"""Source code for magicclass.ext.dask.progress: a progress-bar widget and a thread worker for monitoring dask computations."""

from __future__ import annotations
from functools import wraps
from typing import Any, Callable, TYPE_CHECKING
from dask.diagnostics import Callback as DaskCallback
from psygnal import Signal
from superqt.utils import FunctionWorker, GeneratorWorker, create_worker

from magicclass.utils import move_to_screen_center, QtSignal
from magicclass.utils.qthreading import (
    CallbackList,
    DefaultProgressBar,
    thread_worker,
    ProgressDict,
)

if TYPE_CHECKING:
    from ..._gui import BaseGui


class DaskProgressBar(DefaultProgressBar, DaskCallback):
    """A progress bar widget for dask computation.

    The class is both a magicclass progress-bar widget and a dask
    ``Callback``; its ``_start``, ``_pretask``, ``_posttask`` and ``_finish``
    methods implement dask's callback hooks. The result of every finished
    task is re-emitted through the ``computed`` signal.

    Parameters
    ----------
    max : int, default is 100
        Maximum value of the progress bar.
    minimum : float, default is 0.5
        Stored as ``_minimum``. Not read by the methods defined here
        (presumably consumed by a base class -- TODO confirm).
    dt : float, default is 0.1
        Stored as ``_dt``. Not read by the methods defined here
        (presumably consumed by a base class -- TODO confirm).
    """

    # Emitted from ``_on_computed`` with the result of each finished task.
    computed = Signal(object)

    def __init__(
        self,
        max: int = 100,
        minimum: float = 0.5,
        dt: float = 0.1,
    ):
        self._minimum = minimum
        self._dt = dt
        # Fraction of finished tasks in the current computation (0.0 - 1.0).
        self._frac = 0.0
        # Number of ``compute`` calls seen since the last ``__enter__``.
        self._n_computation = 0
        super().__init__(max=max)
        # Hide footer items 1 and 2 (assumed to be controls inherited from
        # DefaultProgressBar that do not apply here -- TODO confirm).
        self.footer[1].visible = self.footer[2].visible = False
        # dask runs its callbacks in the thread that called ``compute``
        # (typically a background worker). Routing task results through a
        # Qt signal defers the widget update to ``_on_computed``.
        self._computed_signal = QtSignal()
        self._computed_signal.connect(self._on_computed)

    def __enter__(self):
        # Each ``with`` block starts a fresh count of computations.
        self._n_computation = 0
        self._on_timer_updated()
        return super().__enter__()

    def _start(self, dsk):
        """dask hook: a computation (one ``compute`` call) has started."""
        self._state = None
        self._frac = 0.0
        self._n_computation += 1
        self._start_thread()
        return None

    def _on_computed(self, result):
        """Update the bar from the latest scheduler state and re-emit *result*."""
        s = self._state
        if not s:
            self._frac = 0.0
        else:
            ndone = len(s["finished"])
            ntasks = sum(len(s[k]) for k in ["ready", "waiting", "running"]) + ndone
            # ``ntasks >= ndone`` always holds, so only an empty graph
            # (``ntasks == 0``) needs guarding against division by zero.
            self._frac = ndone / ntasks if ntasks else 0.0
        self.pbar.value = self.max * self._frac
        self.computed.emit(result)
        return None

    def _pretask(self, key, dsk, state):
        """dask hook: remember the scheduler state before a task runs."""
        self._state = state
        return None

    def _posttask(self, key, result, dsk, state, worker_id):
        """dask hook: a task finished; forward its result via the Qt signal."""
        self._computed_signal.emit(result)
        self._time_signal.emit()
        return None

    def _finish(self, dsk, state, errored):
        """dask hook: the computation ended (``errored`` tells whether it failed)."""
        self._frac = 1.0
        # Presumably tells the timer thread started by ``_start_thread`` to
        # stop looping; then wait for it to exit and reset the timer.
        self._running = False
        self._thread_timer.join()
        self._timer.reset()
        return None

    def _on_timer_updated(self, _=None):
        """Refresh the elapsed-time label.

        When more than one computation ran in the current ``with`` block, the
        label is prefixed with the computation count, e.g. ``"(2) 00:05"``.
        """
        if self._n_computation > 1:
            _prefix = f"({self._n_computation}) "
        else:
            _prefix = ""
        if self._timer.sec < 3600:
            self.time_label.value = _prefix + self._timer.format_time(
                "{min:0>2}:{sec:0>2}"
            )
        else:
            self.time_label.value = _prefix + self._timer.format_time()
        return None

    def set_worker(self, worker: GeneratorWorker | FunctionWorker):
        """Set currently running worker.

        Raises
        ------
        TypeError
            If *worker* is a ``GeneratorWorker``. The previously set worker
            is left untouched in that case.
        """
        # Validate before mutating so a rejected worker is never stored.
        if isinstance(worker, GeneratorWorker):
            raise TypeError("Cannot set generator.")
        self._worker = worker
        self.footer[1].visible = False
        self.footer[2].visible = False
        self._time_signal.emit()
        return None
