Source code for magicclass.qthreading

from __future__ import annotations
import inspect
from functools import wraps
from typing import (
    Any,
    Callable,
    TYPE_CHECKING,
    Iterable,
    Union,
    overload,
    TypeVar,
    Protocol,
    runtime_checkable,
)

try:
    from superqt.utils import create_worker, GeneratorWorker, FunctionWorker
except ImportError as e:  # pragma: no cover
    msg = f"{e}. To use magicclass with threading please `pip install superqt`"
    raise type(e)(msg)

from qtpy.QtCore import Qt
from magicgui.widgets import ProgressBar, Container, Widget, PushButton

from .fields import MagicField
from .utils import get_signature, move_to_screen_center

if TYPE_CHECKING:
    from .gui import BaseGui
    from .gui.mgui_ext import PushButtonPlus

__all__ = ["thread_worker"]

_F = TypeVar("_F", bound=Callable)


@runtime_checkable
class _SupportProgress(Protocol):
    """Structural interface that every progress bar object must provide.

    Being ``runtime_checkable``, ``isinstance`` accepts any object that exposes
    ``value`` and ``max`` together with ``set_description``, ``show`` and
    ``close``. The bodies below only raise ``NotImplementedError``; concrete
    progress bars are expected to override all of them.
    """

    def __init__(self, max: int = 1, **kwargs):
        raise NotImplementedError

    @property
    def value(self) -> int:
        """Current progress value."""
        raise NotImplementedError

    @value.setter
    def value(self, new_value) -> None:
        raise NotImplementedError

    @property
    def max(self) -> int:
        """Maximum progress value."""
        raise NotImplementedError

    @max.setter
    def max(self, new_max) -> None:
        raise NotImplementedError

    def set_description(self, desc: str):
        """Set the text describing the running task."""
        raise NotImplementedError

    def show(self):
        """Make the progress bar visible."""
        raise NotImplementedError

    def close(self):
        """Close the progress bar."""
        raise NotImplementedError


class Callbacks:
    """List of callback functions.

    Callbacks are stored unbound and are bound to a GUI object only when
    :meth:`_iter_as_method` is called, so the same list can serve many GUI
    instances.
    """

    def __init__(self):
        self._callbacks: list[Callable] = []

    @property
    def callbacks(self) -> tuple[Callable, ...]:
        """Snapshot of the registered callbacks, in registration order."""
        return tuple(self._callbacks)

    def connect(self, callback: _F) -> _F:
        """
        Append a callback function to the callback list.

        Returns the callback unchanged, so this can be used as a decorator.

        Parameters
        ----------
        callback : Callable
            Callback function.

        Raises
        ------
        TypeError
            If ``callback`` is not callable.
        """
        if not callable(callback):
            raise TypeError("Can only connect callable object.")
        self._callbacks.append(callback)
        return callback

    def disconnect(self, callback: _F) -> _F:
        """
        Remove callback function from the callback list.

        Parameters
        ----------
        callback : Callable
            Callback function to be removed.

        Raises
        ------
        ValueError
            If ``callback`` was never connected (raised by ``list.remove``).
        """
        self._callbacks.remove(callback)
        return callback

    def _iter_as_method(self, obj: BaseGui) -> Iterable[Callable]:
        """Yield every callback bound to ``obj`` as a method.

        Each yielded wrapper runs its callback while ``obj.macro`` is blocked.
        """
        for callback in self._callbacks:
            # Build each wrapper in its own scope. A closure created directly in
            # this loop would share the loop variable, so every wrapper would end
            # up calling the *last* callback once iteration finishes.
            yield self._as_method(callback, obj)

    @staticmethod
    def _as_method(callback: Callable, obj: BaseGui) -> Callable:
        """Return a wrapper calling ``callback`` bound to ``obj`` with macro recording blocked."""

        def _method(*args, **kwargs):
            with obj.macro.blocked():
                return callback.__get__(obj)(*args, **kwargs)

        return _method


