# Source code for impy.arrays.bases.metaarray

from __future__ import annotations
from typing import TYPE_CHECKING
from pathlib import Path
import numpy as np
from numpy.typing import DTypeLike
from ..axesmixin import AxesMixin, get_axes_tuple
from ..._types import *
from ...axes import ImageAxesError
from ...array_api import xp
from ...utils.axesop import *
from ...utils.slicer import *
from ...collections import DataList

if TYPE_CHECKING:
    from typing_extensions import Self


class MetaArray(AxesMixin, np.ndarray):
    """
    ``numpy.ndarray`` subclass carrying axes, a name, a source path and metadata.

    Slicing, ufuncs and numpy functions propagate these attributes to their
    results wherever the axes of the output can be determined.
    """

    # Private attributes copied from the "parent" array by
    # ``_set_additional_props`` (slicing, ufuncs, numpy functions ...).
    additional_props = ["_source", "_metadata", "_name"]
    # numpy function -> overriding implementation; filled by ``implements``.
    # NOTE: this dict is shared by all subclasses that do not redefine it.
    NP_DISPATCH = {}

    _name: str
    _source: Path | None
    _metadata: dict[str, Any]

    def __new__(
        cls: type[MetaArray],
        obj,
        name: str | None = None,
        axes: str | None = None,
        source: str | Path | None = None,
        metadata: dict[str, Any] | None = None,
        dtype: DTypeLike = None,
    ) -> Self:
        """
        Create a new array from ``obj``.

        Parameters
        ----------
        obj : array-like
            Input data. If it already is an instance of ``cls`` it is returned
            unchanged and all the other arguments are ignored.
        name : str, optional
            Name of the array.
        axes : str, optional
            Axes of the array.
        source : str or Path, optional
            Source file path.
        metadata : dict, optional
            Metadata dictionary. An empty dict is used if not given.
        dtype : dtype-like, optional
            Data type passed to ``np.asarray``.
        """
        if isinstance(obj, cls):
            return obj

        self = np.asarray(obj, dtype=dtype).view(cls)
        self.source = source
        self._name = name
        self.axes = axes
        self._metadata = metadata or {}
        return self

    @property
    def source(self):
        """The source file path."""
        return self._source

    @source.setter
    def source(self, val):
        if val is None:
            self._source = None
        else:
            self._source = Path(val)

    @property
    def name(self) -> str:
        """Name of the array."""
        if self._name is None:
            # fall back to the file name of the source, if any
            source = self.source
            if source is None:
                return "No name"
            else:
                return source.name
        else:
            return self._name

    @name.setter
    def name(self, val):
        # ``None`` resets the name so that the getter falls back to the source
        # file name again (consistent with ``name=None`` in ``__new__``).
        self._name = None if val is None else str(val)

    @property
    def metadata(self) -> dict[str, Any]:
        """Metadata dictionary of the array."""
        return self._metadata

    @property
    def value(self) -> np.ndarray:
        """Numpy view of the array."""
        return np.asarray(self)

    def _repr_dict_(self) -> dict[str, Any]:
        """Fields used to build the repr of the array."""
        return {
            "name": self.name,
            "shape": self.shape_info,
            "dtype": self.dtype,
            "source": self.source,
            "scale": self.scale,
        }

    def __str__(self):
        return f"{self.__class__.__name__}<{self.name!r}>"

    @property
    def shape(self):
        """Shape of the array, as an axes-named tuple when axes are available."""
        try:
            tup = get_axes_tuple(self)
            return tup(*super().shape)
        except ImageAxesError:
            return super().shape

    def _set_additional_props(self, other):
        """Copy the attributes listed in ``additional_props`` from ``other``."""
        # set additional properties
        # If `other` does not have it and `self` has, then the property will be inherited.
        for p in self.__class__.additional_props:
            setattr(self, p, getattr(other, p, getattr(self, p, None)))

    def _set_info(self, other: Self, new_axes: str = "inherit"):
        """
        Copy additional properties and axes information from ``other``.

        If ``new_axes`` is "inherit", the axes of ``other`` are copied.
        Otherwise ``new_axes`` is set and the scale is taken from ``other``.
        Axes are set to None if they turn out to be invalid.
        """
        self._set_additional_props(other)
        # set axes
        try:
            if new_axes != "inherit":
                self.axes = new_axes
                self.set_scale(other)
            else:
                self.axes = other.axes.copy()
        except ImageAxesError:
            self.axes = None

        return None

    def __getitem__(self, key: int | str | slice | tuple) -> Self:
        if isinstance(key, str):
            # img["t=2;z=4"] ... axis-targeted slicing
            sl = axis_targeted_slicing(self.ndim, str(self.axes), key)
            return self.__getitem__(sl)

        if isinstance(key, np.ndarray):
            key = self._broadcast(key)

        out = super().__getitem__(key)  # get item as np.ndarray
        keystr = key_repr(key)  # write down key e.g. "0,*,*"

        # attributes cannot be set on scalars such as numpy.int32
        if isinstance(out, self.__class__):
            if hasattr(key, "__array__") and key.size > 1:
                # fancy indexing will lose axes information, except for 1D array
                key = np.asarray(key)
                if key.ndim == 1:
                    new_axes = self.axes
                else:
                    new_axes = None
            elif "new" in keystr:
                # np.newaxis or None will add dimension
                new_axes = None
            elif not self.axes.is_none() and self.axes:
                # drop the axes that were indexed with an integer
                del_list = [
                    i for i, s in enumerate(keystr.split(",")) if s not in ("*", "")
                ]
                new_axes = del_axis(self.axes, del_list)
            else:
                new_axes = None

            out._getitem_additional_set_info(
                self, keystr=keystr, new_axes=new_axes, key=key
            )

        return out

    def _getitem_additional_set_info(self, other: Self, **kwargs):
        """Hook called by ``__getitem__``; subclasses may use the extra kwargs."""
        self._set_info(other, kwargs["new_axes"])
        return None

    def __setitem__(self, key: int | str | slice | tuple, value):
        if isinstance(key, str):
            # img["t=2;z=4"] ... ImageJ-like method
            sl = axis_targeted_slicing(self.ndim, str(self.axes), key)
            return self.__setitem__(sl, value)

        if isinstance(key, MetaArray) and key.dtype == bool and not key.axes.is_none():
            key = add_axes(self.axes, self.shape, key, key.axes)

        elif isinstance(key, np.ndarray) and key.dtype == bool and key.ndim == 2:
            # img[arr] ... where arr is 2-D boolean array
            key = add_axes(self.axes, self.shape, key)

        super().__setitem__(key, value)

    def __array_finalize__(self, obj):
        """
        Every time an np.ndarray object is made by numpy functions inherited to ImgArray,
        this function will be called to set essential attributes. Therefore, you can use
        such as img.copy() and img.astype("int") without problems (maybe...).
        """
        if obj is None:
            return None

        self._set_additional_props(obj)

        try:
            self.axes = getattr(obj, "axes", None)
        except Exception:
            self.axes = None
        else:
            # axes inherited from a differently shaped array are meaningless
            if not self.axes.is_none() and len(self.axes) != self.ndim:
                self.axes = None

    def __array_ufunc__(self, ufunc, method, *args, **kwargs):
        """
        Every time a numpy universal function (add, subtract, ...) is called,
        this function will be called to set/update essential attributes.

        Multi-output ufuncs (such as ``np.divmod``) return a tuple of outputs,
        each of which is wrapped separately. Methods that return no array
        (such as ``ufunc.at``, which works in place) return their result as is.
        """
        args_, _ = _replace_inputs(self, args, kwargs)

        result = getattr(ufunc, method)(*args_, **kwargs)

        if result is NotImplemented:
            return NotImplemented

        if isinstance(result, tuple):
            return tuple(
                self._wrap_ufunc_output(r, ufunc, args, kwargs) for r in result
            )
        return self._wrap_ufunc_output(result, ufunc, args, kwargs)

    def _wrap_ufunc_output(self, result, ufunc, args, kwargs):
        """View one ufunc output as this class and inherit metadata from the inputs."""
        if isinstance(result, (np.ndarray, np.generic)):
            # for numpy scalars such as np.float64, ``view`` returns a scalar again
            result = result.view(self.__class__)
        if isinstance(result, self.__class__):
            result._process_output(ufunc, args, kwargs)
        return result

    def _inherit_meta(self, obj: AxesMixin, ufunc, **kwargs):
        """
        Copy axis etc. from obj.
        This is called in __array_ufunc__(). Unlike _set_info(), keyword `axis` must be
        considered because it changes `ndim`.
        """
        if "axis" in kwargs.keys() and not obj.axes.is_none():
            new_axes = del_axis(obj.axes, kwargs["axis"])
        else:
            new_axes = "inherit"
        self._set_info(obj, new_axes=new_axes)
        return self

    def __array_function__(self, func, types, args, kwargs):
        """
        Every time a numpy function (np.mean...) is called, this function will be called.
        Essentially numpy function can be overloaded with this method.
        """
        if (func in self.__class__.NP_DISPATCH and
                all(issubclass(t, MetaArray) for t in types)):
            return self.__class__.NP_DISPATCH[func](*args, **kwargs)

        args_, _ = _replace_inputs(self, args, kwargs)

        result = func(*args_, **kwargs)

        if result is NotImplemented:
            return NotImplemented

        if isinstance(result, (tuple, list)):
            def _as_meta_array(a):
                # only plain ndarrays are converted; other items are kept as is
                if type(a) is np.ndarray:
                    return a.view(self.__class__)._process_output(func, args, kwargs)
                return a

            result = DataList(_as_meta_array(r) for r in result)
        else:
            if isinstance(result, np.ndarray):
                result = result.view(self.__class__)
            # in the case result is such as np.float64
            if isinstance(result, self.__class__):
                result._process_output(func, args, kwargs)

        return result

    def _process_output(self, func, args, kwargs):
        """Inherit metadata from the largest MetaArray found in ``args``."""
        # find the largest MetaArray. Largest because of broadcasting.
        arr = None
        for arg in args:
            if isinstance(arg, self.__class__):
                if arr is None or arr.ndim < arg.ndim:
                    arr = arg

        if isinstance(arr, self.__class__):
            self._inherit_meta(arr, func, **kwargs)

        return self

    @classmethod
    def implements(cls, numpy_function):
        """
        Add functions to NP_DISPATCH so that numpy functions can be overloaded.
        """
        def decorator(func):
            cls.NP_DISPATCH[numpy_function] = func
            func.__name__ = numpy_function.__name__
            return func
        return decorator

    def sort_axes(self) -> Self:
        """
        Sort image dimensions to ptzcyx-order

        Returns
        -------
        MetaArray
            Sorted image
        """
        order = self.axes.argsort()
        return self.transpose(order)

    def argmax_nd(self) -> tuple[int, ...]:
        """
        N-dimensional version of argmax.

        For instance, if yx-array takes its maximum at (5, 8), this function returns
        ``AxesShape(y=5, x=8)``.

        Returns
        -------
        AxesShape
            Argmax of the array. A plain tuple is returned if the axes are
            not available.
        """
        argmax = np.unravel_index(np.argmax(self), self.shape)
        try:
            tup = get_axes_tuple(self)
            return tup(*argmax)
        except ImageAxesError:
            return argmax

    def split(self, axis=None) -> DataList[Self]:
        """
        Split n-dimensional image into (n-1)-dimensional images.

        Parameters
        ----------
        axis : str or int, optional
            Along which axis the original image will be split. If not given,
            the axis is chosen by ``find_first_appeared(self.axes, include="cztp")``,
            i.e. presumably the first of "c", "z", "t", "p" present in the axes.

        Returns
        -------
        list of arrays
            Separate images
        """
        # determine axis in int.
        if axis is None:
            axis = find_first_appeared(self.axes, include="cztp")
        axisint = self.axisof(axis)

        imgs: DataList[MetaArray] = DataList(np.moveaxis(self, axisint, 0))
        for img in imgs:
            img.axes = del_axis(self.axes, axisint)
            img.set_scale(self)

        return imgs

    def _apply_dask(
        self,
        func: Callable,
        c_axes: str | None = None,
        drop_axis: Iterable[int] = [],
        new_axis: Iterable[int] = None,
        dtype = np.float32,
        out_chunks: tuple[int, ...] = None,
        args: tuple[Any] = None,
        kwargs: dict[str, Any] = None
    ) -> Self:
        """
        Convert array into dask array and run a batch process in parallel. In many cases
        batch process in this way is faster than `multiprocess` module.

        Parameters
        ----------
        func : callable
            Function to apply.
        c_axes : str, optional
            Axes to iterate. If empty or None, ``func`` is applied to the whole
            array at once and no dask task is constructed.
        drop_axis : Iterable[int], optional
            Passed to map_blocks.
        new_axis : Iterable[int], optional
            Passed to map_blocks.
        dtype : any that can be converted to np.dtype object, default is np.float32
            Output data type.
        out_chunks : tuple of int, optional
            Output chunks. This argument is important when the output shape will change.
        args : tuple, optional
            Arguments that will passed to `func`.
        kwargs : dict
            Keyword arguments that will passed to `func`.

        Returns
        -------
        MetaArray
            Processed array.
        """
        if args is None:
            args = tuple()
        if kwargs is None:
            kwargs = dict()

        if not c_axes:
            # Do not construct dask tasks if it is not needed.
            # (``not c_axes`` also covers the default ``c_axes=None``.)
            out = xp.asnumpy(func(self.value, *args, **kwargs), dtype=dtype)
        else:
            from dask import array as da

            new_axis = _list_of_axes(self, new_axis)
            drop_axis = _list_of_axes(self, drop_axis)

            # determine chunk size and slices
            chunks = switch_slice(c_axes, self.axes, ifin=1, ifnot=self.shape)
            # Within each block, iterated axes are indexed with 0 before calling
            # `func` and re-inserted (np.newaxis) in its output.
            slice_in = []
            slice_out = []
            for i, a in enumerate(self.axes):
                if a in c_axes:
                    slice_in.append(0)
                    slice_out.append(np.newaxis)
                else:
                    slice_in.append(slice(None))
                    slice_out.append(slice(None))

                if i in drop_axis:
                    slice_out.pop(-1)
                if i in new_axis:
                    slice_in.append(np.newaxis)

            slice_in = tuple(slice_in)
            slice_out = tuple(slice_out)

            # arrays with the same shape as `self` are chunked the same way
            all_args = (self.value,) + args
            img_idx = []
            _args = []
            for i, arg in enumerate(all_args):
                if isinstance(arg, (np.ndarray, xp.ndarray)) and arg.shape == self.shape:
                    _args.append(da.from_array(arg, chunks=chunks))
                    img_idx.append(i)
                else:
                    _args.append(arg)

            def _func(*args, **kwargs):
                args = list(args)
                for i in img_idx:
                    if args[i].ndim < len(slice_in):
                        continue
                    args[i] = args[i][slice_in]
                out = func(*args, **kwargs)
                return xp.asnumpy(out[slice_out])

            out = da.map_blocks(
                _func,
                *_args,
                drop_axis=drop_axis,
                new_axis=new_axis,
                meta=xp.array([], dtype=dtype),
                chunks=out_chunks,
                **kwargs
            )

            out = out.compute()

        out = out.view(self.__class__)

        return out

    def transpose(self, axes=None) -> Self:
        """
        Change the order of image dimensions. 'axes' will also be arranged.

        Parameters
        ----------
        axes : sequence of int, optional
            Permutation of the dimensions. If not given, the order of the
            dimensions is reversed, as in ``np.ndarray.transpose()``.
        """
        if axes is None:
            axes = tuple(range(self.ndim - 1, -1, -1))
        out = super().transpose(axes)
        if self.axes.is_none():
            new_axes = None
        else:
            new_axes = "".join([self.axes[i] for i in list(axes)])
        out._set_info(self, new_axes=new_axes)
        return out

    def _broadcast(self, value):
        """
        More flexible broadcasting. If `self` has "zcyx"-axes and `value` has "zyx"-axes,
        then they should be broadcasted by stacking `value` along "c"-axes
        """
        if isinstance(value, MetaArray) and not value.axes.is_none():
            value = add_axes(self.axes, self.shape, value, value.axes)
        elif isinstance(value, np.ndarray):
            try:
                if self.sizesof("yx") == value.shape:
                    value = add_axes(self.axes, self.shape, value)
            except AttributeError:
                pass
        return value

    # Arithmetic / comparison operators broadcast the operand with `_broadcast`
    # before delegating to np.ndarray.

    def __add__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__add__(value)

    def __sub__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__sub__(value)

    def __mul__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__mul__(value)

    def __truediv__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__truediv__(value)

    def __mod__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__mod__(value)

    def __floordiv__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__floordiv__(value)

    def __gt__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__gt__(value)

    def __ge__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__ge__(value)

    def __lt__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__lt__(value)

    def __le__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__le__(value)

    def __eq__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__eq__(value)

    def __ne__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__ne__(value)

    def __and__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__and__(value)

    def __or__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__or__(value)

    def __iadd__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__iadd__(value)

    def __isub__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__isub__(value)

    def __imul__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__imul__(value)

    def __itruediv__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__itruediv__(value)

    def __imod__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__imod__(value)

    def __ifloordiv__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__ifloordiv__(value)
