from __future__ import annotations
from functools import lru_cache, wraps
import itertools
from pathlib import Path
from typing import Any, Callable, TYPE_CHECKING, Sequence
import numpy as np
from numpy.typing import ArrayLike, DTypeLike
from warnings import warn
import tempfile
import operator
from .labeledarray import LabeledArray
from .imgarray import ImgArray
from .axesmixin import AxesMixin
from ._utils import _misc, _transform, _structures, _filters, _deconv, _corr, _docs
from impy.utils.axesop import slice_axes, switch_slice, complement_axes, find_first_appeared
from impy.utils.deco import check_input_and_output_lazy, dims_to_spatial_axes, same_dtype
from impy.utils.misc import check_nd
from impy.utils import slicer
from impy.io import imsave
from impy.collections import DataList
from impy.axes import ImageAxesError, AxisLike
from impy.array_api import xp
from impy._types import nDFloat, Coords, Iterable, Dims, PaddingMode
from impy._const import Const
if TYPE_CHECKING:
from dask.array.core import Array as DaskArray
[docs]class LazyImgArray(AxesMixin):
additional_props = ["_source", "_metadata", "_name"]
def __init__(
    self,
    obj: "DaskArray",
    name: str | None = None,
    axes: str | None = None,
    source: str | None = None,
    metadata: dict | None = None
):
    """
    Lazily-evaluated image array backed by a dask array.

    Parameters
    ----------
    obj : DaskArray
        The dask array to wrap; must be a ``dask.array.Array``.
    name : str, optional
        Name of the image. If omitted, the ``name`` property falls back to
        the source file name.
    axes : str, optional
        Axes symbols of the image (validated by the ``axes`` property setter).
    source : str, optional
        Path of the source file, normalized to ``pathlib.Path`` by the setter.
    metadata : dict, optional
        Free-form metadata dictionary; defaults to an empty dict.

    Raises
    ------
    TypeError
        If ``obj`` is not a dask array.
    """
    # Local import keeps dask off the module-import path.
    from dask.array.core import Array as DaskArray
    if not isinstance(obj, DaskArray):
        raise TypeError(f"The first input must be dask array, got {type(obj)}")
    self.value: "DaskArray" = obj
    self.source = source
    self._name = name
    self.axes = axes
    self._metadata = metadata or {}
@property
def source(self):
    """The source file path (``pathlib.Path`` or None)."""
    return self._source

@source.setter
def source(self, val):
    # Normalize any path-like value to pathlib.Path; None clears the source.
    if val is None:
        self._source = None
    else:
        self._source = Path(val)
@property
def name(self) -> str:
    """Name of the array, falling back to the source file name or 'No name'."""
    if self._name is not None:
        return self._name
    src = self.source
    return "No name" if src is None else src.name

@name.setter
def name(self, val):
    # Any value is coerced to str so the name is always a string.
    self._name = str(val)
@property
def metadata(self) -> dict[str, Any]:
    """Metadata dictionary of the array."""
    return self._metadata

@property
def ndim(self) -> int:
    """Number of dimensions of the array."""
    return self.value.ndim

@property
def shape(self):
    """Shape of the array, as an axes-aware tuple."""
    return self.axes.tuple(self.value.shape)

@property
def dtype(self):
    """Data type of the array"""
    return self.value.dtype

@property
def size(self) -> int:
    """Total number of elements in the array."""
    return self.value.size

@property
def itemsize(self):
    """Size of each element in bytes."""
    return self.value.itemsize

@property
def chunksize(self):
    """Chunk size of the array, as an axes-aware tuple."""
    return self.axes.tuple(self.value.chunksize)

@property
def GB(self) -> float:
    """Return the array size in GB."""
    return self.value.nbytes / 1e9

gb = GB  # alias
def __array__(self):
    """Materialize the whole array as a numpy array (numpy/napari protocol)."""
    # Should not be `self.compute` because in napari Viewer this function is called every time
    # sliders are moved.
    return xp.asnumpy(self.value.compute())
def __getitem__(self, key):
    """Slice the array lazily, keeping axes information consistent."""
    # Resolve axis-targeted / string slicers into plain numpy-style keys.
    key = slicer.solve_slicer(key, self.axes)
    # Axes that are fully indexed out disappear from the result's axes.
    new_axes = slice_axes(self.axes, key)
    out = self.__class__(
        self.value[key],
        name=self.name,
        source=self.source,
        axes=new_axes,
        metadata=self.metadata
    )
    out._getitem_additional_set_info(
        self, new_axes=new_axes, key=key
    )
    return out
def __neg__(self) -> LazyImgArray:
    """Return the lazily negated array, keeping axes/name/source info."""
    negated = self.__class__(-self.value)
    negated._set_info(self)
    return negated
def _apply_operator(self, op: Callable[[Any, Any], Any], other: Any) -> LazyImgArray:
    """Apply binary operator ``op``, unwrapping ``other`` if it is a LazyImgArray."""
    operand = other.value if isinstance(other, self.__class__) else other
    result = self.__class__(op(self.value, operand))
    result._set_info(self)
    return result
@same_dtype(asfloat=True)
def __add__(self, other) -> LazyImgArray:
    """Lazy element-wise addition."""
    return self._apply_operator(operator.add, other)

@same_dtype(asfloat=True)
def __iadd__(self, other) -> LazyImgArray:
    """In-place lazy addition; rebinds the underlying dask array."""
    if isinstance(other, self.__class__):
        self.value += other.value
    else:
        self.value += other
    return self

@same_dtype(asfloat=True)
def __sub__(self, other) -> LazyImgArray:
    """Lazy element-wise subtraction."""
    return self._apply_operator(operator.sub, other)

@same_dtype(asfloat=True)
def __isub__(self, other) -> LazyImgArray:
    """In-place lazy subtraction; rebinds the underlying dask array."""
    if isinstance(other, self.__class__):
        self.value -= other.value
    else:
        self.value -= other
    return self
@same_dtype(asfloat=True)
def __mul__(self, other) -> LazyImgArray:
    """
    Lazy element-wise multiplication.

    Non-complex ndarray operands are cast to float32; non-complex
    LazyImgArray operands are cast with ``as_float`` and unwrapped.
    """
    if isinstance(other, np.ndarray) and other.dtype.kind != "c":
        other = other.astype(np.float32)
    elif isinstance(other, self.__class__) and other.dtype.kind != "c":
        other = other.as_float().value
    out = self.value * other
    out = self.__class__(out)
    # BUG FIX: propagate axes/name/source/metadata to the result, consistent
    # with __truediv__ and _apply_operator. Previously the returned array
    # silently lost all of its info.
    out._set_info(self)
    return out
@same_dtype(asfloat=True)
def __imul__(self, other) -> LazyImgArray:
    """In-place lazy multiplication; rebinds the underlying dask array."""
    # Cast non-complex operands to float32 before multiplying; complex
    # operands (and anything else) are used as-is.
    if isinstance(other, np.ndarray):
        if other.dtype.kind != "c":
            other = other.astype(np.float32)
    elif isinstance(other, self.__class__):
        if other.dtype.kind != "c":
            other = other.as_float().value
    self.value *= other
    return self
def __truediv__(self, other) -> LazyImgArray:
    """
    Lazy element-wise division.

    The array is cast to float32 first. Zeros in an array divisor are
    replaced with ``inf`` so division yields 0 instead of raising/warning;
    scalar divisors must be strictly positive.
    """
    self = self.as_float()
    if isinstance(other, np.ndarray) and other.dtype.kind != "c":
        other = other.astype(np.float32)
        other[other == 0] = np.inf  # x / inf == 0
    elif isinstance(other, self.__class__) and other.dtype.kind != "c":
        other = other.as_float()
        other[other == 0] = np.inf
        other = other.value
    elif np.isscalar(other) and other <= 0:
        # BUG FIX: the message previously said "multiply" in this division
        # method (copy-paste error).
        raise ValueError("Cannot divide by zero or a negative value.")
    out = self.value / other
    out = self.__class__(out)
    out._set_info(self)
    return out