class NapariProgressBar(_SupportProgress):
    """A progressbar class that provides napari progress bar with same API.

    Wraps a ``napari.utils.progress`` object. Creation is done with the
    "changed" event of ``progress._all_instances`` blocked, so the bar is only
    announced to napari when :meth:`show` is called.
    """

    def __init__(self, value: int = 0, max: int = 1000):
        from napari.utils import progress

        # Block the global "instances changed" event so that creating the bar
        # does not display it yet; ``show`` emits the event manually.
        with progress._all_instances.events.changed.blocker():
            self._pbar = progress(total=max)
            self._pbar.n = value

    @property
    def value(self) -> int:
        """Current progress (``progress.n``)."""
        return self._pbar.n

    @value.setter
    def value(self, v) -> None:
        self._pbar.n = v
        # Setting ``n`` directly does not notify listeners; emit explicitly.
        self._pbar.events.value(value=self._pbar.n)

    @property
    def max(self) -> int:
        """Maximum progress (``progress.total``)."""
        return self._pbar.total

    @max.setter
    def max(self, v) -> None:
        self._pbar.total = v

    def set_description(self, v: str) -> None:
        """Forward the description to the wrapped napari progress object."""
        self._pbar.set_description(v)

    @property
    def visible(self) -> bool:
        """Always ``False``; visibility is managed by napari."""
        return False

    def show(self):
        """Announce the wrapped progress object to napari's instance registry."""
        # Both ``added`` and ``removed`` are passed as sets (previously
        # ``removed`` was an empty dict, a different type from ``added``).
        type(self._pbar)._all_instances.events.changed(
            added={self._pbar}, removed=set()
        )

    def close(self):
        """Close the wrapped napari progress object."""
        self._pbar.close()


class DefaultProgressBar(Container, _SupportProgress):
    """The default progressbar widget.

    A magicgui ``Container`` with a ``ProgressBar`` and a horizontal row of
    "Pause" and "Abort" buttons. The buttons control the worker passed to
    :meth:`set_worker`.
    """

    def __init__(self, max: int = 1):
        self.pbar = ProgressBar(value=0, max=max)
        self.pause_button = PushButton(text="Pause")
        self.abort_button = PushButton(text="Abort")
        cnt = Container(
            layout="horizontal", widgets=[self.pause_button, self.abort_button]
        )
        cnt.margins = (0, 0, 0, 0)
        self.pbar.min_width = 200
        self._paused = False
        self._worker: GeneratorWorker | FunctionWorker | None = None
        # Connect the buttons exactly once. Connecting inside ``set_worker``
        # stacked a new connection every time this widget was reused for another
        # worker. One click would then run the handler several times; two
        # pause toggles cancel out.
        self.pause_button.changed.connect(self._toggle_pause)
        self.abort_button.changed.connect(self._abort_worker)
        super().__init__(widgets=[self.pbar, cnt])

    @property
    def value(self) -> int:
        """Current value of the inner progress bar."""
        return self.pbar.value

    @value.setter
    def value(self, v):
        self.pbar.value = v

    @property
    def max(self) -> int:
        """Maximum value of the inner progress bar."""
        return self.pbar.max

    @max.setter
    def max(self, v):
        self.pbar.max = v

    def set_description(self, desc: str):
        """Set description as the label of the progressbar."""
        self.pbar.label = desc
        return None

    def set_worker(self, worker: GeneratorWorker | FunctionWorker):
        """Set currently running worker and reset the button states for it.

        Only a ``GeneratorWorker`` can be paused, so the pause button is
        disabled for any other worker type.
        """
        self._worker = worker
        self._paused = False

        # initialize abort_button
        self.abort_button.text = "Abort"
        self.abort_button.enabled = True

        # initialize pause_button
        self.pause_button.text = "Pause"
        if not isinstance(worker, GeneratorWorker):
            self.pause_button.enabled = False
            return None
        self.pause_button.enabled = True

        # The worker acknowledges a pause request asynchronously; only then
        # switch the button to "Resume".
        @worker.paused.connect
        def _on_pause():
            self.pause_button.text = "Resume"
            self.pause_button.enabled = True

        return None

    def _toggle_pause(self):
        """Pause or resume the current generator worker."""
        if not isinstance(self._worker, GeneratorWorker):
            # No pausable worker attached yet.
            return None
        if self._paused:
            self._worker.resume()
            self.pause_button.text = "Pause"
        else:
            self._worker.pause()
            # Disabled until the worker's ``paused`` signal re-enables it.
            self.pause_button.text = "Pausing"
            self.pause_button.enabled = False

        self._paused = not self._paused
        return None

    def _abort_worker(self):
        """Request the current worker to quit and disable both buttons."""
        if self._worker is None:
            return None
        self._paused = False
        self.pause_button.text = "Pause"
        self.abort_button.text = "Aborting"
        self.pause_button.enabled = False
        self.abort_button.enabled = False
        self._worker.quit()
        return None


