Source code for impy.arrays.bases.metaarray
from __future__ import annotations
from typing import TYPE_CHECKING, Iterable, Hashable, Union, SupportsInt, Mapping
from pathlib import Path
import numpy as np
from numpy.typing import DTypeLike
from ..axesmixin import AxesMixin, get_axes_tuple
from ..._types import *
from ...axes import ImageAxesError, AxesLike, Axes
from ...array_api import xp
from ...utils import axesop, slicer
from ...collections import DataList
if TYPE_CHECKING:
from typing_extensions import Self
# Index accepted for a single axis: anything convertible to int, or a slice.
SupportOneSlicing = Union[SupportsInt, slice]
# Key types used to annotate ``MetaArray.__getitem__`` / ``__setitem__``.
# Keys are normalized by ``slicer.solve_slicer`` before reaching numpy.
SupportSlicing = Union[
    SupportsInt,
    str,
    slice,
    tuple[SupportOneSlicing, ...],
    Mapping[str, SupportOneSlicing],
]
class MetaArray(AxesMixin, np.ndarray):
    """
    ``numpy.ndarray`` subclass that carries axes, a name, a source path and
    a metadata dictionary along with the array data.

    The extra attributes are propagated through views, slicing, ufuncs and
    numpy functions by ``__array_finalize__``, ``__array_ufunc__`` and
    ``__array_function__``.
    """

    # Private attribute names copied between arrays by `_set_additional_props`.
    additional_props = ["_source", "_metadata", "_name"]
    # Registry {numpy function: override}, filled by `implements`.
    NP_DISPATCH = {}

    _name: str
    _source: Path | None
    _metadata: dict[str, Any]

    def __new__(
        cls: type[MetaArray],
        obj,
        name: str | None = None,
        axes: Iterable[Hashable] | None = None,
        source: str | Path | None = None,
        metadata: dict[str, Any] | None = None,
        dtype: DTypeLike = None,
    ) -> Self:
        """
        Build an array of type ``cls`` from ``obj``.

        If ``obj`` is already an instance of ``cls`` it is returned unchanged
        and all other arguments are ignored.
        """
        if isinstance(obj, cls):
            return obj

        self = np.asarray(obj, dtype=dtype).view(cls)
        self.source = source
        self._name = name
        self.axes = axes
        self._metadata = metadata or {}
        return self

    @property
    def source(self):
        """The source file path."""
        return self._source

    @source.setter
    def source(self, val):
        # Anything other than None is normalized to a Path object.
        self._source = None if val is None else Path(val)

    @property
    def name(self) -> str:
        """Name of the array.

        Falls back to the file name of ``source``, then to ``"No name"``,
        when no explicit name has been set.
        """
        if self._name is not None:
            return self._name
        source = self.source
        if source is None:
            return "No name"
        return source.name

    @name.setter
    def name(self, val):
        self._name = str(val)

    @property
    def metadata(self) -> dict[str, Any]:
        """Metadata dictionary of the array."""
        return self._metadata

    @metadata.setter
    def metadata(self, value):
        if not isinstance(value, dict):
            raise TypeError(f"Cannot set {type(value)} as a metadata.")
        self._metadata = value

    @property
    def value(self) -> np.ndarray:
        """Numpy view of the array."""
        return np.asarray(self)

    def __repr__(self) -> str:
        if self.ndim > 0:
            return super().__repr__()
        # ``__repr__`` must return a ``str``; returning the bare numpy scalar
        # would make ``repr()`` raise TypeError for 0-dimensional arrays.
        return repr(self.value[()])

    def _repr_dict_(self) -> dict[str, Any]:
        """Return the fields describing this array as a dictionary."""
        return {
            "name": self.name,
            "shape": self.shape_info,
            "dtype": self.dtype,
            "source": self.source,
            "scale": self.scale,
        }

    def __str__(self):
        return f"{self.__class__.__name__}<{self.name!r}>"

    @property
    def shape(self):
        """Shape of the array.

        Built with the tuple type returned by ``get_axes_tuple``; falls back
        to the plain ndarray shape when that raises ``ImageAxesError``.
        """
        try:
            tup = get_axes_tuple(self)
            return tup(*super().shape)
        except ImageAxesError:
            return super().shape

    def _set_additional_props(self, other):
        """Copy every attribute listed in ``additional_props`` from ``other``."""
        # If `other` lacks an attribute, the current value of `self` is kept;
        # if neither has it, the attribute is set to None.
        for p in self.__class__.additional_props:
            setattr(self, p, getattr(other, p, getattr(self, p, None)))

    def _set_info(self, other: Self, new_axes: Any = AxesMixin._INHERIT):
        """
        Copy additional properties from ``other`` and set the axes.

        ``new_axes`` is used when given, otherwise a copy of ``other.axes``.
        If assigning the axes raises ``ImageAxesError`` the axes are reset to
        None.
        """
        self._set_additional_props(other)
        try:
            if new_axes is not self._INHERIT:
                self.axes = new_axes
            else:
                self.axes = other.axes.copy()
        except ImageAxesError:
            self.axes = None
        return None

    def __getitem__(self, key: SupportSlicing) -> Self:
        key = slicer.solve_slicer(key, self.axes)
        if isinstance(key, np.ndarray):
            key = self._broadcast(key)

        out = super().__getitem__(key)  # get item as np.ndarray

        # Attributes cannot be set on scalar results such as numpy.int32.
        if isinstance(out, self.__class__):
            new_axes = axesop.slice_axes(self.axes, key)
            out._getitem_additional_set_info(self, new_axes=new_axes, key=key)
        return out

    def _getitem_additional_set_info(self, other: Self, key: Slices, new_axes):
        """Hook called on the result of ``__getitem__`` to inherit information."""
        self._set_info(other, new_axes=new_axes)

    def __setitem__(self, key: SupportSlicing, value):
        key = slicer.solve_slicer(key, self.axes)
        if isinstance(key, MetaArray) and key.dtype == bool:
            key = axesop.add_axes(self.axes, self.shape, key, key.axes)
        elif isinstance(key, np.ndarray) and key.dtype == bool and key.ndim == 2:
            # img[arr] ... where arr is 2-D boolean array
            key = axesop.add_axes(self.axes, self.shape, key)
        super().__setitem__(key, value)

    def __array_finalize__(self, obj):
        """
        Set essential attributes on arrays created by numpy from ``obj``.

        Numpy calls this for views, copies and other derived arrays, so that
        e.g. ``img.copy()`` and ``img.astype("int")`` keep the additional
        properties. The axes are inherited only when their length matches the
        new dimensionality; otherwise they are reset to None.
        """
        if obj is None:
            return None
        self._set_additional_props(obj)
        try:
            self.axes = getattr(obj, "axes", None)
        except Exception:
            self.axes = None
        else:
            if len(self.axes) != self.ndim:
                self.axes = None

    def __array_ufunc__(self, ufunc, method, *args, **kwargs):
        """
        Every time a numpy universal function (add, subtract, ...) is called,
        this function will be called to set/update essential attributes.
        """
        args_, _ = _replace_inputs(self, args, kwargs)
        result = getattr(ufunc, method)(*args_, **kwargs)
        if result is NotImplemented:
            return NotImplemented
        if result is None:
            # In-place methods such as ``ufunc.at`` return nothing.
            return None
        if isinstance(result, tuple):
            # Multi-output ufuncs (e.g. np.divmod, np.modf) return a tuple.
            return tuple(
                self._wrap_ufunc_output(r, ufunc, args, kwargs) for r in result
            )
        return self._wrap_ufunc_output(result, ufunc, args, kwargs)

    def _wrap_ufunc_output(self, result, ufunc, args, kwargs):
        """View one ufunc output as this class and inherit metadata."""
        result = result.view(self.__class__)
        # in the case result is such as np.float64
        if isinstance(result, self.__class__):
            result._process_output(ufunc, args, kwargs)
        return result

    def _inherit_meta(self, obj: AxesMixin, ufunc, **kwargs):
        """
        Copy axis etc. from obj.

        Called from ``_process_output``. Unlike ``_set_info()``, the keyword
        ``axis`` must be considered because it changes ``ndim``.
        """
        if "axis" in kwargs:
            new_axes = obj.axes.drop(kwargs["axis"])
        else:
            new_axes = self._INHERIT
        self._set_info(obj, new_axes=new_axes)
        return self

    def __array_function__(self, func, types, args, kwargs):
        """
        Every time a numpy function (np.mean...) is called, this function will
        be called. Essentially numpy functions can be overloaded with this
        method: functions registered in ``NP_DISPATCH`` are used when all the
        involved types are ``MetaArray`` subclasses.
        """
        dispatch = self.__class__.NP_DISPATCH
        if func in dispatch and all(issubclass(t, MetaArray) for t in types):
            return dispatch[func](*args, **kwargs)

        args_, _ = _replace_inputs(self, args, kwargs)
        result = func(*args_, **kwargs)
        if result is NotImplemented:
            return NotImplemented

        if isinstance(result, (tuple, list)):
            def _as_meta_array(a):
                # Only plain ndarrays are converted; other items pass through.
                if type(a) is np.ndarray:
                    return a.view(self.__class__)._process_output(func, args, kwargs)
                return a

            result = DataList(_as_meta_array(r) for r in result)
        else:
            if isinstance(result, np.ndarray):
                result = result.view(self.__class__)
            # in the case result is such as np.float64
            if isinstance(result, self.__class__):
                result._process_output(func, args, kwargs)
        return result

    def _process_output(self, func, args, kwargs):
        """Inherit metadata from the highest-dimensional array in ``args``."""
        # find the largest MetaArray. Largest because of broadcasting.
        arr = None
        for arg in args:
            if isinstance(arg, self.__class__):
                if arr is None or arr.ndim < arg.ndim:
                    arr = arg

        if isinstance(arr, self.__class__):
            self._inherit_meta(arr, func, **kwargs)
        return self

    @classmethod
    def implements(cls, numpy_function):
        """
        Add functions to NP_DISPATCH so that numpy functions can be overloaded.
        """
        def decorator(func):
            cls.NP_DISPATCH[numpy_function] = func
            func.__name__ = numpy_function.__name__
            return func
        return decorator

    def sort_axes(self) -> Self:
        """
        Sort image dimensions following ``self.axes.argsort()``.

        Returns
        -------
        MetaArray
            Sorted image
        """
        order = self.axes.argsort()
        return self.transpose(order)

    def argmax_nd(self) -> tuple[int, ...]:
        """
        N-dimensional version of argmax.

        For instance, if yx-array takes its maximum at (5, 8), this function
        returns a tuple holding ``y=5, x=8``.

        Returns
        -------
        tuple of int
            Argmax of the array, built with the tuple type returned by
            ``get_axes_tuple``; a plain tuple if that raises
            ``ImageAxesError``.
        """
        argmax = np.unravel_index(np.argmax(self), self.shape)
        try:
            tup = get_axes_tuple(self)
            return tup(*argmax)
        except ImageAxesError:
            return argmax

    def split(self, axis=None) -> DataList[Self]:
        """
        Split n-dimensional image into (n-1)-dimensional images.

        Parameters
        ----------
        axis : str or int, optional
            Along which axis the original image will be split. By default the
            axis returned by
            ``axesop.find_first_appeared(self.axes, include="cztp")``.

        Returns
        -------
        DataList of arrays
            Separate images
        """
        # determine axis in int.
        if axis is None:
            axis = axesop.find_first_appeared(self.axes, include="cztp")
        axisint = self.axisof(axis)

        imgs: DataList[MetaArray] = DataList(np.moveaxis(self, axisint, 0))
        for img in imgs:
            img.axes = self.axes.drop(axisint)
            img.set_scale(self)
        return imgs

    def _apply_dask(
        self,
        func: Callable,
        c_axes: str | None = None,
        drop_axis: Iterable[int] = [],  # never mutated in this method
        new_axis: Iterable[int] = None,
        dtype=np.float32,
        out_chunks: tuple[int, ...] = None,
        args: tuple[Any] = None,
        kwargs: dict[str, Any] = None,
    ) -> Self:
        """
        Convert array into dask array and run a batch process block-wise.

        Parameters
        ----------
        func : callable
            Function to apply.
        c_axes : str, optional
            Axes to iterate. If None or empty, ``func`` is applied directly
            to the whole array without building dask tasks.
        drop_axis : Iterable[int], optional
            Passed to map_blocks.
        new_axis : Iterable[int], optional
            Passed to map_blocks.
        dtype : any that can be converted to np.dtype object, default is np.float32
            Output data type.
        out_chunks : tuple of int, optional
            Output chunks. This argument is important when the output shape
            will change.
        args : tuple, optional
            Arguments that will passed to `func`.
        kwargs : dict
            Keyword arguments that will passed to `func`.

        Returns
        -------
        MetaArray
            Processed array.
        """
        if args is None:
            args = tuple()
        if kwargs is None:
            kwargs = dict()

        if c_axes is None or len(c_axes) == 0:
            # Do not construct dask tasks if it is not needed. ``None`` (the
            # declared default) is treated like an empty set of axes.
            out = xp.asnumpy(func(self.value, *args, **kwargs), dtype=dtype)
        else:
            from dask import array as da

            new_axis = _list_of_axes(self, new_axis)
            drop_axis = _list_of_axes(self, drop_axis)

            # determine chunk size and slices
            chunks = axesop.switch_slice(c_axes, self.axes, ifin=1, ifnot=self.shape)
            slice_in = []
            slice_out = []
            for i, a in enumerate(self.axes):
                if a in c_axes:
                    # Iterated axes have chunk size 1: strip them on input and
                    # restore them on output.
                    slice_in.append(0)
                    slice_out.append(np.newaxis)
                else:
                    slice_in.append(slice(None))
                    slice_out.append(slice(None))

                if i in drop_axis:
                    slice_out.pop(-1)
                if i in new_axis:
                    slice_in.append(np.newaxis)

            slice_in = tuple(slice_in)
            slice_out = tuple(slice_out)

            # Arguments shaped like this array are chunked together with it.
            all_args = (self.value,) + args
            img_idx = []
            _args = []
            for i, arg in enumerate(all_args):
                if isinstance(arg, (np.ndarray, xp.ndarray)) and arg.shape == self.shape:
                    _args.append(da.from_array(arg, chunks=chunks))
                    img_idx.append(i)
                else:
                    _args.append(arg)

            def _func(*args, **kwargs):
                args = list(args)
                for i in img_idx:
                    if args[i].ndim < len(slice_in):
                        continue
                    args[i] = args[i][slice_in]
                out = func(*args, **kwargs)
                return xp.asnumpy(out[slice_out])

            out = da.map_blocks(
                _func,
                *_args,
                drop_axis=drop_axis,
                new_axis=new_axis,
                meta=xp.array([], dtype=dtype),
                chunks=out_chunks,
                **kwargs
            )
            out = out.compute()

        out = out.view(self.__class__)
        return out

    def transpose(self, axes=None) -> Self:
        """
        Change the order of image dimensions.

        'axes' will also be arranged.

        Parameters
        ----------
        axes : iterable of axis symbols or ints, optional
            New order of the dimensions. If None, the order is reversed.
        """
        if axes is None:
            _axes = None
            new_axes = self.axes[::-1]
        else:
            _axes = [self.axisof(a) for a in axes]
            # Index with the resolved positions: `axes` may be a one-shot
            # iterable that the comprehension above has already consumed.
            new_axes = [self.axes[i] for i in _axes]
        out: np.ndarray = np.transpose(self.value, _axes)
        out = out.view(self.__class__)
        out._set_info(self, new_axes=new_axes)
        return out

    def reshape(self, *shape, order="C", axes: AxesLike | None = None) -> Self:
        """Reshape like ``ndarray.reshape`` and set ``axes`` on the result if given."""
        out: MetaArray = super().reshape(*shape, order=order)
        if axes:
            out.axes = axes
        return out

    @property
    def T(self) -> Self:
        """Transposed array with the axes reversed accordingly."""
        out = super().T
        out.axes = out.axes[::-1]
        return out

    def _broadcast(self, value: Any):
        """Broadcasting method used in most of the mathematical operations."""
        if not isinstance(value, MetaArray):
            return value
        current_axes = self.axes
        if (
            current_axes == value.axes
            or current_axes.has_undef()
            or value.axes.has_undef()
        ):
            # In most cases arrays don't need broadcasting. Check axes first to
            # avoid spending time on broadcasting.
            return value
        value = value.broadcast_to(self.shape, current_axes)
        return value

    def broadcast_to(
        self,
        shape: tuple[int, ...],
        axes: AxesLike | None = None,
    ) -> Self:
        """
        Broadcast array to specified shape and axes.

        Parameters
        ----------
        shape : shape-like
            Shape of output array.
        axes : AxesLike, optional
            Axes of output array. If given, it must match the dimensionality of
            input shape.

        Returns
        -------
        MetaArray
            Broadcasted array.

        Raises
        ------
        ValueError
            If ``shape`` and ``axes`` differ in length, or the result does
            not have the requested shape.
        ImageAxesError
            If this array has an axis that is not in ``axes``.
        """
        if axes is None:
            return np.broadcast_to(self, shape)
        elif len(shape) != len(axes):
            raise ValueError(f"Dimensionality mismatch: {shape=} and {axes=}")

        current_axes = self.axes
        if self.shape == shape and current_axes == axes:
            return self

        if any(a not in axes for a in current_axes):
            ax0 = [str(a) for a in current_axes]
            ax1 = [str(a) for a in axes]
            raise ImageAxesError(
                f"Cannot broadcast array with axes {ax0} to {ax1}."
            )

        # Insert every missing axis by repeating the data along it.
        out = self.value
        for i, axis in enumerate(axes):
            if axis not in current_axes:
                out = np.stack([out] * shape[i], axis=i)
        out = out.view(self.__class__)
        if out.shape != shape:
            raise ValueError(
                f"Shape {shape} required but returned {out.shape}."
            )

        if not isinstance(axes, Axes):
            new_axes = Axes(axes)
            for a in self.axes:
                # update axis metadata such as scale
                new_axes.replace(str(a), a)
        else:
            new_axes = axes
        out._set_info(self, new_axes=new_axes)
        return out

    def _dimension_matches(self, array: MetaArray):
        """Check if dimension satisfies ``self <: array``."""
        img_shape = array.shape
        label_shape = self.shape
        return all(
            getattr(img_shape, str(a), _NOTME) == getattr(label_shape, str(a), _NOTME)
            for a in self.axes
        )

    # Binary operators: the operand is first broadcast along the axes of
    # `self` (see `_broadcast`), then handed to the ndarray implementation.

    def __add__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__add__(value)

    def __sub__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__sub__(value)

    def __mul__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__mul__(value)

    def __truediv__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__truediv__(value)

    def __mod__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__mod__(value)

    def __floordiv__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__floordiv__(value)

    def __gt__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__gt__(value)

    def __ge__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__ge__(value)

    def __lt__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__lt__(value)

    def __le__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__le__(value)

    def __eq__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__eq__(value)

    def __ne__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__ne__(value)

    def __and__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__and__(value)

    def __or__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__or__(value)

    def __iadd__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__iadd__(value)

    def __isub__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__isub__(value)

    def __imul__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__imul__(value)

    def __itruediv__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__itruediv__(value)

    def __imod__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__imod__(value)

    def __ifloordiv__(self, value) -> Self:
        value = self._broadcast(value)
        return super().__ifloordiv__(value)

    if TYPE_CHECKING:
        # Typing-only stubs narrowing the return type of ndarray methods.
        def astype(self, dtype) -> Self: ...
        def flatten(self, order="C") -> Self: ...
        def ravel(self, order="C") -> Self: ...