def __itruediv__(self, other) -> LazyImgArray:
    """
    In-place lazy division. The array must already have a float/complex
    dtype; integer arrays cannot be divided in place.
    """
    if self.dtype.kind in "ui":
        raise ValueError("Cannot divide integer inplace.")
    if isinstance(other, np.ndarray) and other.dtype.kind != "c":
        other = other.astype(np.float32)
        other[other == 0] = np.inf  # x / inf == 0
    elif isinstance(other, self.__class__) and other.dtype.kind != "c":
        other = other.as_float()
        other[other == 0] = np.inf
        other = other.value
    elif np.isscalar(other) and other < 0:
        # BUG FIX: the message previously said "multiply" in this division
        # method (copy-paste error).
        raise ValueError("Cannot divide by a negative value.")
    self.value /= other
    return self
def __gt__(self, other) -> LazyImgArray:
    """Lazy element-wise ``>``."""
    return self._apply_operator(operator.gt, other)

def __ge__(self, other) -> LazyImgArray:
    """Lazy element-wise ``>=``."""
    return self._apply_operator(operator.ge, other)

def __lt__(self, other) -> LazyImgArray:
    """Lazy element-wise ``<``."""
    return self._apply_operator(operator.lt, other)

def __le__(self, other) -> LazyImgArray:
    """Lazy element-wise ``<=``."""
    return self._apply_operator(operator.le, other)

# NOTE(review): defining __eq__ element-wise makes instances unhashable
# unless __hash__ is restored elsewhere in the class — confirm.
def __eq__(self, other) -> LazyImgArray:
    """Lazy element-wise ``==`` (returns an array, not a bool)."""
    return self._apply_operator(operator.eq, other)

def __ne__(self, other) -> LazyImgArray:
    """Lazy element-wise ``!=`` (returns an array, not a bool)."""
    return self._apply_operator(operator.ne, other)

def __mod__(self, other) -> LazyImgArray:
    """Lazy element-wise modulo."""
    return self._apply_operator(operator.mod, other)

def __floordiv__(self, other) -> LazyImgArray:
    """Lazy element-wise floor division."""
    return self._apply_operator(operator.floordiv, other)

def __pow__(self, power: float) -> LazyImgArray:
    """Lazy element-wise power."""
    return self._apply_operator(operator.pow, power)
@property
def chunk_info(self):
    """Human-readable chunk sizes, formatted as 'size(axis)' pairs."""
    pairs = (f"{size}({axis})" for size, axis in zip(self.chunksize, self.axes))
    return ", ".join(pairs)
def _repr_dict_(self):
    """Return the key-value pairs shown in the object's repr."""
    return {
        "name": self.name,
        "shape": self.shape_info,
        "chunk sizes": self.chunk_info,
        "dtype": self.dtype,
        "source": self.source,
        "scale": self.scale,
    }
def compute(self, ignore_limit: bool = False) -> ImgArray:
    """
    Compute all the task and convert the result into ImgArray. If image size overwhelms MAX_GB
    then MemoryError is raised.

    Parameters
    ----------
    ignore_limit : bool, default is False
        If True, skip the MAX_GB size guard.
    """
    if self.gb > Const["MAX_GB"] and not ignore_limit:
        raise MemoryError(f"Too large: {self.gb:.2f} GB")
    arr: xp.ndarray = self.value.compute()
    if arr.ndim > 0:
        # View the computed array as ImgArray and copy identity attributes over.
        img = xp.asnumpy(arr).view(ImgArray)
        for attr in ["_name", "_source", "axes", "_metadata"]:
            setattr(img, attr, getattr(self, attr, None))
    else:
        # 0-D result: return the plain Python scalar instead of an array.
        img = arr.item()
    return img
def release(self, update: bool = True) -> LazyImgArray:
    """
    Compute all the tasks and store the data in memory map, and read it as a dask array
    again.

    Parameters
    ----------
    update : bool, default is True
        If True, rebind this object's dask array in place; otherwise return
        a new LazyImgArray.
    """
    from dask import array as da
    with tempfile.NamedTemporaryFile() as ntf:
        # Spill the computed result to a temporary memory-mapped file.
        mmap = np.memmap(ntf, mode="w+", shape=self.shape, dtype=self.dtype)
        da.store(self.value, mmap, compute=True)
        mmap.flush()
    # NOTE(review): the temporary file is removed when the context exits; the
    # memmap keeps the handle alive on POSIX, but this would presumably fail
    # on Windows — confirm.
    img = da.from_array(mmap, chunks=self.chunksize).map_blocks(
        np.array, meta=np.array([], dtype=self.dtype)
    )
    if update:
        self.value = img
        out = self
    else:
        out = self.__class__(img)
        out._set_info(self)
    return out
@_docs.copy_docs(LabeledArray.imsave)
def imsave(self, save_path: str | Path, *, dtype = None, overwrite: bool = True):
    save_path = Path(save_path)
    if self.ndim < 2:
        raise ValueError("Cannot save <2D array as an image.")
    # No suffix given: inherit the source file's suffix, defaulting to ".tif".
    if save_path.suffix == "":
        if self.source is not None:
            ext = self.source.suffix
            if ext == "":
                ext = ".tif"
        else:
            ext = ".tif"
        save_path = save_path.parent / (save_path.name + ext)
    # Relative paths are resolved against the source image's directory.
    if not Path(save_path).is_absolute():
        if self.source is None:
            raise ValueError(
                "Image directory path is unknown. Set by \n"
                " >>> img.source = \"...\"\n"
                "or specify absolute path like\n"
                " >>> img.imsave(\"/path/to/XXX.tif\")"
            )
        save_path = self.source.parent / save_path
    if not overwrite and save_path.exists():
        raise FileExistsError(f"File {save_path!r} already exists.")
    if dtype is None:
        dtype = self.dtype
    # Convert dtype and sort axes into canonical order before saving lazily.
    self = self.as_img_type(dtype).sort_axes()
    imsave(save_path, self, lazy=True)
    return None
def rechunk(
    self,
    chunks="auto",
    *,
    threshold=None,
    block_size_limit=None,
    balance=False,
    update=False
) -> LazyImgArray:
    """
    Rechunk the bound dask array.

    Parameters
    ----------
    chunks, threshold, block_size_limit, balance
        Passed directly to dask.array's rechunk
    update : bool, default is False
        If True, rebind the rechunked array in place instead of returning
        a new object.

    Returns
    -------
    LazyImgArray
        Rechunked dask array is bound.
    """
    rechunked = self.value.rechunk(
        chunks=chunks, threshold=threshold,
        block_size_limit=block_size_limit, balance=balance
    )
    if update:
        self.value = rechunked
        return self
    else:
        out = self.__class__(rechunked)
        out._set_info(self)
        return out
def apply_dask_func(self, funcname: str, *args, **kwargs) -> LazyImgArray:
    """
    Apply dask array function to the connected dask array.

    Parameters
    ----------
    funcname : str
        Name of function to apply.
    args, kwargs :
        Parameters that will be passed to `funcname`.

    Returns
    -------
    LazyImgArray
        Updated one
    """
    out = getattr(self.value, funcname)(*args, **kwargs)
    out = self.__class__(out)
    # Keep the axes only when the shape is unchanged; otherwise reset them.
    # NOTE(review): `_INHERIT` is defined outside this chunk — confirm.
    new_axes = self._INHERIT if out.shape == self.shape else None
    out._set_info(self, new_axes=new_axes)
    return out
