from __future__ import annotations
import numpy as np
from pathlib import Path
import itertools
from functools import partial
from typing import TYPE_CHECKING, Any, Sequence, MutableMapping, Protocol
from warnings import warn
from scipy import ndimage as ndi
from .specials import PropArray
from ._utils._skimage import skmes, skseg
from ._utils import _misc, _docs
from .bases import MetaArray
from .label import Label
from ..utils.misc import check_nd, largest_zeros
from ..utils.axesop import complement_axes, find_first_appeared
from ..utils.deco import check_input_and_output, dims_to_spatial_axes, same_dtype
from ..io import imsave
from ..collections import DataList, DataDict
from ..axes import ImageAxesError, AxesLike, slicer, Axes
from .._types import Dims, nDInt, nDFloat, Callable, Coords, Iterable, PaddingMode
from ..array_api import xp
if TYPE_CHECKING:
from typing_extensions import Self
from ..frame import PathFrame
from numpy.typing import ArrayLike, DTypeLike
from ..roi import RoiList
class SupportAxesSlicing(Protocol):
    """Structural interface that every covariate object of a MetaArray must satisfy."""

    @property
    def axes(self) -> Axes:
        """The Axes object attached to this covariate."""

    def _dimension_matches(self, array: MetaArray) -> bool:
        """Return True if this covariate is compatible with ``array``'s shape and axes."""

    def copy(self) -> Self:
        """Return a shallow copy of this covariate."""

    def __getitem__(self, key) -> SupportAxesSlicing | None:
        """Return this covariate sliced by ``key``."""

    def _slice_by(self, key) -> SupportAxesSlicing | None:
        """Return this covariate sliced by ``key`` (used instead of ``__getitem__`` when defined)."""
def _fmt_slice(k, length: int):
    """Resolve a slice against an axis of size ``length``; return any other key unchanged."""
    # slice.indices turns None/negative bounds into concrete (start, stop, step).
    return slice(*k.indices(length)) if isinstance(k, slice) else k
class ArrayCovariates(MutableMapping[str, SupportAxesSlicing]):
    """A dictionary of covariate objects (labels, ROIs, ...) for a MetaArray.

    The parent array is held through a weak reference, so this mapping does not
    keep the array alive.
    """

    def __init__(self, data: dict[str, SupportAxesSlicing], parent: MetaArray):
        import weakref

        self._data = data
        self._parent_ref = weakref.ref(parent)

    @property
    def parent(self):
        """Return the parent MetaArray object.

        Raises
        ------
        ValueError
            If the parent array has already been garbage-collected.
        """
        out = self._parent_ref()
        if out is None:
            raise ValueError("Parent LabeledArray is deleted.")
        return out

    def construct_by_copying(self, parent: MetaArray | None = None) -> Self:
        """Return a new mapping bound to ``parent`` holding copies of compatible covariates.

        Parameters
        ----------
        parent : MetaArray, optional
            Array the new mapping is bound to. Defaults to the current parent.

        Notes
        -----
        Covariates whose axes are not all contained in ``parent.axes`` are dropped.
        """
        if parent is None:
            parent = self.parent
        data: dict[str, SupportAxesSlicing] = {}
        for k, value in self.items():
            if parent.axes.contains(value.axes):
                data[k] = value.copy()
        return self.__class__(data, parent)

    def construct_by_slicing(self, key, next_parent: MetaArray | None = None) -> Self:
        """Return a new mapping whose covariates are sliced consistently with ``parent[key]``.

        Parameters
        ----------
        key : Any
            Key used to slice the parent array.
        next_parent : MetaArray, optional
            The sliced array the new mapping is bound to. Defaults to the current parent.
        """
        parent = self.parent
        if next_parent is None:
            next_parent = parent
        data: dict[str, SupportAxesSlicing] = {}
        for k, value in self.items():
            if value is None:
                continue
            if isinstance(key, np.ndarray):
                # Array keys (e.g. boolean masks) are forwarded unchanged.
                label_sl = key
            else:
                _keys = key if isinstance(key, tuple) else (key,)
                # Keep only the key components that correspond to axes the
                # covariate actually has, normalized against the parent's sizes.
                label_sl = tuple(
                    _fmt_slice(_keys[i], parent.shape[i])
                    for i, a in enumerate(parent.axes)
                    if (a in value.axes and i < len(_keys))
                )
                if len(label_sl) == 0 or len(label_sl) > len(value.axes):
                    label_sl = ()
            # Prefer the covariate's own slicing hook when it provides one.
            if hasattr(value, "_slice_by"):
                data[k] = value._slice_by(label_sl)
            else:
                data[k] = value[label_sl]
        return self.__class__(data, next_parent)

    def __getitem__(self, key: str) -> SupportAxesSlicing:
        return self._data[key]

    def __setitem__(self, key: str, value: SupportAxesSlicing) -> None:
        """Store ``value`` under ``key``; assigning None removes the entry.

        Raises
        ------
        ValueError
            If ``value`` does not match the parent's shape and axes.
        """
        if value is None:
            self.pop(key, None)
            return
        parent = self.parent
        if not value._dimension_matches(parent):
            raise ValueError(
                f"Shape of input object ({value!r}) does not match the "
                f"parent array ({parent.shape_info})."
            )
        self._data[key] = value

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __len__(self) -> int:
        return len(self._data)

    def __iter__(self):
        return iter(self._data)
[docs]class LabeledArray(MetaArray):
_name: str
_source: Path | None
_metadata: dict[str, Any]
_covariates: ArrayCovariates
def __new__(
    cls: type[LabeledArray],
    obj,
    name: str | None = None,
    axes: AxesLike | None = None,
    source: str | Path | None = None,
    metadata: dict[str, Any] | None = None,
    dtype: DTypeLike = None,
) -> Self:
    """Create a LabeledArray; all arguments are forwarded to ``MetaArray.__new__``."""
    instance: LabeledArray = super().__new__(
        cls, obj, name, axes, source, metadata, dtype
    )
    # A freshly created array starts without any covariates.
    instance._covariates = ArrayCovariates({}, instance)
    return instance
@MetaArray.axes.setter
def axes(self, value: AxesLike):
    # While the array is still being constructed ``_axes`` does not exist and
    # there are no covariates to rename yet.
    if not hasattr(self, "_axes"):
        MetaArray.axes.fset(self, value)
        return
    previous = self.axes
    MetaArray.axes.fset(self, value)
    # Old axis -> new axis at the same position; rename every covariate's axes.
    renamed = dict(zip(previous, self.axes))
    for cov in self._covariates.values():
        cov.axes = [renamed[ax] for ax in cov.axes]
@property
def range(self) -> tuple[float, float]:
    """The ``(min, max)`` pair of the array values."""
    lowest = self.min()
    highest = self.max()
    return lowest, highest
@property
def covariates(self) -> ArrayCovariates:
    """The mapping of all covariates (labels, ROIs, ...) attached to this array."""
    mapping = self._covariates
    return mapping
@property
def labels(self) -> Label | None:
    """The label image attached to this array, or None if there is none."""
    return self.covariates.get("labels", None)
