from __future__ import annotations
import warnings
from typing import TYPE_CHECKING
import numpy as np
import napari
import os
from ..arrays import *
from .._const import Const
from ..core import imread, lazy_imread
if TYPE_CHECKING:
from ..frame import TrackFrame, PathFrame
def copy_layer(layer):
    """
    Create a new layer of the same class as ``layer`` from its data tuple.

    The keyword dict returned by ``layer.as_layer_data_tuple()`` is copied
    before it is edited, so the source layer's state is never modified.

    Parameters
    ----------
    layer : napari.layers.Layer
        Layer to duplicate.

    Returns
    -------
    napari.layers.Layer
        A new instance of ``type(layer)`` built from the same data.
    """
    data, kwargs, *_ = layer.as_layer_data_tuple()
    # Always edit a private copy; the dict may be shared with the source layer.
    kwargs = dict(kwargs)
    # linear interpolation is valid only in 3D mode.
    if kwargs.get("interpolation") == "linear":
        kwargs["interpolation"] = "nearest"
    # These keys must be dropped for text-bound layers.
    kwargs.pop("properties", None)
    kwargs.pop("property_choices", None)
    return layer.__class__(data, **kwargs)
def iter_layer(viewer: "napari.Viewer", layer_type: str | list[str]):
    """
    Iterate over layers and yield only certain type of layers.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer whose ``layers`` list is scanned in order.
    layer_type : str or list of str
        Name(s) of classes in ``napari.layers`` to accept, e.g. ``"Image"``,
        ``"Labels"``, ``"Shapes"``, ``"Points"``, ``"Tracks"``. Names are
        resolved with ``getattr(napari.layers, name)``, so an unknown name
        raises ``AttributeError`` when iteration starts.

    Yields
    ------
    napari.layers.Layer
        Layers that are instances of any of the requested classes.
    """
    if isinstance(layer_type, str):
        layer_type = [layer_type]
    layer_classes = tuple(getattr(napari.layers, t) for t in layer_type)
    for layer in viewer.layers:
        if isinstance(layer, layer_classes):
            yield layer
def iter_selected_layer(viewer: "napari.Viewer", layer_type: str | list[str]):
    """
    Yield the currently selected layers that match ``layer_type``.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer whose ``layers.selection`` is scanned.
    layer_type : str or list of str
        Name(s) of classes in ``napari.layers`` (e.g. ``"Image"``, ``"Shapes"``).

    Yields
    ------
    napari.layers.Layer
        Selected layers that are instances of any requested class.
    """
    names = [layer_type] if isinstance(layer_type, str) else layer_type
    wanted = tuple(getattr(napari.layers, cls_name) for cls_name in names)
    yield from (lyr for lyr in viewer.layers.selection if isinstance(lyr, wanted))
def front_image(viewer: "napari.Viewer"):
    """
    Return the front-most visible image layer.

    Image layers are visited in layer-list order and the last visible one
    is returned (later layers are treated as being in front).

    Returns
    -------
    napari.layers.Image
        The napari layer object (not its data array).

    Raises
    ------
    ValueError
        If no image layer is visible.
    """
    visible_images = [lyr for lyr in iter_layer(viewer, "Image") if lyr.visible]
    if not visible_images:
        raise ValueError("There is no visible image layer.")
    return visible_images[-1]
def to_labels(layer: napari.layers.Shapes, labels_shape, zoom_factor=1):
    """
    Rasterize the shapes of ``layer`` into a label array.

    Thin wrapper that delegates to the layer's private ``_data_view.to_labels``.

    Parameters
    ----------
    layer : napari.layers.Shapes
        Shapes layer to rasterize.
    labels_shape : tuple of int
        Shape of the output label array.
    zoom_factor : float, default is 1
        Passed through unchanged.
    """
    data_view = layer._data_view
    return data_view.to_labels(labels_shape=labels_shape, zoom_factor=zoom_factor)
def make_world_scale(obj):
    """
    Build a napari ``scale`` list from an object's axes.

    For every axis in ``obj._axes``: the channel axis ``"c"`` is skipped,
    spatial axes (``z``, ``y``, ``x``) take ``obj.scale[axis]`` and any other
    axis gets a scale of 1.

    Returns
    -------
    list
        One entry per non-channel axis, in axis order.
    """
    world_scale = []
    for axis in obj._axes:
        if axis == "c":
            continue
        world_scale.append(obj.scale[axis] if axis in "zyx" else 1)
    return world_scale
