# Source code for magicclass.qthreading

from __future__ import annotations
import inspect
from functools import wraps
from typing import Any, Callable, TYPE_CHECKING, Iterable, Union, overload, TypeVar

try:
    from superqt.utils import create_worker, GeneratorWorker, FunctionWorker
except ImportError as e:  # pragma: no cover
    msg = f"{e}. To use magicclass with threading please `pip install superqt`"
    raise type(e)(msg)

from magicgui.widgets import ProgressBar

from .fields import MagicField
from .utils.functions import get_signature

if TYPE_CHECKING:
    from .gui import BaseGui
    from .gui.mgui_ext import PushButtonPlus

# Type variable bound to callables; lets ``Callbacks.connect`` and
# ``Callbacks.disconnect`` return exactly the type of the callback they receive.
_F = TypeVar("_F", bound=Callable)


class Callbacks:
    """List of callback functions."""

    def __init__(self):
        self._callbacks: list[Callable] = []

    @property
    def callbacks(self) -> tuple[Callable, ...]:
        """Snapshot of the connected callbacks, in connection order."""
        return tuple(self._callbacks)

    def connect(self, callback: _F) -> _F:
        """
        Append a callback function to the callback list.

        Parameters
        ----------
        callback : Callable
            Callback function.

        Returns
        -------
        Callable
            The given ``callback`` itself, so this method works as a decorator.

        Raises
        ------
        TypeError
            If ``callback`` is not callable.
        """
        if not callable(callback):
            raise TypeError("Can only connect callable object.")
        self._callbacks.append(callback)
        return callback

    def disconnect(self, callback: _F) -> _F:
        """
        Remove callback function from the callback list.

        Parameters
        ----------
        callback : Callable
            Callback function to be removed.

        Returns
        -------
        Callable
            The given ``callback`` itself.

        Raises
        ------
        ValueError
            If ``callback`` is not connected (raised by ``list.remove``).
        """
        self._callbacks.remove(callback)
        return callback

    def _iter_as_method(self, obj: BaseGui) -> Iterable[Callable]:
        """Yield each callback bound to ``obj`` as a method, in connection order."""
        for callback in self._callbacks:
            # Build each wrapper in its own scope. A closure defined directly in
            # this loop would capture the loop variable itself, so every wrapper
            # would end up calling the last callback.
            yield self._as_method(callback, obj)

    @staticmethod
    def _as_method(callback: Callable, obj: BaseGui) -> Callable:
        """Bind ``callback`` to ``obj`` and call it inside ``obj.macro.blocked()``."""

        def f(*args, **kwargs):
            with obj.macro.blocked():
                out = callback.__get__(obj)(*args, **kwargs)
            return out

        return f
class NapariProgressBar:
    """A progressbar class that provides napari progress bar with same API.

    Wraps ``napari.utils.progress``. It exposes ``value``, ``max``, ``label``,
    ``visible``, ``show`` and ``close``, mirroring the members of magicgui's
    ``ProgressBar`` that this module relies on.
    """

    def __init__(self, value: int = 0, max: int = 1000):
        from napari.utils import progress

        # The "changed" event of the global instance set is blocked while the
        # progress object is created (presumably so napari does not display it
        # yet); ``show`` emits that event explicitly later.
        with progress._all_instances.events.changed.blocker():
            self._pbar = progress(total=max)
        self._pbar.n = value

    @property
    def value(self) -> int:
        """Current progress count (``progress.n``)."""
        return self._pbar.n

    @value.setter
    def value(self, v) -> None:
        self._pbar.n = v
        # Emit the value event by hand after assigning ``n`` directly
        # (presumably the plain assignment does not notify listeners).
        self._pbar.events.value(value=self._pbar.n)

    @property
    def max(self) -> int:
        """Total number of steps (``progress.total``)."""
        return self._pbar.total

    @max.setter
    def max(self, v) -> None:
        self._pbar.total = v

    @property
    def label(self) -> str:
        """Description text of the bar (``progress.desc``)."""
        return self._pbar.desc

    @label.setter
    def label(self, v) -> None:
        self._pbar.set_description(v)

    @property
    def visible(self) -> bool:
        """Always ``False``; visibility is not tracked for napari bars."""
        return False

    def show(self):
        """Announce the bar as a newly added progress instance."""
        # ``removed`` is a set, matching ``added``; a bare ``{}`` would be an
        # empty dict.
        type(self._pbar)._all_instances.events.changed(
            added={self._pbar}, removed=set()
        )

    def close(self):
        """Close the underlying napari progress object."""
        self._pbar.close()