@labels.setter
def labels(self, value: np.ndarray | None):
    if value is None:
        # Assigning None detaches the current labels.
        self.covariates.pop("labels", None)
        return
    if value is self:
        raise ValueError("Setting labels recursively is not allowed.")
    if not isinstance(value, Label):
        # Coerce a plain array: signed ints are reinterpreted as uint64, any
        # other non-unsigned dtype is rejected.
        arr = np.asarray(value)
        kind = arr.dtype.kind
        if kind == "i":
            arr = arr.astype(np.uint64)
        elif kind != "u":
            raise TypeError(
                f"Input label must be unsigned int but has wrong dtype {arr.dtype}."
            )
        # The array is assumed to span the trailing axes of this image.
        value = Label(arr, axes=str(self.axes)[-arr.ndim:]).optimize()
    if not value._dimension_matches(self):
        raise ValueError(
            f"Shape of input label ({value.shape_info}) does not match the "
            f"parent array ({self.shape_info})."
        )
    self.covariates["labels"] = value
@labels.deleter
def labels(self):
    # Deleting when no labels are attached is a no-op.
    covs = self.covariates
    covs.pop("labels", None)
@property
def rois(self) -> RoiList | None:
    """ROIs of the image, or None when no ROIs are attached."""
    # ``get`` returns None for a missing key, hence the optional return type.
    return self.covariates.get("rois")
@rois.setter
def rois(self, val) -> None:
    from ..roi import RoiList, POS

    if not isinstance(val, RoiList):
        # Anything that is not already a RoiList is wrapped with this image's axes.
        self.covariates["rois"] = RoiList(self.axes, val)
        return
    if val.axes[0] == POS and val.axes[1:] == self.axes[1:]:
        # Rename the leading position axis to this image's leading axis, on a
        # shallow copy so the caller's RoiList is left untouched.
        import copy

        val = copy.copy(val)
        val.axes = self.axes[0] + val.axes[1:]
    self.covariates["rois"] = val
@rois.deleter
def rois(self) -> None:
    # A missing "rois" entry is silently ignored.
    covs = self.covariates
    covs.pop("rois", None)
def set_scale(self, other=None, **kwargs) -> None:
    """Set the scale of this array and propagate it to every covariate supporting ``set_scale``."""
    super().set_scale(other, **kwargs)
    scalable = [cov for cov in self.covariates.values() if hasattr(cov, "set_scale")]
    for cov in scalable:
        cov.set_scale(other, **kwargs)
    return None
def _repr_dict_(self):
    """Return the fields displayed in the array's repr."""
    lbl = self.labels
    if lbl is None:
        label_info = "No label"
    else:
        label_info = lbl.shape_info
        if len(label_info) == 0:
            # A label without axes has an empty shape description.
            label_info = "scalar"
    return {
        "name": self.name,
        "shape": self.shape_info,
        "label shape": label_info,
        "dtype": self.dtype,
        "source": self.source,
        "scale": self.scale,
    }
def imsave(
    self,
    save_path: str | Path,
    *,
    dtype: DTypeLike = None,
    overwrite: bool = True,
) -> None:
    """
    Save image at the same directory as the original image by default.

    For tif file format, if the image contains wrong axes for ImageJ (= except for tzcyx),
    then it will be converted automatically if possible. For mrc file format, only zyx and
    yx are allowed. zyx-scale is also saved.

    Parameters
    ----------
    save_path : str
        File name. If it has no suffix, the suffix of ``self.source`` (or ".tif")
        is appended. A relative path is resolved against the directory of
        ``self.source``.
    dtype : dtype-like, optional
        In what data type img will be saved. The image is converted with
        ``as_img_type``; by default the current dtype is kept.
    overwrite : bool, default is True
        Whether to overwrite the file if it already exists.

    Raises
    ------
    ValueError
        If the array has fewer than 2 dimensions, or if ``save_path`` is relative
        while ``self.source`` is unknown.
    FileExistsError
        If ``overwrite`` is False and the file already exists.
    """
    save_path = Path(save_path)
    if self.ndim < 2:
        raise ValueError("Cannot save <2D array as an image.")
    if save_path.suffix == "":
        # Reuse the source's extension; fall back to tif.
        ext = self.source.suffix if self.source is not None else ""
        save_path = save_path.parent / (save_path.name + (ext or ".tif"))
    if not save_path.is_absolute():
        if self.source is None:
            raise ValueError(
                "Image directory path is unknown. Set by \n"
                " >>> img.source = \"...\"\n"
                "or specify absolute path like\n"
                " >>> img.imsave(\"/path/to/XXX.tif\")"
            )
        save_path = self.source.parent / save_path
    if not overwrite and save_path.exists():
        raise FileExistsError(f"File {save_path!r} already exists.")
    if self.metadata is None:
        self.metadata = {}
    if dtype is None:
        dtype = self.dtype
    # Apply the requested dtype (``as_img_type`` returns self when it already matches).
    imsave(save_path, self.as_img_type(dtype))
    return None
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Basic Functions
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
def __array_finalize__(self, obj):
    # numpy hook for new views/instances: run the MetaArray bookkeeping first,
    # then derive this array's covariates from ``obj``.
    super().__array_finalize__(obj)
    self._inherit_covariates(obj)
def _set_info(self, other: Self, new_axes: Any = MetaArray._INHERIT):
    """Copy metadata from ``other`` via MetaArray, then inherit its covariates."""
    super()._set_info(other, new_axes)
    self._inherit_covariates(other)
def _inherit_covariates(self, other: Self):
    """Attach covariates derived from ``other`` to this array.

    * ``other`` is not a LabeledArray: start with an empty mapping.
    * ``other`` is this very array: keep the existing mapping object.
    * otherwise: copy the covariates compatible with this array's axes.
    """
    if not isinstance(other, LabeledArray):
        self._covariates = ArrayCovariates({}, self)
    elif other is self:
        self._covariates = other._covariates
    else:
        self._covariates = other.covariates.construct_by_copying(self)
def _getitem_additional_set_info(self, other: Self, key, new_axes) -> None:
    """Set metadata and covariates on the array produced by slicing ``other`` with ``key``."""
    # Make sure ``_covariates`` exists before MetaArray._set_info runs: if the
    # ``covariates`` property cannot be read yet (no ``_covariates``), getattr
    # falls back to a fresh empty mapping.
    # NOTE(review): presumably needed because MetaArray._set_info may touch covariates — confirm.
    self._covariates = getattr(self, "covariates", ArrayCovariates({}, self))
    # Call MetaArray._set_info directly (not LabeledArray._set_info), so the
    # covariates are not copied wholesale; they are sliced below instead.
    super()._set_info(other, new_axes)
    if isinstance(other, LabeledArray):
        # Slice every covariate consistently with ``key`` and bind it to this array.
        self._covariates = other.covariates.construct_by_slicing(key, self)
    return None
def _update(self, out: Self):
    """Overwrite this array's values in place with ``out`` converted to this array's dtype."""
    converted = out.as_img_type(self.dtype)
    self.value[:] = converted.value[:]
    return None
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Type Conversions
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
def as_uint8(self) -> Self:
    """Convert the image to uint8.

    * uint8: returned unchanged (same object).
    * uint16: the upper 8 bits are kept (``value // 256``).
    * bool: True/False become 1/0.
    * float: rounded half up and clipped into [0, 255].

    Raises
    ------
    TypeError
        For any other dtype.
    """
    if self.dtype == np.uint8:
        return self
    if self.dtype == np.uint16:
        # Same result as truncating ``value / 256``, without a float64 temporary.
        out = self.value // 256
    elif self.dtype == bool:
        out = self.value
    elif self.dtype.kind == "f":
        # +0.5 followed by the truncating cast rounds non-negative values half up.
        out = np.clip(self.value + 0.5, 0, 255)
    else:
        raise TypeError(f"invalid data type: {self.dtype}")
    out = out.astype(np.uint8).view(self.__class__)
    out._set_info(self)
    return out
