# Source code for impy.arrays.lazy

from __future__ import annotations
from functools import wraps
import os
import itertools
from typing import Any, Callable, TYPE_CHECKING
import numpy as np
from numpy.typing import ArrayLike, DTypeLike
from warnings import warn
from collections import namedtuple
import tempfile

from .labeledarray import LabeledArray
from .imgarray import ImgArray
from .axesmixin import AxesMixin
from ._utils._skimage import skres
from ._utils import _misc, _transform, _structures, _filters, _deconv, _corr, _docs

from ..utils.axesop import switch_slice, complement_axes, find_first_appeared, del_axis
from ..utils.deco import record_lazy, dims_to_spatial_axes, same_dtype, make_history
from ..utils.misc import check_nd
from ..utils.slicer import axis_targeted_slicing, key_repr
from ..utils.utilcls import Progress
from ..utils.io import get_imsave_meta_from_img, memmap
from ..collections import DataList

from .._types import nDFloat, Coords, Iterable, Dims
from ..axes import ImageAxesError
from .._const import Const
from .._cupy import xp, xp_ndi, xp_fft, asnumpy

if TYPE_CHECKING:
    from dask import array as da


class LazyImgArray(AxesMixin):
    """Lazily evaluated image array backed by a dask array.

    The raw dask array is stored in ``self.value``; most arithmetic and metadata
    handling mirrors ``ImgArray`` but defers computation until ``compute()``.
    """

    # Properties copied from one instance to another by ``_set_additional_props``.
    additional_props = ["dirpath", "metadata", "name"]

    def __init__(self, obj: "da.core.Array", name: str = None, axes: str = None,
                 dirpath: str = None, history: list[str] = None, metadata: dict = None):
        from dask import array as da
        if not isinstance(obj, da.core.Array):
            raise TypeError(f"The first input must be dask array, got {type(obj)}")
        self.value = obj
        self.dirpath = dirpath
        self.name = name
        # MicroManager: strip the "_MMStack_Pos0.ome" suffix (17 characters) that
        # MicroManager appends to its OME-TIFF stack names.
        if isinstance(self.name, str) and self.name.endswith("_MMStack_Pos0.ome"):
            self.name = self.name[:-17]
        self.axes = axes
        self.metadata = metadata
        self.history = [] if history is None else history

    @property
    def ndim(self):
        # Number of dimensions of the underlying dask array.
        return self.value.ndim

    @property
    def shape(self):
        # Shape as a namedtuple keyed by axis symbols when axes are known,
        # otherwise the plain dask shape tuple.
        try:
            tup = namedtuple("AxesShape", list(self.axes))
            return tup(*self.value.shape)
        except ImageAxesError:
            return self.value.shape

    @property
    def dtype(self):
        return self.value.dtype

    @property
    def size(self):
        return self.value.size

    @property
    def itemsize(self):
        return self.value.itemsize

    @property
    def chunksize(self):
        # Chunk sizes as a namedtuple keyed by axis symbols when axes are known.
        try:
            tup = namedtuple("AxesShape", list(self.axes))
            return tup(*self.value.chunksize)
        except ImageAxesError:
            return self.value.chunksize

    @property
    def gb(self):
        # Total (uncomputed) size of the array in gigabytes.
        return self.value.nbytes / 1e9

    def __array__(self):
        # Should not be `self.compute` because in napari Viewer this function is called every time
        # sliders are moved.
        return asnumpy(self.value.compute())

    def __getitem__(self, key):
        # Axis-targeted string keys such as "t=0;z=1:4" are translated to slices first.
        if isinstance(key, str):
            key = axis_targeted_slicing(self.value.ndim, self.axes, key)
        keystr = key_repr(key)  # write down key like "0,*,*"
        if hasattr(key, "__array__"):
            # fancy indexing will lose axes information
            new_axes = None
        elif "new" in keystr:
            # np.newaxis or None will add dimension
            new_axes = None
        elif self.axes:
            # Axes indexed by a scalar disappear from the output.
            del_list = [i for i, s in enumerate(keystr.split(",")) if s not in ("*", "")]
            new_axes = del_axis(self.axes, del_list)
        else:
            new_axes = None
        out = self.__class__(self.value[key], name=self.name, dirpath=self.dirpath,
                             axes=new_axes, metadata=self.metadata, history=self.history)
        out._getitem_additional_set_info(self, keystr=keystr, new_axes=new_axes, key=key)
        return out

    def __neg__(self) -> LazyImgArray:
        out = self.__class__(-self.value)
        out._set_info(self, next_history="neg")
        return out

    @same_dtype(asfloat=True)
    def __add__(self, other) -> LazyImgArray:
        if isinstance(other, self.__class__):
            out = self.value + other.value
        else:
            out = self.value + other
        out = self.__class__(out)
        out._set_info(self, next_history="add")
        return out

    @same_dtype(asfloat=True)
    def __iadd__(self, other) -> LazyImgArray:
        if isinstance(other, self.__class__):
            self.value += other.value
        else:
            self.value += other
        self.history.append("add")
        return self

    @same_dtype(asfloat=True)
    def __sub__(self, other) -> LazyImgArray:
        if isinstance(other, self.__class__):
            out = self.value - other.value
        else:
            out = self.value - other
        out = self.__class__(out)
        out._set_info(self, next_history="subtract")
        return out

    @same_dtype(asfloat=True)
    def __isub__(self, other) -> LazyImgArray:
        if isinstance(other, self.__class__):
            self.value -= other.value
        else:
            self.value -= other
        self.history.append("subtract")
        return self

    @same_dtype(asfloat=True)
    def __mul__(self, other) -> LazyImgArray:
        # Non-complex operands are normalized to float32 before multiplication.
        if isinstance(other, np.ndarray) and other.dtype.kind != "c":
            other = other.astype(np.float32)
            other = other
        elif isinstance(other, self.__class__) and other.dtype.kind != "c":
            other = other.as_float()
            other = other.value
        elif np.isscalar(other) and other < 0:
            raise ValueError("Cannot multiply negative value.")
        else:
            other = other
        out = self.value * other
        out = self.__class__(out)
        out._set_info(self, next_history="multiply")
        return out

    @same_dtype(asfloat=True)
    def __imul__(self, other) -> LazyImgArray:
        if isinstance(other, np.ndarray) and other.dtype.kind != "c":
            other = other.astype(np.float32)
            other = other
        elif isinstance(other, self.__class__) and other.dtype.kind != "c":
            other = other.as_float()
            other = other.value
        elif np.isscalar(other) and other < 0:
            raise ValueError("Cannot multiply negative value.")
        else:
            other = other
        self.value *= other
        self.history.append("multiply")
        return self

    def __truediv__(self, other) -> LazyImgArray:
        # Division is always carried out in float32.
        self = self.as_float()
        if isinstance(other, np.ndarray) and other.dtype.kind != "c":
            other = other.astype(np.float32)
            # NOTE(review): this mutates the caller's array in place (zeros -> inf)
            # so that division yields 0 instead of raising — confirm this is intended.
            other[other==0] = np.inf
            other = other
        elif isinstance(other, self.__class__) and other.dtype.kind != "c":
            other = other.as_float()
            other[other==0] = np.inf
            other = other.value
        elif np.isscalar(other) and other <= 0:
            # NOTE(review): message says "multiply" but this is division.
            raise ValueError("Cannot multiply negative value.")
        else:
            other = other
        out = self.value / other
        out = self.__class__(out)
        out._set_info(self, next_history="divide")
        return out

    def __itruediv__(self, other) -> LazyImgArray:
        if self.dtype.kind in "ui":
            raise ValueError("Cannot divide integer inplace.")
        if isinstance(other, np.ndarray) and other.dtype.kind != "c":
            other = other.astype(np.float32)
            other[other==0] = np.inf
            other = other
        elif isinstance(other, self.__class__) and other.dtype.kind != "c":
            other = other.as_float()
            other[other==0] = np.inf
            other = other.value
        elif np.isscalar(other) and other < 0:
            # NOTE(review): message says "multiply" but this is division.
            raise ValueError("Cannot multiply negative value.")
        else:
            other = other
        self.value /= other
        self.history.append("divide")
        return self

    @property
    def chunk_info(self):
        # Human-readable chunk summary, e.g. "1(t), 512(y), 512(x)".
        if self.axes.is_none():
            chunk_info = self.chunksize
        else:
            chunk_info = ", ".join([f"{s}({o})" for s, o in zip(self.chunksize, self.axes)])
        return chunk_info

    def _repr_dict_(self):
        # Key/value pairs shown by __repr__.
        return {" shape ": self.shape_info,
                " chunk sizes ": self.chunk_info,
                " dtype ": self.dtype,
                " directory ": self.dirpath,
                "original image": self.name,
                " history ": "->".join(self.history)}

    def __repr__(self):
        return "\n" + "\n".join(f"{k}: {v}" for k, v in self._repr_dict_().items()) + "\n"
    def compute(self, ignore_limit: bool = False) -> ImgArray:
        """
        Compute all the task and convert the result into ImgArray. If image size overwhelms
        MAX_GB then MemoryError is raised.

        Parameters
        ----------
        ignore_limit : bool, default is False
            If True, skip the MAX_GB memory guard.

        Returns
        -------
        ImgArray
            In-memory image with name/dirpath/axes/metadata/history copied over.
        """
        # Guard against accidentally materializing a huge array in memory.
        if self.gb > Const["MAX_GB"] and not ignore_limit:
            raise MemoryError(f"Too large: {self.gb:.2f} GB")
        with Progress("Converting to ImgArray"):
            arr = self.value.compute()
            if arr.ndim > 0:
                img = asnumpy(arr).view(ImgArray)
                # Carry image metadata over to the computed array.
                for attr in ["name", "dirpath", "axes", "metadata", "history"]:
                    setattr(img, attr, getattr(self, attr, None))
            else:
                # 0-dim result (e.g. a reduction) is returned as a scalar-like array.
                img = arr
        return img
    @property
    def data(self):
        """Deprecated. Equivalent to calling :meth:`compute`."""
        warn("'data' should no longer be used and will be removed soon. Use 'img.compute()' instead.",
             DeprecationWarning)
        return self.compute()

    @property
    def img(self):
        """Deprecated. Equivalent to the :attr:`value` attribute."""
        warn("'img' is renamed to 'value' for compatibility with ImgArray and will be removed soon.",
             DeprecationWarning)
        return self.value
    def release(self, update: bool = True) -> LazyImgArray:
        """
        Compute all the tasks and store the data in memory map, and read it as a dask array
        again.

        Parameters
        ----------
        update : bool, default is True
            If True, rebind ``self.value`` to the released array and return self;
            otherwise return a new LazyImgArray.
        """
        from dask import array as da
        with Progress("Releasing jobs"):
            with tempfile.NamedTemporaryFile() as ntf:
                # Materialize the whole graph into a disk-backed memmap.
                mmap = np.memmap(ntf, mode="w+", shape=self.shape, dtype=self.dtype)
                mmap[:] = self.value[:]
                # Re-wrap the memmap as a fresh dask array with the original chunking.
                # map_blocks(np.array) forces each chunk to be read as a plain ndarray.
                img = da.from_array(mmap, chunks=self.chunksize).map_blocks(
                    np.array, meta=np.array([], dtype=self.dtype)
                )
            # NOTE(review): the temporary file is deleted when this context exits;
            # the memmap keeps working only while the mapping stays alive — confirm
            # this is safe on all platforms (Windows in particular).
        if update:
            self.value = img
            out = self
        else:
            out = self.__class__(img)
        out._set_info(self)
        return out
    @_docs.copy_docs(LabeledArray.imsave)
    def imsave(self, tifname: str, dtype = None):
        # Normalize the output path: force ".tif" suffix and resolve relative
        # names against the source directory.
        if not tifname.endswith(".tif"):
            tifname += ".tif"
        if os.sep not in tifname:
            tifname = os.path.join(self.dirpath, tifname)
        if self.metadata is None:
            self.metadata = {}
        if dtype is None:
            dtype = self.dtype
        # Convert dtype and sort axes into the standard order before writing.
        self = self.as_img_type(dtype).sort_axes()
        imsave_kwargs = get_imsave_meta_from_img(self, update_lut=False)
        # Stream the dask array into a memory-mapped TIFF to avoid loading
        # the full image into RAM.
        memmap_image = memmap(tifname, shape=self.shape, dtype=self.dtype, **imsave_kwargs)
        with Progress("Saving"):
            memmap_image[:] = self.value[:]
            memmap_image.flush()
        return None