def _list_of_axes(img: MetaArray, axis):
    """
    Normalize an axis specification into a list.

    ``None`` gives an empty list, an iterable is mapped element-wise through
    ``img.axisof``, and a scalar is wrapped in a one-element list. Any other
    object is returned unchanged.
    """
    if axis is None:
        return []
    if hasattr(axis, "__iter__"):
        return list(map(img.axisof, axis))
    if np.isscalar(axis):
        return [axis]
    return axis
def _replace_inputs(img: MetaArray, args: tuple[Any], kwargs: dict[str, Any]):
    """
    Prepare the arguments of a numpy call so that it runs on plain arrays.

    Every ``MetaArray`` in ``args`` is replaced by its ``value``. In
    ``kwargs``, ``axis`` and ``axes`` are translated to a tuple through
    ``img.axisof``, and ``MetaArray`` objects given as ``out`` are unwrapped.

    Parameters
    ----------
    img : MetaArray
        Array whose ``axisof`` resolves the axis specifications.
    args : tuple
        Positional arguments of the numpy call.
    kwargs : dict
        Keyword arguments of the numpy call. Modified in place.

    Returns
    -------
    tuple
        The converted positional arguments and the same ``kwargs`` dict.
    """
    def _as_np_ndarray(a):
        return a.value if isinstance(a, MetaArray) else a

    # convert arguments
    args = tuple(_as_np_ndarray(a) for a in args)

    axis = kwargs.get("axis", None)
    if axis is not None:
        if not hasattr(axis, "__iter__"):
            axis = [axis]
        kwargs["axis"] = tuple(map(img.axisof, axis))

    axes = kwargs.get("axes", None)
    if axes is not None:
        # used in such as np.rot90
        kwargs["axes"] = tuple(map(img.axisof, axes))

    out = kwargs.get("out", None)
    if out is not None:
        if isinstance(out, tuple):
            # ufuncs always hand over `out` as a tuple of arrays.
            kwargs["out"] = tuple(_as_np_ndarray(a) for a in out)
        else:
            # numpy functions such as np.mean pass a single array; iterating
            # over it would split it into rows instead of unwrapping it.
            kwargs["out"] = _as_np_ndarray(out)

    return args, kwargs
# Sentinel used as the ``getattr`` default in ``MetaArray._dimension_matches``
# for axes that are missing from a shape. ``NotMe`` is not defined in this
# file (presumably provided by the ``..._types`` star import -- TODO confirm).
_NOTME = NotMe()