from __future__ import annotations
import numpy as np
from ..axes import Axes, Axis, UndefAxis, AxisLike
[docs]def find_first_appeared(axes, include="", exclude=""):
include = list(include)
exclude = list(exclude)
for a in axes:
if a in include and not a in exclude:
return a
raise ValueError(f"Inappropriate axes: {axes}")
[docs]def add_axes(axes: Axes, shape: tuple[int, ...], key: np.ndarray, key_axes="yx"):
"""
Stack `key` to make its shape key_axes-> axes.
"""
key_axes = list(key_axes)
if shape == key.shape:
return key
for i, o in enumerate(axes):
if o not in key_axes:
key = np.stack([key] * shape[i], axis=i)
return key
[docs]def complement_axes(axes, all_axes="ptzcyx") -> list[AxisLike]:
c_axes = []
axes_list = list(axes)
for a in all_axes:
if a not in axes_list:
c_axes.append(a)
return c_axes
[docs]def switch_slice(axes, all_axes, ifin=np.newaxis, ifnot=":"):
axes = list(axes)
if ifnot == ":":
ifnot = [slice(None)] * len(all_axes)
elif not hasattr(ifnot, "__iter__"):
ifnot = [ifnot] * len(all_axes)
if not hasattr(ifin, "__iter__"):
ifin = [ifin] * len(all_axes)
sl = []
for a, slin, slout in zip(all_axes, ifin, ifnot):
if a in axes:
sl.append(slin)
else:
sl.append(slout)
sl = tuple(sl)
return sl
[docs]def slice_axes(axes: Axes, key):
ndim = len(axes)
if isinstance(key, tuple):
ndim += sum(k is None for k in key)
rest = ndim - len(key)
if any(k is ... for k in key):
idx = key.index(...)
_keys = key[:idx] + (slice(None),) * (rest + 1) + key[idx + 1:]
else:
_keys = key + (slice(None),) * rest
elif isinstance(key, np.ndarray) or hasattr(key, "__array__"):
if key.ndim == 1:
new_axes = axes
else:
new_axes = [UndefAxis()] + axes[key.ndim:]
return new_axes
elif key is None:
return [UndefAxis()] + axes
elif key is ...:
return axes
else:
_keys = (key,) +(slice(None),) * (ndim - 1)
new_axes: list[Axis] = []
list_idx: list[int] = []
axes_iter = iter(axes)
for sl in _keys:
if sl is not None:
a = next(axes_iter)
if isinstance(sl, (slice, np.ndarray)):
new_axes.append(a.slice_axis(sl))
elif isinstance(sl, list):
new_axes.append(a.slice_axis(sl))
list_idx.append(a)
else:
new_axes.append(UndefAxis()) # new axis
if len(list_idx) > 1:
added = False
out: list[Axis] = []
for a in new_axes:
if a not in list_idx:
out.append(a)
elif not added:
out.append(UndefAxis())
added = True
new_axes = out
return new_axes