from __future__ import annotations
import numpy as np
from .metaarray import MetaArray
from .historyarray import HistoryArray
from ...axes import Axes
from ...utils.axesop import del_axis
from ...collections import DataList
[docs]def safe_set_info(out: MetaArray, img: MetaArray, history, new_axes):
if isinstance(img, HistoryArray):
out._set_info(img, history, new_axes=new_axes)
else:
try:
out._set_info(img, new_axes=new_axes)
except Exception:
pass
return None
# Overloading numpy functions using __array_function__.
# https://numpy.org/devdocs/reference/arrays.classes.html
@MetaArray.implements(np.squeeze)
def _(img: MetaArray):
out = np.squeeze(img.value).view(img.__class__)
new_axes = "".join(a for a in img.axes if img.sizeof(a) > 1)
safe_set_info(out, img, "squeeze", new_axes)
return out
@MetaArray.implements(np.take)
def _(a: MetaArray, indices, axis=None, out=None, mode="raise"):
new_axes = del_axis(a.axes, axis)
if isinstance(axis, str):
axis = a.axes.find(axis)
out = np.take(a.value, indices, axis=axis, out=out, mode=mode).view(a.__class__)
if isinstance(out, a.__class__):
out._set_info(a, new_axes=new_axes)
return out
@MetaArray.implements(np.stack)
def _(imgs: list[MetaArray], axis="c", dtype=None):
"""
Create stack image from list of images.
Parameters
----------
imgs : iterable object of images.
Images to stack. These images must have exactly the same shapes.
axis : str or int, default is "c"
Which axis will be the new one.
dtype : str, optional
Output dtype.
Returns
-------
ImgArray
Image stack
"""
if imgs[0].axes:
old_axes = str(imgs[0].axes)
else:
old_axes = None
if isinstance(axis, int):
_axis = axis
axis = "p"
if old_axes:
new_axes = old_axes[:_axis] + axis + old_axes[_axis:]
else:
new_axes = None
elif not isinstance(axis, str):
raise TypeError(f"`axis` must be int or str, but got {type(axis)}")
else:
# find where to add new axis
if old_axes:
if imgs[0].axes.is_sorted():
new_axes = Axes(axis + old_axes)
new_axes.sort()
_axis = new_axes.find(axis)
else:
new_axes = axis + old_axes
_axis = 0
else:
new_axes = None
_axis = 0
if dtype is None:
dtype = imgs[0].dtype
arrs = [img.as_img_type(dtype).value for img in imgs]
out = np.stack(arrs, axis=0)
out = np.moveaxis(out, 0, _axis)
out = out.view(imgs[0].__class__)
safe_set_info(out, imgs[0], f"stack(axis={axis})", new_axes)
return out
@MetaArray.implements(np.concatenate)
def _(imgs: list[MetaArray], axis="c", dtype=None, casting="same_kind"):
if not isinstance(axis, (int, str)):
raise TypeError(f"`axis` must be int or str, but got {type(axis)}")
axis = imgs[0].axisof(axis)
out = np.concatenate([img.value for img in imgs], axis=axis, dtype=dtype, casting=casting)
out = out.view(imgs[0].__class__)
safe_set_info(out, imgs[0], f"concatenate(axis={axis})", imgs[0].axes)
return out
@MetaArray.implements(np.block)
def _(imgs: list[MetaArray]):
def _recursive_view(obj):
if isinstance(obj, MetaArray):
return obj.value
else:
return [_recursive_view(a) for a in obj]
def _recursive_get0(obj):
first = obj[0]
if isinstance(first, MetaArray):
return first
else:
return _recursive_get0(first)
img0 = _recursive_get0(imgs)
imgs = _recursive_view(imgs)
out = np.block(imgs).view(img0.__class__)
safe_set_info(out, img0, "block", img0.axes)
return out
@MetaArray.implements(np.zeros_like)
def _(img: MetaArray, name: str = None):
out = np.zeros_like(img.value).view(img.__class__)
out._set_info(img, new_axes=img.axes)
if isinstance(name, str):
out.name = name
return out
@MetaArray.implements(np.empty_like)
def _(img: MetaArray, name:str = None):
out = np.empty_like(img.value).view(img.__class__)
out._set_info(img, new_axes=img.axes)
if isinstance(name, str):
out.name = name
return out
@MetaArray.implements(np.expand_dims)
def _(img: MetaArray, axis):
if isinstance(axis, str):
new_axes = Axes(axis + str(img.axes))
new_axes.sort()
axisint = tuple(new_axes.find(a) for a in axis)
else:
axisint = axis
new_axes = img.axes
out = np.expand_dims(img.value, axisint).view(img.__class__)
safe_set_info(out, img, f"expand_dims({axis})", new_axes)
return out
@MetaArray.implements(np.transpose)
def _(img: MetaArray, axes: str):
return img.transpose(axes)
@MetaArray.implements(np.split)
def _(img: MetaArray, indices_or_sections, axis=0):
if not isinstance(axis, (int, str)):
raise TypeError(f"`axis` must be int or str, but got {type(axis)}")
axis = img.axisof(axis)
imgs = np.split(img.value, indices_or_sections, axis=axis)
out = []
for i, each in enumerate(imgs):
each = each.view(img.__class__)
safe_set_info(each, img, f"np.split[{i}]", new_axes="inherit")
out.append(each)
return DataList(out)