# Anything accepted as a progress bar: a magicgui ``ProgressBar`` or an object
# satisfying the runtime-checkable ``_SupportProgress`` protocol. Besides being
# a type alias, it is passed directly to ``isinstance`` in ``thread_worker``.
ProgressBarLike = Union[ProgressBar, _SupportProgress]


class thread_worker:
    """Create a worker in a superqt/napari style.

    Decorating a magicclass method with ``thread_worker`` makes the method,
    when accessed on a GUI instance, build a ``superqt`` worker via
    ``create_worker``. The worker is started in a thread if the corresponding
    button is running; otherwise it runs synchronously. Callbacks registered on
    :attr:`started`, :attr:`returned`, :attr:`errored`, :attr:`yielded`,
    :attr:`finished` and :attr:`aborted` are connected to every worker.

    Parameters
    ----------
    f : callable, optional
        Function to wrap. If not given, the instance acts as a decorator.
    ignore_errors : bool, default is False
        Forwarded to ``create_worker`` as ``_ignore_errors``.
    progress : dict or bool, optional
        Progress bar configuration. ``True`` is the same as ``{}``. Recognized keys:

        - "desc": str, or callable taking the GUI and returning the description.
        - "total": int; a callable taking the GUI; or a string expression
          evaluated with the method's bound arguments as local names.
        - "pbar": None (find or create one); a ``MagicField`` that creates a
          progress bar; or a progress bar object.
    """

    _DEFAULT_PROGRESS_BAR = DefaultProgressBar

    def __init__(
        self,
        f: Callable | None = None,
        *,
        ignore_errors: bool = False,
        progress: dict[str, Any] | bool | None = None,
    ) -> None:
        self._func: Callable | None = None
        self._started = Callbacks()
        self._returned = Callbacks()
        self._errored = Callbacks()
        self._yielded = Callbacks()
        self._finished = Callbacks()
        self._aborted = Callbacks()
        self._ignore_errors = ignore_errors
        # id(gui) -> worker-creating function bound to that gui (see __get__).
        self._objects: dict[int, Callable] = {}
        self._progressbars: dict[int, ProgressBarLike | None] = {}

        if f is not None:
            self(f)

        if progress:
            if isinstance(progress, bool):
                progress = {}
            # Normalize to exactly these three keys.
            progress = {
                "desc": progress.get("desc", None),
                "total": progress.get("total", 0),
                "pbar": progress.get("pbar", None),
            }

        self._progress = progress

    @classmethod
    def set_default(cls, pbar_cls: Callable | str):
        """
        Set the default progressbar class.

        This class method is useful when there is an user-defined class that
        follows ``_SupportProgress`` protocol.

        Parameters
        ----------
        pbar_cls : callable or str
            The default class. In principle this parameter does not have to be
            a class. As long as ``pbar_cls(max=...)`` returns a
            ``_SupportProgress`` object it works. Either "default" or "napari"
            is also accepted.

        Raises
        ------
        ValueError
            If ``pbar_cls`` is a string other than "default" or "napari".
        """
        if isinstance(pbar_cls, str):
            if pbar_cls == "napari":
                pbar_cls = NapariProgressBar
            elif pbar_cls == "default":
                pbar_cls = DefaultProgressBar
            else:
                raise ValueError(
                    f"Unknown progress bar {pbar_cls!r}. Must be either 'default' or "
                    "'napari', or a proper type object."
                )
        cls._DEFAULT_PROGRESS_BAR = pbar_cls
        return pbar_cls

    @overload
    def __call__(self, f: Callable) -> thread_worker:
        ...

    @overload
    def __call__(self, bgui: BaseGui, *args, **kwargs) -> Any:
        ...

    def __call__(self, *args, **kwargs):
        """Register the wrapped function on first call; afterwards call it directly."""
        if self._func is None:
            f = args[0]
            self._func = f
            wraps(f)(self)
            return self
        return self._func(*args, **kwargs)

    def __get__(self, gui: BaseGui, objtype=None):
        """Return a function that creates and starts a worker for ``gui``.

        The returned function is cached per ``id(gui)``. It holds a reference to
        ``gui``, so the cache keeps ``gui`` referenced for as long as this
        descriptor lives.
        """
        if gui is None:
            return self

        gui_id = id(gui)
        if gui_id in self._objects:
            return self._objects[gui_id]

        @wraps(self)
        def _create_worker(*args, **kwargs):
            worker: FunctionWorker | GeneratorWorker = create_worker(
                self._func.__get__(gui),
                _ignore_errors=self._ignore_errors,
                *args,
                **kwargs,
            )
            self._connect_callbacks(worker, gui)
            if self._progress:
                self._setup_progressbar(worker, gui, args, kwargs)

            _push_button: PushButtonPlus = gui[self._func.__name__]
            if _push_button.running:
                # Disable the button while the threaded worker runs.
                _push_button.enabled = False

                @worker.finished.connect
                def _enable():
                    _push_button.enabled = True

                worker.start()
            else:
                worker.run()
            return None

        _create_worker.__signature__ = self._get_method_signature()
        self._objects[gui_id] = _create_worker  # cache
        return _create_worker

    def _connect_callbacks(self, worker: FunctionWorker | GeneratorWorker, gui: BaseGui) -> None:
        """Connect every registered callback, bound to ``gui``, to ``worker``."""
        for signal_name, callbacks in (
            ("started", self._started),
            ("returned", self._returned),
            ("errored", self._errored),
            ("yielded", self._yielded),
            ("finished", self._finished),
            ("aborted", self._aborted),
        ):
            for c in callbacks._iter_as_method(gui):
                # Look the signal up lazily: it is only touched when a callback
                # exists, because not every worker type defines every signal.
                getattr(worker, signal_name).connect(c)

    def _setup_progressbar(
        self,
        worker: FunctionWorker | GeneratorWorker,
        gui: BaseGui,
        args: tuple,
        kwargs: dict[str, Any],
    ) -> None:
        """Resolve the configured progress bar and wire it to ``worker``.

        Raises
        ------
        TypeError
            If "total" has an unsupported type, or "pbar" does not resolve to a
            progress bar.
        """
        _desc = self._progress["desc"]
        if callable(_desc):
            desc = _desc(gui)
        elif _desc is None:
            # Keep None so the fallbacks below (field name / "Progress") apply
            # instead of labelling the bar with the literal string "None".
            desc = None
        else:
            desc = str(_desc)

        total = self._resolve_total(gui, args, kwargs)

        _pbar = self._progress["pbar"]
        if _pbar is None:
            pbar = self._find_progressbar(gui, desc=desc, total=total)
        elif isinstance(_pbar, MagicField):
            pbar = _pbar.get_widget(gui)
            if not isinstance(pbar, ProgressBarLike):
                raise TypeError(f"{_pbar.name} does not create a ProgressBar.")
            pbar.label = desc or _pbar.name
            pbar.max = total
        else:
            if not isinstance(_pbar, ProgressBarLike):
                raise TypeError(f"{_pbar} is not a ProgressBar.")
            pbar = _pbar

        worker.started.connect(init_pbar.__get__(pbar))
        if not getattr(pbar, "visible", False):
            worker.finished.connect(close_pbar.__get__(pbar))
        if pbar.max != 0 and isinstance(worker, GeneratorWorker):
            worker.pbar = pbar  # avoid garbage collection
            worker.yielded.connect(increment.__get__(pbar))
        if hasattr(pbar, "set_worker"):
            pbar.set_worker(worker)

    def _resolve_total(self, gui: BaseGui, args: tuple, kwargs: dict[str, Any]) -> int:
        """Compute the progress bar maximum from the "total" setting."""
        _total = self._progress["total"]
        if isinstance(_total, str):
            arguments = self.__signature__.bind(gui, *args, **kwargs)
            arguments.apply_defaults()
            # NOTE: ``eval`` runs the expression given in the decorator's
            # ``progress`` argument (developer code), with the bound call
            # arguments as local names. Never feed it untrusted text.
            return eval(_total, {}, arguments.arguments)
        elif callable(_total):
            return _total(gui)
        elif isinstance(_total, int):
            return _total
        raise TypeError("'total' must be int, callable or evaluatable string.")

    def _get_method_signature(self) -> inspect.Signature:
        """Signature of the wrapped function without its first (``self``) parameter."""
        sig = self.__signature__
        params = list(sig.parameters.values())[1:]
        return sig.replace(parameters=params)

    @property
    def __signature__(self) -> inspect.Signature:
        """Get the signature of the bound function."""
        return get_signature(self._func)

    @__signature__.setter
    def __signature__(self, sig: inspect.Signature) -> None:
        """Update signature of the bound function."""
        if not isinstance(sig, inspect.Signature):
            raise TypeError(f"Cannot set type {type(sig)}.")
        self._func.__signature__ = sig
        return None

    def _find_progressbar(self, gui: BaseGui, desc: str | None = None, total: int = 0):
        """Find available progressbar. Create a new one if not found.

        The class namespace of ``gui`` is scanned for the first attribute that is,
        or whose ``MagicField`` creates, a ``ProgressBar``. If there is none, a
        new ``_DEFAULT_PROGRESS_BAR(max=total)`` is created. When ``desc`` is
        None, the field name (or "Progress") is used as the description.
        """
        _pbar = None
        for name, attr in gui.__class__.__dict__.items():
            if isinstance(attr, MagicField):
                attr = attr.get_widget(gui)
            if isinstance(attr, ProgressBar):
                _pbar = attr
                if desc is None:
                    desc = name
                break
        # The scan always runs (a cached entry was previously read and then
        # immediately overwritten); record the result for this GUI.
        self._progressbars[id(gui)] = _pbar

        if desc is None:
            desc = "Progress"

        if _pbar is None:
            _pbar = self.__class__._DEFAULT_PROGRESS_BAR(max=total)
            if isinstance(_pbar, Widget) and _pbar.parent is None:
                # Popup progressbar as a splashscreen if it is not a child widget.
                _pbar.native.setParent(
                    gui.native,
                    Qt.WindowTitleHint | Qt.WindowMinimizeButtonHint | Qt.Window,
                )
                move_to_screen_center(_pbar.native)
        else:
            _pbar.max = total

        # try to set description
        if hasattr(_pbar, "set_description"):
            _pbar.set_description(desc)
        else:
            _pbar.label = desc
        return _pbar

    @property
    def started(self) -> Callbacks:
        """Event that will be emitted on started."""
        return self._started

    @property
    def returned(self) -> Callbacks:
        """Event that will be emitted on returned."""
        return self._returned

    @property
    def errored(self) -> Callbacks:
        """Event that will be emitted on errored."""
        return self._errored

    @property
    def yielded(self) -> Callbacks:
        """Event that will be emitted on yielded."""
        return self._yielded

    @property
    def finished(self) -> Callbacks:
        """Event that will be emitted on finished."""
        return self._finished

    @property
    def aborted(self) -> Callbacks:
        """Event that will be emitted on aborted."""
        return self._aborted
def init_pbar(pbar: ProgressBarLike):
    """Reset ``pbar`` to zero, then show it."""
    pbar.value = 0
    pbar.show()


def close_pbar(pbar: ProgressBarLike):
    """Close ``pbar``.

    For a magicgui ``ProgressBar`` that is wrapped in a labeled widget, the
    labeled wrapper is closed instead, so the label disappears too.
    """
    target = pbar
    if isinstance(pbar, ProgressBar):
        labeled = pbar._labeled_widget()
        target = pbar if labeled is None else labeled
    target.close()


def increment(pbar: ProgressBarLike):
    """Advance ``pbar`` by one step.

    Once the value has reached the maximum, the maximum is set to 0 instead of
    incrementing further.
    """
    reached_end = pbar.value == pbar.max
    if reached_end:
        pbar.max = 0
        return None
    pbar.value = pbar.value + 1
    return None