def as_uint16(self) -> Self:
    """Convert the image to uint16.

    uint8 values are multiplied by 256, booleans become 0/1, and floats are
    rounded half up and clipped into [0, 65535]. A uint16 image is returned
    unchanged; any other dtype raises TypeError.
    """
    dt = self.dtype
    if dt == np.uint16:
        return self
    if dt == np.uint8:
        converted = self.value.astype(np.uint16) * 256
    elif dt == bool:
        converted = self.value
    elif dt.kind == "f":
        converted = self.value + 0.5
        converted[converted < 0] = 0
        converted[converted > 65535] = 65535
    else:
        raise TypeError(f"invalid data type: {self.dtype}")
    result = converted.astype(np.uint16).view(self.__class__)
    result._set_info(self)
    return result
def as_float(self) -> Self:
    """Return the image as float32 (the same object if it already is float32)."""
    if self.dtype != np.float32:
        converted = self.value.astype(np.float32).view(self.__class__)
        converted._set_info(self)
        return converted
    return self
def as_img_type(self, dtype=np.uint16) -> Self:
    """Convert the image to one of the supported image dtypes.

    float64 and complex128 are downgraded to float32 and complex64 with a
    UserWarning. Unsupported dtypes raise ValueError.
    """
    target = np.dtype(dtype)
    if self.dtype == target:
        return self
    if target == "uint16":
        return self.as_uint16()
    if target == "uint8":
        return self.as_uint8()
    if target == "float32":
        return self.as_float()
    if target == "bool":
        return self.astype("bool")
    if target == "float64":
        warn("Data type float64 is not valid for images. It was converted to float32 instead",
             UserWarning)
        return self.as_float()
    if target == "complex64":
        return self.astype(np.complex64)
    if target == "complex128":
        warn("Data type complex128 is not valid for images. It was converted to complex64 instead",
             UserWarning)
        return self.astype(np.complex64)
    if target in ("int8", "int16"):
        return self.astype(target)
    raise ValueError(f"dtype: {target}")
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Simple Visualizations
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
def hist(self, contrast=None):
    """Plot the intensity histogram of the image."""
    from ._utils import _plot

    _plot.hist(self.value, contrast)
    return None
@dims_to_spatial_axes
def imshow(self, label: bool = False, dims=2, plugin="matplotlib", **kwargs):
    """Display the image with ``._imshow.imshow`` and return whatever it returns."""
    from ._imshow import imshow as _show

    return _show(self, label=label, dims=dims, plugin=plugin, **kwargs)
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Interpolation
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
@_docs.write_docs
@same_dtype(asfloat=True)
@dims_to_spatial_axes
def map_coordinates(
    self,
    coordinates,
    *,
    mode: PaddingMode = "constant",
    cval: float = 0,
    order: int = 3,
    prefilter: bool | None = None,
    dims: Dims = None,
) -> Self:
    """
    Interpolate the image at arbitrary coordinates. See ``scipy.ndimage.map_coordinates``.

    Parameters
    ----------
    coordinates : ArrayLike
        Coordinates to sample at, of shape (D, N) or (D, X_1, ..., X_D).
    {mode}{cval}{order}
    prefilter : bool, optional
        Whether to apply the spline prefilter. If not given, it is applied only
        when ``order`` is larger than 1.
    {dims}

    Returns
    -------
    LabeledArray
        Interpolated image. When ``coordinates`` is a MetaArray its axes are
        reused; otherwise (D, N) input produces a new "#" axis.
    """
    coords = xp.asarray(coordinates)
    c_axes = complement_axes(dims, self.axes)
    is_point_list = coords.ndim == 2
    # With (D, N) input the D spatial axes collapse into one axis of length N,
    # so every spatial axis except the last is dropped from each chunk.
    drop_axis = [self.axisof(a) for a in dims[:-1]] if is_point_list else []
    if prefilter is None:
        prefilter = order > 1
    mapped = self._apply_dask(
        _map_coordinates,
        c_axes,
        dtype=self.dtype,
        drop_axis=drop_axis,
        args=(coords,),
        kwargs=dict(mode=mode, cval=cval, order=order, prefilter=prefilter),
    )
    coords_have_axes = isinstance(coordinates, MetaArray)
    if coords.ndim == len(dims) + 1:
        new_axes = c_axes + coordinates.axes[1:] if coords_have_axes else self.axes
    elif coords_have_axes:
        new_axes = c_axes + coordinates.axes[1:2]
    else:
        new_axes = c_axes + ["#"]
    mapped = mapped.view(self.__class__)
    mapped._set_info(self, new_axes=new_axes)
    return mapped
@_docs.write_docs
def pointprops(self, coords: Coords, *, order: int = 3, squeeze: bool = True) -> PropArray:
    """
    Measure interpolated intensities at points given by float coordinates.

    This is a convenience wrapper of :func:`map_coordinates` specialized for
    sampling a set of points.

    Parameters
    ----------
    coords : DataFrame or array-like
        Coordinates of the points to measure, one point per row.
    {order}
    squeeze : bool, default is True
        If True and only one point is measured, the point axis "N" is removed.

    Returns
    -------
    PropArray or float
        Intensities at the points, with the point axis first.

    Examples
    --------
    Calculate centroids and measure intensities.

        >>> coords = img.proj("t").centroid_sm()
        >>> prop = img.pointprops(coords)
    """
    id_axis = "N"
    points = MetaArray(np.atleast_2d(coords), axes=[id_axis, "dim"])
    n_points, n_dims = points.shape
    # Point coordinates refer to the trailing ``n_dims`` axes of the image.
    sampled = self.map_coordinates(points.T, order=order, dims=self.axes[-n_dims:])
    result = PropArray(
        sampled,
        name=sampled.name,
        axes=sampled.axes,
        source=sampled.source,
        metadata=sampled.metadata,
        propname="pointprops",
    )
    result = np.moveaxis(result, result.axisof(id_axis), 0)
    if squeeze and n_points == 1:
        result = result[0]
    return result