def _apply_function(
    self,
    func: Callable,
    c_axes: str = None,
    drop_axis: Iterable[int] = [],  # NOTE(review): mutable default — never mutated here, but fragile
    new_axis: Iterable[int] = None,
    dtype = np.float32,
    rechunk_to: tuple[int, ...] | str = "none",
    dask_wrap: bool = False,
    args: tuple = None, kwargs: dict[str] = None
) -> DaskArray:
    """
    Rechunk array in a correct shape and apply function using `map_blocks`. This function is similar
    to the `apply_dask` function in `MetaArray` while returns the raw dask array (callers are
    expected to wrap it into a LazyImgArray).

    Parameters
    ----------
    func : callable
        Function to apply for each chunk.
    c_axes : str, optional
        Axes to iterate.
    drop_axis : Iterable[int], optional
        Passed to map_blocks.
    new_axis : Iterable[int], optional
        Passed to map_blocks.
    dtype : any that can be converted to np.dtype object, default is np.float32
        Output data type.
    rechunk_to : tuple[int,...], optional
        In what size input array should be rechunked before `map_blocks` iteration. If str is given,
        array will be rechunked in following rules:
        - "none": No rechunking
        - "default": Rechunked with "auto" method for each spatial dimension.
        - "max": Rechunked to the shape size for each spatial dimension.
    dask_wrap : bool, optional
        If True, for each chunk array will be converted to dask and rechunked with "auto" option
        before function call.
    args : tuple, optional
        Arguments that will passed to `func`.
    kwargs : dict
        Keyword arguments that will passed to `func`.

    Returns
    -------
    DaskArray
        Dask array after function is applied.
    """
    # Build two slicers: `slice_in` picks index 0 on every iterated (c_axes)
    # axis of a chunk, `slice_out` re-inserts those axes with np.newaxis so
    # the output chunk keeps the original dimensionality.
    slice_in = []
    slice_out = []
    for a in self.axes:
        if a in c_axes:
            slice_in.append(0)
            slice_out.append(np.newaxis)
        else:
            slice_in.append(slice(None))
            slice_out.append(slice(None))
    slice_in = tuple(slice_in)
    slice_out = tuple(slice_out)
    if args is None:
        args = tuple()
    if kwargs is None:
        kwargs = dict()
    if rechunk_to == "none":
        input_ = self.value
    else:
        if rechunk_to == "default":
            # chunk size 1 along iterated axes, "auto" along the others
            rechunk_to = switch_slice(c_axes, self.axes, ifin=1, ifnot="auto")
        elif rechunk_to == "max":
            # chunk size 1 along iterated axes, full length along the others
            rechunk_to = switch_slice(c_axes, self.axes, ifin=1, ifnot=self.shape)
        input_ = self.value.rechunk(rechunk_to)
    if dask_wrap:
        from dask import array as da
        @wraps(func)
        def _func(arr, *args, **kwargs):
            # Re-wrap each chunk as a dask array before calling func, then
            # compute the result eagerly inside the block.
            out = func(da.from_array(arr[slice_in]), *args, **kwargs)
            return out[slice_out].compute()
    else:
        @wraps(func)
        def _func(arr, *args, **kwargs):
            out = func(arr[slice_in], *args, **kwargs)
            return out[slice_out]
    out = input_.map_blocks(_func, *args, drop_axis=drop_axis, new_axis=new_axis,
                            meta=xp.array([], dtype=dtype), **kwargs)
    return out
def _apply_map_blocks(
    self,
    func: Callable,
    c_axes: str = None,
    args: tuple = None,
    kwargs: dict[str, Any] = None
) -> DaskArray:
    """Apply ``func`` chunk-wise, iterating over ``c_axes`` inside each chunk."""
    from dask import array as da
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}
    all_axes = str(self.axes)
    def _func(input: ArrayLike, *args, **kwargs):
        out = xp.empty(input.shape, input.dtype)
        # Apply func independently to each slice along the iterated axes.
        # NOTE(review): `iter_slice` is not visible in this chunk — presumably
        # a module-level helper defined elsewhere in the file; confirm.
        for sl in iter_slice(input.shape, c_axes, all_axes):
            out[sl] = func(input[sl], *args, **kwargs)
        return out
    return da.map_blocks(_func, self.value, dtype=self.dtype, *args, **kwargs)
def _apply_map_overlap(
    self,
    func: Callable,
    c_axes: str = None,
    depth: int = 16,
    boundary="reflect",
    dtype: DTypeLike = None,
    args: tuple = None,
    kwargs: dict[str, Any] = None
) -> DaskArray:
    """
    Apply ``func`` with ``da.map_overlap`` so chunk borders share ``depth``
    pixels of halo, iterating over ``c_axes`` inside each chunk.
    """
    from dask import array as da
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}
    if dtype is None:
        dtype = self.dtype
    all_axes = str(self.axes)
    def _func(input: xp.ndarray, *args, **kwargs):
        out = xp.empty(input.shape, input.dtype)
        # NOTE(review): `iter_slice` is defined elsewhere in the file — confirm.
        for sl in iter_slice(input.shape, c_axes, all_axes):
            out[sl] = func(input[sl], *args, **kwargs)
        return out
    # No overlap is needed along the iterated axes: depth 0 there, `depth`
    # along the spatial axes.
    depth = switch_slice(c_axes, self.axes, 0, depth)
    return da.map_overlap(
        _func, self.value, *args, depth=depth, boundary=boundary, dtype=dtype,
        **kwargs
    )
def _apply_dask_filter(
    self,
    func: Callable,
    c_axes: str = None,
    args: tuple = None,
    kwargs: dict[str, Any] = None
) -> DaskArray:
    """Apply a dask-aware filter slice-by-slice along ``c_axes``."""
    # TODO: This is not efficient. Maybe using da.stack is better?
    from dask import array as da
    out = da.empty_like(self.value)
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}
    # NOTE(review): `self.iter` is defined outside this chunk — confirm its
    # (slice, image) yield contract.
    for sl, img in self.iter(c_axes, israw=False):
        out[sl] = func(img, *args, **kwargs)
    return out
# TODO: This should wait for dask-image implement map_coordinates
# @_docs.copy_docs(LabeledArray.rotated_crop)
# @dims_to_spatial_axes
# @record_lazy
# def rotated_crop(self, origin, dst1, dst2, dims=2) -> LazyImgArray:
# origin = np.asarray(origin)
# dst1 = np.asarray(dst1)
# dst2 = np.asarray(dst2)
# ax0 = _misc.make_rotated_axis(origin, dst2)
# ax1 = _misc.make_rotated_axis(dst1, origin)
# all_coords = ax0[:, np.newaxis] + ax1[np.newaxis] - origin
# all_coords = np.moveaxis(all_coords, -1, 0)
# Because output shape changes, we have to tell dask what chunk size it would be, otherwise output
# shape is estimated in a wrong way.
# output_chunks = [1] * self.ndim
# for i, a in enumerate(dims):
# it = self.axisof(a)
# output_chunks[it] = all_coords.shape[i+1]
#
# cropped_img = self._apply_function(xp.ndi.map_coordinates,
# c_axes=complement_axes(dims, self.axes),
# dtype=self.dtype,
# rechunk_to="max",
# args=(xp.asarray(all_coords),),
# kwargs=dict(prefilter=False, order=1, chunks=output_chunks)
# )
@_docs.copy_docs(ImgArray.erosion)
@dims_to_spatial_axes
@check_input_and_output_lazy
def erosion(
    self,
    radius: float = 1,
    *,
    dims: Dims = None,
    update: bool = False
) -> LazyImgArray:
    # Boolean images use binary erosion; grayscale images use grey erosion.
    disk = _structures.ball_like(radius, len(dims))
    c_axes = complement_axes(dims, self.axes)
    filter_func = xp.ndi.grey_erosion if self.dtype != bool else xp.ndi.binary_erosion
    # NOTE(review): `_ceilint` is a module-level helper outside this chunk.
    return self._apply_map_overlap(
        filter_func,
        c_axes=c_axes,
        depth=_ceilint(radius),
        kwargs=dict(footprint=disk)
    )

