from __future__ import annotations
from functools import wraps
import os
import itertools
from typing import Any, Callable, TYPE_CHECKING
import numpy as np
from numpy.typing import ArrayLike, DTypeLike
from warnings import warn
from collections import namedtuple
import tempfile
from .labeledarray import LabeledArray
from .imgarray import ImgArray
from .axesmixin import AxesMixin
from ._utils._skimage import skres
from ._utils import _misc, _transform, _structures, _filters, _deconv, _corr, _docs
from ..utils.axesop import switch_slice, complement_axes, find_first_appeared, del_axis
from ..utils.deco import record_lazy, dims_to_spatial_axes, same_dtype, make_history
from ..utils.misc import check_nd
from ..utils.slicer import axis_targeted_slicing, key_repr
from ..utils.utilcls import Progress
from ..utils.io import get_imsave_meta_from_img, memmap
from ..collections import DataList
from .._types import nDFloat, Coords, Iterable, Dims
from ..axes import ImageAxesError
from .._const import Const
from .._cupy import xp, xp_ndi, xp_fft, asnumpy
if TYPE_CHECKING:
from dask import array as da
class LazyImgArray(AxesMixin):
    """
    Axes-labeled image whose pixel data is a lazily evaluated dask array.

    The dask array is stored in ``value``. Most methods build a new task graph and return a
    new ``LazyImgArray``; methods decorated with ``record_lazy`` return a raw dask array from
    their body, which the decorator presumably wraps. Nothing is computed until
    :meth:`compute`, :meth:`release` or ``np.asarray`` is called. ``name``, ``dirpath``,
    ``metadata`` and a list of applied operations (``history``) are carried along to derived
    images.
    """
    additional_props = ["dirpath", "metadata", "name"]

    def __init__(self, obj: "da.core.Array", name: str = None, axes: str = None, dirpath: str = None,
                 history: list[str] = None, metadata: dict = None):
        from dask import array as da
        if not isinstance(obj, da.core.Array):
            raise TypeError(f"The first input must be dask array, got {type(obj)}")
        self.value = obj
        self.dirpath = dirpath
        self.name = name
        # MicroManager files are named like "<name>_MMStack_Pos0.ome"; strip that suffix.
        if isinstance(self.name, str) and self.name.endswith("_MMStack_Pos0.ome"):
            self.name = self.name[:-17]
        self.axes = axes
        self.metadata = metadata
        self.history = [] if history is None else history

    @property
    def ndim(self):
        """Number of dimensions of the bound dask array."""
        return self.value.ndim

    @property
    def shape(self):
        """Shape as a namedtuple keyed by axis symbols; plain tuple if axes are unusable."""
        try:
            tup = namedtuple("AxesShape", list(self.axes))
            return tup(*self.value.shape)
        except ImageAxesError:
            return self.value.shape

    @property
    def dtype(self):
        """Data type of the bound dask array."""
        return self.value.dtype

    @property
    def size(self):
        """Total number of elements."""
        return self.value.size

    @property
    def itemsize(self):
        """Bytes per element."""
        return self.value.itemsize

    @property
    def chunksize(self):
        """Chunk shape as a namedtuple keyed by axis symbols; plain tuple if axes are unusable."""
        try:
            tup = namedtuple("AxesShape", list(self.axes))
            return tup(*self.value.chunksize)
        except ImageAxesError:
            return self.value.chunksize

    @property
    def gb(self):
        """Size of the (uncomputed) array in gigabytes (1e9 bytes)."""
        return self.value.nbytes / 1e9

    def __array__(self, dtype=None, copy=None):
        # Should not be `self.compute` because in napari Viewer this function is called every time
        # sliders are moved (compute() would also apply the MAX_GB limit).
        # `dtype`/`copy` follow the NumPy array protocol. The result is always a freshly computed
        # array, so `copy` needs no extra handling.
        arr = asnumpy(self.value.compute())
        if dtype is not None:
            arr = arr.astype(dtype, copy=False)
        return arr

    def __getitem__(self, key):
        if isinstance(key, str):
            key = axis_targeted_slicing(self.value.ndim, self.axes, key)
        keystr = key_repr(key)  # write down key like "0,*,*"
        if hasattr(key, "__array__"):
            # fancy indexing will lose axes information
            new_axes = None
        elif "new" in keystr:
            # np.newaxis or None will add dimension
            new_axes = None
        elif self.axes:
            # axes indexed by a scalar disappear from the output
            del_list = [i for i, s in enumerate(keystr.split(",")) if s not in ("*", "")]
            new_axes = del_axis(self.axes, del_list)
        else:
            new_axes = None
        out = self.__class__(self.value[key], name=self.name, dirpath=self.dirpath, axes=new_axes,
                             metadata=self.metadata, history=self.history)
        out._getitem_additional_set_info(self, keystr=keystr,
                                         new_axes=new_axes, key=key)
        return out

    def __neg__(self) -> LazyImgArray:
        out = self.__class__(-self.value)
        out._set_info(self, next_history="neg")
        return out

    @same_dtype(asfloat=True)
    def __add__(self, other) -> LazyImgArray:
        if isinstance(other, LazyImgArray):
            out = self.value + other.value
        else:
            out = self.value + other
        out = self.__class__(out)
        out._set_info(self, next_history="add")
        return out

    @same_dtype(asfloat=True)
    def __iadd__(self, other) -> LazyImgArray:
        if isinstance(other, LazyImgArray):
            self.value += other.value
        else:
            self.value += other
        self.history.append("add")
        return self

    @same_dtype(asfloat=True)
    def __sub__(self, other) -> LazyImgArray:
        if isinstance(other, LazyImgArray):
            out = self.value - other.value
        else:
            out = self.value - other
        out = self.__class__(out)
        out._set_info(self, next_history="subtract")
        return out

    @same_dtype(asfloat=True)
    def __isub__(self, other) -> LazyImgArray:
        if isinstance(other, LazyImgArray):
            self.value -= other.value
        else:
            self.value -= other
        self.history.append("subtract")
        return self

    def _to_multiplier(self, other):
        """
        Convert the right operand of ``*``/``*=`` into something a dask array can be multiplied by.

        Non-complex ndarrays and LazyImgArrays are converted to float32. Every LazyImgArray
        (complex ones included) is unwrapped to its dask array.

        Raises
        ------
        ValueError
            If ``other`` is a negative scalar.
        """
        if isinstance(other, np.ndarray) and other.dtype.kind != "c":
            return other.astype(np.float32)
        if isinstance(other, LazyImgArray):
            if other.dtype.kind != "c":
                other = other.as_float()
            return other.value
        if np.isscalar(other) and other < 0:
            raise ValueError("Cannot multiply negative value.")
        return other

    def _to_divisor(self, other):
        """
        Convert the right operand of ``/``/``/=`` into something a dask array can be divided by.

        In non-complex array divisors, zeros are replaced by ``inf`` so that they produce 0
        instead of inf/nan. The caller's array is never modified: ``astype`` copies, and dask
        divisors go through ``da.where``.

        Raises
        ------
        ValueError
            If ``other`` is a non-positive scalar.
        """
        from dask import array as da
        if isinstance(other, np.ndarray) and other.dtype.kind != "c":
            other = other.astype(np.float32)
            other[other == 0] = np.inf
            return other
        if isinstance(other, LazyImgArray):
            if other.dtype.kind == "c":
                return other.value
            value = other.as_float().value
            # np.float32 scalar keeps the result float32 regardless of NumPy promotion rules.
            return da.where(value == 0, np.float32(np.inf), value)
        if np.isscalar(other) and other <= 0:
            raise ValueError("Cannot divide by non-positive value.")
        return other

    @same_dtype(asfloat=True)
    def __mul__(self, other) -> LazyImgArray:
        out = self.__class__(self.value * self._to_multiplier(other))
        out._set_info(self, next_history="multiply")
        return out

    @same_dtype(asfloat=True)
    def __imul__(self, other) -> LazyImgArray:
        self.value *= self._to_multiplier(other)
        self.history.append("multiply")
        return self

    def __truediv__(self, other) -> LazyImgArray:
        # Division always yields a float32 image.
        self = self.as_float()
        out = self.__class__(self.value / self._to_divisor(other))
        out._set_info(self, next_history="divide")
        return out

    def __itruediv__(self, other) -> LazyImgArray:
        if self.dtype.kind in "ui":
            raise ValueError("Cannot divide integer inplace.")
        self.value /= self._to_divisor(other)
        self.history.append("divide")
        return self

    @property
    def chunk_info(self):
        """Chunk sizes for repr, e.g. ``"1(t), 512(y), 512(x)"``; raw chunksize if axes are None."""
        if self.axes.is_none():
            chunk_info = self.chunksize
        else:
            chunk_info = ", ".join([f"{s}({o})" for s, o in zip(self.chunksize, self.axes)])
        return chunk_info

    def _repr_dict_(self):
        return {" shape ": self.shape_info,
                " chunk sizes ": self.chunk_info,
                " dtype ": self.dtype,
                " directory ": self.dirpath,
                "original image": self.name,
                " history ": "->".join(self.history)}

    def __repr__(self):
        return "\n" + "\n".join(f"{k}: {v}" for k, v in self._repr_dict_().items()) + "\n"

    def compute(self, ignore_limit: bool = False) -> ImgArray:
        """
        Compute all the task and convert the result into ImgArray. If image size overwhelms MAX_GB
        then MemoryError is raised.

        Parameters
        ----------
        ignore_limit : bool, default is False
            If True, skip the ``Const["MAX_GB"]`` size check.

        Returns
        -------
        ImgArray
            Computed image carrying name, dirpath, axes, metadata and history. A 0-dimensional
            result is returned as-is.
        """
        if self.gb > Const["MAX_GB"] and not ignore_limit:
            raise MemoryError(f"Too large: {self.gb:.2f} GB")
        with Progress("Converting to ImgArray"):
            arr = self.value.compute()
            if arr.ndim > 0:
                img = asnumpy(arr).view(ImgArray)
                for attr in ["name", "dirpath", "axes", "metadata", "history"]:
                    setattr(img, attr, getattr(self, attr, None))
            else:
                img = arr
        return img

    @property
    def data(self):
        """Deprecated alias of :meth:`compute`."""
        warn("'data' should no longer be used and will be removed soon. Use 'img.compute()' instead.",
             DeprecationWarning, stacklevel=2)
        return self.compute()

    @property
    def img(self):
        """Deprecated alias of ``value``."""
        warn("'img' is renamed to 'value' for compatibility with ImgArray and will be removed soon.",
             DeprecationWarning, stacklevel=2)
        return self.value

    def release(self, update: bool = True) -> LazyImgArray:
        """
        Compute all the tasks and store the data in memory map, and read it as a dask array
        again.

        Parameters
        ----------
        update : bool, default is True
            If True, rebind ``self.value`` and return ``self``; otherwise return a new object.
        """
        from dask import array as da
        with Progress("Releasing jobs"):
            # NOTE(review): the temp file is deleted when this `with` exits while `mmap` is still
            # mapped; this relies on POSIX unlink semantics -- confirm behavior on Windows.
            with tempfile.NamedTemporaryFile() as ntf:
                mmap = np.memmap(ntf, mode="w+", shape=self.shape, dtype=self.dtype)
                # Write chunk by chunk instead of materializing the whole array in memory.
                da.store(self.value, mmap)
                img = da.from_array(mmap, chunks=self.chunksize).map_blocks(
                    np.array, meta=np.array([], dtype=self.dtype)
                )
        if update:
            self.value = img
            out = self
        else:
            out = self.__class__(img)
            out._set_info(self)
        return out

    @_docs.copy_docs(LabeledArray.imsave)
    def imsave(self, tifname: str, dtype=None):
        from dask import array as da
        if not tifname.endswith((".tif", ".tiff")):
            tifname += ".tif"
        # A bare file name is saved next to the source image (when its directory is known).
        if os.sep not in tifname and self.dirpath is not None:
            tifname = os.path.join(self.dirpath, tifname)
        if self.metadata is None:
            self.metadata = {}
        if dtype is None:
            dtype = self.dtype
        self = self.as_img_type(dtype).sort_axes()
        imsave_kwargs = get_imsave_meta_from_img(self, update_lut=False)
        memmap_image = memmap(tifname, shape=self.shape, dtype=self.dtype, **imsave_kwargs)
        with Progress("Saving"):
            # Stream chunks into the memory-mapped file instead of computing everything at once.
            da.store(self.value, memmap_image)
            memmap_image.flush()
        return None

    def rechunk(self, chunks="auto", *, threshold=None, block_size_limit=None, balance=False,
                update=False) -> LazyImgArray:
        """
        Rechunk the bound dask array.

        Parameters
        ----------
        chunks, threshold, block_size_limit, balance
            Passed directly to dask.array's rechunk
        update : bool, default is False
            If True, rebind ``self.value`` and return ``self``.

        Returns
        -------
        LazyImgArray
            Rechunked dask array is bound. History will not be updated.
        """
        rechunked = self.value.rechunk(chunks=chunks, threshold=threshold,
                                       block_size_limit=block_size_limit, balance=balance)
        if update:
            self.value = rechunked
            return self
        else:
            out = self.__class__(rechunked)
            out._set_info(self)
            return out

    def apply_dask_func(self, funcname: str, *args, **kwargs) -> LazyImgArray:
        """
        Apply dask array function to the connected dask array.

        Parameters
        ----------
        funcname : str
            Name of function to apply.
        args, kwargs :
            Parameters that will be passed to `funcname`.

        Returns
        -------
        LazyImgArray
            Updated one. Axes are inherited only if the shape did not change.
        """
        out = getattr(self.value, funcname)(*args, **kwargs)
        out = self.__class__(out)
        new_axes = "inherit" if out.shape == self.shape else None
        out._set_info(self, make_history(funcname, args, kwargs), new_axes=new_axes)
        return out

    def _apply_function(self,
                        func: Callable,
                        c_axes: str = None,
                        drop_axis: Iterable[int] = (),
                        new_axis: Iterable[int] = None,
                        dtype=np.float32,
                        rechunk_to: tuple[int, ...] | str = "none",
                        dask_wrap: bool = False,
                        args: tuple = None, kwargs: dict[str] = None) -> LazyImgArray:
        """
        Rechunk array in a correct shape and apply function using `map_blocks`. This function is similar
        to the `apply_dask` function in `MetaArray` while returns dask array bound LazyImgArray.

        Parameters
        ----------
        func : callable
            Function to apply for each chunk.
        c_axes : str, optional
            Axes to iterate. None means no axis is iterated.
        drop_axis : Iterable[int], optional
            Passed to map_blocks.
        new_axis : Iterable[int], optional
            Passed to map_blocks.
        dtype : any that can be converted to np.dtype object, default is np.float32
            Output data type.
        rechunk_to : tuple[int,...], optional
            In what size input array should be rechunked before `map_blocks` iteration. If str is given,
            array will be rechunked in following rules:
            - "none": No rechunking
            - "default": Rechunked with "auto" method for each spatial dimension.
            - "max": Rechunked to the shape size for each spatial dimension.
        dask_wrap : bool, optional
            If True, for each chunk array will be converted to dask and rechunked with "auto" option
            before function call.
        args : tuple, optional
            Arguments that will passed to `func`.
        kwargs : dict
            Keyword arguments that will passed to `func`.

        Returns
        -------
        dask array
            Dask array after function is applied.
        """
        if c_axes is None:
            c_axes = ""
        slice_in = []
        slice_out = []
        for a in self.axes:
            if a in c_axes:
                slice_in.append(0)
                slice_out.append(np.newaxis)
            else:
                slice_in.append(slice(None))
                slice_out.append(slice(None))
        slice_in = tuple(slice_in)
        slice_out = tuple(slice_out)
        args = () if args is None else tuple(args)
        kwargs = {} if kwargs is None else dict(kwargs)

        if rechunk_to == "none":
            input_ = self.value
        else:
            if rechunk_to == "default":
                rechunk_to = switch_slice(c_axes, self.axes, ifin=1, ifnot="auto")
            elif rechunk_to == "max":
                rechunk_to = switch_slice(c_axes, self.axes, ifin=1, ifnot=self.shape)
            input_ = self.value.rechunk(rechunk_to)

        # args/kwargs are bound in the closure: passing them through map_blocks would make dask
        # treat array arguments as block-aligned inputs and could collide with its own keywords.
        if dask_wrap:
            from dask import array as da

            @wraps(func)
            def _func(arr):
                out = func(da.from_array(arr[slice_in]), *args, **kwargs)
                return out[slice_out].compute()
        else:
            @wraps(func)
            def _func(arr):
                out = func(arr[slice_in], *args, **kwargs)
                return out[slice_out]

        out = input_.map_blocks(_func, drop_axis=drop_axis, new_axis=new_axis,
                                meta=xp.array([], dtype=dtype))
        return out

    def _apply_map_blocks(self, func: Callable, c_axes: str = None,
                          args: tuple = None, kwargs: dict[str, Any] = None):
        """
        Apply ``func(slice, *args, **kwargs)`` to every slice of every block, iterating over
        ``c_axes``. Output blocks have the same shape and dtype as the input blocks.

        Returns a dask array.
        """
        from dask import array as da
        args = () if args is None else tuple(args)
        kwargs = {} if kwargs is None else dict(kwargs)
        if c_axes is None:
            c_axes = ""
        all_axes = str(self.axes)

        # args/kwargs are closed over rather than routed through map_blocks.
        def _func(input: ArrayLike):
            out = xp.empty(input.shape, input.dtype)
            for sl in iter_slice(input.shape, c_axes, all_axes):
                out[sl] = func(input[sl], *args, **kwargs)
            return out

        return da.map_blocks(_func, self.value, dtype=self.dtype)

    def _apply_map_overlap(self, func: Callable, c_axes: str = None, depth=16, boundary="reflect",
                           dtype: DTypeLike = None, args: tuple = None,
                           kwargs: dict[str, Any] = None):
        """
        Apply ``func(slice, *args, **kwargs)`` to every slice iterated over ``c_axes`` using
        ``da.map_overlap``.

        Parameters
        ----------
        depth : int or tuple
            Overlap for non-iterated axes; iterated axes (``c_axes``) always get 0. A tuple is
            passed to ``switch_slice``, which appears to pair it with ``self.axes`` position by
            position (``tiled_lowpass_filter`` relies on this).
        boundary : str
            Passed to ``da.map_overlap``.
        dtype : dtype, optional
            Output dtype (default: input dtype). Used both for the task metadata and for the
            per-block output buffer.

        Returns a dask array.
        """
        from dask import array as da
        args = () if args is None else tuple(args)
        kwargs = {} if kwargs is None else dict(kwargs)
        if c_axes is None:
            c_axes = ""
        if dtype is None:
            dtype = self.dtype
        all_axes = str(self.axes)

        # args/kwargs are closed over: forwarding them positionally to map_overlap would make
        # dask treat e.g. a convolution kernel as a block-aligned input array.
        def _func(input: ArrayLike):
            out = xp.empty(input.shape, dtype)
            for sl in iter_slice(input.shape, c_axes, all_axes):
                out[sl] = func(input[sl], *args, **kwargs)
            return out

        depth = switch_slice(c_axes, self.axes, 0, depth)
        return da.map_overlap(_func, self.value, depth=depth, boundary=boundary, dtype=dtype)

    def _apply_dask_filter(self,
                           func: Callable,
                           c_axes: str = None,
                           args: tuple = None,
                           kwargs: dict[str, Any] = None) -> LazyImgArray:
        """
        Apply a dask-aware ``func`` to each slice from ``self.iter(c_axes, israw=False)``.
        The results are assembled by item assignment into a dask array shaped like ``self.value``.
        """
        # TODO: This is not efficient. Maybe using da.stack is better?
        from dask import array as da
        out = da.empty_like(self.value)
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}
        for sl, img in self.iter(c_axes, israw=False):
            out[sl] = func(img, *args, **kwargs)
        return out

    # TODO: This should wait for dask-image implement map_coordinates
    # @_docs.copy_docs(LabeledArray.rotated_crop)
    # @dims_to_spatial_axes
    # @record_lazy
    # def rotated_crop(self, origin, dst1, dst2, dims=2) -> LazyImgArray:
    #     origin = np.asarray(origin)
    #     dst1 = np.asarray(dst1)
    #     dst2 = np.asarray(dst2)
    #     ax0 = _misc.make_rotated_axis(origin, dst2)
    #     ax1 = _misc.make_rotated_axis(dst1, origin)
    #     all_coords = ax0[:, np.newaxis] + ax1[np.newaxis] - origin
    #     all_coords = np.moveaxis(all_coords, -1, 0)
    #     # Because output shape changes, we have to tell dask what chunk size it would be,
    #     # otherwise output shape is estimated in a wrong way.
    #     output_chunks = [1] * self.ndim
    #     for i, a in enumerate(dims):
    #         it = self.axisof(a)
    #         output_chunks[it] = all_coords.shape[i+1]
    #
    #     cropped_img = self._apply_function(xp_ndi.map_coordinates,
    #                                        c_axes=complement_axes(dims, self.axes),
    #                                        dtype=self.dtype,
    #                                        rechunk_to="max",
    #                                        args=(xp.asarray(all_coords),),
    #                                        kwargs=dict(prefilter=False, order=1, chunks=output_chunks)
    #                                        )

    def _apply_morphology(self, grey_func: Callable, binary_func: Callable, radius: float,
                          dims: Dims, depth: int):
        """
        Apply a morphological filter with a ball-like structuring element over ``dims``.

        ``binary_func`` is used for bool images and ``grey_func`` otherwise. The scipy/cupyx
        ``grey_*`` functions take the element as ``footprint``, while ``binary_*`` take it as
        ``structure``, so the keyword is chosen together with the function.
        """
        disk = _structures.ball_like(radius, len(dims))
        if self.dtype == bool:
            filter_func, kwargs = binary_func, dict(structure=disk)
        else:
            filter_func, kwargs = grey_func, dict(footprint=disk)
        return self._apply_map_overlap(
            filter_func,
            c_axes=complement_axes(dims, self.axes),
            depth=depth,
            kwargs=kwargs,
        )

    @_docs.copy_docs(ImgArray.erosion)
    @dims_to_spatial_axes
    @record_lazy
    def erosion(self, radius: float = 1, *, dims: Dims = None, update: bool = False
                ) -> LazyImgArray:
        return self._apply_morphology(xp_ndi.grey_erosion, xp_ndi.binary_erosion,
                                      radius, dims, depth=_ceilint(radius))

    @_docs.copy_docs(ImgArray.dilation)
    @dims_to_spatial_axes
    @record_lazy
    def dilation(self, radius: float = 1, *, dims: Dims = None, update: bool = False
                 ) -> LazyImgArray:
        return self._apply_morphology(xp_ndi.grey_dilation, xp_ndi.binary_dilation,
                                      radius, dims, depth=_ceilint(radius))

    @_docs.copy_docs(ImgArray.opening)
    @dims_to_spatial_axes
    @record_lazy
    def opening(self, radius: float = 1, *, dims: Dims = None, update: bool = False
                ) -> LazyImgArray:
        # opening = erosion + dilation, so twice the overlap is needed
        return self._apply_morphology(xp_ndi.grey_opening, xp_ndi.binary_opening,
                                      radius, dims, depth=_ceilint(radius)*2)

    @_docs.copy_docs(ImgArray.closing)
    @dims_to_spatial_axes
    @record_lazy
    def closing(self, radius: float = 1, *, dims: Dims = None, update: bool = False
                ) -> LazyImgArray:
        # closing = dilation + erosion, so twice the overlap is needed
        return self._apply_morphology(xp_ndi.grey_closing, xp_ndi.binary_closing,
                                      radius, dims, depth=_ceilint(radius)*2)

    @_docs.copy_docs(ImgArray.gaussian_filter)
    @dims_to_spatial_axes
    @same_dtype(asfloat=True)
    @record_lazy
    def gaussian_filter(self, sigma: nDFloat = 1.0, *, dims: Dims = None, update: bool = False
                        ) -> LazyImgArray:
        c_axes = complement_axes(dims, self.axes)
        # 4 sigma of overlap; use the largest sigma so per-axis sequences also work.
        depth = _ceilint(np.max(sigma) * 4)
        return self._apply_map_overlap(
            xp_ndi.gaussian_filter,
            c_axes=c_axes,
            depth=depth,
            kwargs=dict(sigma=sigma),
        )

    @_docs.copy_docs(ImgArray.mean_filter)
    @same_dtype(asfloat=True)
    @dims_to_spatial_axes
    @record_lazy
    def mean_filter(self, radius: float = 1, *, dims: Dims = None, update: bool = False
                    ) -> LazyImgArray:
        disk = _structures.ball_like(radius, len(dims))
        kernel = (disk/np.sum(disk)).astype(np.float32)
        return self._apply_map_overlap(
            xp_ndi.convolve,
            depth=_ceilint(radius),
            c_axes=complement_axes(dims, self.axes),
            kwargs=dict(weights=kernel),
        )

    @_docs.copy_docs(ImgArray.convolve)
    @dims_to_spatial_axes
    @same_dtype(asfloat=True)
    @record_lazy
    def convolve(self, kernel, *, mode: str = "reflect", cval: float = 0, dims: Dims = None,
                 update: bool = False) -> LazyImgArray:
        kernel = np.asarray(kernel)
        # The kernel only spans `dims`, but the overlap must be given for every array axis
        # (in self.axes order). Use half the kernel width on spatial axes and 0 elsewhere.
        half_size = {str(a): int(s) // 2 for a, s in zip(dims, kernel.shape)}
        depth = tuple(half_size.get(a, 0) for a in str(self.axes))
        c_axes = complement_axes(dims, self.axes)
        return self._apply_map_overlap(
            xp_ndi.convolve,
            c_axes=c_axes,
            depth=depth,
            kwargs=dict(weights=kernel, mode=mode, cval=cval),
        )

    @_docs.copy_docs(ImgArray.edge_filter)
    @dims_to_spatial_axes
    @same_dtype
    @record_lazy
    def edge_filter(self, method: str = "sobel", *, dims: Dims = None, update: bool = False
                    ) -> LazyImgArray:
        # BUG: returns zero array
        # NOTE(review): possibly because the skimage filters return small floats that are
        # written into an output buffer of the input (integer) dtype -- confirm.
        from ._utils._skimage import skfil
        method_dict = {"sobel": (skfil.sobel, 1),
                       "farid": (skfil.farid, 2),
                       "scharr": (skfil.scharr, 1),
                       "prewitt": (skfil.prewitt, 1)}
        try:
            filter_func, depth = method_dict[method]
        except KeyError:
            raise ValueError("`method` must be 'sobel', 'farid' 'scharr', or 'prewitt'.")
        return self._apply_map_overlap(
            filter_func,
            depth=depth,
            c_axes=complement_axes(dims, self.axes)
        )

    @_docs.copy_docs(ImgArray.laplacian_filter)
    @dims_to_spatial_axes
    @same_dtype
    @record_lazy
    def laplacian_filter(self, radius: int = 1, *, dims: Dims = None, update: bool = False
                         ) -> LazyImgArray:
        ndim = len(dims)
        _, laplace_op = skres.uft.laplacian(ndim, (2*radius+1,) * ndim)
        # The kernel is passed by keyword so it reaches `convolve` instead of being mistaken
        # for an additional dask input.
        return self._apply_map_overlap(
            xp_ndi.convolve,
            depth=_ceilint(radius),
            c_axes=complement_axes(dims, self.axes),
            kwargs=dict(weights=laplace_op),
        )

    @_docs.copy_docs(ImgArray.affine)
    @dims_to_spatial_axes
    @same_dtype(asfloat=True)
    @record_lazy
    def affine(self, matrix=None, scale=None, rotation=None, shear=None, translation=None, *,
               mode="constant", cval=0, output_shape=None, order=1, dims=None) -> LazyImgArray:
        if matrix is None:
            matrix = _transform.compose_affine_matrix(scale=scale, rotation=rotation,
                                                      shear=shear, translation=translation,
                                                      ndim=len(dims))
        from dask_image.ndinterp import affine_transform
        return self._apply_dask_filter(
            affine_transform,
            c_axes=complement_axes(dims, self.axes),
            kwargs=dict(matrix=matrix, mode=mode, cval=cval,
                        output_shape=output_shape, order=order)
        )

    @_docs.copy_docs(ImgArray.kalman_filter)
    @dims_to_spatial_axes
    @same_dtype(asfloat=True)
    @record_lazy
    def kalman_filter(self, gain: float = 0.8, noise_var: float = 0.05, *, along: str = "t",
                      dims: Dims = None, update: bool = False) -> LazyImgArray:
        if self.axisof(along) != 0:
            raise ValueError("Currently kalman_filter does not support t-axis != 0.")
        return self._apply_map_blocks(
            _filters.kalman_filter,
            c_axes=complement_axes(along + dims, self.axes),
            args=(gain, noise_var)
        )

    @_docs.copy_docs(ImgArray.fft)
    @dims_to_spatial_axes
    @record_lazy
    def fft(self, *, shape: int | Iterable[int] | str = "same", shift: bool = True,
            dims: Dims = None) -> LazyImgArray:
        from dask import array as da
        axes = [self.axisof(a) for a in dims]
        if isinstance(shape, str) and shape == "square":
            # smallest power of two that covers the largest spatial size
            s = 2**int(np.ceil(np.log2(np.max(self.sizesof(dims)))))
            shape = (s,) * len(dims)
        elif isinstance(shape, str) and shape == "same":
            shape = None
        else:
            shape = check_nd(shape, len(dims))
        # dask.array.fft requires each transformed axis to be a single chunk.
        value = self.value.astype(np.float32).rechunk({ax: -1 for ax in axes})
        freq = da.fft.fftn(value, s=shape, axes=axes).astype(np.complex64)
        if shift:
            freq = da.fft.fftshift(freq, axes=axes)
        return freq

    @_docs.copy_docs(ImgArray.ifft)
    @dims_to_spatial_axes
    @record_lazy
    def ifft(self, real: bool = True, *, shift: bool = True, dims=None) -> LazyImgArray:
        from dask import array as da
        axes = [self.axisof(a) for a in dims]
        # dask.array.fft requires each transformed axis to be a single chunk.
        freq = self.value.rechunk({ax: -1 for ax in axes})
        if shift:
            freq = da.fft.ifftshift(freq, axes=axes)
        out = da.fft.ifftn(freq, axes=axes).astype(np.complex64)
        if real:
            out = da.real(out)
        return out

    @_docs.copy_docs(ImgArray.power_spectra)
    @dims_to_spatial_axes
    @record_lazy
    def power_spectra(self, shape="same", norm: bool = False, zero_norm: bool = False, *,
                      dims: Dims = None) -> LazyImgArray:
        freq = self.fft(dims=dims, shape=shape)
        pw = freq.value.real**2 + freq.value.imag**2
        if norm:
            pw /= pw.max()
        if zero_norm:
            # fft() shifts by default, so the zero-frequency component is at the center.
            # `pw` is a raw dask array, so axes come from `freq`.
            sl = switch_slice(dims, freq.axes, ifin=np.array(pw.shape)//2, ifnot=slice(None))
            pw[sl] = 0
        return pw

    def chunksizeof(self, axis: str):
        """Chunk size along a single axis symbol."""
        return self.value.chunksize[self.axes.find(axis)]

    def chunksizesof(self, axes: str):
        """Tuple of chunk sizes along each axis symbol in ``axes``."""
        return tuple(self.chunksizeof(a) for a in axes)

    def transpose(self, axes):
        """Transpose by integer axis order; axis labels are permuted accordingly."""
        if self.axes.is_none():
            new_axes = None
        else:
            new_axes = "".join([self.axes[i] for i in list(axes)])
        out = self.__class__(self.value.transpose(axes))
        out._set_info(self, new_axes=new_axes)
        return out

    def sort_axes(self):
        """Transpose so that axes follow the order given by ``self.axes.argsort()``."""
        order = self.axes.argsort()
        return self.transpose(tuple(order))

    @_docs.copy_docs(LabeledArray.crop_center)
    @dims_to_spatial_axes
    @record_lazy
    def crop_center(self, scale=0.5, *, dims=2) -> LazyImgArray:
        # check scale
        if hasattr(scale, "__iter__") and len(scale) == 3 and len(dims) == 2:
            dims = "zyx"
        scale = np.asarray(check_nd(scale, len(dims)))
        if np.any((scale <= 0) | (1 < scale)):
            raise ValueError(f"scale must be (0, 1], but got {scale}")

        # Make axis-targeted slicing string
        sizes = self.sizesof(dims)
        slices = []
        for a, size, sc in zip(dims, sizes, scale):
            x0 = int(size / 2 * (1 - sc))
            x1 = int(np.ceil(size / 2 * (1 + sc)))
            slices.append(f"{a}={x0}:{x1}")

        out = self[";".join(slices)]
        return out

    @_docs.copy_docs(ImgArray.tiled_lowpass_filter)
    @dims_to_spatial_axes
    @record_lazy
    def tiled_lowpass_filter(self, cutoff: float = 0.2, order: int = 2, overlap: int = 16, *,
                             dims: Dims = None, update: bool = False) -> LazyImgArray:
        from ._utils._skimage import _get_ND_butterworth_filter
        self = self.as_float()
        cutoff = check_nd(cutoff, len(dims))
        c_axes = complement_axes(dims, self.axes)
        if all((c >= 0.5 or c <= 0) for c in cutoff):
            return self

        depth = switch_slice(dims, self.axes, overlap, 0)

        def func(arr):
            arr = xp.asarray(arr)
            shape = arr.shape
            weight = _get_ND_butterworth_filter(shape, cutoff, order, False, True)
            ft = weight * xp_fft.rfftn(arr)
            ift = xp_fft.irfftn(ft, s=shape)
            return ift

        out = self._apply_map_overlap(func, c_axes=c_axes, depth=depth, boundary="reflect")
        return out

    @_docs.copy_docs(ImgArray.proj)
    @same_dtype
    def proj(self, axis: str = None, method: str = "mean") -> LazyImgArray:
        from dask import array as da
        if axis is None:
            axis = find_first_appeared("ztpi<c", include=self.axes, exclude="yx")
        elif not isinstance(axis, str):
            raise TypeError("`axis` must be str.")
        axisint = [self.axisof(a) for a in axis]
        if method == "mean":
            projection = getattr(da, method)(self.value, axis=tuple(axisint), dtype=np.float32)
        else:
            projection = getattr(da, method)(self.value, axis=tuple(axisint))

        out = self.__class__(projection)
        out._set_info(self, f"proj(axis={axis}, method={method})", del_axis(self.axes, axisint))
        return out

    @_docs.copy_docs(ImgArray.binning)
    @dims_to_spatial_axes
    @same_dtype
    def binning(self, binsize: int = 2, method="mean", *, check_edges: bool = True,
                dims: Dims = None) -> LazyImgArray:
        if binsize == 1:
            return self
        if isinstance(method, str):
            binfunc = getattr(xp, method)
        elif callable(method):
            binfunc = method
        else:
            raise TypeError("`method` must be a numpy function or callable object.")
        img_to_reshape, shape, scale_ = _misc.adjust_bin(self.value, binsize, check_edges, dims, self.axes)
        reshaped_img = img_to_reshape.reshape(shape)
        # every original axis is split into (n_bins, binsize); reduce the odd (binsize) axes
        axes_to_reduce = tuple(i*2+1 for i in range(self.ndim))
        out = binfunc(reshaped_img, axis=axes_to_reduce)
        out = self.__class__(out)
        out._set_info(self, f"binning(binsize={binsize})")
        out.axes = str(self.axes)  # _set_info does not pass copy so new axes must be defined here.
        out.set_scale({a: self.scale[a]/scale for a, scale in zip(self.axes, scale_)})
        return out

    @_docs.copy_docs(ImgArray.track_drift)
    def track_drift(self, along: str = None, upsample_factor: int = 10) -> "da.core.Array":
        if along is None:
            along = find_first_appeared("tpzc<i", include=self.axes)
        elif len(along) != 1:
            raise ValueError("`along` must be single character.")

        dims = complement_axes(along, self.axes)
        chunks = switch_slice(dims, self.axes, ifin=self.shape, ifnot=1)
        img_fft = self.fft(shift=False, dims=dims).value.rechunk(chunks)
        ndim = len(dims)
        slice_out = (np.newaxis, slice(None)) + (np.newaxis,)*(ndim-1)
        each_shape = (1, ndim) + (1,)*(ndim-1)
        len_t = self.sizeof(along)

        def pcc(x):
            # x holds two consecutive frames (previous, current) thanks to the (1, 0) overlap.
            if x.shape[0] < 2:
                return np.array([0]*ndim, dtype=np.float32).reshape(*each_shape)
            x = xp.asarray(x)
            result = _corr.subpixel_pcc(x[0], x[1], upsample_factor=upsample_factor)
            return asnumpy(result[slice_out])

        from dask import array as da
        # I don't know the reason why but output dask array's chunk size along t-axis should be
        # specified to be 1, and rechunk it map_overlap.
        result = da.map_overlap(pcc, img_fft,
                                depth={0: (1, 0)},
                                trim=False,
                                boundary="none",
                                chunks=(1, ndim) + (1,)*(ndim-1),
                                meta=np.array([], dtype=np.float32)
                                )

        # For cupy, we must call map_blocks (or from_delayed and delayed) here.
        result = da.map_blocks(np.cumsum, result[..., 0].rechunk((len_t, ndim)),
                               axis=0, meta=np.array([], dtype=np.float32)
                               )
        return result

    @_docs.copy_docs(ImgArray.drift_correction)
    @same_dtype(asfloat=True)
    @record_lazy
    @dims_to_spatial_axes
    def drift_correction(self, shift: Coords = None, ref: ImgArray = None, *,
                         zero_ave: bool = True, along: str = None, dims: Dims = 2,
                         update: bool = False, **affine_kwargs) -> LazyImgArray:
        if along is None:
            along = find_first_appeared("tpzcia", include=self.axes, exclude=dims)
        elif len(along) != 1:
            raise ValueError("`along` must be single character.")

        from ..frame import MarkerFrame
        if shift is None:
            # determine 'ref'
            if ref is None:
                ref = self
                _dims = complement_axes(along, self.axes)
                if dims != _dims:
                    warn(f"dims={dims} with along={along} and {self.axes}-image are not "
                         f"valid input. Changed to dims={_dims}",
                         UserWarning)
                    dims = _dims
            elif not isinstance(ref, self.__class__):
                raise TypeError(f"'ref' must be LazyImgArray object, but got {type(ref)}")
            elif ref.axes != along + dims:
                raise ValueError(f"Arguments `along`({along}) + `dims`({dims}) do not match "
                                 f"axes of `ref`({ref.axes})")

            shift = ref.track_drift(along=along)

        elif isinstance(shift, MarkerFrame):
            if len(shift) != self.sizeof(along):
                raise ValueError("Wrong shape of 'shift'.")
            shift = shift.values

        from dask import array as da
        if zero_ave:
            shift = shift - da.mean(shift, axis=0)

        t_index = self.axisof(along)
        slice_in = switch_slice(dims, self.axes, ifin=slice(None), ifnot=0)
        slice_out = switch_slice(dims, self.axes, ifin=slice(None), ifnot=np.newaxis)
        ndim = len(dims)

        # Here shift must be a local variable for the function. Otherwise, it takes dask very long time
        # for graph construction.
        def warp(arr, shift, block_info=None):
            arr = xp.asarray(arr)
            mx = xp.eye(ndim+1, dtype=np.float32)
            loc = block_info[None]["array-location"][0]
            mx[:-1, -1] = -xp.asarray(shift[loc[t_index]])
            return asnumpy(
                _transform.warp(arr[slice_in], mx, **affine_kwargs)[slice_out]
            )

        chunks = switch_slice(dims, self.axes, ifin=self.shape, ifnot=1)
        out = da.map_blocks(warp, self.value.rechunk(chunks), shift, meta=np.array([], dtype=self.dtype))
        return out

    @_docs.copy_docs(ImgArray.pad)
    @dims_to_spatial_axes
    @record_lazy
    def pad(self, pad_width, mode: str = "constant", *, dims: Dims = None, **kwargs) -> LazyImgArray:
        pad_width = _misc.make_pad(pad_width, dims, self.axes, **kwargs)
        # np.pad dispatches to dask via __array_function__, so this stays lazy.
        padimg = np.pad(self.value, pad_width, mode, **kwargs)
        return padimg

    # @_docs.copy_docs(ImgArray.wiener)
    # @dims_to_spatial_axes
    # @same_dtype(asfloat=True)
    # @record_lazy
    # def wiener(self, psf: np.ndarray, lmd: float = 0.1, *, depth="auto", dims: Dims = None, update: bool = False) -> LazyImgArray:
    #     if lmd <= 0:
    #         raise ValueError(f"lmd must be positive, but got: {lmd}")
    #     if depth == "auto":
    #         depth = 32  # TODO: any better way?
    #     psf_ft, psf_ft_conj = _deconv.check_psf(self, psf, dims)
    #     return self._apply_map_overlap
    #     return self._apply_function(_deconv.wiener,
    #                                 c_axes=complement_axes(dims, self.axes),
    #                                 rechunk_to="max",
    #                                 args=(psf_ft, psf_ft_conj, lmd)
    #                                 )

    # @_docs.copy_docs(ImgArray.lucy)
    # @dims_to_spatial_axes
    # @same_dtype(asfloat=True)
    # @record_lazy
    # def lucy(self, psf: np.ndarray, niter: int = 50, eps: float = 1e-5, depth: int = 32, *, dims: Dims = None,
    #          update: bool = False) -> LazyImgArray:
    #     psf_ft, psf_ft_conj = _deconv.check_psf(self, psf, dims)
    #     return self._apply_map_overlap(_deconv.richardson_lucy,
    #                                    c_axes=complement_axes(dims, self.axes),
    #                                    depth=depth,
    #                                    boundary="nearest",
    #                                    args=(psf_ft, psf_ft_conj, niter, eps)
    #                                    )

    def __array_function__(self, func, types, args, kwargs):
        """
        Every time a numpy function (np.mean...) is called, this function will be called. Essentially numpy
        function can be overloaded with this method.

        LazyImgArray inputs are unwrapped and the call is delegated to dask. Dask array results
        are re-wrapped as LazyImgArray (a tuple/list of results becomes a DataList). Any other
        result (e.g. a scalar) is returned unchanged.
        """
        from dask import array as da
        args, kwargs = _replace_inputs(self, args, kwargs)

        _types = [da.core.Array if t is self.__class__ else t for t in types]

        result = self.value.__array_function__(func, _types, args, kwargs)

        if result is NotImplemented:
            return NotImplemented

        if isinstance(result, (tuple, list)):
            out = []
            for r in result:
                if isinstance(r, da.core.Array):
                    r_img = self.__class__(r)
                    r_img._process_output(self, args, kwargs)
                    out.append(r_img)
                else:
                    out.append(r)
            out = DataList(out)
        elif isinstance(result, da.core.Array):
            out = self.__class__(result)
            out._process_output(self, args, kwargs)
        else:
            out = result
        return out

    def _process_output(self, input: LazyImgArray, args: tuple, kwargs: dict):
        """Copy info from ``input``; drop reduced axes when an integer ``axis`` was given."""
        axis = kwargs.get("axis", None)
        if axis is not None and not input.axes.is_none():
            new_axes = del_axis(input.axes, axis)
        else:
            new_axes = "inherit"
        self._set_info(input, new_axes=new_axes)
        return None

    def as_uint8(self) -> LazyImgArray:
        """
        Convert to uint8.

        uint16 values are divided by 256. Float values are shifted by 0.5, clipped to [0, 255]
        and truncated. bool becomes 0/1. An image that is already uint8 is returned unchanged.

        Raises
        ------
        TypeError
            For any other dtype.
        """
        img = self.value
        if img.dtype == np.uint8:
            return self
        if img.dtype == np.uint16:
            out = img / 256
        elif img.dtype == bool:
            out = img
        elif img.dtype.kind == "f":
            out = img + 0.5
            out = np.clip(out, 0, 255)
        else:
            raise TypeError(f"invalid data type: {img.dtype}")
        out = out.astype(np.uint8)
        out = self.__class__(out)
        out._set_info(self)
        return out

    def as_uint16(self) -> LazyImgArray:
        """
        Convert to uint16.

        uint8 values are multiplied by 256. Float values are shifted by 0.5, clipped to
        [0, 65535] and truncated. bool becomes 0/1. An image that is already uint16 is returned
        unchanged.

        Raises
        ------
        TypeError
            For any other dtype.
        """
        img = self.value
        if img.dtype == np.uint16:
            return self
        if img.dtype == np.uint8:
            # widen first: uint8 * 256 overflows (or raises under NumPy 2 promotion rules)
            out = img.astype(np.uint16) * 256
        elif img.dtype == bool:
            out = img
        elif img.dtype.kind == "f":
            out = img + 0.5
            out = np.clip(out, 0, 65535)
        else:
            raise TypeError(f"invalid data type: {img.dtype}")
        out = out.astype(np.uint16)
        out = self.__class__(out)
        out._set_info(self)
        return out

    def as_float(self) -> LazyImgArray:
        """Convert to float32. An image that is already float32 is returned as-is (not copied)."""
        if self.dtype == np.float32:
            return self
        out = self.value.astype(np.float32)
        out = self.__class__(out)
        out._set_info(self)
        return out

    def as_img_type(self, dtype=np.uint16) -> LazyImgArray:
        """
        Convert to one of the supported image dtypes: uint8, uint16, float32, complex64 or int8.

        float64 and complex128 are downcast to float32/complex64 with a UserWarning.

        Raises
        ------
        ValueError
            For any other dtype.
        """
        dtype = np.dtype(dtype)
        if self.dtype == dtype:
            return self
        elif dtype == "uint16":
            return self.as_uint16()
        elif dtype == "uint8":
            return self.as_uint8()
        elif dtype == "float32":
            return self.as_float()
        elif dtype == "float64":
            warn("Data type float64 is not valid for images. It was converted to float32 instead",
                 UserWarning)
            return self.as_float()
        elif dtype == "complex64":
            out = self.value.astype(np.complex64)
            out = self.__class__(out)
            out._set_info(self)
            return out
        elif dtype == "complex128":
            warn("Data type complex128 is not valid for images. It was converted to complex64 instead",
                 UserWarning)
            out = self.value.astype(np.complex64)
            out = self.__class__(out)
            out._set_info(self)
            return out
        elif dtype == "int8":
            out = self.value.astype(np.int8)
            out = self.__class__(out)
            out._set_info(self)
            return out
        else:
            raise ValueError(f"dtype: {dtype}")

    def _set_additional_props(self, other):
        # set additional properties
        # If `other` does not have it and `self` has, then the property will be inherited.
        for p in self.__class__.additional_props:
            setattr(self, p, getattr(other, p,
                                     getattr(self, p,
                                             None)))

    def _getitem_additional_set_info(self, other, **kwargs):
        """Copy info from ``other`` after slicing, recording ``getitem[<keystr>]`` in history."""
        keystr = kwargs["keystr"]
        self._set_info(other, f"getitem[{keystr}]", kwargs["new_axes"])
        return None

    def _set_info(self, other: LazyImgArray, next_history=None, new_axes: str = "inherit"):
        """
        Copy name/dirpath/metadata, axes (and scale) and history from ``other``.

        ``new_axes="inherit"`` copies ``other.axes``; any other value is set as the new axes and
        the scale is taken from ``other``. If the axes are invalid (ImageAxesError), the axes
        become None. ``next_history``, when given, is appended to a new history list.
        """
        self._set_additional_props(other)
        # set axes
        try:
            if new_axes != "inherit":
                self.axes = new_axes
                self.set_scale(other)
            else:
                self.axes = other.axes.copy()
        except ImageAxesError:
            self.axes = None

        # set history
        if next_history is not None:
            self.history = other.history + [next_history]
        else:
            self.history = other.history.copy()

        return None
def _replace_inputs(img: LazyImgArray, args, kwargs):
    """
    Unwrap LazyImgArray arguments of a NumPy API call so that dask can handle it.

    Parameters
    ----------
    img : LazyImgArray
        The array whose ``__array_function__`` is running. It converts a string ``axis``
        (e.g. ``"t"`` or ``"yx"``) into integer axis indices.
    args : tuple
        Positional arguments of the NumPy call. LazyImgArray objects are replaced by their dask
        arrays, including those directly inside a list/tuple argument (e.g. ``np.stack([a, b])``),
        which would otherwise be computed eagerly through ``__array__``.
    kwargs : dict
        Keyword arguments of the NumPy call. It is not modified in place.

    Returns
    -------
    tuple, dict
        Converted positional arguments and a new keyword-argument dict.
    """
    def _as_dask_array(a):
        return a.value if isinstance(a, LazyImgArray) else a

    def _convert(a):
        # Only exact list/tuple are rebuilt, so namedtuples and other sequences pass unchanged.
        if type(a) is list:
            return [_as_dask_array(x) for x in a]
        if type(a) is tuple:
            return tuple(_as_dask_array(x) for x in a)
        return _as_dask_array(a)

    # convert arguments
    args = tuple(_convert(a) for a in args)
    kwargs = dict(kwargs)

    if "axis" in kwargs:
        axis = kwargs["axis"]
        if isinstance(axis, str):
            _axis = tuple(map(img.axisof, axis))
            if len(_axis) == 1:
                _axis = _axis[0]
            kwargs["axis"] = _axis

    if "out" in kwargs:
        out = kwargs["out"]
        if isinstance(out, (tuple, list)):
            kwargs["out"] = tuple(_as_dask_array(a) for a in out)
        else:
            # a single output array (or None) keeps its form
            kwargs["out"] = _as_dask_array(out)

    return args, kwargs
def _ceilint(a: float):
return int(np.ceil(a))
def iter_slice(shape, iteraxes: str, all_axes: str, exclude: str = ""):
    """
    Yield index tuples that walk over the ``iteraxes`` dimensions of an array.

    Parameters
    ----------
    shape : tuple of int
        Shape of the array being iterated, in ``all_axes`` order.
    iteraxes : str
        Axes to iterate over. Every integer position is produced for each of them.
    all_axes : str
        Axis symbols of the whole array.
    exclude : str, optional
        Axes whose entries are dropped from the yielded tuples.

    Yields
    ------
    tuple
        Index tuples built by ``switch_slice``: a ``range`` position for the iterated axes and
        ``slice(None)`` elsewhere. If the product is empty (e.g. an iterated dimension has
        length 0), a single tuple of ``slice(None)`` of length
        ``len(all_axes) - len(exclude)`` is yielded instead.
    """
    full = (slice(None),)
    candidates = switch_slice(axes=iteraxes,
                              all_axes=all_axes,
                              ifin=[range(size) for size in shape],
                              ifnot=[full] * len(all_axes))
    nothing_yielded = True
    for index in itertools.product(*candidates):
        if len(exclude) != 0:
            index = tuple(item for pos, item in enumerate(index)
                          if all_axes[pos] not in exclude)
        nothing_yielded = False
        yield index
    # fall back to a single "take everything" index when the product was empty
    if nothing_yielded:
        yield (slice(None),) * (len(all_axes) - len(exclude))