class dask_thread_worker(thread_worker):
    """
    Create a dask worker in a superqt/napari style.

    This thread worker class can monitor the progress of dask computation.
    Callback functions connected to the ``computed`` signal are called each
    time one of the tasks finishes, and the returned value of that task is
    passed as the callback argument. The returned value is useful when
    delayed functions are computed, but it is not always meaningful when
    dask mapping functions such as ``map_blocks`` are used.

    Unlike standard ``thread_worker``, you should not specify the ``total``
    parameter since the dask progress bar knows it.

    Examples
    --------
    .. code-block:: python

        @magicclass
        class A:
            @dask_thread_worker
            def func(self):
                arr = da.random.random((30000, 30000))
                da.mean(arr).compute()

            @func.computed.connect
            def _on_computed(self, _=None):
                print("computed")
    """

    _DEFAULT_PROGRESS_BAR = DaskProgressBar
    _DEFAULT_TOTAL = 100

    def __init__(
        self,
        f: Callable | None = None,
        *,
        ignore_errors: bool = False,
        progress: ProgressDict | bool | None = True,
    ) -> None:
        super().__init__(f, ignore_errors=ignore_errors, progress=progress)
        self._callback_dict_["computed"] = CallbackList()

    @property
    def computed(self) -> CallbackList[Any]:
        """Callbacks to be connected to the progress bar's ``computed`` signal."""
        return self._callback_dict_["computed"]

    def _create_method(self, gui: BaseGui):
        """Ensure a progress configuration exists, then build the GUI method.

        ``"pbar"`` is always reset to ``None``; the actual bar is created per
        call in ``_create_qt_worker``.
        """
        if self._progress is None:
            self._progress = {
                "pbar": None,
                "desc": "Progress",
                # Keep in sync with the bar's ``max`` in ``_create_qt_worker``.
                "total": self._DEFAULT_TOTAL,
            }
        else:
            self._progress["pbar"] = None
        return super()._create_method(gui)

    def _create_qt_worker(
        self, gui, *args, **kwargs
    ) -> FunctionWorker | GeneratorWorker:
        """Create a progress bar for *gui* and an unstarted worker running the function."""
        pbar = self._DEFAULT_PROGRESS_BAR(max=self._DEFAULT_TOTAL)
        self._progressbars[id(gui)] = pbar
        # Bind each registered ``computed`` callback to *gui* and forward the
        # bar's ``computed`` signal to it.
        for callback in self.computed._iter_as_method(gui):
            pbar.computed.connect(callback)
        pbar.native.setParent(gui.native, self.__class__._WINDOW_FLAG)
        move_to_screen_center(pbar.native)
        # Positional ``*args`` follow the function; this is equivalent to the
        # previous call that listed ``*args`` after the keyword arguments.
        return create_worker(
            self._define_function(pbar).__get__(gui),
            *args,
            _ignore_errors=self._ignore_errors,
            _start_thread=False,
            **kwargs,
        )

    def _define_function(self, pbar):
        """Wrap the decorated function so that it runs inside ``with pbar:``."""

        @wraps(self._func)
        def _wrapped(*args, **kwargs):
            with pbar:
                out = self._func(*args, **kwargs)
            return out

        return _wrapped