@_docs.copy_docs(ImgArray.dilation)
@dims_to_spatial_axes
@check_input_and_output_lazy
def dilation(
    self,
    radius: float = 1,
    *,
    dims: Dims = None,
    update: bool = False
) -> LazyImgArray:
    disk = _structures.ball_like(radius, len(dims))
    c_axes = complement_axes(dims, self.axes)
    filter_func = xp.ndi.grey_dilation if self.dtype != bool else xp.ndi.binary_dilation
    return self._apply_map_overlap(
        filter_func,
        c_axes=c_axes,
        depth=_ceilint(radius),
        kwargs=dict(footprint=disk)
    )

@_docs.copy_docs(ImgArray.opening)
@dims_to_spatial_axes
@check_input_and_output_lazy
def opening(
    self,
    radius: float = 1,
    *,
    dims: Dims = None,
    update: bool = False
) -> LazyImgArray:
    disk = _structures.ball_like(radius, len(dims))
    c_axes = complement_axes(dims, self.axes)
    filter_func = xp.ndi.grey_opening if self.dtype != bool else xp.ndi.binary_opening
    # Opening is erosion followed by dilation, so twice the halo is needed.
    return self._apply_map_overlap(
        filter_func,
        c_axes=c_axes,
        depth=_ceilint(radius)*2,
        kwargs=dict(footprint=disk)
    )

@_docs.copy_docs(ImgArray.closing)
@dims_to_spatial_axes
@check_input_and_output_lazy
def closing(
    self,
    radius: float = 1,
    *,
    dims: Dims = None,
    update: bool = False
) -> LazyImgArray:
    disk = _structures.ball_like(radius, len(dims))
    c_axes = complement_axes(dims, self.axes)
    filter_func = xp.ndi.grey_closing if self.dtype != bool else xp.ndi.binary_closing
    # Closing is dilation followed by erosion, so twice the halo is needed.
    return self._apply_map_overlap(
        filter_func,
        c_axes=c_axes,
        depth=_ceilint(radius)*2,
        kwargs=dict(footprint=disk)
    )
@_docs.copy_docs(ImgArray.gaussian_filter)
@dims_to_spatial_axes
@same_dtype(asfloat=True)
@check_input_and_output_lazy
def gaussian_filter(
    self,
    sigma: nDFloat = 1.0,
    *,
    dims: Dims = None,
    update: bool = False
) -> LazyImgArray:
    c_axes = complement_axes(dims, self.axes)
    # 4*sigma halo covers essentially the full Gaussian support.
    depth = _ceilint(sigma*4)
    return self._apply_map_overlap(
        xp.ndi.gaussian_filter,
        c_axes=c_axes,
        depth=depth,
        kwargs=dict(sigma=sigma),
    )

@_docs.copy_docs(ImgArray.spline_filter)
@dims_to_spatial_axes
@same_dtype(asfloat=True)
@check_input_and_output_lazy
def spline_filter(
    self,
    order: int = 3,
    mode: PaddingMode = "mirror",
    *,
    dims: Dims = None,
    update: bool = False,
) -> LazyImgArray:
    # Halo of `order` pixels is enough for the spline prefilter support.
    depth = order
    from functools import partial
    func = partial(
        _filters.spline_filter, order=order, output=np.float32, mode=mode
    )
    return self._apply_map_overlap(
        func,
        c_axes=complement_axes(dims, self.axes),
        depth=depth,
    )

@_docs.copy_docs(ImgArray.mean_filter)
@same_dtype(asfloat=True)
@dims_to_spatial_axes
@check_input_and_output_lazy
def mean_filter(
    self,
    radius: float = 1,
    *,
    dims: Dims = None,
    update: bool = False
) -> LazyImgArray:
    # Mean filter implemented as convolution with a normalized ball kernel.
    disk = _structures.ball_like(radius, len(dims))
    kernel = (disk/np.sum(disk)).astype(np.float32)
    return self._apply_map_overlap(
        xp.ndi.convolve,
        depth=_ceilint(radius),
        c_axes=complement_axes(dims, self.axes),
        kwargs=dict(weights=kernel),
    )
@_docs.copy_docs(ImgArray.convolve)
@dims_to_spatial_axes
@same_dtype(asfloat=True)
@check_input_and_output_lazy
def convolve(
    self,
    kernel,
    *,
    mode: str = "reflect",
    cval: float = 0,
    dims: Dims = None,
    update: bool = False
) -> LazyImgArray:
    kernel = np.asarray(kernel)
    # Each axis needs a halo of half the kernel size.
    shape = np.array(kernel.shape)
    half_size = shape // 2
    depth = tuple(half_size)
    c_axes = complement_axes(dims, self.axes)
    return self._apply_map_overlap(
        xp.ndi.convolve,
        c_axes=c_axes,
        depth=depth,
        kwargs=dict(weights=kernel, mode=mode, cval=cval),
    )
@_docs.copy_docs(ImgArray.edge_filter)
@dims_to_spatial_axes
@same_dtype(asfloat=True)
@check_input_and_output_lazy
def edge_filter(
    self,
    method: str = "sobel",
    *,
    dims: Dims = None,
    update: bool = False
) -> LazyImgArray:
    from skimage.filters.edges import sobel, farid, scharr, prewitt
    # Map method name -> (filter function, required overlap depth).
    method_dict = {
        "sobel": (sobel, 1), "farid": (farid, 2), "scharr": (scharr, 1),
        "prewitt": (prewitt, 1)
    }
    try:
        filter_func, depth = method_dict[method]
    except KeyError:
        # BUG FIX: the message was missing a comma between 'farid' and
        # 'scharr'; `from None` hides the uninformative KeyError chain.
        raise ValueError(
            "`method` must be 'sobel', 'farid', 'scharr', or 'prewitt'."
        ) from None
    return self._apply_map_overlap(
        filter_func,
        depth=depth,
        c_axes=complement_axes(dims, self.axes)
    )
@_docs.copy_docs(ImgArray.laplacian_filter)
@dims_to_spatial_axes
@same_dtype
@check_input_and_output_lazy
def laplacian_filter(
    self,
    radius: int = 1,
    *,
    dims: Dims = None,
    update: bool = False
) -> LazyImgArray:
    from skimage.restoration import uft
    ndim = len(dims)
    # Build the Laplacian operator kernel of size (2*radius+1) per axis,
    # then apply it as a convolution with matching halo.
    _, laplace_op = uft.laplacian(ndim, (2*radius+1,) * ndim)
    return self._apply_map_overlap(
        xp.ndi.convolve,
        depth=_ceilint(radius),
        c_axes=complement_axes(dims, self.axes),
        kwargs=dict(weights=laplace_op),
    )