def _list_of_axes(img: MetaArray, axis):
    """
    Normalize ``axis`` into a list of integer axes.

    None becomes an empty list, a string is converted character by character
    with ``img.axisof``, and a scalar is wrapped in a list. Any other value
    (such as a list or tuple) is returned unchanged.
    """
    if axis is None:
        axis = []
    elif isinstance(axis, str):
        axis = [img.axisof(a) for a in axis]
    elif np.isscalar(axis):
        axis = [axis]
    return axis


def _replace_inputs(img: MetaArray, args, kwargs):
    """
    Prepare arguments of a numpy function/ufunc call made on ``img``.

    Top-level MetaArray arguments are replaced by plain numpy views. String
    "axis"/"axes" keywords are converted into integer axes with
    ``img.axisof``. MetaArray values in "out" are also replaced.

    NOTE: ``kwargs`` is modified in place; it is also returned for convenience.

    Returns
    -------
    tuple
        (converted positional arguments, kwargs)
    """
    _as_np_ndarray = lambda a: a.value if isinstance(a, MetaArray) else a
    # convert arguments
    args = tuple(_as_np_ndarray(a) for a in args)
    if "axis" in kwargs:
        axis = kwargs["axis"]
        if isinstance(axis, str):
            _axis = tuple(map(img.axisof, axis))
            if len(_axis) == 1:
                _axis = _axis[0]
            kwargs["axis"] = _axis

    if "axes" in kwargs:
        # used in such as np.rot90
        axes = kwargs["axes"]
        if isinstance(axes, str):
            _axes = tuple(map(img.axisof, axes))
            kwargs["axes"] = _axes

    if "out" in kwargs:
        out = kwargs["out"]
        if isinstance(out, tuple):
            # ufunc protocol: `out` is always normalized to a tuple
            kwargs["out"] = tuple(_as_np_ndarray(a) for a in out)
        else:
            # numpy functions (__array_function__) receive a single array or None;
            # iterating over it would split the array into rows (or fail on None).
            kwargs["out"] = _as_np_ndarray(out)
    return args, kwargs