# Progressbar types accepted by the progress helpers in this module: magicgui's
# ``ProgressBar`` or the napari-backed ``NapariProgressBar`` wrapper.
_SupportProgress = Union[ProgressBar, NapariProgressBar]
class thread_worker:
    """Create a worker in a superqt/napari style.

    When accessed on a GUI instance, the wrapped method becomes a function that
    builds a superqt worker via ``create_worker``. The worker is started in a
    thread if the push button of the same name reports ``running``; otherwise it
    is run synchronously with ``worker.run()``.

    Parameters
    ----------
    f : Callable, optional
        Function to wrap. If omitted, the instance acts as a decorator.
    ignore_errors : bool, default is False
        Forwarded to ``create_worker`` as ``_ignore_errors``.
    progress : dict or bool, optional
        Progressbar configuration with optional keys ``"desc"``, ``"total"``
        and ``"pbar"``. ``True`` means a progressbar with default settings.
    """

    def __init__(
        self,
        f: Callable | None = None,
        *,
        ignore_errors: bool = False,
        progress: dict[str, Any] | None = None,
    ) -> None:
        self._func: Callable | None = None
        self._started = Callbacks()
        self._returned = Callbacks()
        self._errored = Callbacks()
        self._yielded = Callbacks()
        self._finished = Callbacks()
        self._aborted = Callbacks()
        self._ignore_errors = ignore_errors
        # id(gui) -> the ``_create_worker`` function built in ``__get__``.
        self._objects: dict[int, Callable] = {}
        # id(gui) -> ProgressBar found on that gui's class, or None.
        self._progressbars: dict[int, _SupportProgress | None] = {}

        if f is not None:
            self(f)

        if progress:
            if isinstance(progress, bool):
                progress = {}
            # Normalize to a dict that always has all three keys.
            progress = {
                "desc": progress.get("desc", None),
                "total": progress.get("total", 0),
                "pbar": progress.get("pbar", None),
            }
        self._progress = progress

    @property
    def not_ready(self) -> bool:
        """True until a function to wrap has been given."""
        return self._func is None

    @overload
    def __call__(self, f: Callable) -> thread_worker:
        ...

    @overload
    def __call__(self, bgui: BaseGui, *args, **kwargs) -> Any:
        ...

    def __call__(self, *args, **kwargs):
        """Register the function to wrap (first call), else call it directly."""
        if self.not_ready:
            f = args[0]
            self._func = f
            wraps(f)(self)
            return self
        else:
            return self._func(*args, **kwargs)

    def __get__(self, gui: BaseGui, objtype=None):
        if gui is None:
            return self
        gui_id = id(gui)
        if gui_id in self._objects:
            return self._objects[gui_id]

        @wraps(self)
        def _create_worker(*args, **kwargs):
            worker: FunctionWorker | GeneratorWorker = create_worker(
                self._func.__get__(gui),
                _ignore_errors=self._ignore_errors,
                *args,
                **kwargs,
            )
            self._connect_callbacks(worker, gui)
            if self._progress:
                self._setup_progressbar(worker, gui, args, kwargs)

            _push_button: PushButtonPlus = gui[self._func.__name__]
            if _push_button.running:
                # Disable the button while the threaded worker runs.
                _push_button.enabled = False

                @worker.finished.connect
                def _enable():
                    _push_button.enabled = True

                worker.start()
            else:
                worker.run()
            return None

        _create_worker.__signature__ = self._get_method_signature()
        self._objects[gui_id] = _create_worker  # cache
        return _create_worker

    def _connect_callbacks(
        self, worker: FunctionWorker | GeneratorWorker, gui: BaseGui
    ) -> None:
        """Connect the user callbacks, bound to ``gui``, to the worker signals."""
        pairs = (
            (self._started, "started"),
            (self._returned, "returned"),
            (self._errored, "errored"),
            (self._yielded, "yielded"),
            (self._finished, "finished"),
            (self._aborted, "aborted"),
        )
        for callbacks, signal_name in pairs:
            for c in callbacks._iter_as_method(gui):
                # Look the signal up only when a callback exists: a
                # FunctionWorker has no "yielded"/"aborted" signal.
                getattr(worker, signal_name).connect(c)

    def _setup_progressbar(
        self,
        worker: FunctionWorker | GeneratorWorker,
        gui: BaseGui,
        args: tuple,
        kwargs: dict[str, Any],
    ) -> None:
        """Resolve the progressbar for one call and connect it to ``worker``."""
        _desc = self._progress["desc"]
        if callable(_desc):
            desc = _desc(gui)
        elif _desc is None:
            # Keep None so a default label (field name or "Progress") is used
            # instead of the literal string "None".
            desc = None
        else:
            desc = str(_desc)

        total = self._resolve_total(gui, args, kwargs)

        _pbar = self._progress["pbar"]
        if _pbar is None:
            pbar = self._find_progressbar(gui, desc=desc, total=total)
        elif isinstance(_pbar, MagicField):
            pbar = _pbar.get_widget(gui)
            if not isinstance(pbar, ProgressBar):
                raise TypeError(f"{_pbar.name} does not create a ProgressBar.")
            pbar.label = desc or _pbar.name
            pbar.max = total
        else:
            if not isinstance(_pbar, ProgressBar):
                raise TypeError(f"{_pbar} is not a ProgressBar.")
            # NOTE(review): unlike the MagicField branch, ``desc`` and ``total``
            # are not applied to a directly-passed ProgressBar; confirm intended.
            pbar = _pbar

        worker.started.connect(init_pbar.__get__(pbar))
        if not pbar.visible:
            worker.finished.connect(close_pbar.__get__(pbar))
        if pbar.max != 0 and isinstance(worker, GeneratorWorker):
            worker.pbar = pbar  # avoid garbage collection
            worker.yielded.connect(increment.__get__(pbar))

    def _resolve_total(
        self, gui: BaseGui, args: tuple, kwargs: dict[str, Any]
    ) -> int:
        """Compute the progressbar total for one call.

        ``total`` may be an int, a callable taking ``gui``, or a string that is
        evaluated with the call's bound arguments (defaults applied) as locals.

        Raises
        ------
        TypeError
            If ``total`` is none of the accepted kinds.
        """
        _total = self._progress["total"]
        if isinstance(_total, str):
            arguments = self.__signature__.bind(gui, *args, **kwargs)
            arguments.apply_defaults()
            # The expression is the one given in ``progress["total"]`` when the
            # decorator was created, not runtime input.
            return eval(_total, {}, arguments.arguments)
        elif callable(_total):
            return _total(gui)
        elif isinstance(_total, int):
            return _total
        raise TypeError("'total' must be int, callable or evaluatable string.")

    def _get_method_signature(self) -> inspect.Signature:
        """Signature of the wrapped function without its first (self) parameter."""
        sig = self.__signature__
        params = list(sig.parameters.values())[1:]
        return sig.replace(parameters=params)

    @property
    def __signature__(self) -> inspect.Signature:
        """Get the signature of the bound function."""
        return get_signature(self._func)

    @__signature__.setter
    def __signature__(self, sig: inspect.Signature) -> None:
        """Update signature of the bound function."""
        if not isinstance(sig, inspect.Signature):
            raise TypeError(f"Cannot set type {type(sig)}.")
        self._func.__signature__ = sig
        return None

    def _find_progressbar(self, gui: BaseGui, desc: str | None = None, total: int = 0):
        """Find available progressbar. Create a new one if not found.

        The class attributes of ``gui`` are scanned (MagicFields resolved to
        their widgets) for the first ProgressBar. If none is found, a napari
        progress bar is created when ``gui.parent_viewer`` exists, otherwise a
        new magicgui ProgressBar parented to ``gui.native``.
        """
        gui_id = id(gui)
        # The scan always runs; the previous lookup of ``self._progressbars``
        # before it was overwritten unconditionally and has been dropped.
        for name, attr in gui.__class__.__dict__.items():
            if isinstance(attr, MagicField):
                attr = attr.get_widget(gui)
            if isinstance(attr, ProgressBar):
                _pbar = self._progressbars[gui_id] = attr
                if desc is None:
                    desc = name
                break
        else:
            _pbar = self._progressbars[gui_id] = None
            if desc is None:
                desc = "Progress"

        if _pbar is None:
            if gui.parent_viewer is not None:
                _pbar = NapariProgressBar(value=0, max=total)
            else:
                _pbar = ProgressBar(value=0, max=total)
                _pbar.native.setParent(gui.native, _pbar.native.windowFlags())
        else:
            _pbar.max = total
        _pbar.label = desc
        return _pbar

    @property
    def started(self) -> Callbacks:
        """Event that will be emitted on started."""
        return self._started

    @property
    def returned(self) -> Callbacks:
        """Event that will be emitted on returned."""
        return self._returned

    @property
    def errored(self) -> Callbacks:
        """Event that will be emitted on errored."""
        return self._errored

    @property
    def yielded(self) -> Callbacks:
        """Event that will be emitted on yielded."""
        return self._yielded

    @property
    def finished(self) -> Callbacks:
        """Event that will be emitted on finished."""
        return self._finished

    @property
    def aborted(self) -> Callbacks:
        """Event that will be emitted on aborted."""
        return self._aborted
def init_pbar(pbar: _SupportProgress):
    """Reset ``pbar`` to zero, then make it visible.

    Returns
    -------
    None
    """
    # Reset first so the bar never appears with a stale value.
    pbar.value = 0
    pbar.show()
def close_pbar(pbar: _SupportProgress):
    """Close ``pbar``, or its labeled container widget when it has one.

    Returns
    -------
    None
    """
    target = pbar
    if isinstance(pbar, ProgressBar):
        # ``_labeled_widget()`` gives the enclosing labeled widget or None;
        # when present, that container is closed instead of the bare bar.
        container = pbar._labeled_widget()
        if container is not None:
            target = container
    target.close()
def increment(pbar: _SupportProgress):
    """Advance ``pbar`` by one step.

    When the bar is already full (``value == max``), ``max`` is set to 0
    instead of advancing (presumably switching the bar to an indeterminate
    display).

    Returns
    -------
    None
    """
    if pbar.value != pbar.max:
        pbar.value += 1
        return None
    pbar.max = 0
    return None