[docs] def rechunk(self, chunks="auto", *, threshold=None, block_size_limit=None, balance=False, update=False) -> LazyImgArray: """ Rechunk the bound dask array. Parameters ---------- chunks, threshold, block_size_limit, balance Passed directly to dask.array's rechunk Returns ------- LazyImgArray Rechunked dask array is bound. History will not be updated. """ rechunked = self.value.rechunk(chunks=chunks, threshold=threshold, block_size_limit=block_size_limit, balance=balance) if update: self.value = rechunked return self else: out = self.__class__(rechunked) out._set_info(self) return out
[docs] def apply_dask_func(self, funcname: str, *args, **kwargs) -> LazyImgArray: """ Apply dask array function to the connected dask array. Parameters ---------- funcname : str Name of function to apply. args, kwargs : Parameters that will be passed to `funcname`. Returns ------- LazyImgArray Updated one """ out = getattr(self.value, funcname)(*args, **kwargs) out = self.__class__(out) new_axes = "inherit" if out.shape == self.shape else None out._set_info(self, make_history(funcname, args, kwargs), new_axes=new_axes) return out
    def _apply_function(self, func: Callable, c_axes: str = None, drop_axis: Iterable[int] = [],
                        new_axis: Iterable[int] = None, dtype = np.float32,
                        rechunk_to: tuple[int, ...] | str = "none", dask_wrap: bool = False,
                        args: tuple = None, kwargs: dict[str] = None) -> LazyImgArray:
        """
        Rechunk array in a correct shape and apply function using `map_blocks`. This function is
        similar to the `apply_dask` function in `MetaArray` while returns dask array bound
        LazyImgArray.

        Parameters
        ----------
        func : callable
            Function to apply for each chunk.
        c_axes : str, optional
            Axes to iterate.
        drop_axis : Iterable[int], optional
            Passed to map_blocks.
        new_axis : Iterable[int], optional
            Passed to map_blocks.
        dtype : any that can be converted to np.dtype object, default is np.float32
            Output data type.
        rechunk_to : tuple[int,...], optional
            In what size input array should be rechunked before `map_blocks` iteration.
            If str is given, array will be rechunked in following rules:
            - "none": No rechunking
            - "default": Rechunked with "auto" method for each spatial dimension.
            - "max": Rechunked to the shape size for each spatial dimension.
        dask_wrap : bool, optional
            If True, for each chunk array will be converted to dask and rechunked with "auto"
            option before function call.
        args : tuple, optional
            Arguments that will passed to `func`.
        kwargs : dict
            Keyword arguments that will passed to `func`.

        Returns
        -------
        LazyImgArray
            Dask array after function is applied is bound to this newly generated object.
        """
        # Build index tuples: iterated axes are collapsed to index 0 on input
        # and re-expanded with np.newaxis on output so shapes round-trip.
        slice_in = []
        slice_out = []
        for a in self.axes:
            if a in c_axes:
                slice_in.append(0)
                slice_out.append(np.newaxis)
            else:
                slice_in.append(slice(None))
                slice_out.append(slice(None))
        slice_in = tuple(slice_in)
        slice_out = tuple(slice_out)
        # NOTE: drop_axis default [] is a mutable default; harmless here because
        # it is never mutated, only forwarded to map_blocks.
        if args is None:
            args = tuple()
        if kwargs is None:
            kwargs = dict()
        if rechunk_to == "none":
            input_ = self.value
        else:
            if rechunk_to == "default":
                rechunk_to = switch_slice(c_axes, self.axes, ifin=1, ifnot="auto")
            elif rechunk_to == "max":
                # One chunk per iterated axis; full size along spatial axes.
                rechunk_to = switch_slice(c_axes, self.axes, ifin=1, ifnot=self.shape)
            input_ = self.value.rechunk(rechunk_to)
        if dask_wrap:
            from dask import array as da
            @wraps(func)
            def _func(arr, *args, **kwargs):
                out = func(da.from_array(arr[slice_in]), *args, **kwargs)
                return out[slice_out].compute()
        else:
            @wraps(func)
            def _func(arr, *args, **kwargs):
                out = func(arr[slice_in], *args, **kwargs)
                return out[slice_out]
        # NOTE(review): despite the annotation, this returns the raw dask array;
        # presumably the @record_lazy decorator on callers wraps it — confirm.
        out = input_.map_blocks(_func, *args, drop_axis=drop_axis, new_axis=new_axis,
                                meta=xp.array([], dtype=dtype), **kwargs)
        return out

    def _apply_map_blocks(self, func: Callable, c_axes: str = None, args: tuple = None,
                          kwargs: dict[str, Any] = None):
        # Apply ``func`` chunk-wise via da.map_blocks, iterating over ``c_axes``
        # inside each chunk.
        from dask import array as da
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}
        all_axes = str(self.axes)
        def _func(input: ArrayLike, *args, **kwargs):
            out = xp.empty(input.shape, input.dtype)
            for sl in iter_slice(input.shape, c_axes, all_axes):
                out[sl] = func(input[sl], *args, **kwargs)
            return out
        return da.map_blocks(_func, self.value, dtype=self.dtype, *args, **kwargs)

    def _apply_map_overlap(self, func: Callable, c_axes: str = None, depth = 16,
                           boundary="reflect", dtype: DTypeLike = None, args: tuple = None,
                           kwargs: dict[str, Any] = None):
        # Like _apply_map_blocks but with chunk overlap (``depth``) so that
        # neighborhood filters produce correct values near chunk borders.
        from dask import array as da
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}
        if dtype is None:
            dtype = self.dtype
        all_axes = str(self.axes)
        def _func(input: ArrayLike, *args, **kwargs):
            out = xp.empty(input.shape, input.dtype)
            for sl in iter_slice(input.shape, c_axes, all_axes):
                out[sl] = func(input[sl], *args, **kwargs)
            return out
        # No overlap needed along iterated (non-spatial) axes.
        depth = switch_slice(c_axes, self.axes, 0, depth)
        return da.map_overlap(_func, self.value, depth=depth, boundary=boundary, dtype=dtype,
                              *args, **kwargs)

    def _apply_dask_filter(self, func: Callable, c_axes: str = None, args: tuple = None,
                           kwargs: dict[str, Any] = None) -> LazyImgArray:
        # TODO: This is not efficient. Maybe using da.stack is better?
        from dask import array as da
        out = da.empty_like(self.value)
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}
        for sl, img in self.iter(c_axes, israw=False):
            out[sl] = func(img, *args, **kwargs)
        return out

    # TODO: This should wait for dask-image implement map_coordinates
    # @_docs.copy_docs(LabeledArray.rotated_crop)
    # @dims_to_spatial_axes
    # @record_lazy
    # def rotated_crop(self, origin, dst1, dst2, dims=2) -> LazyImgArray:
    #     origin = np.asarray(origin)
    #     dst1 = np.asarray(dst1)
    #     dst2 = np.asarray(dst2)
    #     ax0 = _misc.make_rotated_axis(origin, dst2)
    #     ax1 = _misc.make_rotated_axis(dst1, origin)
    #     all_coords = ax0[:, np.newaxis] + ax1[np.newaxis] - origin
    #     all_coords = np.moveaxis(all_coords, -1, 0)
    #     # Because output shape changes, we have to tell dask what chunk size it would be,
    #     # otherwise output shape is estimated in a wrong way.
    #     output_chunks = [1] * self.ndim
    #     for i, a in enumerate(dims):
    #         it = self.axisof(a)
    #         output_chunks[it] = all_coords.shape[i+1]
    #
    #     cropped_img = self._apply_function(xp_ndi.map_coordinates,
    #                                        c_axes=complement_axes(dims, self.axes),
    #                                        dtype=self.dtype,
    #                                        rechunk_to="max",
    #                                        args=(xp.asarray(all_coords),),
    #                                        kwargs=dict(prefilter=False, order=1, chunks=output_chunks)
    #                                        )