def upon_add_layer(event):
    """
    Configure a layer right after it has been inserted into a viewer.

    The new layer is taken as the last element of ``event.sources[0]``
    (presumably the viewer's layer list — verify against the connecting
    code); nothing is done if that sequence is empty.

    - ``translate`` is cast to float64.
    - Shapes layers get text shortcuts, a rotation-handle length scaled by the
      mean of the last two ``scale`` entries, and arrow keys that move the
      selected shapes by one unit along x (Left/Right) or y (Up/Down).
    - Points layers get text shortcuts.
    - Every layer records its initial translate/scale in ``metadata``.
    """
    try:
        new_layer = event.sources[0][-1]
    except IndexError:
        return None

    new_layer.translate = new_layer.translate.astype(np.float64)

    if isinstance(new_layer, napari.layers.Shapes):
        _text_bound_init(new_layer)
        new_layer._rotation_handle_length = 20 / np.mean(new_layer.scale[-2:])

        def _arrow_handler(axis, step):
            # Factory so each handler captures its own (axis, step) pair.
            def handler(layer):
                _translate_shape(layer, axis, step)
            return handler

        # key -> (column index in vertex arrays, shift)
        arrow_moves = {
            "Left": (-1, -1),
            "Right": (-1, 1),
            "Up": (-2, -1),
            "Down": (-2, 1),
        }
        for key, (axis, step) in arrow_moves.items():
            new_layer.bind_key(key, overwrite=True)(_arrow_handler(axis, step))
    elif isinstance(new_layer, napari.layers.Points):
        _text_bound_init(new_layer)

    new_layer.metadata["init_translate"] = new_layer.translate.copy()
    new_layer.metadata["init_scale"] = new_layer.scale.copy()
    return None
def image_tuple(input: "napari.layers.Image", out: ImgArray, translate="inherit", **kwargs):
    """
    Build a ``(data, kwargs, "image")`` layer-data tuple for ``out``.

    The tuple reuses the display settings (colormap, blending, scale,
    translation) of the source layer ``input``.

    Parameters
    ----------
    input : napari.layers.Image
        Source layer. ``input.data`` must expose ``ndim``, ``axes`` and the
        attributes read by ``make_world_scale``.
    out : ImgArray
        Result image. Complex data is replaced by its magnitude.
    translate : "inherit" or sequence of float, default is "inherit"
        Layer translation. With ``"inherit"`` it is derived from
        ``input.translate``:
        - copied when the dimensionality is unchanged;
        - restricted to the axes still present in ``out`` when axes were dropped;
        - left-padded with one zero per new axis when ``out`` has more axes.
    **kwargs
        Extra layer keywords; they override the inherited ones.

    Returns
    -------
    tuple
        ``(out, layer_kwargs, "image")``.
    """
    data = input.data
    scale = make_world_scale(data)
    if out.dtype.kind == "c":
        out = np.abs(out)
    contrast_limits = [float(x) for x in out.range]
    inherit = isinstance(translate, str) and translate == "inherit"

    if data.ndim == out.ndim:
        if inherit:
            translate = input.translate
    elif data.ndim > out.ndim:
        # Some axes were reduced away: keep entries only for surviving axes.
        kept = [i for i in range(data.ndim) if data.axes[i] in out.axes]
        if inherit:
            translate = [input.translate[i] for i in kept]
        scale = [scale[i] for i in kept]
    else:
        # New leading axes: pad with zero offset / unit scale, one per new axis
        # (previously only one was added regardless of the difference).
        n_new = out.ndim - data.ndim
        if inherit:
            translate = [0.0] * n_new + list(input.translate)
        scale = [1.0] * n_new + list(scale)

    kw = dict(scale=scale, colormap=input.colormap, translate=translate,
              blending=input.blending, contrast_limits=contrast_limits)
    kw.update(kwargs)
    return (out, kw, "image")
def label_tuple(input: "napari.layers.Labels", out: Label, translate="inherit", **kwargs):
    """
    Build a ``(data, kwargs, "labels")`` layer-data tuple for ``out``.

    Scale comes from ``input.data`` via ``make_world_scale``. Translation is
    taken from ``input.translate`` unless given explicitly. The default
    opacity is 0.3, and ``kwargs`` override any of these.
    """
    layer_scale = make_world_scale(input.data)
    if isinstance(translate, str) and translate == "inherit":
        translate = input.translate
    layer_kwargs = {"opacity": 0.3, "scale": layer_scale, "translate": translate}
    layer_kwargs.update(kwargs)
    return out, layer_kwargs, "labels"