@_docs.copy_docs(ImgArray.affine)
@dims_to_spatial_axes
@same_dtype(asfloat=True)
@check_input_and_output_lazy
def affine(
    self,
    matrix=None,
    scale=None,
    rotation=None,
    shear=None,
    translation=None,
    *,
    mode="constant",
    cval=0,
    output_shape=None,
    order=1,
    dims=None
) -> LazyImgArray:
    # If no explicit matrix is given, compose one from the individual
    # scale/rotation/shear/translation components.
    if matrix is None:
        matrix = _transform.compose_affine_matrix(
            scale=scale, rotation=rotation,
            shear=shear, translation=translation,
            ndim=len(dims)
        )
    # dask_image emits warnings on import; silence them locally.
    import warnings
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        from dask_image.ndinterp import affine_transform
    return self._apply_dask_filter(
        affine_transform,
        c_axes=complement_axes(dims, self.axes),
        kwargs=dict(matrix=matrix, mode=mode, cval=cval,
                    output_shape=output_shape, order=order)
    )
@_docs.copy_docs(ImgArray.kalman_filter)
@dims_to_spatial_axes
@same_dtype(asfloat=True)
@check_input_and_output_lazy
def kalman_filter(self, gain: float = 0.8, noise_var: float = 0.05, *, along: str = "t",
                  dims: Dims = None, update: bool = False) -> LazyImgArray:
    # The implementation iterates along the first axis, so the time axis
    # must be axis 0.
    if self.axisof(along) != 0:
        raise ValueError("Currently kalman_filter does not support t-axis != 0.")
    return self._apply_map_blocks(
        _filters.kalman_filter,
        c_axes=complement_axes([along] + dims, self.axes),
        args=(gain, noise_var)
    )
@_docs.copy_docs(ImgArray.threshold)
@check_input_and_output_lazy
def threshold(
    self,
    thr: float | str,
    *,
    along: AxisLike | None = None,
) -> ImgArray:
    # Already binary: nothing to do.
    if self.dtype == bool:
        return self
    from dask import array as da
    if along is None:
        along = "c" if "c" in self.axes else ""
    if isinstance(thr, str) and thr.endswith("%"):
        # Percentile threshold, computed per slice along `along`.
        p = float(thr[:-1].strip())
        out = da.zeros(self.shape, dtype=bool)
        for sl, img in self.iter(along):
            # NOTE(review): `thr` is deliberately rebound here (the string
            # form is no longer needed after parsing `p`).
            thr = da.percentile(img, p)
            out[sl] = img >= thr
    elif np.isscalar(thr) or (hasattr(thr, "compute") and thr.ndim == 0):
        out = self >= thr
    else:
        raise TypeError(
            "'thr' must be a scalar or string in 'X%' format."
        )
    return out
@_docs.copy_docs(ImgArray.fft)
@dims_to_spatial_axes
@check_input_and_output_lazy
def fft(self, *, shape: int | Iterable[int] | str = "same", shift: bool = True,
        dims: Dims = None) -> LazyImgArray:
    from dask import array as da
    axes = [self.axisof(a) for a in dims]
    if shape == "square":
        # BUG FIX: pad to the next power of two of the largest spatial size.
        # The previous code computed 2**max(size) (the log2 was missing),
        # which requested an astronomically large FFT shape.
        s = 2 ** int(np.ceil(np.log2(np.max(self.sizesof(dims)))))
        shape = (s,) * len(dims)
    elif shape == "same":
        shape = None
    else:
        shape = check_nd(shape, len(dims))
    # Choose the numpy- or cupy-backed FFT depending on the resource setting.
    fft = _get_fft_func(Const["RESOURCE"])
    freq = fft(self.value.astype(np.float32), s=shape, axes=axes)
    if shift:
        freq[:] = da.fft.fftshift(freq, axes=axes)
    return freq
@_docs.copy_docs(ImgArray.ifft)
@dims_to_spatial_axes
@check_input_and_output_lazy
def ifft(self, real: bool = True, *, shift: bool = True, dims: Dims = None) -> LazyImgArray:
    from dask import array as da
    axes = [self.axisof(a) for a in dims]
    # Undo the centering applied by `fft(shift=True)` before inverting.
    if shift:
        freq = da.fft.ifftshift(self.value, axes=axes)
    else:
        freq = self.value
    # NOTE(review): `_get_ifft_func` is a module-level helper outside this chunk.
    ifft = _get_ifft_func(Const["RESOURCE"])
    out = ifft(freq, axes=axes)
    if real:
        out = da.real(out)
    return out