@_docs.write_docs
def reslice(
    self,
    a: ArrayLike,
    b=None,
    *,
    order: int = 3,
    prefilter: bool | None = None,
) -> PropArray:
    """
    Measure line profile (kymograph) iteratively for every slice of image. This
    function is almost same as `skimage.measure.profile_line`, but can reslice
    3D-images. The argument `linewidth` is not implemented here because it is
    useless.

    Parameters
    ----------
    a : array-like
        Path or source coordinate. If the former, it must be like:
        `a = [[y0, x0], [y1, x1], ..., [yn, xn]]`
    b : array-like, optional
        Destination coordinate. If specified, `a` must be the source coordinate.
    {order}
    prefilter : bool, optional
        Whether to apply the spline prefilter before interpolation. Passed to
        ``map_coordinates``, where None means "only if ``order`` is larger than 1".

    Returns
    -------
    PropArray
        Line scans. The sampled axis along the path is named "s".

    Raises
    ------
    ValueError
        If the path is not a 2D array of shape (number of points, number of dimensions).

    Examples
    --------
    1. Reslice along a line and fit to a model function for every time frame.

        >>> scan = img.reslice([18, 32], [53, 48])
        >>> out = scan.curve_fit(func, init, return_fit=True)
        >>> plt.plot(scan[0])
        >>> plt.plot(out.fit[0])

    2. Reslice along a path.

        >>> scan = img.reslice([[18, 32], [53,48], [22,45], [28, 32]])
    """
    # path = [[y1, x1], [y2, x2], ..., [yn, xn]]
    if b is not None:
        a = [list(a), list(b)]
    a = np.asarray(a, dtype=np.float32)
    if a.ndim != 2:
        raise ValueError(
            f"Path must be a 2D array of shape (N, ndim), but got shape {a.shape}."
        )
    ndim = a.shape[1]
    seg = SegmentedLine(a)
    coords = seg.sample_points().T
    if ndim == self.ndim:
        dims = self.axes
    else:
        # The channel axis is never resliced; use the trailing non-"c" axes.
        dims = complement_axes("c", self.axes)[-ndim:]
    c_axes = complement_axes(dims, self.axes)
    result = self.map_coordinates(
        coords, order=order, mode="constant", prefilter=prefilter, dims=dims,
    )
    new_axis = "s"
    out = PropArray(result, name=self.name, dtype=np.float32,
                    axes=c_axes + [new_axis], propname="reslice")
    out.set_scale(self)
    # NOTE(review): ``seg.interv`` is presumably the spacing between sample
    # points in pixels, scaled here by the last spatial axis — confirm.
    out.set_scale({new_axis: self.scale[dims[-1]] * seg.interv})
    return out
@_docs.write_docs
@check_input_and_output
def pathprops(
    self,
    paths: PathFrame | ArrayLike | Sequence[ArrayLike],
    properties: str | Callable | Iterable[str | Callable] = "mean",
    *,
    order: int = 1,
) -> DataDict[str, PropArray]:
    """
    Measure properties of line scans along paths, for every slice of the image.

    Every path is resliced and each function in ``properties`` is applied
    along the sampled axis.

    Parameters
    ----------
    paths : PathFrame, array-like or sequence of array-like
        Paths to measure properties.
    properties : str or callable, or their iterable
        Properties to be analyzed. Strings are looked up as numpy functions
        (e.g. "mean", "max").
    {order}

    Returns
    -------
    DataDict of PropArray
        Line properties. Keys are property names and values are the corresponding PropArrays.

    Examples
    --------
    1. Time-course measurement of intensities on a path.

        >>> img.pathprops([[2, 3], [102, 301], [200, 400]])
    """
    id_axis = "N"

    # Normalize ``paths`` into a list of arrays, one per path.
    if type(paths).__name__ == "PathFrame":
        path_list = [np.asarray(p) for p in paths.split(id_axis)]
    elif _count_list_depth(paths) == 2:
        path_list = [np.asarray(paths)]
    else:
        path_list = [np.asarray(p) for p in paths]
    ndim = path_list[0].shape[1]
    dims = ["z", "y", "x"][-ndim:]

    # Resolve property names/callables into a name -> function mapping.
    if isinstance(properties, str) or callable(properties):
        properties = (properties,)
    funcdict = {}
    for prop in properties:
        if isinstance(prop, str):
            funcdict[prop] = getattr(np, prop)
        elif callable(prop):
            funcdict[prop.__name__] = prop
        else:
            raise TypeError(f"Cannot interpret property {prop}")

    c_axes = complement_axes(dims, self.axes)
    out_shape = (len(path_list),) + tuple(self.sizeof(a) for a in c_axes)
    out = DataDict(
        {
            name: PropArray(
                np.empty(out_shape, dtype=np.float32),
                name=self.name,
                axes=[id_axis] + c_axes,
                source=self.source,
                propname=f"pathprops<{name}>",
                dtype=np.float32,
            )
            for name in funcdict
        }
    )
    # Prefilter once here so that every reslice below can skip the prefilter.
    source_img = self.spline_filter(order=order, mode="constant") if order > 1 else self
    for i, path in enumerate(path_list):
        resliced = source_img.reslice(path, order=order, prefilter=False)
        for name, func in funcdict.items():
            out[name][i] = np.apply_along_axis(func, axis=-1, arr=resliced.value)
    return out
@_docs.write_docs
@dims_to_spatial_axes
@same_dtype(asfloat=True)
@check_input_and_output
def spline_filter(
    self,
    order: int = 3,
    mode: PaddingMode = "mirror",
    *,
    dims: Dims = None,
    update: bool = False,
):
    """
    Apply a spline filter over the spatial dimensions.

    Parameters
    ----------
    {order}{mode}{dims}{update}

    Returns
    -------
    LabeledArray
        Filtered image.
    """
    from ._utils import _filters

    iter_axes = complement_axes(dims, self.axes)
    return self._apply_dask(
        _filters.spline_filter,
        c_axes=iter_axes,
        args=(order, np.float32, mode),
    )
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Cropping
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
@_docs.write_docs
@check_input_and_output
@dims_to_spatial_axes
def crop_center(self, scale: nDFloat = 0.5, *, dims=2) -> Self:
    r"""
    Crop out the central region of an image.

    Parameters
    ----------
    scale : float or array-like, default is 0.5
        Relative size of the cropped region, in (0, 1]. If an array is given, each
        axis is cropped with its own scale; three values with 2D dims crop zyx.
    {dims}

    Returns
    -------
    Self
        Cropped image.

    Examples
    --------
    1. Create a :math:`512\times512` image from a :math:`1024\times1024` image.

        >>> img_cropped = img.crop_center(scale=0.5)

    2. Create a :math:`21\times256\times256` image from a :math:`63\times1024\times1024` image.

        >>> img_cropped = img.crop_center(scale=[1/3, 1/2, 1/2])
    """
    if hasattr(scale, "__iter__") and len(scale) == 3 and dims == "yx":
        dims = "zyx"
    scale = np.asarray(check_nd(scale, len(dims)))
    if np.any((scale <= 0) | (1 < scale)):
        raise ValueError(f"scale must be (0, 1], but got {scale}")

    def _centered(size, sc):
        # A full-size crop keeps the whole axis exactly.
        if sc == 1:
            return slice(0, size)
        start = int(np.ceil((size - 1) / 2 * (1 - sc)))
        stop = int((size - 1) / 2 * (1 + sc) + 1)
        return slice(start, stop)

    sizes = self.sizesof(dims)
    fmt = slicer.get_formatter(dims)
    slices = tuple(_centered(size, sc) for size, sc in zip(sizes, scale))
    return self[fmt[slices]]
