import copy
import cv2
from functools import partial
from typing import Optional, Tuple, Union, Annotated
import magicgui as mgui
from magicgui import magicgui
from magicgui.widgets import Container, PushButton, Widget, create_widget
from napari.layers import Shapes
import torch
from livecellx.core.single_cell import SingleCellTrajectoryCollection, SingleCellStatic
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from napari.layers import Shapes
from livecellx.livecell_logger import main_info, main_warning, main_debug
from livecellx.core import SingleCellTrajectory, SingleCellStatic
from livecellx.segment.ou_utils import create_ou_input_from_sc
from livecellx.segment.utils import find_contours_opencv, filter_contours_by_size
from livecellx.core.datasets import SingleImageDataset
def correct_sc_segment(
sc,
model,
create_ou_input_kwargs={
"padding_pixels": 50,
"dtype": float,
"remove_bg": False,
"one_object": True,
"scale": 0,
},
) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray, torch.Tensor]:
import torch
from torchvision import transforms
from livecellx.model_zoo.segmentation.sc_correction_aux import CorrectSegNetAux
# padding_pixels=padding_pixels, dtype=dtype, remove_bg=remove_bg, one_object=one_object, scale=scale
input_transforms = transforms.Compose(
[
transforms.Resize(size=(412, 412)),
]
)
temp_sc = sc.copy()
new_contour = np.array(temp_sc.contour)
new_contour = new_contour[:, -2:] # remove slice index (time)
temp_sc.update_contour(new_contour)
temp_sc.update_bbox()
res_bbox = temp_sc.bbox
ou_input = create_ou_input_from_sc(temp_sc, **create_ou_input_kwargs)
# ou_input = create_ou_input_from_sc(self.sc, **create_ou_input_kwargs)
original_shape = ou_input.shape
# TODO: change to comply with the training data preparation
# for now we simply use one of the input types during training: raw_aug_duplicate.
# Please read sc_correction_dataset impl.
ou_input = input_transforms(torch.tensor([ou_input]))
ou_input = torch.stack([ou_input, ou_input, ou_input], dim=1)
ou_input = ou_input.float().cuda()
back_transforms = transforms.Compose(
[
transforms.Resize(size=(original_shape[0], original_shape[1])),
]
)
seg_output, aux_output = None, None
if isinstance(model, CorrectSegNetAux):
model_output = model(ou_input)
seg_output, aux_output = model_output
else:
seg_output = model(ou_input)
seg_output = back_transforms(seg_output)
if not model.apply_gt_seg_edt:
seg_output = torch.sigmoid(seg_output)
return ou_input, seg_output, res_bbox, aux_output
[docs]
class ScSegOperator:
"""
A class for performing segmentation on single cell images.
Attributes
----------
viewer : napari.Viewer
The napari viewer.
sc : SingleCellStatic
The single cell static object.
shape_layer : napari.layers.Shapes
The napari shape layer for displaying the segmentation.
"""
MANUAL_CORRECT_SEG_MODE = 0
CSN_CORRECT_SEG_MODE = 1
DEFAULT_CSN_MODEL = None
[docs]
@staticmethod
def load_default_csn_model(path, cuda=True, has_aux=True):
import torch
from livecellx.model_zoo.segmentation.sc_correction import CorrectSegNet
from livecellx.model_zoo.segmentation.sc_correction_aux import CorrectSegNetAux
if has_aux:
model = CorrectSegNetAux.load_from_checkpoint(path)
else:
model = CorrectSegNet.load_from_checkpoint(path)
if cuda:
model.cuda()
model.eval()
ScSegOperator.DEFAULT_CSN_MODEL = model
return model
def __init__(
self,
sc: SingleCellStatic,
viewer,
shape_layer: Optional[Shapes] = None,
face_color=(0, 0, 1, 1),
magicgui_container: Optional[Container] = None,
csn_model=None,
create_sc_layer=True,
sct_observers: Optional[list] = None,
):
"""
Parameters
----------
viewer : napari.Viewer
The napari viewer.
sc : SingleCellStatic
The single cell static object.
"""
self.sc = sc
self.viewer = viewer
self.shape_layer = shape_layer
self.face_color = face_color
self.mode = self.MANUAL_CORRECT_SEG_MODE
self.magicgui_container = magicgui_container
self.csn_model = csn_model
self.sct_observers = sct_observers
if sct_observers is None:
self.sct_observers = []
if not (self.shape_layer is None):
self.setup_edit_contour_shape_layer()
if create_sc_layer:
self.create_sc_layer()
def __repr__(self) -> str:
return f"ScSegOperator(sc={self.sc}, mode={self.mode})"
[docs]
def create_sc_layer(self, name=None, contour_sample_num=100):
if name is None:
name = f"sc_{self.sc.id}"
shape_vec = self.sc.get_napari_shape_contour_vec(contour_sample_num=contour_sample_num)
shapes_data = [shape_vec]
is_dummy_shape = False
if len(shape_vec) == 0:
main_warning(f"sc {self.sc.id} has no contour (or contour list length is 0)")
# add a square shape with area = 16
tmp_contour = [[0, 0], [4, 0], [4, 4], [0, 4], [0, 0]]
tmp_shape_data = [[self.sc.timeframe] + coord for coord in tmp_contour]
shapes_data = [tmp_shape_data]
is_dummy_shape = True
properties = {"sc": [self.sc]}
shape_layer = self.viewer.add_shapes(
shapes_data if len(shapes_data) > 0 else None,
properties=properties,
face_color=[self.face_color],
shape_type="polygon",
name=name,
)
self.shape_layer = shape_layer
if is_dummy_shape:
# delete the dummy shape
self.shape_layer.data = []
self.setup_edit_contour_shape_layer()
print(">>> create sc layer done")
[docs]
def remove_sc_layer(self):
if self.shape_layer is None:
return
self.viewer.layers.remove(self.shape_layer)
self.shape_layer = None
[docs]
def update_shape_layer_by_sc(self, contour_sample_num=100):
shape_vec = self.sc.get_napari_shape_contour_vec(contour_sample_num=contour_sample_num)
self.shape_layer.data = [shape_vec]
[docs]
def correct_segment(self, model, create_ou_input_kwargs=None):
import torch
from torchvision import transforms
# padding_pixels=padding_pixels, dtype=dtype, remove_bg=remove_bg, one_object=one_object, scale=scale
temp_sc = self.sc.copy()
if create_ou_input_kwargs is None:
# Use default values
return correct_sc_segment(temp_sc, model)
else:
return correct_sc_segment(temp_sc, model, create_ou_input_kwargs=create_ou_input_kwargs)
[docs]
def replace_sc_contour(self, contour, padding_pixels=0, refresh=True):
self.sc.contour = contour + self.sc.bbox[:2] - padding_pixels
self.sc.update_bbox()
if refresh:
self.update_shape_layer_by_sc()
[docs]
def setup_edit_contour_shape_layer(self):
# [DEPRECATED] Ke did not find a way to make this work
# TODO: make sure only 1 shape in the shape layer...
return
# TODO
from copy import deepcopy
# Callback to check if shape_layer has more than one shape and remove the last one
self.saved_data = deepcopy(self.shape_layer.data)
def _shape_data_changed(event):
print("_shape_data_changed fired")
print("len of shape_layer.data:", len(self.shape_layer.data))
if len(self.shape_layer.data) > 1:
# self.shape_layer.events.data.disconnect(self._shape_data_changed) # disconnect the callback
print("[_shape_data_changed] len of saved_data:", len(self.saved_data))
self.shape_layer.data = deepcopy(self.saved_data)
# self.shape_layer.events.data.connect(self._shape_data_changed)
elif len(self.shape_layer.data) == 1:
self.saved_data = deepcopy(self.shape_layer.data)
# If the shape_layer already exists, connect the callback
if self.shape_layer is not None:
self.shape_layer.events.data.connect(_shape_data_changed)
[docs]
def notify_sct_to_update(self):
for observer in self.sct_observers:
observer.update_shape_layer_by_sc(self.sc)
[docs]
def notify_sct_to_remove_sc_operator(self):
for observer in self.sct_observers:
observer.remove_sc_operator(self)
@staticmethod
def _get_contours_from_shape_layer(layer: Shapes):
res_contours = []
for shape in layer.data:
vertices = np.array(shape)
# ignore the first vertex, which is the slice index
vertices = vertices[:, 1:3]
res_contours.append(vertices)
return res_contours
[docs]
def save_seg_callback(self, clip=True):
"""Save the segmentation to the single cell object."""
import napari
from PyQt5.QtWidgets import QMessageBox
from livecellx.core.utils import clip_polygon
print("<save_seg_callback fired>")
# Get the contour coordinates from the shape layer
contours = self._get_contours_from_shape_layer(self.shape_layer)
if len(contours) != 1:
message = "Warning: Expected 1 contour, found {}.".format(len(contours))
QMessageBox.warning(None, "Warning", message)
return
assert len(contours) > 0, "No contour is found in the shape layer."
contour = contours[0] # n x 2
# limit the contour coordinates to the image height and width
if clip:
main_info("Limiting the contour coordinates to the image height and width.", indent_level=2)
main_debug("contour before clipping:" + str(contour.shape), indent_level=2)
image_dim = self.sc.get_img_shape()
# Clipping algorithm
contour = clip_polygon(contour, image_dim[0], image_dim[1])
# Ensure the contour is within the image
contour[:, 0] = np.clip(contour[:, 0], 0, image_dim[0] - 1)
contour[:, 1] = np.clip(contour[:, 1], 0, image_dim[1] - 1)
# update the shape layer as well
main_info("Updating the shape layer of sc...", indent_level=2)
napari_vertices = [[self.sc.timeframe] + list(point) for point in contour]
napari_vertices = np.array(napari_vertices)
self.shape_layer.data = []
self.shape_layer.add([(napari_vertices, "polygon")], shape_type=["polygon"])
# Store the contour in the single cell object
self.sc.update_contour(contour)
# Notify the observers
print("<save_seg_callback> notifying sct operator to update the sc")
self.notify_sct_to_update()
print("<save_seg_callback finished>")
[docs]
def csn_correct_seg_callback(self, padding_pixels=50, threshold=0.5):
print("csn_correct_seg_callback fired")
if self.csn_model is None and ScSegOperator.DEFAULT_CSN_MODEL is None:
print("No CSN model is loaded. Please load a CSN model first.")
return
elif self.csn_model is None:
print("Using default CSN model and loading it to the operator...")
self.csn_model = ScSegOperator.DEFAULT_CSN_MODEL
create_ou_input_kwargs = {
"padding_pixels": padding_pixels,
"dtype": float,
"remove_bg": False,
"one_object": True,
"scale": 0,
}
model_ou_input, output, res_bbox, aux_output = self.correct_segment(
self.csn_model, create_ou_input_kwargs=create_ou_input_kwargs
)
bin_mask = output[0].cpu().detach().numpy()[0] > threshold
contours = find_contours_opencv(bin_mask.astype(bool))
# contour = [0]
new_shape_data = []
for contour in contours:
contour_in_original_image = contour + res_bbox[:2] - padding_pixels
# replace the current shape_layer's data with the new contour
napari_vertices = [[self.sc.timeframe] + list(point) for point in contour_in_original_image]
napari_vertices = np.array(napari_vertices)
new_shape_data.append((napari_vertices, "polygon"))
self.shape_layer.data = []
self.shape_layer.add(new_shape_data, shape_type=["polygon"])
print("csn_correct_seg_callback done!")
[docs]
def clear_sc_layer_callback(self):
self.shape_layer.data = []
print("clear_sc_layer_callback done!")
[docs]
def restore_sc_contour_callback(self):
self.update_shape_layer_by_sc()
print("restore_sc_contour_callback done!")
[docs]
def filter_cells_by_size_callback(self, min_size, max_size):
print("filter_cells_by_size_callback fired!")
contours = self._get_contours_from_shape_layer(self.shape_layer)
required_contours = filter_contours_by_size(contours, min_size, max_size)
time = self.sc.timeframe
new_shape_data = []
for contour in required_contours:
napari_vertices = [[time] + list(point) for point in contour]
napari_vertices = np.array(napari_vertices)
new_shape_data.append((napari_vertices, "polygon"))
self.shape_layer.data = []
self.shape_layer.add(new_shape_data, shape_type=["polygon"])
print("filter_cells_by_size_callback done!")
[docs]
def focus_on_sc_callback(self):
print("focus_on_sc_callback fired!")
self.viewer.dims.set_point(0, self.sc.timeframe)
print("focus_on_sc_callback done!")
[docs]
@staticmethod
def resample_contour(contour, sample_num=50, start_idx=None):
if start_idx is None:
start_idx = np.random.randint(0, len(contour))
if len(contour) == 0 or start_idx > len(contour):
main_info("The contour is empty or the start_idx is out of range.")
return contour
# rotate contour so that the start_idx is at the beginning
contour = np.roll(contour, -start_idx, axis=0)
slice_step = int(len(contour) / sample_num)
slice_step = max(slice_step, 1) # make sure slice_step is at least 1
if sample_num is not None:
contour = contour[::slice_step]
return contour
[docs]
def resample_contours_callback(self, sample_num):
print("resample_contours_callback fired!")
contours = self._get_contours_from_shape_layer(self.shape_layer)
resampled_contours = []
for contour in contours:
resampled_contours.append(self.resample_contour(contour, sample_num=sample_num))
time = self.sc.timeframe
new_shape_data = []
for contour in resampled_contours:
napari_vertices = [[time] + list(point) for point in contour]
napari_vertices = np.array(napari_vertices)
new_shape_data.append((napari_vertices, "polygon"))
self.shape_layer.data = []
self.shape_layer.add(new_shape_data, shape_type=["polygon"])
# select the newly added last contour
# The contours may all be empty? (double check) so we need to check the length first
if len(self.shape_layer.data) > 0:
self.shape_layer.selected_data = [len(self.shape_layer.data) - 1]
print("resample_contours_callback done!")
[docs]
def close(self):
# remove the shaper layer
self.viewer.layers.remove(self.shape_layer)
# self.magicgui_container.hide()
# self.magicgui_container.close()
if self.magicgui_container is not None:
try:
self.viewer.window.remove_dock_widget(self.magicgui_container.native)
except Exception as e:
main_warning("Exception when removing dock widget:", e)
self.notify_sct_to_remove_sc_operator()
[docs]
def create_sc_seg_napari_ui(sc_operator: ScSegOperator):
"""Usage
# viewer = napari.view_image(dic_dataset.to_dask(), name="dic_image", cache=True)
# shape_layer = NapariVisualizer.viz_trajectories(traj_collection, viewer, contour_sample_num=20)
# sct_operator = SctOperator(traj_collection, shape_layer, viewer)
# sct_operator.setup_shape_layer(shape_layer)
Parameters
----------
sct_operator : SctOperator
_description_
"""
@magicgui(call_button="save seg to sc")
def save_seg_to_sc():
print("[button] save callback fired!")
sc_operator.save_seg_callback()
@magicgui(call_button="auto correct seg")
def csn_correct_seg(
threshold: Annotated[float, {"widget_type": "FloatSpinBox", "max": int(1e4)}] = 0.5,
padding: Annotated[int, {"widget_type": "SpinBox", "max": int(1e4)}] = 50,
):
print("[button] csn callback fired!")
main_info("csn output threshold:" + str(threshold), indent_level=2)
sc_operator.csn_correct_seg_callback(threshold=threshold, padding_pixels=padding)
@magicgui(
auto_call=True,
mode={"choices": ["segmentation"]},
)
def switch_mode_widget(mode):
print("switch mode callback fired!")
print("mode changed to", mode)
if mode == "segmentation":
sc_operator.mode = sc_operator.MANUAL_CORRECT_SEG_MODE
sc_operator.show_selected_mode_widget()
@magicgui(call_button="clear sc layer")
def clear_sc_layer():
print("[button] clear sc layer callback fired!")
sc_operator.clear_sc_layer_callback()
@magicgui(call_button="restore sc contour")
def restore_sc_contour():
print("[button] restore sc contour callback fired!")
sc_operator.restore_sc_contour_callback()
@magicgui(call_button="filter cells by size")
def filter_cells_by_size(
lower: Annotated[int, {"widget_type": "SpinBox", "max": int(1e6)}] = 100,
upper: Annotated[int, {"widget_type": "SpinBox", "max": int(1e6)}] = 100000,
):
print("[button] filter cells by size callback fired!")
sc_operator.filter_cells_by_size_callback(lower, upper)
@magicgui(call_button="focus on sc")
def focus_on_sc():
print("[button] focus on sc callback fired!")
sc_operator.focus_on_sc_callback()
@magicgui(call_button=None)
def show_sc_id(sc_id="No SC"):
return
@magicgui(call_button="resample contours")
def resample_contours(sample_num: Annotated[int, {"widget_type": "SpinBox", "max": int(1e6)}] = 15):
print("[button] resample contours callback fired!")
sc_operator.resample_contours_callback(sample_num)
def on_close_callback():
print("on_close_callback fired!")
container = Container(
widgets=[
show_sc_id,
switch_mode_widget,
save_seg_to_sc,
csn_correct_seg,
clear_sc_layer,
restore_sc_contour,
filter_cells_by_size,
focus_on_sc,
resample_contours,
],
labels=False,
)
container.native.setParent(None)
container.native.deleteLater = lambda: on_close_callback()
show_sc_id.sc_id.value = str(sc_operator.sc.id)[:12] + "-..."
# hide call button
show_sc_id.call_button.hide()
sc_operator.magicgui_container = container
sc_operator.hide_function_widgets()
sc_operator.show_selected_mode_widget()
sc_operator.viewer.window.add_dock_widget(container, name="Sc Operator")