import datetime
from livecellx.core.single_cell import SingleCellStatic, SingleCellTrajectory, SingleCellTrajectoryCollection
import numpy as np
from napari.viewer import Viewer
from livecellx.livecell_logger import main_info
from livecellx.plot.visualizer import Visualizer
class NapariVisualizer:
    """Helpers that draw livecellx single-cell trajectories into a napari viewer.

    Every method is a ``@staticmethod``; call them on the class, e.g.
    ``NapariVisualizer.gen_trajectories_shapes(trajectories, viewer)``.
    """

    @staticmethod
    def viz_traj(traj: SingleCellTrajectory, viewer: Viewer, viewer_kwargs=None):
        """Add the shapes of one trajectory to ``viewer`` as a single shapes layer.

        Args:
            traj: trajectory whose ``get_scs_napari_shapes()`` output is drawn.
            viewer: napari viewer that receives the layer.
            viewer_kwargs: optional extra keyword arguments forwarded verbatim to
                ``viewer.add_shapes``.

        Returns:
            The layer object returned by ``viewer.add_shapes``.
        """
        if viewer_kwargs is None:
            viewer_kwargs = {}
        shapes = traj.get_scs_napari_shapes()
        return viewer.add_shapes(shapes, **viewer_kwargs)

    @staticmethod
    def map_colors(values, cmap="viridis"):
        """Map numeric values to RGBA colors, normalized over their own min/max.

        Args:
            values: iterable of numbers; ``None`` or an empty iterable yields ``[]``.
            cmap: matplotlib colormap name (or Colormap) used for the mapping.

        Returns:
            A list with one RGBA tuple per input value, in input order.
        """
        from matplotlib.cm import ScalarMappable
        from matplotlib.colors import Normalize

        if values is None:
            return []
        # Materialize once so generators/iterators work with min/max and the mapping pass.
        values = list(values)
        if not values:
            return []
        norm = Normalize(vmin=min(values), vmax=max(values), clip=True)
        mapper = ScalarMappable(norm=norm, cmap=cmap)
        return [mapper.to_rgba(v) for v in values]

    @staticmethod
    def gen_trajectories_shapes(
        trajectories: SingleCellTrajectoryCollection,
        viewer: Viewer,
        bbox=False,
        contour_sample_num=100,
        viewer_kwargs=None,
        text_parameters=None,
    ):
        """Add every trajectory in ``trajectories`` to ``viewer`` as one shapes layer.

        Each shape carries the properties ``track_id`` (``int(track_id)``), ``sc``
        (the single cell it was built from) and ``status`` (currently always ``""``).
        All shapes belonging to the same track receive the same face color.

        Args:
            trajectories: iterable of ``(track_id, trajectory)`` pairs.
            viewer: napari viewer that receives the layer.
            bbox: forwarded to ``get_scs_napari_shapes``.
            contour_sample_num: forwarded to ``get_scs_napari_shapes``.
            viewer_kwargs: optional extra keyword arguments for ``viewer.add_shapes``;
                they override this method's defaults (e.g. ``name``).
            text_parameters: napari text spec for shape labels; ``None`` uses the
                default "track id / status" label.

        Returns:
            The layer object returned by ``viewer.add_shapes``.
        """
        import time

        if viewer_kwargs is None:
            viewer_kwargs = {}
        if text_parameters is None:
            # Built fresh per call so mutations by a caller or napari never leak
            # into later calls (the old mutable default dict was shared).
            text_parameters = {
                "string": "{track_id:0.0f}\n{status}",
                "size": 12,
                "color": "white",
                "anchor": "center",
                "translation": [-2, 0],
            }

        all_shapes = []
        track_ids = []
        all_scs = []
        for track_id, traj in trajectories:
            traj_shapes, scs = traj.get_scs_napari_shapes(
                bbox=bbox, contour_sample_num=contour_sample_num, return_scs=True
            )
            all_shapes.extend(traj_shapes)
            track_ids.extend([int(track_id)] * len(traj_shapes))
            all_scs.extend(scs)
        # Status text is a placeholder: one empty string per shape.
        all_status = [""] * len(all_shapes)
        properties = {"track_id": track_ids, "sc": all_scs, "status": all_status}

        # Track IDs can be arbitrary (e.g. large UUID-derived ints), so map each
        # distinct track to a compact integer in first-seen order; every shape of
        # a track then shares one color.
        track_to_index = {}
        for tid in track_ids:
            track_to_index.setdefault(tid, len(track_to_index))
        track_value_indices = [track_to_index[tid] for tid in track_ids]

        main_info(f"Number of trajectories: {len(trajectories)}", indent_level=2)
        main_info("Calling viewer.add_shapes to add trajectories to napari", indent_level=2)

        # Monotonic clock: wall-clock (datetime.now) can jump and skew durations.
        start = time.perf_counter()
        layer_kwargs = {
            "properties": properties,
            "face_color": NapariVisualizer.map_colors(track_value_indices),
            "face_colormap": "viridis",
            "shape_type": "polygon",
            "text": text_parameters,
            "name": "trajectories",
        }
        # Let caller kwargs override defaults instead of raising a duplicate-keyword TypeError.
        layer_kwargs.update(viewer_kwargs)
        shape_layer = viewer.add_shapes(all_shapes, **layer_kwargs)
        elapsed = time.perf_counter() - start
        main_info(f"Time to add shapes: {elapsed}", indent_level=2)
        return shape_layer