def _translate_shape(layer, ind, direction):
data = layer.data
selected = layer.selected_data
for i in selected:
data[i][:, ind] += direction
layer.data = data
layer.selected_data = selected
layer._set_highlight()
return None
def _text_bound_init(new_layer):
@new_layer.bind_key("Alt-A", overwrite=True)
def select_all(layer):
layer.selected_data = set(np.arange(len(layer.data)))
layer._set_highlight()
@new_layer.bind_key("Control-Shift-<", overwrite=True)
def size_down(layer):
if layer.text.size > 4:
layer.text.size -= 1.0
else:
layer.text.size *= 0.8
@new_layer.bind_key("Control-Shift->", overwrite=True)
def size_up(layer):
if layer.text.size < 4:
layer.text.size += 1.0
else:
layer.text.size /= 0.8
return None
def viewer_imread(viewer: "napari.Viewer", path: str):
    """
    Read an image file and add it to ``viewer``.

    Glob patterns (containing ``*``) and files smaller than ``Const["MAX_GB"]``
    gigabytes are read eagerly with ``imread``; larger files use
    ``lazy_imread``. The image's ``repr`` is then shown in the viewer's text
    overlay.

    Returns
    -------
    The value returned by ``add_labeledarray``.
    """
    read_eagerly = "*" in path or os.path.getsize(path) / 1e9 < Const["MAX_GB"]
    reader = imread if read_eagerly else lazy_imread
    img = reader(path)
    layer = add_labeledarray(viewer, img)

    overlay = viewer.text_overlay
    overlay.font_size = 4 * Const["FONT_SIZE_FACTOR"]
    overlay.visible = True
    overlay.color = "white"
    overlay.text = repr(img)
    return layer
def add_labeledarray(viewer: "napari.Viewer", img: LabeledArray, **kwargs):
    """
    Add a labeled array to ``viewer`` as an image layer (one per channel).

    Parameters
    ----------
    viewer : napari.Viewer
        Target viewer.
    img : LabeledArray
        Image to add. A ``"c"`` axis is passed to napari as ``channel_axis``.
        Complex data is added as its magnitude.
    **kwargs
        Forwarded to ``viewer.add_image``. If ``name`` is given it is used
        as-is; otherwise names are derived from ``img.name`` (with ``[Ci]``
        prefixes per channel). PhaseArray defaults to an ``"hsv"`` colormap
        and complex data to ``"plasma"`` unless ``colormap`` is given.

    Returns
    -------
    The value returned by ``viewer.add_image``.
    """
    if not img.axes.is_sorted() and img.ndim > 2:
        msg = (
            f"Input image has axes that are not correctly sorted: {img.axes}. "
            "This may cause unexpected results."
        )
        warnings.warn(msg, UserWarning)
    chn_ax = img.axisof("c") if "c" in img.axes else None

    if isinstance(img, PhaseArray) and "colormap" not in kwargs:
        kwargs["colormap"] = "hsv"
        kwargs["contrast_limits"] = img.border
    elif img.dtype.kind == "c" and "colormap" not in kwargs:
        kwargs["colormap"] = "plasma"

    scale = make_world_scale(img)

    if "name" in kwargs:
        name = kwargs.pop("name")
    else:
        name = "No-Name" if img.name is None else img.name
        if chn_ax is not None:
            name = [f"[C{i}]{name}" for i in range(img.shape.c)]
        else:
            name = [name]
    # Unwrap one-element sequences only; strings (including "") and None are
    # passed through instead of being indexed, which used to raise.
    if isinstance(name, (list, tuple)) and len(name) == 1:
        name = name[0]

    if img.dtype.kind == "c":
        img = np.abs(img)  # complex data is added as its magnitude

    layer = viewer.add_image(img, channel_axis=chn_ax, scale=scale, name=name, **kwargs)

    if viewer.scale_bar.unit:
        if viewer.scale_bar.unit != img.scale_unit:
            msg = f"Incompatible scales. Viewer is {viewer.scale_bar.unit} while image is {img.scale_unit}."
            warnings.warn(msg)
    else:
        viewer.scale_bar.unit = img.scale_unit

    new_axes = [a for a in img.axes if a != "c"]
    # add axis labels to slide bars and image orientation.
    if len(new_axes) >= len(viewer.dims.axis_labels):
        viewer.dims.axis_labels = new_axes
    return layer