[docs] @_docs.copy_docs(ImgArray.erosion) @dims_to_spatial_axes @record_lazy def erosion(self, radius: float = 1, *, dims: Dims = None, update: bool = False ) -> LazyImgArray: disk = _structures.ball_like(radius, len(dims)) c_axes = complement_axes(dims, self.axes) filter_func = xp_ndi.grey_erosion if self.dtype != bool else xp_ndi.binary_erosion return self._apply_map_overlap( filter_func, c_axes=c_axes, depth=_ceilint(radius), kwargs=dict(footprint=disk) )
[docs] @_docs.copy_docs(ImgArray.dilation) @dims_to_spatial_axes @record_lazy def dilation(self, radius: float = 1, *, dims: Dims = None, update: bool = False ) -> LazyImgArray: disk = _structures.ball_like(radius, len(dims)) c_axes = complement_axes(dims, self.axes) filter_func = xp_ndi.grey_dilation if self.dtype != bool else xp_ndi.binary_dilation return self._apply_map_overlap( filter_func, c_axes=c_axes, depth=_ceilint(radius), kwargs=dict(footprint=disk) )
[docs] @_docs.copy_docs(ImgArray.opening) @dims_to_spatial_axes @record_lazy def opening(self, radius: float = 1, *, dims: Dims = None, update: bool = False ) -> LazyImgArray: disk = _structures.ball_like(radius, len(dims)) c_axes = complement_axes(dims, self.axes) filter_func = xp_ndi.grey_opening if self.dtype != bool else xp_ndi.binary_opening return self._apply_map_overlap( filter_func, c_axes=c_axes, depth=_ceilint(radius)*2, kwargs=dict(footprint=disk) )
[docs] @_docs.copy_docs(ImgArray.closing) @dims_to_spatial_axes @record_lazy def closing(self, radius: float = 1, *, dims: Dims = None, update: bool = False ) -> LazyImgArray: disk = _structures.ball_like(radius, len(dims)) c_axes = complement_axes(dims, self.axes) filter_func = xp_ndi.grey_closing if self.dtype != bool else xp_ndi.binary_closing return self._apply_map_overlap( filter_func, c_axes=c_axes, depth=_ceilint(radius)*2, kwargs=dict(footprint=disk) )
[docs] @_docs.copy_docs(ImgArray.gaussian_filter) @dims_to_spatial_axes @same_dtype(asfloat=True) @record_lazy def gaussian_filter(self, sigma: nDFloat = 1.0, *, dims: Dims = None, update: bool = False ) -> LazyImgArray: c_axes = complement_axes(dims, self.axes) depth = _ceilint(sigma*4) return self._apply_map_overlap( xp_ndi.gaussian_filter, c_axes=c_axes, depth=depth, kwargs=dict(sigma=sigma), )
[docs] @_docs.copy_docs(ImgArray.median_filter) @dims_to_spatial_axes @same_dtype @record_lazy def median_filter(self, radius: float = 1, *, dims: Dims = None, update: bool = False ) -> LazyImgArray: disk = _structures.ball_like(radius, len(dims)) return self._apply_map_overlap( xp_ndi.median_filter, depth=_ceilint(radius), c_axes=complement_axes(dims, self.axes), kwargs=dict(footprint=disk) )
[docs] @_docs.copy_docs(ImgArray.mean_filter) @same_dtype(asfloat=True) @dims_to_spatial_axes @record_lazy def mean_filter(self, radius: float = 1, *, dims: Dims = None, update: bool = False ) -> LazyImgArray: disk = _structures.ball_like(radius, len(dims)) kernel = (disk/np.sum(disk)).astype(np.float32) return self._apply_map_overlap( xp_ndi.convolve, depth=_ceilint(radius), c_axes=complement_axes(dims, self.axes), kwargs=dict(weights=kernel), )
[docs] @_docs.copy_docs(ImgArray.convolve) @dims_to_spatial_axes @same_dtype(asfloat=True) @record_lazy def convolve(self, kernel, *, mode: str = "reflect", cval: float = 0, dims: Dims = None, update: bool = False) -> LazyImgArray: from dask_image.ndfilters import convolve kernel = np.asarray(kernel) shape = np.array(kernel.shape) half_size = shape // 2 depth = tuple(half_size) c_axes = complement_axes(dims, self.axes) return self._apply_map_overlap( xp_ndi.convolve, c_axes=c_axes, depth=depth, kwargs=dict(weights=kernel, mode=mode, cval=cval), )
[docs] @_docs.copy_docs(ImgArray.edge_filter) @dims_to_spatial_axes @same_dtype @record_lazy def edge_filter(self, method: str = "sobel", *, dims: Dims = None, update: bool = False ) -> LazyImgArray: # BUG: returns zero array from ._utils._skimage import skfil method_dict = {"sobel": (skfil.sobel, 1), "farid": (skfil.farid, 2), "scharr": (skfil.scharr, 1), "prewitt": (skfil.prewitt, 1)} try: filter_func, depth = method_dict[method] except KeyError: raise ValueError("`method` must be 'sobel', 'farid' 'scharr', or 'prewitt'.") return self._apply_map_overlap( filter_func, depth=depth, c_axes=complement_axes(dims, self.axes) )
[docs] @_docs.copy_docs(ImgArray.laplacian_filter) @dims_to_spatial_axes @same_dtype @record_lazy def laplacian_filter(self, radius: int = 1, *, dims: Dims = None, update: bool = False ) -> LazyImgArray: ndim = len(dims) _, laplace_op = skres.uft.laplacian(ndim, (2*radius+1,) * ndim) return self._apply_map_overlap( xp_ndi.convolve, depth=_ceilint(radius), c_axes=complement_axes(dims, self.axes), args=(laplace_op,), )
    @_docs.copy_docs(ImgArray.affine)
    @dims_to_spatial_axes
    @same_dtype(asfloat=True)
    @record_lazy
    def affine(self, matrix=None, scale=None, rotation=None, shear=None, translation=None, *,
               mode="constant", cval=0, output_shape=None, order=1, dims=None) -> LazyImgArray:
        # Compose the affine matrix from the individual components when it is not
        # given explicitly.
        if matrix is None:
            matrix = _transform.compose_affine_matrix(scale=scale, rotation=rotation,
                                                      shear=shear, translation=translation,
                                                      ndim=len(dims))
        from dask_image.ndinterp import affine_transform
        return self._apply_dask_filter(
            affine_transform,
            c_axes=complement_axes(dims, self.axes),
            kwargs=dict(matrix=matrix, mode=mode, cval=cval, output_shape=output_shape,
                        order=order)
        )