@check_input_and_output
def crop_kernel(self, radius: nDInt = 2) -> Self:
    r"""
    Make a kernel from an image by cropping out the center region.

    This function is useful especially in `ImgArray.defocus()`.

    Parameters
    ----------
    radius : int or array-like of int, default is 2
        Radius of the kernel. Along an axis of size ``s`` the range
        ``[s//2 - radius, (s+1)//2 + radius)`` is kept, clamped to the array bounds.

    Returns
    -------
    LabeledArray
        Kernel

    Examples
    --------
    Make a :math:`4\times4\times4` kernel from a point spread function image (suppose the
    image shapes are all even numbers).

        >>> psf = ip.imread(r".../PSF.tif")
        >>> psfker = psf.crop_kernel()
        >>> psfker.shape
        (4, 4, 4)
    """
    sizes = self.shape
    radii = check_nd(radius, len(sizes))
    # Clamp the start at 0: a negative start would wrap to the end of the axis
    # and give a truncated (or empty) kernel instead of the whole axis.
    return self[
        tuple(slice(max(s // 2 - r, 0), (s + 1) // 2 + r) for s, r in zip(sizes, radii))
    ]
@_docs.write_docs
@check_input_and_output
@dims_to_spatial_axes
def remove_edges(self, pixel: nDInt = 1, *, dims=2) -> Self:
    """
    Remove pixels from the edges.

    Parameters
    ----------
    pixel : int or array-like, default is 1
        Number of pixels to remove from both ends of each axis. If an array is given,
        each axis is cropped with its own value; three values with 2D dims crop zyx.
    {dims}

    Returns
    -------
    LabeledArray
        Cropped image.

    Raises
    ------
    ValueError
        If any value of ``pixel`` is negative.
    """
    if hasattr(pixel, "__iter__") and len(pixel) == 3 and len(dims) == 2:
        dims = "zyx"
    pixel = np.asarray(check_nd(pixel, len(dims)), dtype=np.int64)
    if np.any(pixel < 0):
        # Zero is accepted (no cropping on that axis), so the bound is non-negative.
        raise ValueError("`pixel` must be non-negative.")
    fmt = slicer.get_formatter(dims)
    # ``-px or None``: a stop of -0 would select nothing, so px == 0 maps to None.
    sl = tuple(slice(px, (-px or None)) for px in pixel)
    return self[fmt[sl]]
@_docs.write_docs
@check_input_and_output
@dims_to_spatial_axes
def rotated_crop(self, origin, dst1, dst2, dims=2) -> Self:
    """
    Crop the image at the four corners of a rotated rectangle.

    Currently only rotation within the yx-plane is supported. The rotated rectangle
    is specified by an origin and two destinations `dst1` and `dst2`, i.e., the vectors
    (dst1-origin) and (dst2-origin) span the rectangle. Taking origin as the origin of
    the yx-plane, the rotation from dst1 to dst2 must be counter-clockwise, or the
    cropped image will be reversed.

    Parameters
    ----------
    origin : (float, float)
    dst1 : (float, float)
    dst2 : (float, float)
    {dims}

    Returns
    -------
    LabeledArray
        Cropped array. Labels, if any, are cropped too (nearest neighbor); if that
        fails a UserWarning is emitted and the image is returned without labels.
    """
    origin = np.asarray(origin)
    dst1 = np.asarray(dst1)
    dst2 = np.asarray(dst2)
    ax0 = _misc.make_rotated_axis(origin, dst2)
    ax1 = _misc.make_rotated_axis(dst1, origin)
    all_coords = ax0[:, np.newaxis] + ax1[np.newaxis] - origin
    all_coords = np.moveaxis(all_coords, -1, 0)
    cropped_img = self._apply_dask(
        ndi.map_coordinates, complement_axes(dims, self.axes),
        dtype=self.dtype,
        args=(all_coords,),
        kwargs=dict(prefilter=False, order=1)
    )
    cropped_img = cropped_img.view(self.__class__)
    cropped_img.axes = self.axes
    lbl = self.labels
    if lbl is not None:
        try:
            cropped_labels = np.empty(lbl.shape[:-2] + all_coords.shape[1:], dtype=lbl.dtype)
            for sl, lbl2d in lbl.iter(complement_axes(dims, lbl.axes)):
                # order=0 (nearest neighbor) keeps label values intact.
                cropped_labels[sl] = ndi.map_coordinates(lbl2d, all_coords, prefilter=False, order=0)
        except Exception as e:
            # Label cropping is best effort; report it instead of printing silently.
            warn(f"Cropping labels failed: {e!r}", UserWarning)
        else:
            cropped_img.append_label(cropped_labels)
    return cropped_img
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Label handling and others
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
@_docs.write_docs
@dims_to_spatial_axes
def specify(self, center: Coords, radius: Coords, *, dims: Dims = None,
            labeltype: str = "square") -> Label:
    """
    Make rectangle or ellipse labels from points.

    Parameters
    ----------
    center : array like or MarkerFrame
        Coordinates of centers. For MarkerFrame, it must have the same axes order.
        A plain array must have 2 (yx) or 3 (zyx) columns and is converted to a
        uint16 MarkerFrame.
    radius : float or array-like
        Radius of labels. For square labels it is truncated to an integer.
    {dims}
    labeltype : str, default is "square"
        The shape of labels: "square" ("s") or "circle" ("c").

    Returns
    -------
    Label
        Labeled regions. Existing labels are replaced (with a UserWarning).

    Raises
    ------
    ValueError
        If ``labeltype`` is not recognized.

    Examples
    --------
    Find single molecules and draw circular labels of radius 3.5 around them.

        >>> coords = img.find_sm()
        >>> img.specify(coords, 3.5, labeltype="circle")
        >>> ip.gui.add(img)
    """
    from ..frame import MarkerFrame

    if isinstance(center, MarkerFrame):
        from ._utils._process_numba import _specify_circ_2d, _specify_circ_3d, _specify_square_2d, _specify_square_3d

        ndim = len(dims)
        radius = np.asarray(check_nd(radius, ndim), dtype=np.float32)
        if labeltype in ("square", "s"):
            radius = radius.astype(np.uint8)
            _specify = {2: _specify_square_2d,
                        3: _specify_square_3d}[ndim]
        elif labeltype in ("circle", "c"):
            _specify = {2: _specify_circ_2d,
                        3: _specify_circ_3d}[ndim]
        else:
            raise ValueError("`labeltype` must be 'square' or 'circle'.")
        label_axes = str(center.col_axes)
        label_shape = self.sizesof(label_axes)
        labels = largest_zeros(label_shape)
        # Label ids keep increasing across iterated slices so they stay unique.
        n_label = 1
        for sl, crds in center.iter(complement_axes(dims, center.col_axes)):
            _specify(labels[sl], crds.values, radius, n_label)
            n_label += len(crds)
        if self.labels is not None:
            warn("Existing labels are updated.", UserWarning)
        self.labels = Label(labels, axes=label_axes).optimize()
        self.labels.set_scale(self)
    else:
        center = np.asarray(center)
        if center.ndim == 1:
            center = center.reshape(1, -1)
        cols = {2:"yx", 3:"zyx"}[center.shape[1]]
        center = MarkerFrame(center, columns=cols, dtype=np.uint16)
        return self.specify(center, radius, dims=dims, labeltype=labeltype)
    return self.labels
@_docs.write_docs
@dims_to_spatial_axes
def label(
    self,
    ref_image: np.ndarray | None = None,
    filt: Callable[..., bool] | None = None,
    *,
    dims: Dims = None,
    connectivity: int | None = None
) -> Label:
    """
    Label image using skimage's label().

    Label image using `ref_image` as reference image, or image itself. If
    ``filt`` is given, a region is kept only if the condition dictated in
    `filt` is satisfied. `regionprops` is called for every labeled slice.

    .. code-block:: python

        def filt(img, lbl, area, major_axis_length):
            return area>10 and major_axis_length>5

    Parameters
    ----------
    ref_image : array, optional
        Image to make label, by default self is used.
    filt : callable, optional
        Filter function, called once per region. The first argument is the
        intensity image sliced from `self`, the second is the label image sliced
        from labeled `ref_image`, and the remaining arguments are region
        properties computed by `regionprops`. The property arguments **must be
        named exactly same** as the properties in `regionprops`. The function
        may take only the first two arguments.
    {dims}
    {connectivity}

    Returns
    -------
    Label
        Newly created label.

    Raises
    ------
    ImageAxesError
        If ``ref_image`` does not match the shape of this image.
    TypeError
        If ``filt`` is given but not callable.

    Examples
    --------
    1. Label the image with threshold and visualize with napari.

        >>> thr = img.threshold()
        >>> img.label(thr)
        >>> ip.gui.add(img)

    2. Label regions if only intensity is high.

        >>> def high_intensity(img, lbl, slice):
        >>>     return np.mean(img[slice]) > 10000
        >>> img.label(thr, high_intensity)

    3. Label regions if no hole exists.

        >>> def no_hole(img, lbl, euler_number):
        >>>     return euler_number > 0
        >>> img.label(thr, no_hole)

    4. Label regions if centroids are inside themselves.

        >>> def centroid_inside(img, lbl, centroid):
        >>>     yc, xc = map(int, centroid)
        >>>     return lbl[yc, xc] > 0
        >>> img.label(thr, centroid_inside)
    """
    # check the shape of label_image
    if ref_image is None:
        ref_image = self
    else:
        if not isinstance(ref_image, MetaArray):
            ref_image = np.asarray(ref_image)
            ref_image = MetaArray(
                ref_image,
                axes=self.axes[-ref_image.ndim:]
            )
        if not ref_image._dimension_matches(self):
            raise ImageAxesError(
                f"Shape mismatch. Image is {self.shape_info} but reference is "
                f"{ref_image.shape_info}."
            )

    c_axes = complement_axes(dims, self.axes)
    labels = largest_zeros(ref_image.shape)
    if filt is None:
        labels[:] = ref_image._apply_dask(
            skmes.label,
            c_axes=c_axes,
            kwargs=dict(background=0, connectivity=connectivity)
        ).view(np.ndarray)
    else:
        if not callable(filt):
            raise TypeError("`filt` must be callable.")
        import inspect

        # Parameters after (img, lbl) name the regionprops properties to pass.
        properties = tuple(inspect.signature(filt).parameters)[2:]
        offset = 1
        for sl, lbl in ref_image.iter(c_axes):
            lbl = skmes.label(lbl, background=0, connectivity=connectivity)
            img = self.value[sl]
            # Evaluate ``filt`` for every region (also when it takes no property
            # arguments) with the raw property values, and collect the ids of
            # rejected regions.
            del_list = [
                region.label
                for region in skmes.regionprops(lbl, img, cache=False)
                if not filt(img, lbl, **{name: getattr(region, name) for name in properties})
            ]
            labels[sl] = skseg.relabel_sequential(
                np.where(np.isin(lbl, del_list), 0, lbl),
                offset=offset
            )[0]
            # Next slice continues right after the largest id used so far.
            # (The old ``offset += labels.max()`` roughly doubled the offset on
            # every slice, leaving gaps and risking integer overflow.)
            offset = max(offset, int(labels[sl].max()) + 1)

    # correct the label numbers of `labels`
    labels = labels.view(Label)
    labels._set_info(ref_image)
    labels = labels.increment_iter(c_axes).optimize()
    labels.set_scale(self)
    self.labels = labels
    return self.labels
@_docs.write_docs
@dims_to_spatial_axes
def label_if(
    self,
    ref_image: np.ndarray | None = None,
    filt: Callable[..., bool] | None = None,
    *,
    dims: Dims = None,
    connectivity: int | None = None,
) -> Label:
    """Deprecated alias of ``label``: emits a DeprecationWarning and delegates to it."""
    message = (
        "`label_if` is deprecated and will be removed soon. `label` method does the "
        "same function"
    )
    warn(message, DeprecationWarning)
    return self.label(ref_image, filt, dims=dims, connectivity=connectivity)
@check_input_and_output
def append_label(self, label_image: np.ndarray, new: bool = False) -> Label:
    """
    Append new labels from an array. This function works for boolean or integer arrays.

    Parameters
    ----------
    label_image : np.ndarray
        Labeled image.
    new : bool, default is False
        If True, existing labels will be removed anyway.

    Returns
    -------
    Label
        New labels.

    Raises
    ------
    TypeError
        If ``label_image`` is not an ndarray.
    ValueError
        If its dtype cannot be interpreted as a label, or its axes cannot be inferred.
    ImageAxesError
        If its shape or axes do not match the existing labels / this image.

    Example
    -------
    Make label from different channels.

        >>> thr0 = img["c=0"].threshold("90%")
        >>> thr0.label() # binary to label
        >>> thr1 = img["c=1"].threshold("90%")
        >>> thr1.label() # binary to label
        >>> img.append_label(thr0.labels)
        >>> img.append_label(thr1.labels)

    If `thr0` has 100 labels and `thr1` has 150 labels then `img` will have :math:`100+150=250` labels.
    """
    # check and cast label dtype
    if not isinstance(label_image, np.ndarray):
        raise TypeError(f"`label_image` must be ndarray, but got {type(label_image)}")
    elif label_image.dtype.kind == "u":
        pass
    elif label_image.dtype == bool:
        label_image = label_image.astype(np.uint8)
    elif label_image.dtype == np.int32:
        label_image = label_image.astype(np.uint16)
    elif label_image.dtype == np.int64:
        label_image = label_image.astype(np.uint32)
    elif label_image.dtype.kind == "i":
        # Remaining signed ints (int8/int16, non-native byte orders): cast to the
        # unsigned type of the same width. Casting everything to uint8 silently
        # truncated int16 labels larger than 255.
        label_image = label_image.astype(f"u{label_image.dtype.itemsize}")
    else:
        raise ValueError(
            f"`label_image` has dtype {label_image.dtype}, which is unable "
            "to be interpreted as an label."
        )

    if self.labels is not None and not new:
        if label_image.shape != self.labels.shape:
            raise ImageAxesError(
                "Shape mismatch. Existing labels have shape "
                f"{self.labels.shape} while labels with shape "
                f"{label_image.shape} is given."
            )
        self.labels = self.labels.add_label(label_image)
    else:
        # when label_image is simple ndarray
        if not isinstance(label_image, MetaArray):
            if label_image.shape == self.shape:
                axes = self.axes
            elif label_image.ndim == 2 and self.axes.contains("yx"):
                axes = "yx"
            else:
                raise ValueError("Could not infer axes of `label_image`.")
        else:
            axes = label_image.axes
            if not self.axes.contains(label_image.axes):
                raise ImageAxesError(
                    f"Axes mismatch. Image has {self.axes}-axes but "
                    f"{axes} was given."
                )
        self.labels = Label(label_image, axes=axes, source=self.source)
    return self.labels
@check_input_and_output(need_labels=True)
def proj_labels(self, axis=None, forbid_overlap=False) -> Label:
    """
    Project the labels along an axis, e.g. to reduce zyx-labels to yx.

    Parameters
    ----------
    axis : str, optional
        Axis along which to project. If None, the most plausible one is chosen.
    forbid_overlap : bool, default is False
        If True and labels overlap, a ValueError is raised.

    Returns
    -------
    Label
        Projected labels, which also replace this image's labels.
    """
    projected = self.labels.proj(axis=axis, forbid_overlap=forbid_overlap)
    self.labels = projected
    return self.labels
def split(self, axis=None) -> DataList[Self]:
    """
    Split an n-dimensional image into (n-1)-dimensional images.

    Unlike ``np.split``, which cuts an array into smaller n-D pieces, every
    output here has the split axis removed.

    Parameters
    ----------
    axis : str or int, optional
        Axis along which the image is split. If not given, it is chosen with
        ``find_first_appeared(self.axes, include="cztpa")``.

    Returns
    -------
    DataList
        The separated images. If the image has labels, each output image gets
        the matching slice of the labels, with the split axis dropped and the
        scale copied from the original labels.
    """
    if axis is None:
        axis = find_first_appeared(self.axes, include="cztpa")
    axis_index = self.axisof(axis)
    pieces = super().split(axis_index)

    if self.labels is None:
        return pieces

    # NOTE(review): `axis_index` is a position in the *image* axes but is also
    # used to split the label array and drop its axis. That is only right when
    # the labels share the image's axes; `append_label` also accepts e.g. "yx"
    # labels on a higher-dimensional image — confirm Label.split handles it.
    label_pieces = self.labels.split(axis_index)
    for piece, sub_label in zip(pieces, label_pieces):
        sub_label.axes = self.labels.axes.drop(axis_index)
        sub_label.set_scale(self.labels)
        piece.labels = sub_label
    return pieces
def tile(
    self,
    shape: tuple[int, int] | None = None,
    along: str | None = None,
    order: str | None = None,
) -> Self:
    """
    Tile images in a certain order.

    Parameters
    ----------
    shape : tuple[int, int], optional
        Grid shape as (number of rows, number of columns). This parameter must
        be specified unless the length of `along` is 2.
    along : str, optional
        Axis (Axes) over which will be iterated. If not given, the first axis
        other than y and x whose size equals the number of grid cells is used.
    order : str, {"r", "c"}, optional
        Order of iteration. "r" (default) means row-wise and "c" means
        column-wise. If `along` has length 2, `shape` and `order` are derived
        from the relative position of the two axes in the image instead.

        row-wise
            ----->
            ----->
            ----->

        column-wise
            | | |
            | | |
            v v v

    Returns
    -------
    LabeledArray
        Tiled array whose axes are the remaining axes followed by "yx". Labels,
        if present, are tiled with the same parameters.

    Raises
    ------
    ValueError
        If `shape` is missing or not of length 2, no axis matching `shape` is
        found, `along` has an unsupported length, or `order` is not "r"/"c".
    """
    if along is None:
        if shape is None:
            raise ValueError("`shape` must be specified unless the length of `along` is 2.")
        n_tiles = np.prod(shape)
        # y and x are the content of each tile, so never iterate over them.
        for a in complement_axes("yx", self.axes):
            if self.sizeof(a) == n_tiles:
                along = a
                break
        else:
            raise ValueError(f"Could not find axis that can be reshaped to shape {shape}")
    elif len(along) == 2:
        # Grid shape and iteration order follow the image's own axis order.
        uyaxis, uxaxis = self.axisof(along[0]), self.axisof(along[1])
        if uyaxis < uxaxis:
            shape = self.sizesof(along)
            order = "r"
        else:
            order = "c"
            shape = self.sizesof(along[::-1])
    elif len(along) == 1:
        if shape is None:
            raise ValueError("`shape` must be specified unless the length of `along` is 2.")
    else:
        raise ValueError("`along` must be a string with length 1 or 2.")

    if order is None:
        order = "r"
    if order == "r":
        iter_tile = _iter_tile_yx
    elif order == "c":
        iter_tile = _iter_tile_xy
    else:
        raise ValueError(f"Could not interpret order={repr(order)}.")
    # Check before unpacking so the caller gets a meaningful message.
    if len(shape) != 2:
        raise ValueError("Shape mismatch")

    uy_max, ux_max = shape
    imgy, imgx = self.sizesof("yx")
    c_axes = complement_axes("yx" + along, self.axes)
    new_axes = c_axes + "yx"
    outshape = self.sizesof(c_axes) + (uy_max * imgy, ux_max * imgx)
    out = np.zeros(outshape, dtype=self.dtype)
    for (_, img), sl in zip(self.iter(along), iter_tile(uy_max, ux_max, imgy, imgx)):
        out[sl] = img

    out = out.view(self.__class__)
    out._set_info(self, new_axes=new_axes)
    if self.labels is not None:
        out.labels = self.labels.tile(shape, along, order)
    return out
@check_input_and_output
def for_each_channel(self, func: str, along: str = "c", **kwargs) -> Self:
    """
    Apply the same method to each channel, with per-channel parameters.

    This is useful when parameters depend on the channel, like wave length.

    Parameters
    ----------
    func : str
        Name of the method to call on each split image.
    along : str, default is "c"
        Axis along which the image is split and the results are stacked back.
    kwargs
        Keyword arguments for `func`. A list value is distributed over the
        images (its i-th element goes to the i-th image) and must have one
        element per image. Any other value is passed to every image unchanged.

    Returns
    -------
    LabeledArray
        Output image stack.

    Raises
    ------
    AttributeError
        If the image has no method named `func`.
    ValueError
        If a list in `kwargs` does not have one element per image.
    """
    if not hasattr(self, func):
        raise AttributeError(f"{self.__class__} does not have method {func}")
    imgs = self.split(along)
    # The original loop appended an undefined `out`; actually call `func` here.
    outs = [
        getattr(img, func)(**kw)
        for img, kw in zip(imgs, _iter_dict(kwargs, len(imgs)))
    ]
    return np.stack(outs, axis=along)
@check_input_and_output
def for_params(self, func: Callable | str, var: dict[str, Iterable] | None = None, **kwargs) -> DataList:
    """
    Apply the same function to the same input with different parameters.

    This is useful when you want to try different conditions on the same image.

    Parameters
    ----------
    func : callable or str
        Function to apply repetitively. If str, the member method with that
        name is called; otherwise ``func(self, **kwargs)`` is called.
    var : dict[str, Iterable], optional
        A single-entry dict mapping the name of the variable to the values to
        try. To try sigma=1,2,3, give ``var={"sigma": [1, 2, 3]}``.
    kwargs
        Fixed parameters passed to `func`. If `var` is not given and exactly one
        parameter is provided in `kwargs`, that parameter is used as `var`.

    Returns
    -------
    DataList
        List of outputs, one per value.

    Raises
    ------
    AttributeError
        If `func` is neither a method name nor callable.
    ValueError
        If `var` does not contain exactly one entry, the inputs cannot be
        interpreted, or the varied keyword also appears in `kwargs`.

    Example
    -------
    1. Try LoG filter with different Gaussian kernel size and visualize all of them in napari.

    >>> out = img.for_params("log_filter", var={"sigma":[1, 2, 3, 4]})
    # or
    >>> out = img.for_params("log_filter", sigma=[1, 2, 3, 4])
    # then
    >>> ip.gui.add(out)
    """
    if isinstance(func, str) and hasattr(self, func):
        f = getattr(self, func)
    elif callable(func):
        f = partial(func, self)
    else:
        raise AttributeError(f"{func} is neither {self.__class__}'s method nor callable object.")

    if isinstance(var, dict):
        # Previously an empty dict raised IndexError and extra entries were
        # silently ignored.
        if len(var) != 1:
            raise ValueError(f"`var` must have exactly one entry, but got {len(var)}.")
        ((key, values),) = var.items()
    elif var is None and len(kwargs) == 1:
        ((key, values),) = kwargs.items()
        kwargs = {}
    else:
        raise ValueError("Wrong inputs.")

    if key in kwargs:
        raise ValueError(f"Keyword {key} exists in `kwargs`.")

    outlist = DataList()
    for v in values:
        outlist.append(f(**kwargs, **{key: v}))
    return outlist
def _iter_dict(d, nparam):
    """
    Yield one keyword-argument dict per parameter set.

    List values are distributed element-wise (the i-th dict gets ``v[i]``).
    Any other value is shared by every dict.

    Parameters
    ----------
    d : dict
        Keyword arguments to distribute.
    nparam : int
        Number of dicts to yield (e.g. the number of channels).

    Yields
    ------
    dict
        A new dict for each index ``0 .. nparam - 1``.

    Raises
    ------
    ValueError
        If a list value does not have exactly ``nparam`` elements. Raised when
        iteration starts, before anything is yielded.
    """
    # Validate every list up front, for an earlier feedback.
    for k, v in d.items():
        if isinstance(v, list) and len(v) != nparam:
            raise ValueError(f"Number of parameter '{k}' does not match the number channels.")
    for i in range(nparam):
        # Build a fresh dict each time. Yielding one mutated dict would make all
        # previously yielded dicts change together.
        yield {k: (v[i] if isinstance(v, list) else v) for k, v in d.items()}
def _iter_tile_yx(ymax, xmax, imgy, imgx):
    """
    Yield index tuples that place tiles on a ``ymax`` x ``xmax`` grid row by row.

    +--+--+--+
    | 0| 1| 2|
    +--+--+--+
    | 3| 4| 5|
    +--+--+--+
    | 6| 7|..|
    +--+--+--+

    Each item is ``(..., y_slice, x_slice)``, where a tile is ``imgy`` high and
    ``imgx`` wide.
    """
    for row in range(ymax):
        y_start = row * imgy
        y_slice = slice(y_start, y_start + imgy, None)
        for col in range(xmax):
            x_start = col * imgx
            yield ..., y_slice, slice(x_start, x_start + imgx, None)
def _iter_tile_xy(ymax, xmax, imgy, imgx):
    """
    Yield index tuples that place tiles on a ``ymax`` x ``xmax`` grid column by column.

    +--+--+--+
    | 0| 3| 6|
    +--+--+--+
    | 1| 4| 7|
    +--+--+--+
    | 2| 5|..|
    +--+--+--+

    Parameters
    ----------
    ymax, xmax : int
        Number of tile rows and columns.
    imgy, imgx : int
        Height and width of a single tile.

    Yields
    ------
    tuple
        ``(..., y_slice, x_slice)`` usable as an index into the tiled output.
    """
    # Columns in the outer loop and rows in the inner loop give column-major order.
    # The row index must scale with the tile height and the column index with
    # the tile width. The previous version swapped them, which broke non-square tiles.
    for ux, uy in itertools.product(range(xmax), range(ymax)):
        sly = slice(uy * imgy, (uy + 1) * imgy, None)
        slx = slice(ux * imgx, (ux + 1) * imgx, None)
        yield ..., sly, slx
class SegmentedLine:
    """
    A polyline through ``nodes`` that can be resampled at a constant spacing.

    The total length is divided into ``max(int(length), 1)`` equal intervals.
    For lines of length >= 1 the spacing is therefore in [1, 2). A shorter line
    is treated as a single interval.

    Parameters
    ----------
    nodes : array-like
        Vertices of the polyline, one point per row. At least two are required.
    """

    def __init__(self, nodes: np.ndarray):
        nodes = np.asarray(nodes)
        if nodes.shape[0] < 2:
            raise ValueError("More than one points must be given.")
        vec = np.diff(nodes, axis=0)
        dist = np.sqrt(np.sum(vec**2, axis=1))
        dist_sum = np.sum(dist)
        # At least one interval. int(L) == 0 for L < 1 used to give an infinite
        # spacing, so only the start point was ever sampled.
        n_intervals = max(int(dist_sum), 1)
        self.length = dist_sum
        self.vec = vec
        self.dist = dist
        self.nodes = nodes
        self.interv = dist_sum / n_intervals

    def sample_points(self) -> np.ndarray:
        """
        Sample points along the line, ``self.interv`` apart, starting at the first node.

        Returns
        -------
        np.ndarray
            Sampled points, one per row. The last node is appended if
            floating-point error left the sampling one point short.

        Raises
        ------
        ValueError
            If the line has zero length.
        """
        if self.interv == 0:
            raise ValueError("Cannot sample points from a line of zero length.")
        res = 0
        out = [self.nodes[0:1]]
        npoints = 1
        for v, d, p in zip(self.vec, self.dist, self.nodes[:-1]):
            res0 = res
            # d_int: number of samples falling on this segment;
            # res: distance covered past the last sample.
            d_int, res = divmod(d + res, self.interv)
            idx = self.interv * (np.arange(d_int) + 1) - res0
            xs = idx[:, np.newaxis] * v[np.newaxis] / d + p
            out.append(xs)
            npoints += xs.shape[0]
        # Expected number of samples is n_intervals + 1 (including the start).
        n_intervals = max(int(self.length), 1)
        if npoints <= n_intervals:
            out.append(self.nodes[-1:])
        return np.concatenate(out, axis=0)
def _count_list_depth(x) -> int:
    """
    Return how many times ``x`` can be indexed with ``[0]`` before reaching a scalar.

    For example ``[[1, 2], [3, 4]]`` -> 2, ``[1, 2]`` -> 1, ``5`` -> 0.
    Empty sequences stop the descent (``[]`` -> 0, ``[[]]`` -> 1), and strings
    are treated as scalars.
    """
    n = 0
    out = x
    # A str indexes to itself ("a"[0] == "a"), which would loop forever.
    while not isinstance(out, str):
        try:
            out = out[0]
        except (IndexError, TypeError):
            # IndexError: empty sequence or NumPy scalar/0-d array.
            # TypeError: non-subscriptable Python scalar such as int or float.
            break
        n += 1
    return n
def _map_coordinates(input, coordinates, order, mode, cval, prefilter):
    """
    Call ``xp.ndi.map_coordinates`` after converting both arrays with ``xp.asarray``.

    All interpolation options are forwarded unchanged as keyword arguments.
    """
    options = {
        "order": order,
        "mode": mode,
        "cval": cval,
        "prefilter": prefilter,
    }
    image = xp.asarray(input)
    coords = xp.asarray(coordinates)
    return xp.ndi.map_coordinates(image, coords, **options)