def add_labels(viewer: "napari.Viewer", labels: Label, opacity: float = 0.3,
               name: str | list[str] = None, **kwargs):
    """
    Add a label image to ``viewer``, one Labels layer per channel.

    Parameters
    ----------
    viewer : napari.Viewer
        Target viewer.
    labels : Label
        Label image; if it has a ``"c"`` axis it is split along it.
    opacity : float, default is 0.3
        Layer opacity.
    name : str or list of str, optional
        A list gives one name per channel, a string is reused for every
        channel, and both are prefixed with ``"[L]"``. If omitted,
        ``labels.name`` is used for every channel. Extra channels beyond a
        shorter name list are not added (``zip`` truncation).
    **kwargs
        Forwarded to ``viewer.add_labels``.

    Returns
    -------
    list
        The layers returned by ``viewer.add_labels``.
    """
    scale = make_world_scale(labels)
    lbls = labels.split("c") if "c" in labels.axes else [labels]

    if isinstance(name, list):
        names = [f"[L]{n}" for n in name]
    elif isinstance(name, str):
        names = [f"[L]{name}"] * len(lbls)
    else:
        # Repeat the default for every channel; a one-element list here used
        # to make zip() drop all channels except the first.
        names = [labels.name] * len(lbls)

    kw = dict(opacity=opacity, scale=scale)
    kw.update(kwargs)
    return [viewer.add_labels(lbl.value, name=layer_name, **kw)
            for lbl, layer_name in zip(lbls, names)]