@_docs.copy_docs(ImgArray.power_spectra)
@dims_to_spatial_axes
@check_input_and_output_lazy
def power_spectra(self, shape = "same", norm: bool = False, zero_norm: bool = False, *,
                  dims: Dims = None) -> LazyImgArray:
    freq = self.fft(dims=dims, shape=shape)
    # |F|^2 without forming the complex absolute value explicitly.
    pw = freq.value.real ** 2 + freq.value.imag ** 2
    if norm:
        pw /= pw.max()
    if zero_norm:
        # Zero out the DC component at the center of each spatial plane.
        sl = switch_slice(dims, self.axes, ifin=np.array(self.shape)//2, ifnot=slice(None))
        pw[sl] = 0
    return pw
def chunksizeof(self, axis: str):
    """Get the chunk size of the given axis"""
    idx = self.axes.find(axis)
    return self.value.chunksize[idx]

def chunksizesof(self, axes: str):
    """Get the chunk sizes of the given axes"""
    return tuple(self.chunksizeof(each) for each in axes)
@_docs.copy_docs(ImgArray.transpose)
def transpose(self, axes=None):
    if axes is None:
        # Default: reverse all axes, like ndarray.transpose().
        _axes = None
        new_axes = self.axes[::-1]
    else:
        _axes = [self.axisof(a) for a in axes]
        # NOTE(review): this indexes self.axes with the raw elements of
        # `axes` (labels or ints), not the integer positions in `_axes` —
        # confirm Axes supports this form of __getitem__.
        new_axes = [self.axes[i] for i in list(axes)]
    out = self.__class__(self.value.transpose(_axes))
    out._set_info(self, new_axes=new_axes)
    return out

@_docs.copy_docs(ImgArray.sort_axes)
def sort_axes(self):
    # Transpose into the canonical axis order defined by Axes.argsort.
    order = self.axes.argsort()
    return self.transpose(tuple(order))
@_docs.copy_docs(LabeledArray.crop_center)
@dims_to_spatial_axes
@check_input_and_output_lazy
def crop_center(self, scale=0.5, *, dims=2) -> LazyImgArray:
    # check scale
    # A 3-element scale with 2-D dims implies a z/y/x crop.
    if hasattr(scale, "__iter__") and len(scale) == 3 and len(dims) == 2:
        dims = "zyx"
    scale = np.asarray(check_nd(scale, len(dims)))
    if np.any((scale <= 0) | (1 < scale)):
        raise ValueError(f"scale must be (0, 1], but got {scale}")
    # Make axis-targeted slicing string
    sizes = self.sizesof(dims)
    slices = []
    for a, size, sc in zip(dims, sizes, scale):
        # Keep the central fraction `sc` of each axis.
        x0 = int(size / 2 * (1 - sc))
        x1 = int(np.ceil(size / 2 * (1 + sc)))
        slices.append(f"{a}={x0}:{x1}")
    out = self[";".join(slices)]
    return out
@_docs.copy_docs(ImgArray.tiled_lowpass_filter)
@dims_to_spatial_axes
@check_input_and_output_lazy
def tiled_lowpass_filter(self, cutoff: float = 0.2, order: int = 2, overlap: int = 16, *,
                         dims: Dims = None, update: bool = False) -> LazyImgArray:
    from ._utils._skimage import _get_ND_butterworth_filter
    self = self.as_float()
    c_axes = complement_axes(dims, self.axes)
    ndims = len(dims)
    cutoff = check_nd(cutoff, ndims)
    # Cutoffs at/above the Nyquist diagonal (or non-positive) are a no-op.
    if all((c >= 0.5*np.sqrt(ndims) or c <= 0) for c in cutoff):
        return self
    # Overlap only along the filtered (spatial) axes.
    depth = switch_slice(dims, self.axes, overlap, 0)
    def func(arr):
        # Butterworth low-pass in frequency space, per tile.
        arr = xp.asarray(arr)
        shape = arr.shape
        weight = _get_ND_butterworth_filter(shape, cutoff, order, False, True)
        ft = xp.asarray(weight) * xp.fft.rfftn(arr)
        ift = xp.fft.irfftn(ft, s=shape)
        return ift
    out = self._apply_map_overlap(func, c_axes=c_axes, depth=depth, boundary="reflect")
    return out
@_docs.copy_docs(ImgArray.proj)
@same_dtype
def proj(self, axis: str = None, method: str = "mean") -> LazyImgArray:
    from dask import array as da
    if axis is None:
        # Pick the first projectable axis, never the spatial y/x axes.
        axis = find_first_appeared("ztpi<c", include=self.axes, exclude="yx")
    elif not isinstance(axis, str):
        raise TypeError("`axis` must be str.")
    axisint = [self.axisof(a) for a in axis]
    if method == "mean":
        # Accumulate in float32 to avoid integer overflow in the mean.
        projection = getattr(da, method)(self.value, axis=tuple(axisint), dtype=np.float32)
    else:
        projection = getattr(da, method)(self.value, axis=tuple(axisint))
    out = self.__class__(projection)
    out._set_info(self, self.axes.drop(axisint))
    return out
@_docs.copy_docs(ImgArray.binning)
@dims_to_spatial_axes
@same_dtype
def binning(self, binsize: int = 2, method = "mean", *, check_edges: bool = True,
            dims: Dims = None) -> LazyImgArray:
    if binsize == 1:
        return self
    if isinstance(method, str):
        binfunc = getattr(xp, method)
    elif callable(method):
        binfunc = method
    else:
        raise TypeError("`method` must be a numpy function or callable object.")
    # Trim/validate edges and get the (out_size, bin) interleaved shape.
    img_to_reshape, shape, scale_ = _misc.adjust_bin(self.value, binsize, check_edges, dims, self.axes)
    reshaped_img = img_to_reshape.reshape(shape)
    # Reduce over every second (bin) axis of the interleaved shape.
    axes_to_reduce = tuple(i*2+1 for i in range(self.ndim))
    out = binfunc(reshaped_img, axis=axes_to_reduce)
    out = self.__class__(out)
    out._set_info(self)
    out.axes = str(self.axes) # _set_info does not pass copy so new axes must be defined here.
    # Binning enlarges the physical pixel scale accordingly.
    out.set_scale({a: self.scale[a]/scale for a, scale in zip(self.axes, scale_)})
    return out
@_docs.copy_docs(ImgArray.track_drift)
def track_drift(self, along: str = None, upsample_factor: int = 10) -> "DaskArray":
    if along is None:
        along = find_first_appeared("tpzc<i", include=self.axes)
    elif len(along) != 1:
        raise ValueError("`along` must be single character.")
    dims = complement_axes(along, self.axes)
    # One chunk per time point: full spatial extent, length 1 along `along`.
    chunks = switch_slice(dims, self.axes, ifin=self.shape, ifnot=1)
    img_fft = self.fft(shift=False, dims=dims).value.rechunk(chunks)
    ndim = len(dims)
    slice_out = (np.newaxis, slice(None)) + (np.newaxis,)*(ndim-1)
    each_shape = (1, ndim) + (1,)*(ndim-1)
    len_t = self.sizeof(along)
    def pcc(x):
        # Phase cross-correlation between consecutive frames; the last
        # (border) block has a single frame and contributes zero shift.
        if x.shape[0] < 2:
            return np.array([0]*ndim, dtype=np.float32).reshape(*each_shape)
        x = xp.asarray(x)
        result = _corr.subpixel_pcc(x[0], x[1], upsample_factor=upsample_factor)[0]
        return xp.asnumpy(result[slice_out])
    from dask import array as da
    # I don't know the reason why but output dask array's chunk size along t-axis should be
    # specified to be 1, and rechunk it map_overlap.
    result = da.map_overlap(
        pcc,
        img_fft,
        depth={0: (1, 0)},
        trim=False,
        boundary="none",
        chunks=(1, ndim) + (1,)*(ndim-1),
        meta=np.array([], dtype=np.float32)
    )
    # For cupy, we must call map_blocks (or from_delayed and delayed) here.
    # Cumulative sum of frame-to-frame shifts gives the absolute drift.
    result = da.map_blocks(
        np.cumsum,
        result[..., 0].rechunk((len_t, ndim)),
        axis=0,
        meta=np.array([], dtype=np.float32)
    )
    return result
@_docs.copy_docs(ImgArray.drift_correction)
@same_dtype(asfloat=True)
@check_input_and_output_lazy
@dims_to_spatial_axes
def drift_correction(self, shift: Coords = None, ref: ImgArray = None, *,
                     zero_ave: bool = True, along: str = None, dims: Dims = 2,
                     update: bool = False, **affine_kwargs) -> LazyImgArray:
    if along is None:
        along = find_first_appeared("tpzcia", include=self.axes, exclude=dims)
    elif len(along) != 1:
        raise ValueError("`along` must be single character.")
    from impy.frame import MarkerFrame
    if shift is None:
        # determine 'ref'
        if ref is None:
            ref = self
            _dims = complement_axes(along, self.axes)
            if dims != _dims:
                warn(f"dims={dims} with along={along} and {self.axes}-image are not "
                     f"valid input. Changed to dims={_dims}",
                     UserWarning)
                dims = _dims
        elif not isinstance(ref, self.__class__):
            raise TypeError(f"'ref' must be LazyImgArray object, but got {type(ref)}")
        elif ref.axes != along + dims:
            raise ValueError(f"Arguments `along`({along}) + `dims`({dims}) do not match "
                             f"axes of `ref`({ref.axes})")
        # No shift given: measure the drift from the reference stack.
        shift = ref.track_drift(along=along)
    elif isinstance(shift, MarkerFrame):
        if len(shift) != self.sizeof(along):
            raise ValueError("Wrong shape of 'shift'.")
        shift = shift.values
    from dask import array as da
    if zero_ave:
        # Center the shifts so the average displacement is zero.
        shift = shift - da.mean(shift, axis=0)
    t_index = self.axisof(along)
    slice_in = switch_slice(dims, self.axes, ifin=slice(None), ifnot=0)
    slice_out = switch_slice(dims, self.axes, ifin=slice(None), ifnot=np.newaxis)
    ndim = len(dims)
    # Here shift must be a local variable for the function. Otherwise, it takes dask very long time
    # for graph construction.
    def warp(arr, shift, block_info=None):
        arr = xp.asarray(arr)
        # Build a translation-only affine matrix for this time point.
        mx = xp.eye(ndim+1, dtype=np.float32)
        loc = block_info[None]["array-location"][0]
        mx[:-1, -1] = -xp.asarray(shift[loc[t_index]])
        return xp.asnumpy(
            _transform.warp(arr[slice_in], mx, **affine_kwargs)[slice_out]
        )
    # One chunk per time point so each block warps a single frame.
    chunks = switch_slice(dims, self.axes, ifin=self.shape, ifnot=1)
    out = da.map_blocks(warp, self.value.rechunk(chunks), shift, meta=np.array([], dtype=self.dtype))
    return out
@_docs.copy_docs(ImgArray.radon)
@dims_to_spatial_axes
def radon(
    self,
    degrees: float | Iterable[float],
    *,
    central_axis: AxisLike | Sequence[float] | None = None,
    order: int = 3,
    dims: Dims = None,
) -> LazyImgArray:
    from dask import array as da
    from dask_image.ndinterp import affine_transform
    params, output_shape, squeeze = _transform.normalize_radon_input(
        self, dims, central_axis, degrees
    )
    # apply spline filter in advance.
    input = self.as_float().spline_filter(order=order)
    # One rotated projection (sum along axis 0) per requested angle.
    tasks: list[DaskArray] = [
        affine_transform(
            input.value, p, order=order, output_shape=output_shape, prefilter=False
        ).sum(axis=0)
        for p in params
    ]
    out = LazyImgArray(da.stack(tasks, axis=0))
    # The stacked axis becomes a new "degree" axis replacing the summed one.
    out._set_info(self, self.axes.drop(0).insert(0, "degree"))
    out.axes[0].labels = list(degrees)
    if squeeze:
        # Scalar `degrees` input: drop the singleton degree axis.
        out = out[0]
    return out
@_docs.copy_docs(ImgArray.iradon)
def iradon(
    self,
    degrees: Sequence[float],
    *,
    central_axis: AxisLike = "y",
    degree_axis: AxisLike | None = None,
    height_axis: AxisLike | None = None,
    height: int | None = None,
    window: str = "hamming",
    order: int = 3,
) -> LazyImgArray:
    from dask import delayed, array as da
    # Map interpolation order to the interpolation mode name.
    interp = {0: "nearest", 1: "linear", 3: "cubic"}[order]
    central_axis, degree_axis, output_shape, new_axes = _transform.normalize_iradon_input(
        self, central_axis, height_axis, degree_axis, height
    )
    # Move the degree axis last, as expected by the iradon implementation.
    input = np.moveaxis(self.value, self.axisof(degree_axis), -1)
    filter_func = _transform.get_fourier_filter(input.shape[-2], window)
    func = delayed(_transform.iradon)
    if self.ndim == 3:
        # Reconstruct each slice independently as a delayed task.
        arrays = [
            da.from_delayed(
                func(
                    image_slice,
                    degrees=degrees,
                    interpolation=interp,
                    filter_func=filter_func,
                    output_shape=output_shape,
                ),
                shape=self.shape[1:],
                dtype=self.dtype,
            )
            for image_slice in input
        ]
        out = da.stack(arrays, axis=0)
        out = np.moveaxis(out, 0, 1)
    else:
        out = da.from_delayed(
            func(
                input,
                degrees=degrees,
                interpolation=interp,
                filter_func=filter_func,
                output_shape=output_shape,
            ),
            shape=output_shape,
            dtype=self.dtype,
        )
    # NOTE(review): flip along the first axis — presumably to match the
    # orientation convention of ImgArray.iradon; confirm.
    out = out[::-1]
    out = self.__class__(out, axes=new_axes, name=self.name, source=self.source)
    out._set_info(self, new_axes=new_axes)
    return out
@_docs.copy_docs(ImgArray.pad)
@dims_to_spatial_axes
@check_input_and_output_lazy
def pad(
    self,
    pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
    mode: str = "constant",
    *,
    dims: Dims = None,
    **kwargs,
) -> LazyImgArray:
    """Lazily pad the image along the spatial dimensions."""
    # Expand the user-supplied pad_width into a full per-axis specification,
    # then delegate to np.pad (dispatched to dask via the array protocol).
    widths = _misc.make_pad(pad_width, dims, self.axes, **kwargs)
    return np.pad(self.value, widths, mode, **kwargs)
@_docs.copy_docs(ImgArray.wiener)
@same_dtype(asfloat=True)
@check_input_and_output_lazy
@dims_to_spatial_axes
def wiener(
    self,
    psf: np.ndarray,
    lmd: float = 0.1,
    *,
    dims: Dims = None,
    update: bool = False,
) -> LazyImgArray:
    """Wiener deconvolution of the image with point spread function *psf*.

    Raises ValueError if the regularization parameter *lmd* is not positive.
    """
    if lmd <= 0:
        raise ValueError(f"lmd must be positive, but got: {lmd}")
    psf_ft, psf_ft_conj = _deconv.check_psf(self, psf, dims)
    # Apply the deconvolution over each spatial slice, iterating the
    # non-spatial (complement) axes.
    iter_axes = complement_axes(dims, self.axes)
    return self._apply_function(
        _deconv.wiener,
        c_axes=iter_axes,
        args=(psf_ft, psf_ft_conj, lmd),
    )
@_docs.copy_docs(ImgArray.lucy)
@same_dtype(asfloat=True)
@check_input_and_output_lazy
@dims_to_spatial_axes
def lucy(
    self,
    psf: np.ndarray,
    niter: int = 50,
    eps: float = 1e-5,
    *,
    dims: Dims = None,
    update: bool = False,
) -> LazyImgArray:
    """Richardson-Lucy deconvolution of the image with *psf*, run lazily."""
    psf_ft, psf_ft_conj = _deconv.check_psf(self, psf, dims)
    # Iterate over the non-spatial axes and deconvolve each spatial slice.
    iter_axes = complement_axes(dims, self.axes)
    return self._apply_function(
        _deconv.richardson_lucy,
        c_axes=iter_axes,
        args=(psf_ft, psf_ft_conj, niter, eps),
    )
def __array_function__(self, func, types, args, kwargs):
    """
    Every time a numpy function (np.mean...) is called, this function will be called. Essentially numpy
    function can be overloaded with this method.
    """
    from dask.array.core import Array as DaskArray
    # Replace LazyImgArray inputs with their underlying dask arrays so the
    # call can be forwarded to dask's own __array_function__.
    args, kwargs = _replace_inputs(self, args, kwargs)
    # Advertise DaskArray in place of this class so dask accepts the types.
    _types = [DaskArray if t is self.__class__ else t for t in types]
    result = self.value.__array_function__(func, _types, args, kwargs)
    if result is NotImplemented:
        return NotImplemented
    if isinstance(result, (tuple, list)):
        out = []
        for r in result:
            if isinstance(r, DaskArray):
                # BUG FIX: _process_output returns None, so the wrapped array
                # itself must be appended -- previously the None return value
                # was appended, filling the DataList with Nones.
                each = self.__class__(r)
                each._process_output(self, args, kwargs)
                out.append(each)
            else:
                out.append(r)
        out = DataList(out)
    elif isinstance(result, DaskArray):
        out = self.__class__(result)
        out._process_output(self, args, kwargs)
    else:
        # BUG FIX: scalar / non-dask results previously left `out` unbound
        # and raised UnboundLocalError; pass them through unchanged.
        out = result
    return out
def _process_output(self, input: LazyImgArray, args: tuple, kwargs: dict):
    """Inherit metadata from *input*, dropping reduced axes if the call had an `axis` keyword."""
    new_axes = input.axes.drop(kwargs["axis"]) if "axis" in kwargs else self._INHERIT
    self._set_info(input, new_axes=new_axes)
def as_uint8(self) -> LazyImgArray:
    """Return the image converted to ``uint8``.

    Returns ``self`` unchanged when the dtype is already ``uint8``.
    Float images are rounded (``+ 0.5``) and clipped to [0, 255]; ``uint16``
    images are rescaled by 1/256.

    Raises
    ------
    TypeError
        If the dtype cannot be converted to ``uint8``.
    """
    img = self.value
    if img.dtype == np.uint8:
        # BUG FIX: previously returned the raw dask array (`img`), breaking
        # the annotated LazyImgArray return type; return self, consistent
        # with as_float().
        return self
    if img.dtype == np.uint16:
        out = img / 256
    elif img.dtype.kind == "f":
        out = img + 0.5
        out = np.clip(out, 0, 255)
    else:
        raise TypeError(f"invalid data type: {img.dtype}")
    out = out.astype(np.uint8)
    out = self.__class__(out)
    out._set_info(self)
    return out
def as_uint16(self) -> LazyImgArray:
    """Return the image converted to ``uint16``.

    Returns ``self`` unchanged when the dtype is already ``uint16``.
    ``uint8`` images are rescaled by 256, booleans become 0/1, and float
    images are rounded (``+ 0.5``) and clipped to [0, 65535].

    Raises
    ------
    TypeError
        If the dtype cannot be converted to ``uint16``.
    """
    img = self.value
    if img.dtype == np.uint16:
        # BUG FIX: previously returned the raw dask array (`img`), breaking
        # the annotated LazyImgArray return type; return self, consistent
        # with as_float().
        return self
    if img.dtype == np.uint8:
        out = img * 256
    elif img.dtype == bool:
        out = img
    elif img.dtype.kind == "f":
        out = img + 0.5
        out = np.clip(out, 0, 65535)
    else:
        raise TypeError(f"invalid data type: {img.dtype}")
    out = out.astype(np.uint16)
    out = self.__class__(out)
    out._set_info(self)
    return out
def as_float(self) -> LazyImgArray:
    """Return the image as ``float32``, or ``self`` if it already is."""
    if self.dtype == np.float32:
        return self
    converted = self.__class__(self.value.astype(np.float32))
    converted._set_info(self)
    return converted
def as_img_type(self, dtype=np.uint16) -> LazyImgArray:
    """Convert the image to *dtype*, delegating to the dedicated converters.

    float64 and complex128 are downgraded (with a warning) to float32 and
    complex64 respectively; unsupported dtypes raise ValueError.
    """
    dtype = np.dtype(dtype)
    if self.dtype == dtype:
        return self

    def _cast(np_dtype):
        # Generic cast + metadata inheritance for dtypes that have no
        # dedicated conversion method.
        casted = self.__class__(self.value.astype(np_dtype))
        casted._set_info(self)
        return casted

    if dtype == "uint16":
        return self.as_uint16()
    if dtype == "uint8":
        return self.as_uint8()
    if dtype == "float32":
        return self.as_float()
    if dtype == "float64":
        warn("Data type float64 is not valid for images. It was converted to float32 instead",
             UserWarning)
        return self.as_float()
    if dtype == "complex64":
        return _cast(np.complex64)
    if dtype == "complex128":
        warn("Data type complex128 is not valid for images. It was converted to complex64 instead",
             UserWarning)
        return _cast(np.complex64)
    if dtype == "int8":
        return _cast(np.int8)
    raise ValueError(f"dtype: {dtype}")
@_docs.copy_docs(ImgArray.mean)
def mean(self, axis=None, keepdims=False, **kwargs):
    """Mean along the given axis, dispatched to dask via ``np.mean``."""
    kwargs.update(axis=axis, keepdims=keepdims)
    return np.mean(self, **kwargs)
@_docs.copy_docs(ImgArray.std)
def std(self, axis=None, keepdims=False, **kwargs):
    """Standard deviation along the given axis, dispatched via ``np.std``."""
    kwargs.update(axis=axis, keepdims=keepdims)
    return np.std(self, **kwargs)
@_docs.copy_docs(ImgArray.sum)
def sum(self, axis=None, keepdims=False, **kwargs):
    """Sum along the given axis, dispatched to dask via ``np.sum``."""
    kwargs.update(axis=axis, keepdims=keepdims)
    return np.sum(self, **kwargs)
@_docs.copy_docs(ImgArray.max)
def max(self, axis=None, keepdims=False, **kwargs):
    """Maximum along the given axis, dispatched to dask via ``np.max``."""
    kwargs.update(axis=axis, keepdims=keepdims)
    return np.max(self, **kwargs)
@_docs.copy_docs(ImgArray.min)
def min(self, axis=None, keepdims=False, **kwargs):
    """Minimum along the given axis, dispatched to dask via ``np.min``."""
    kwargs.update(axis=axis, keepdims=keepdims)
    return np.min(self, **kwargs)
def _set_additional_props(self, other):
# set additional properties
# If `other` does not have it and `self` has, then the property will be inherited.
for p in self.__class__.additional_props:
setattr(self, p, getattr(other, p,
getattr(self, p,
None)))
def _getitem_additional_set_info(self, other, **kwargs):
    """Inherit metadata from *other* after slicing, using the sliced axes."""
    new_axes = kwargs["new_axes"]
    self._set_info(other, new_axes)
def _set_info(self, other: LazyImgArray, new_axes: Any = AxesMixin._INHERIT):
    """Copy additional properties and axes from *other* onto *self*."""
    self._set_additional_props(other)
    try:
        self.axes = other.axes.copy() if new_axes is self._INHERIT else new_axes
    except ImageAxesError:
        # Axes incompatible with the new array: fall back to undefined axes.
        self.axes = None
def _replace_inputs(img: LazyImgArray, args, kwargs):
    """Prepare arguments for forwarding to dask's ``__array_function__``.

    LazyImgArray inputs (positional and ``out=``) are unwrapped to their dask
    arrays, and a string ``axis`` keyword is translated into integer indices.
    """
    def _unwrap(a):
        return a.value if isinstance(a, LazyImgArray) else a

    args = tuple(_unwrap(a) for a in args)
    axis = kwargs.get("axis")
    if isinstance(axis, str):
        indices = tuple(img.axisof(c) for c in axis)
        # A single-character axis string becomes a plain int, not a 1-tuple.
        kwargs["axis"] = indices[0] if len(indices) == 1 else indices
    if "out" in kwargs:
        kwargs["out"] = tuple(_unwrap(a) for a in kwargs["out"])
    return args, kwargs
def _ceilint(a: float):
return int(np.ceil(a))
def iter_slice(shape, iteraxes: str, all_axes: str, exclude: str = ""):
    """Yield index tuples that iterate over *iteraxes* of an array.

    Axes in *iteraxes* iterate over their full range from *shape*; other axes
    get ``slice(None)``. Axes listed in *exclude* are removed from each
    yielded tuple. If the product is empty, a single all-``slice(None)``
    tuple is yielded as a fallback.
    """
    ndim = len(all_axes)
    per_axis = switch_slice(
        axes=iteraxes,
        all_axes=all_axes,
        ifin=[range(s) for s in shape],
        ifnot=[(slice(None),)] * ndim,
    )
    yielded = False
    for sl in itertools.product(*per_axis):
        if exclude:
            sl = tuple(s for a, s in zip(all_axes, sl) if a not in exclude)
        yield sl
        yielded = True
    # Fallback when the product was empty (e.g. a zero-length axis).
    if not yielded:
        yield (slice(None),) * (ndim - len(exclude))
# NOTE: following FFT functions needs the "resource" argument to make it work with both
# numpy and cupy.
@lru_cache
def _get_fft_func(resource) -> Callable[[DaskArray], DaskArray]:
    """Get the scipy FFT function for dask."""
    # BUG FIX: the module lives inside the impy package (see the top-level
    # `from impy.array_api import xp`); a bare `array_api` import raises
    # ModuleNotFoundError.
    from impy.array_api import scipy_fft
    from dask import array as da
    return da.fft.fft_wrap(scipy_fft.fftn)
@lru_cache
def _get_ifft_func(resource) -> Callable[[DaskArray], DaskArray]:
    """Get the scipy IFFT function for dask."""
    # BUG FIX: import from the impy package, matching the module's other
    # imports; a bare `array_api` import raises ModuleNotFoundError.
    from impy.array_api import scipy_fft
    from dask import array as da
    return da.fft.fft_wrap(scipy_fft.ifftn)