[docs] @_docs.copy_docs(ImgArray.kalman_filter) @dims_to_spatial_axes @same_dtype(asfloat=True) @record_lazy def kalman_filter(self, gain: float = 0.8, noise_var: float = 0.05, *, along: str = "t", dims: Dims = None, update: bool = False) -> LazyImgArray: if self.axisof(along) != 0: raise ValueError("Currently kalman_filter does not support t-axis != 0.") return self._apply_map_blocks( _filters.kalman_filter, c_axes=complement_axes(along + dims, self.axes), args=(gain, noise_var) )
[docs] @_docs.copy_docs(ImgArray.fft) @dims_to_spatial_axes @record_lazy def fft(self, *, shape: int | Iterable[int] | str = "same", shift: bool = True, dims: Dims = None) -> LazyImgArray: from dask import array as da axes = [self.axisof(a) for a in dims] if shape == "square": s = 2**int(np.ceil(np.max(self.sizesof(dims)))) shape = (s,) * len(dims) elif shape == "same": shape = None else: shape = check_nd(shape, len(dims)) freq = da.fft.fftn(self.value.astype(np.float32), s=shape, axes=axes).astype(np.complex64) if shift: freq[:] = da.fft.fftshift(freq, axes=axes) return freq
[docs] @_docs.copy_docs(ImgArray.ifft) @dims_to_spatial_axes @record_lazy def ifft(self, real:bool=True, *, shift:bool=True, dims=None) -> LazyImgArray: from dask import array as da axes = [self.axisof(a) for a in dims] if shift: freq = da.fft.ifftshift(self.value, axes=axes) else: freq = self.value out = da.fft.ifftn(freq, axes=axes).astype(np.complex64) if real: out = da.real(out) return out
[docs] @_docs.copy_docs(ImgArray.power_spectra) @dims_to_spatial_axes @record_lazy def power_spectra(self, shape = "same", norm: bool = False, zero_norm: bool = False, *, dims: Dims = None) -> LazyImgArray: freq = self.fft(dims=dims, shape=shape) pw = freq.value.real**2 + freq.value.imag**2 if norm: pw /= pw.max() if zero_norm: sl = switch_slice(dims, pw.axes, ifin=np.array(pw.shape)//2, ifnot=slice(None)) pw[sl] = 0 return pw
[docs] def chunksizeof(self, axis:str): return self.value.chunksize[self.axes.find(axis)]
[docs] def chunksizesof(self, axes:str): return tuple(self.chunksizeof(a) for a in axes)
[docs] def transpose(self, axes): if self.axes.is_none(): new_axes = None else: new_axes = "".join([self.axes[i] for i in list(axes)]) out = self.__class__(self.value.transpose(axes)) out._set_info(self, new_axes=new_axes) return out
[docs] def sort_axes(self): order = self.axes.argsort() return self.transpose(tuple(order))
    @_docs.copy_docs(LabeledArray.crop_center)
    @dims_to_spatial_axes
    @record_lazy
    def crop_center(self, scale=0.5, *, dims=2) -> LazyImgArray:
        # check scale
        # A 3-element scale with 2D dims implies the user meant z, y, x.
        if hasattr(scale, "__iter__") and len(scale) == 3 and len(dims) == 2:
            dims = "zyx"
        scale = np.asarray(check_nd(scale, len(dims)))
        if np.any((scale <= 0) | (1 < scale)):
            raise ValueError(f"scale must be (0, 1], but got {scale}")
        # Make axis-targeted slicing string, e.g. "y=128:384;x=128:384".
        sizes = self.sizesof(dims)
        slices = []
        for a, size, sc in zip(dims, sizes, scale):
            x0 = int(size / 2 * (1 - sc))
            x1 = int(np.ceil(size / 2 * (1 + sc)))
            slices.append(f"{a}={x0}:{x1}")
        out = self[";".join(slices)]
        return out
    @_docs.copy_docs(ImgArray.tiled_lowpass_filter)
    @dims_to_spatial_axes
    @record_lazy
    def tiled_lowpass_filter(self, cutoff: float = 0.2, order: int = 2, overlap: int = 16, *,
                             dims: Dims = None, update: bool = False) -> LazyImgArray:
        from ._utils._skimage import _get_ND_butterworth_filter
        self = self.as_float()
        cutoff = check_nd(cutoff, len(dims))
        c_axes = complement_axes(dims, self.axes)
        # Cutoffs at or above Nyquist (0.5) or non-positive mean no filtering at all.
        if all((c >= 0.5 or c <= 0) for c in cutoff):
            return self
        # Overlap only along the spatial (filtered) axes.
        depth = switch_slice(dims, self.axes, overlap, 0)
        def func(arr):
            arr = xp.asarray(arr)
            shape = arr.shape
            # Butterworth weight applied in the (real) frequency domain per tile.
            weight = _get_ND_butterworth_filter(shape, cutoff, order, False, True)
            ft = weight * xp_fft.rfftn(arr)
            ift = xp_fft.irfftn(ft, s=shape)
            return ift
        out = self._apply_map_overlap(func, c_axes=c_axes, depth=depth, boundary="reflect")
        return out
    @_docs.copy_docs(ImgArray.proj)
    @same_dtype
    def proj(self, axis: str = None, method: str = "mean") -> LazyImgArray:
        from dask import array as da
        # Default projection axis: first non-spatial axis that exists in the image.
        if axis is None:
            axis = find_first_appeared("ztpi<c", include=self.axes, exclude="yx")
        elif not isinstance(axis, str):
            raise TypeError("`axis` must be str.")
        axisint = [self.axisof(a) for a in axis]
        if method == "mean":
            # Accumulate in float32 to avoid integer overflow during averaging.
            projection = getattr(da, method)(self.value, axis=tuple(axisint), dtype=np.float32)
        else:
            projection = getattr(da, method)(self.value, axis=tuple(axisint))
        out = self.__class__(projection)
        out._set_info(self, f"proj(axis={axis}, method={method})", del_axis(self.axes, axisint))
        return out
    @_docs.copy_docs(ImgArray.binning)
    @dims_to_spatial_axes
    @same_dtype
    def binning(self, binsize: int = 2, method = "mean", *, check_edges: bool = True,
                dims: Dims = None) -> LazyImgArray:
        if binsize == 1:
            # Nothing to do.
            return self
        if isinstance(method, str):
            binfunc = getattr(xp, method)
        elif callable(method):
            binfunc = method
        else:
            raise TypeError("`method` must be a numpy function or callable object.")
        # adjust_bin crops/validates edges and returns the reshape plan.
        img_to_reshape, shape, scale_ = _misc.adjust_bin(self.value, binsize, check_edges,
                                                         dims, self.axes)
        reshaped_img = img_to_reshape.reshape(shape)
        # Each original axis is split into (out_size, binsize); reduce the odd
        # (binsize) axes.
        axes_to_reduce = tuple(i*2+1 for i in range(self.ndim))
        out = binfunc(reshaped_img, axis=axes_to_reduce)
        out = self.__class__(out)
        out._set_info(self, f"binning(binsize={binsize})")
        out.axes = str(self.axes)  # _set_info does not pass copy so new axes must be defined here.
        # Pixel scale grows by the binning factor along each binned axis.
        out.set_scale({a: self.scale[a]/scale for a, scale in zip(self.axes, scale_)})
        return out
    @_docs.copy_docs(ImgArray.track_drift)
    def track_drift(self, along: str = None, upsample_factor: int = 10) -> "da.core.Array":
        if along is None:
            along = find_first_appeared("tpzc<i", include=self.axes)
        elif len(along) != 1:
            raise ValueError("`along` must be single character.")
        dims = complement_axes(along, self.axes)
        # One chunk per frame: full size along spatial axes, size-1 along `along`.
        chunks = switch_slice(dims, self.axes, ifin=self.shape, ifnot=1)
        img_fft = self.fft(shift=False, dims=dims).value.rechunk(chunks)
        ndim = len(dims)
        slice_out = (np.newaxis, slice(None)) + (np.newaxis,)*(ndim-1)
        each_shape = (1, ndim) + (1,)*(ndim-1)
        len_t = self.sizeof(along)
        def pcc(x):
            # Phase cross correlation between two consecutive frames; the first
            # frame (no predecessor in the overlap) reports zero shift.
            if x.shape[0] < 2:
                return np.array([0]*ndim, dtype=np.float32).reshape(*each_shape)
            x = xp.asarray(x)
            result = _corr.subpixel_pcc(x[0], x[1], upsample_factor=upsample_factor)
            return asnumpy(result[slice_out])
        from dask import array as da
        # I don't know the reason why but output dask array's chunk size along t-axis should be
        # specified to be 1, and rechunk it map_overlap.
        result = da.map_overlap(pcc,
                                img_fft,
                                depth={0: (1, 0)},
                                trim=False,
                                boundary="none",
                                chunks=(1, ndim) + (1,)*(ndim-1),
                                meta=np.array([], dtype=np.float32)
                                )
        # For cupy, we must call map_blocks (or from_delayed and delayed) here.
        # Cumulative sum converts frame-to-frame shifts into absolute drift.
        result = da.map_blocks(np.cumsum,
                               result[..., 0].rechunk((len_t, ndim)),
                               axis=0,
                               meta=np.array([], dtype=np.float32)
                               )
        return result
    @_docs.copy_docs(ImgArray.drift_correction)
    @same_dtype(asfloat=True)
    @record_lazy
    @dims_to_spatial_axes
    def drift_correction(self, shift: Coords = None, ref: ImgArray = None, *,
                         zero_ave: bool = True, along: str = None, dims: Dims = 2,
                         update: bool = False, **affine_kwargs) -> LazyImgArray:
        if along is None:
            along = find_first_appeared("tpzcia", include=self.axes, exclude=dims)
        elif len(along) != 1:
            raise ValueError("`along` must be single character.")
        from ..frame import MarkerFrame
        if shift is None:
            # determine 'ref'
            if ref is None:
                ref = self
                _dims = complement_axes(along, self.axes)
                if dims != _dims:
                    warn(f"dims={dims} with along={along} and {self.axes}-image are not "
                         f"valid input. Changed to dims={_dims}", UserWarning)
                    dims = _dims
            elif not isinstance(ref, self.__class__):
                raise TypeError(f"'ref' must be LazyImgArray object, but got {type(ref)}")
            elif ref.axes != along + dims:
                raise ValueError(f"Arguments `along`({along}) + `dims`({dims}) do not match "
                                 f"axes of `ref`({ref.axes})")
            # Estimate the drift lazily from the reference image.
            shift = ref.track_drift(along=along)
        elif isinstance(shift, MarkerFrame):
            if len(shift) != self.sizeof(along):
                raise ValueError("Wrong shape of 'shift'.")
            shift = shift.values
        from dask import array as da
        if zero_ave:
            # Center the shifts so the average displacement is zero.
            shift = shift - da.mean(shift, axis=0)
        t_index = self.axisof(along)
        slice_in = switch_slice(dims, self.axes, ifin=slice(None), ifnot=0)
        slice_out = switch_slice(dims, self.axes, ifin=slice(None), ifnot=np.newaxis)
        ndim = len(dims)
        # Here shift must be a local variable for the function. Otherwise, it takes dask very
        # long time for graph construction.
        def warp(arr, shift, block_info=None):
            arr = xp.asarray(arr)
            # Pure-translation affine matrix for this frame.
            mx = xp.eye(ndim+1, dtype=np.float32)
            # block_info gives this chunk's location in the full array; use it to
            # pick the matching per-frame shift.
            loc = block_info[None]["array-location"][0]
            mx[:-1, -1] = -xp.asarray(shift[loc[t_index]])
            return asnumpy(
                _transform.warp(arr[slice_in], mx, **affine_kwargs)[slice_out]
            )
        # One chunk per frame so each warp sees exactly one time point.
        chunks = switch_slice(dims, self.axes, ifin=self.shape, ifnot=1)
        out = da.map_blocks(warp, self.value.rechunk(chunks), shift,
                            meta=np.array([], dtype=self.dtype))
        return out
[docs] @_docs.copy_docs(ImgArray.pad) @dims_to_spatial_axes @record_lazy def pad(self, pad_width, mode: str = "constant", *, dims: Dims = None, **kwargs) -> LazyImgArray: pad_width = _misc.make_pad(pad_width, dims, self.axes, **kwargs) padimg = np.pad(self.value, pad_width, mode, **kwargs) return padimg
# @_docs.copy_docs(ImgArray.wiener) # @dims_to_spatial_axes # @same_dtype(asfloat=True) # @record_lazy # def wiener(self, psf: np.ndarray, lmd: float = 0.1, *, depth="auto", dims: Dims = None, update: bool = False) -> LazyImgArray: # if lmd <= 0: # raise ValueError(f"lmd must be positive, but got: {lmd}") # if depth == "auto": # depth = 32 # TODO: any better way? # psf_ft, psf_ft_conj = _deconv.check_psf(self, psf, dims) # return self._apply_map_overlap # return self._apply_function(_deconv.wiener, # c_axes=complement_axes(dims, self.axes), # rechunk_to="max", # args=(psf_ft, psf_ft_conj, lmd) # ) # @_docs.copy_docs(ImgArray.lucy) # @dims_to_spatial_axes # @same_dtype(asfloat=True) # @record_lazy # def lucy(self, psf: np.ndarray, niter: int = 50, eps: float = 1e-5, depth: int = 32, *, dims: Dims = None, # update: bool = False) -> LazyImgArray: # psf_ft, psf_ft_conj = _deconv.check_psf(self, psf, dims) # return self._apply_map_overlap(_deconv.richardson_lucy, # c_axes=complement_axes(dims, self.axes), # depth=depth, # boundary="nearest", # args=(psf_ft, psf_ft_conj, niter, eps) # ) def __array_function__(self, func, types, args, kwargs): """ Every time a numpy function (np.mean...) is called, this function will be called. Essentially numpy function can be overloaded with this method. 
""" from dask import array as da args, kwargs = _replace_inputs(self, args, kwargs) _types = [] for t in types: if t is self.__class__: _types.append(da.core.Array) else: _types.append(t) result = self.value.__array_function__(func, _types, args, kwargs) if result is NotImplemented: return NotImplemented if isinstance(result, (tuple, list)): out = [] for r in result: if isinstance(r, da.core.Array): out.append( self.__class__(r)._process_output(self, args, kwargs) ) else: out.append(r) out = DataList(out) elif isinstance(result, da.core.Array): out = self.__class__(result) out._process_output(self, args, kwargs) return out def _process_output(self, input: LazyImgArray, args: tuple, kwargs: dict): if "axis" in kwargs.keys() and not input.axes.is_none(): new_axes = del_axis(input.axes, kwargs["axis"]) else: new_axes = "inherit" self._set_info(input, new_axes=new_axes) return None
[docs] def as_uint8(self) -> LazyImgArray: img = self.value if img.dtype == np.uint8: return img if img.dtype == np.uint16: out = img / 256 elif img.dtype.kind == "f": out = img + 0.5 out = np.clip(out, 0, 255) else: raise TypeError(f"invalid data type: {img.dtype}") out = out.astype(np.uint8) out = self.__class__(out) out._set_info(self) return out
[docs] def as_uint16(self) -> LazyImgArray: img = self.value if img.dtype == np.uint16: return img if img.dtype == np.uint8: out = img * 256 elif img.dtype == bool: out = img elif img.dtype.kind == "f": out = img + 0.5 out = np.clip(out, 0, 65535) else: raise TypeError(f"invalid data type: {img.dtype}") out = out.astype(np.uint16) out = self.__class__(out) out._set_info(self) return out
[docs] def as_float(self) -> LazyImgArray: if self.dtype == np.float32: return self out = self.value.astype(np.float32) out = self.__class__(out) out._set_info(self) return out
    def as_img_type(self, dtype=np.uint16) -> LazyImgArray:
        """Convert the image to the given dtype, delegating to the dedicated
        converters where available. float64/complex128 are downgraded (with a
        warning) to float32/complex64 respectively."""
        dtype = np.dtype(dtype)
        if self.dtype == dtype:
            return self
        elif dtype == "uint16":
            return self.as_uint16()
        elif dtype == "uint8":
            return self.as_uint8()
        elif dtype == "float32":
            return self.as_float()
        elif dtype == "float64":
            warn("Data type float64 is not valid for images. It was converted to float32 instead",
                 UserWarning)
            return self.as_float()
        elif dtype == "complex64":
            out = self.value.astype(np.complex64)
            out = self.__class__(out)
            out._set_info(self)
            return out
        elif dtype == "complex128":
            warn("Data type complex128 is not valid for images. It was converted to complex64 instead",
                 UserWarning)
            out = self.value.astype(np.complex64)
            out = self.__class__(out)
            out._set_info(self)
            return out
        elif dtype == "int8":
            out = self.value.astype(np.int8)
            out = self.__class__(out)
            out._set_info(self)
            return out
        else:
            raise ValueError(f"dtype: {dtype}")
    def _set_additional_props(self, other):
        # set additional properties
        # If `other` does not have it and `self` has, then the property will be inherited.
        for p in self.__class__.additional_props:
            setattr(self, p, getattr(other, p, getattr(self, p, None)))

    def _getitem_additional_set_info(self, other, **kwargs):
        # Hook called by __getitem__ to record the slicing in history/axes.
        keystr = kwargs["keystr"]
        self._set_info(other, f"getitem[{keystr}]", kwargs["new_axes"])
        return None

    def _set_info(self, other: LazyImgArray, next_history = None, new_axes: str = "inherit"):
        # Copy metadata (name/dirpath/metadata), axes and history from `other`.
        self._set_additional_props(other)
        # set axes
        try:
            if new_axes != "inherit":
                self.axes = new_axes
                self.set_scale(other)
            else:
                self.axes = other.axes.copy()
        except ImageAxesError:
            # Axes could not be interpreted; drop them rather than fail.
            self.axes = None
        # set history
        if next_history is not None:
            self.history = other.history + [next_history]
        else:
            self.history = other.history.copy()
        return None
def _replace_inputs(img: LazyImgArray, args, kwargs): _as_dask_array = lambda a: a.value if isinstance(a, LazyImgArray) else a # convert arguments args = tuple(_as_dask_array(a) for a in args) if "axis" in kwargs: axis = kwargs["axis"] if isinstance(axis, str): _axis = tuple(map(img.axisof, axis)) if len(_axis) == 1: _axis = _axis[0] kwargs["axis"] = _axis if "out" in kwargs: kwargs["out"] = tuple(_as_dask_array(a) for a in kwargs["out"]) return args, kwargs def _ceilint(a: float): return int(np.ceil(a))
def iter_slice(shape, iteraxes: str, all_axes: str, exclude: str = ""):
    """Yield index tuples that iterate over every position along ``iteraxes``.

    Axes not in ``iteraxes`` get ``slice(None)``; axes in ``exclude`` are
    removed from the yielded tuples. If nothing can be iterated, a single
    all-covering slice tuple is yielded.
    """
    ndim = len(all_axes)
    per_axis = switch_slice(axes=iteraxes,
                            all_axes=all_axes,
                            ifin=[range(s) for s in shape],
                            ifnot=[(slice(None),)] * ndim)
    yielded_any = False
    for sl in itertools.product(*per_axis):
        if exclude:
            sl = tuple(s for i, s in enumerate(sl) if all_axes[i] not in exclude)
        yield sl
        yielded_any = True
    # Degenerate case (an empty iterable in the product): still yield one slice.
    if not yielded_any:
        yield (slice(None),) * (ndim - len(exclude))