def add_dask(viewer: "napari.Viewer", img: LazyImgArray, **kwargs):
    """
    Add a lazily-loaded image to ``viewer`` as an image layer (one per channel).

    Parameters
    ----------
    viewer : napari.Viewer
        Target viewer.
    img : LazyImgArray
        Lazy image; ``img.value`` is indexed and reduced with ``.compute()``.
    **kwargs
        Forwarded to ``viewer.add_image``. ``name`` may be given to override
        the default ``"[Lazy]..."`` naming; passing it previously raised
        ``TypeError`` for a duplicated keyword. If ``contrast_limits`` is
        omitted it is estimated from a strided sub-sample.

    Returns
    -------
    The value returned by ``viewer.add_image``.
    """
    chn_ax = img.axisof("c") if "c" in img.axes else None
    scale = make_world_scale(img)

    if "contrast_limits" not in kwargs:
        # contrast limits should be determined quickly: sample the last two
        # axes with a step of len // 10 (step 1 for axes shorter than 10).
        leny, lenx = img.shape[-2:]
        sample = img.value[..., ::leny//min(10, leny), ::lenx//min(10, lenx)]
        kwargs["contrast_limits"] = [float(sample.min().compute()),
                                     float(sample.max().compute())]

    if "name" in kwargs:
        name = kwargs.pop("name")
    else:
        base_name = "No-Name" if img.name is None else img.name
        if chn_ax is not None:
            name = [f"[Lazy][C{i}]{base_name}" for i in range(img.shape.c)]
        else:
            name = ["[Lazy]" + base_name]
    if isinstance(name, (list, tuple)) and len(name) == 1:
        name = name[0]

    layer = viewer.add_image(img, channel_axis=chn_ax, scale=scale, name=name, **kwargs)

    # Same policy as add_labeledarray: keep an existing unit, warn on mismatch.
    if viewer.scale_bar.unit:
        if viewer.scale_bar.unit != img.scale_unit:
            msg = f"Incompatible scales. Viewer is {viewer.scale_bar.unit} while image is {img.scale_unit}."
            warnings.warn(msg)
    else:
        viewer.scale_bar.unit = img.scale_unit

    new_axes = [a for a in img.axes if a != "c"]
    # add axis labels to slide bars and image orientation.
    if len(new_axes) >= len(viewer.dims.axis_labels):
        viewer.dims.axis_labels = new_axes
    return layer
def add_points(viewer: "napari.Viewer", points, **kwargs):
    """
    Add point coordinates to ``viewer`` as Points layers (one per channel).

    A MarkerFrame is converted with ``get_coords()`` and its world scale is
    applied. Other inputs are used directly with no scale; they must expose
    ``_axes``, ``split``, ``scale`` and ``values``. Each layer records its
    axes and scale in ``metadata`` and gets an edge color from
    ``viewer.window.cmap()`` (called once per layer). ``kwargs`` override
    the default keywords.
    """
    from ..frame import MarkerFrame

    scale = None
    if isinstance(points, MarkerFrame):
        scale = make_world_scale(points)
        points = points.get_coords()

    groups = points.split("c") if "c" in points._axes else [points]
    for group in groups:
        layer_kwargs = dict(
            size=3.2,
            face_color=[0, 0, 0, 0],
            metadata={"axes": str(group._axes), "scale": group.scale},
            edge_color=viewer.window.cmap(),
        )
        layer_kwargs.update(kwargs)
        viewer.add_points(group.values, scale=scale, **layer_kwargs)
    return None
def add_tracks(viewer: "napari.Viewer", track: TrackFrame, **kwargs):
    """
    Add a TrackFrame to ``viewer`` as Tracks layers (one per channel).

    The world scale is computed from every column except the ID axis
    (``Const["ID_AXIS"]``). Each layer stores its axes and scale in
    ``metadata``, and ``kwargs`` are forwarded to ``viewer.add_tracks``.
    """
    track_list = track.split("c") if "c" in track._axes else [track]
    non_id_axes = [a for a in track._axes if a != Const["ID_AXIS"]]
    scale = make_world_scale(track[non_id_axes])
    for tr in track_list:
        meta = {"axes": str(tr._axes), "scale": tr.scale}
        viewer.add_tracks(tr, scale=scale, metadata=meta, **kwargs)
    return None
def add_paths(viewer: "napari.Viewer", paths: PathFrame, **kwargs):
    """
    Add a PathFrame to ``viewer`` as path-type Shapes layers (one per channel).

    Each path (one group per ``Const["ID_AXIS"]`` value) becomes one shape.
    The defaults are lime edges of width 0.3, and ``kwargs`` override them.
    """
    path_list = paths.split("c") if "c" in paths._axes else [paths]
    non_id_axes = [a for a in paths._axes if a != Const["ID_AXIS"]]
    scale = make_world_scale(paths[non_id_axes])
    shape_kwargs = {"edge_color": "lime", "edge_width": 0.3, "shape_type": "path"}
    shape_kwargs.update(kwargs)
    for path in path_list:
        meta = {"axes": str(path._axes), "scale": path.scale}
        # One coordinate array per ID (no longer rebinding the `paths` argument).
        per_id_coords = [single.values for single in path.split(Const["ID_AXIS"])]
        viewer.add_shapes(per_id_coords, scale=scale, metadata=meta, **shape_kwargs)
    return None
def add_table(viewer: "napari.Viewer", data=None, columns=None, name=None):
    """
    Create a TableWidget and dock it on the right side of the viewer window.

    Returns
    -------
    TableWidget
        The docked widget; its ``name`` is used as the dock title.
    """
    from .widgets import TableWidget

    widget = TableWidget(viewer, data, columns=columns, name=name)
    viewer.window.add_dock_widget(widget, area="right", name=widget.name)
    return widget
def get_viewer_scale(viewer: "napari.Viewer"):
    """
    Map each axis label of the viewer to its step size.

    The step is the third element of the corresponding ``viewer.dims.range``
    entry.
    """
    dims = viewer.dims
    return dict(zip(dims.axis_labels, (rng[2] for rng in dims.range)))
def _coords_layer_to_frame(frame_cls, layer, data, axes, scale):
    # Wrap the (N, ndim) coordinates of a Points/Tracks layer in frame_cls,
    # preferring the axes/scale recorded in layer.metadata.
    axes = axes[-data.shape[1]:]
    df = frame_cls(data, columns=layer.metadata.get("axes", axes))
    fallback_scale = {k: v for k, v in scale.items() if k in axes}
    df.set_scale(layer.metadata.get("scale", fallback_scale))
    return df.as_standard_type()


def layer_to_impy_object(viewer: "napari.Viewer", layer):
    """
    Convert a napari layer into the corresponding impy object.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer whose ``dims.axis_labels`` and step sizes supply the axes and
        scale when the layer does not carry them.
    layer : napari.layers.Layer
        Input layer.

    Returns
    -------
    ImgArray, Label, MarkerFrame or TrackFrame (the latter two after
    ``as_standard_type()``), or, for a Shapes layer, ``layer.data`` unchanged.

    Raises
    ------
    NotImplementedError
        For any other layer type.
    """
    data = layer.data
    axes = "".join(viewer.dims.axis_labels)
    scale = get_viewer_scale(viewer)

    if isinstance(layer, (napari.layers.Image, napari.layers.Labels)):
        # Manually drawn layers hold a plain np.ndarray (exact type check, so
        # ndarray subclasses are returned untouched) and need wrapping.
        if type(data) is not np.ndarray:
            return data
        axes = axes[-data.ndim:]
        if isinstance(layer, napari.layers.Image):
            data = ImgArray(data, name=layer.name, axes=axes, dtype=layer.data.dtype)
        else:
            try:
                data = layer.metadata["destination_image"].labels
            except (KeyError, AttributeError):
                data = Label(data, name=layer.name, axes=axes)
        data.set_scale({k: v for k, v in scale.items() if k in axes})
        return data

    if isinstance(layer, napari.layers.Shapes):
        return data

    if isinstance(layer, napari.layers.Points):
        from ..frame import MarkerFrame
        return _coords_layer_to_frame(MarkerFrame, layer, data, axes, scale)

    if isinstance(layer, napari.layers.Tracks):
        from ..frame import TrackFrame
        return _coords_layer_to_frame(TrackFrame, layer, data, axes, scale)

    raise NotImplementedError(type(layer))
def get_a_selected_layer(viewer: "napari.Viewer"):
    """
    Return the single selected layer of ``viewer``.

    Raises
    ------
    ValueError
        If no layer, or more than one layer, is selected.
    """
    selected = list(viewer.layers.selection)
    if not selected:
        raise ValueError("No layer is selected.")
    if len(selected) > 1:
        raise ValueError("More than one layers are selected.")
    (only_layer,) = selected
    return only_layer
def crop_rotated_rectangle(img: LabeledArray, crds: np.ndarray, dims="yx"):
    """
    Crop a (possibly rotated) rectangle out of ``img``.

    Parameters
    ----------
    img : LabeledArray
        Image providing ``axisof`` and ``rotated_crop``.
    crds : np.ndarray
        Rectangle corners, one row per corner with one column per entry of
        ``dims`` (assumed 4 rows — confirm with callers).
    dims : str, default is "yx"
        The two axes the rectangle lives in.

    Returns
    -------
    (cropped image, translate)
        ``translate`` is the column-wise minimum of the original ``crds``.
    """
    translate = np.min(crds, axis=0)
    axis_ids = [img.axisof(a) for a in dims]
    in_array_order = axis_ids == sorted(axis_ids)
    if in_array_order:
        first_corner = crds[1]
    else:
        # dims are not in the image's axis order: swap the coordinate columns,
        # which also changes which corner is passed first.
        crds = np.fliplr(crds)
        first_corner = crds[3]
    cropped_img = img.rotated_crop(first_corner, crds[0], crds[2], dims=dims)
    return cropped_img, translate
def crop_rectangle(img: LabeledArray, crds: np.ndarray, dims="yx") -> tuple[LabeledArray, np.ndarray]:
    """
    Crop an axis-aligned rectangle out of ``img``.

    The rectangle is defined by the opposite corners ``crds[0]`` and
    ``crds[2]``. Bounds are truncated with ``int`` and clipped to
    ``[0, img.sizeof(axis)]``.

    Returns
    -------
    (cropped image, translate)
        ``translate`` holds the clipped start index along the two axes.
    """
    corner_a, corner_b = crds[0], crds[2]
    translate = np.empty(2)
    slice_specs = []
    for i in (0, 1):
        axis = dims[i]
        lo, hi = sorted((corner_a[i], corner_b[i]))
        start = max(int(lo), 0)
        stop = min(int(hi), img.sizeof(axis))
        slice_specs.append(f"{axis}={start}:{stop}")
        translate[i] = start
    cropped_img = img[";".join(slice_specs)]
    return cropped_img, translate
class ColorCycle:
    """
    Callable that returns a new color from a matplotlib colormap on each call.

    The colormap is resampled to 16 entries. Each call advances an internal
    counter and picks entry ``(counter * (N // 2 + 1)) % N``. With N = 16 the
    stride is 9, which is coprime with 16, so all entries are visited before
    any repeats.
    """

    def __init__(self, cmap="rainbow") -> None:
        import matplotlib.pyplot as plt

        self.cmap = plt.get_cmap(cmap, 16)
        self.color_id = 0

    def __call__(self):
        """Advance the cycle and return the next color as a list (RGBA for matplotlib colormaps)."""
        self.color_id += 1
        n_entries = self.cmap.N
        stride = n_entries // 2 + 1
        index = (self.color_id * stride) % n_entries
        return list(self.cmap(index))