Train a Cellpose model to segment A549 cells

Author: Ke
Data source: Dr. Weikang Wang
[1]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
from pathlib import Path

# Start from one of Cellpose's built-in pretrained models; valid choices for
# model_type here are "cyto", "nuclei" or "cyto2".
model = models.Cellpose(model_type="cyto2", gpu=True)
type(model)
[1]:
cellpose.models.Cellpose
[2]:
# Directory handed to the training call as `save_path`; the re-fitted
# Cellpose model is written underneath it.
model_save_path = Path(
    "./notebook_results/cellpose/cellpose_A549_cyto2"
)

Load the training data for fine-tuning the Cellpose model

[ ]:
# Root folders of the annotated training sets. Each root is expected to hold
# three sub-folders: "Img" (raw images), "Bwdist" (distance images) and
# "Interior" (mask images).
data_dirs = [
    # "../datasets/wwk_train/A549_cellbody_seg_train/label/train",
    "../datasets/wwk_train/A549_seg_train/train",
]
dataset_roots = [Path(root) for root in data_dirs]
raw_img_dir = [root / "Img" for root in dataset_roots]
dist_img_dir = [root / "Bwdist" for root in dataset_roots]
mask_img_dir = [root / "Interior" for root in dataset_roots]

# Fail fast if any expected folder is missing, checking raw, distance and
# mask folders of each dataset in turn.
for dataset_folders in zip(raw_img_dir, dist_img_dir, mask_img_dir):
    for folder in dataset_folders:
        assert folder.exists(), f"{folder} does not exist"
[3]:
raw_img_dir
[3]:
[PosixPath('../datasets/wwk_train/A549_seg_train/train/Img')]
[4]:
# Gather the files of every dataset folder in sorted order: raw and distance
# images are stored as .tif, mask images as .png.
raw_img_paths = [sorted(folder.glob("*.tif")) for folder in raw_img_dir]
dist_img_paths = [sorted(folder.glob("*.tif")) for folder in dist_img_dir]
mask_img_paths = [sorted(folder.glob("*.png")) for folder in mask_img_dir]

# Each dataset must contribute the same number of raw, distance and mask files.
# NOTE(review): only the counts are compared; pairing the i-th raw image with
# the i-th mask relies on the sorted file names lining up across folders.
for idx, (raw_group, dist_group, mask_group) in enumerate(
    zip(raw_img_paths, dist_img_paths, mask_img_paths)
):
    n_raw, n_dist, n_mask = len(raw_group), len(dist_group), len(mask_group)
    assert n_raw == n_dist == n_mask, (
        f"Number of images in {raw_img_dir[idx]}, {dist_img_dir[idx]}, {mask_img_dir[idx]} do not match, number of images: {n_raw}, {n_dist}, {n_mask}"
    )

# Merge the per-dataset lists into one flat list per image kind.
raw_img_paths = [img_path for group in raw_img_paths for img_path in group]
dist_img_paths = [img_path for group in dist_img_paths for img_path in group]
mask_img_paths = [img_path for group in mask_img_paths for img_path in group]
[5]:
# Load every file into memory with cellpose's `imread` (paths are passed as
# strings), keeping the raw / distance / mask lists in matching order.
raw_imgs, dist_imgs, mask_imgs = (
    [imread(str(img_path)) for img_path in path_list]
    for path_list in (raw_img_paths, dist_img_paths, mask_img_paths)
)
[6]:
len(raw_imgs), len(dist_imgs), len(mask_imgs)
[6]:
(767, 767, 767)
[7]:
# Remove singleton dimensions so the three image kinds can be compared
# shape-for-shape.
raw_imgs = [np.squeeze(raw) for raw in raw_imgs]
dist_imgs = [np.squeeze(dist) for dist in dist_imgs]
mask_imgs = [np.squeeze(mask) for mask in mask_imgs]

Check that the image shapes match

[ ]:
# Every raw image must have exactly the same shape as its distance image and
# its mask image; stop at the first mismatch and report the offending index.
for idx, (raw, dist, mask) in enumerate(zip(raw_imgs, dist_imgs, mask_imgs)):
    shapes_agree = raw.shape == dist.shape == mask.shape
    assert shapes_agree, (
        f"Image shapes do not match for image {idx}, {raw.shape}, {dist.shape}, {mask.shape}"
    )
[8]:
len(raw_imgs), len(dist_imgs), len(mask_imgs)
[8]:
(767, 767, 767)
Note the following assumption:
when Dr. WWK annotated the datasets, he intentionally avoided overlapping masks, so we can obtain label masks simply by calling label().
[9]:
from livecellx.preprocess.utils import normalize_img_to_uint8

# Pass every raw image through livecellx's uint8 normalisation helper; the
# distance and mask images are left untouched.
raw_imgs = list(map(normalize_img_to_uint8, raw_imgs))
[10]:
# Import the submodule explicitly: a bare `import skimage` does not guarantee
# that `skimage.measure` is loaded on every scikit-image version. This still
# binds the top-level name `skimage`.
import skimage.measure

# Turn each interior mask into an instance-label image by connected-component
# labelling. This relies on the annotations containing no overlapping masks.
label_mask_imgs = [skimage.measure.label(mask_img) for mask_img in mask_imgs]

# Count the label images without any object. Labelled objects are numbered
# from 1 and background is 0, so "no non-zero pixel" is the emptiness test.
# Counting unique values instead would also flag a mask that is entirely
# foreground (a single unique value, 1) as empty.
empty_mask_label_count = sum(1 for label_img in label_mask_imgs if not label_img.any())
print(f"Number of empty mask labels: {empty_mask_label_count}", "total number of images:", len(label_mask_imgs))
Number of empty mask labels: 29 total number of images: 767
[11]:
# Fine-tune the network reached through the size model (`model.sz.cp`) on the
# normalised raw images and their instance-label masks.
# NOTE(review): presumably the same object as `model.cp` — confirm against the
# installed cellpose version.
cellpose_net = model.sz.cp
cellpose_net.train(
    train_data=raw_imgs,
    train_labels=label_mask_imgs,
    channels=[0, 0],
    batch_size=4,
    n_epochs=500,
    save_path=model_save_path,
)
  6%|▌         | 45/767 [00:02<00:38, 18.68it/s]empty masks!
  7%|▋         | 56/767 [00:03<00:27, 26.12it/s]empty masks!
  9%|▉         | 69/767 [00:03<00:28, 24.59it/s]empty masks!
 11%|█         | 85/767 [00:04<00:25, 26.46it/s]empty masks!
 14%|█▎        | 104/767 [00:05<00:36, 18.33it/s]empty masks!
 14%|█▍        | 107/767 [00:05<00:32, 20.41it/s]empty masks!
 26%|██▌       | 196/767 [00:11<00:28, 20.01it/s]empty masks!
empty masks!
 28%|██▊       | 217/767 [00:12<00:28, 19.06it/s]empty masks!
 29%|██▉       | 224/767 [00:12<00:28, 18.82it/s]empty masks!
 32%|███▏      | 242/767 [00:13<00:25, 20.36it/s]empty masks!
 34%|███▍      | 259/767 [00:14<00:22, 22.49it/s]empty masks!
 34%|███▍      | 262/767 [00:14<00:21, 23.74it/s]empty masks!
 41%|████      | 316/767 [00:17<00:31, 14.45it/s]empty masks!
empty masks!
 42%|████▏     | 325/767 [00:17<00:22, 19.62it/s]empty masks!
 44%|████▍     | 340/767 [00:18<00:20, 20.45it/s]empty masks!
 45%|████▍     | 345/767 [00:18<00:17, 24.18it/s]empty masks!
empty masks!
 50%|████▉     | 383/767 [00:20<00:29, 12.95it/s]

Randomly show a few prediction samples

[ ]:
from livecellx.segment.cellpose_utils import segment_single_images_by_cellpose

# Visual sanity check on the training data: for a few randomly chosen images,
# show the input, the mask predicted by the re-trained model and the
# ground-truth label mask side by side.
for _ in range(3):
    index = np.random.randint(0, len(raw_imgs))
    masks = segment_single_images_by_cellpose(raw_imgs[index], model, channels=[[0, 0]], diameter=100)
    print("masks shape: ", masks.shape)
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    # Plot the whole image. The images were squeezed and shape-matched to
    # their masks, so indexing with [0] would hand imshow only the first row
    # instead of the picture.
    axes[0].imshow(raw_imgs[index])
    axes[0].set_title("raw image")
    axes[1].imshow(masks)
    axes[1].set_title("cellpose mask")
    axes[2].imshow(label_mask_imgs[index])
    axes[2].set_title("label mask")
    plt.show()

Predict with the model trained on your own data

[ ]:
from livecellx.core.datasets import LiveCellImageDataset, SingleImageDataset

# Locations of the STAV-A549 test data: the mask folder (not used by the
# cells visible here) and the DIC image folder.
mask_dataset_path = Path("../datasets/test_data_STAV-A549/mask_data")
dataset_dir_path = Path("../datasets/test_data_STAV-A549/DIC_data")

# Expose the .tif files of the DIC folder as an indexable image dataset.
dic_dataset = LiveCellImageDataset(dataset_dir_path, ext="tif")
[ ]:
from livecellx.preprocess.utils import enhance_contrast

# Segment every DIC test image with the re-trained model and show the
# (contrast-enhanced) input next to the predicted mask.
for i in range(len(dic_dataset)):
    img = normalize_img_to_uint8(dic_dataset[i])
    masks = segment_single_images_by_cellpose(img, model, channels=[[0, 0]], diameter=150)
    fig, (raw_ax, mask_ax) = plt.subplots(1, 2, figsize=(10, 5))
    raw_ax.imshow(enhance_contrast(img))
    raw_ax.set_title("raw image")
    mask_ax.imshow(masks)
    mask_ax.set_title("cellpose mask")
    plt.show()
[ ]:
# Baseline for comparison: a fresh, untouched pretrained "cyto2" model run on
# the same DIC test images. Uses `enhance_contrast`, which is imported in the
# preceding cell.
original_pretrained_model = models.Cellpose(model_type="cyto2", gpu=True)
for i in range(len(dic_dataset)):
    img = normalize_img_to_uint8(dic_dataset[i])
    masks = segment_single_images_by_cellpose(img, original_pretrained_model, channels=[[0, 0]], diameter=80)
    fig, (raw_ax, mask_ax) = plt.subplots(1, 2, figsize=(10, 5))
    raw_ax.imshow(enhance_contrast(img))
    raw_ax.set_title("raw image")
    mask_ax.imshow(masks)
    mask_ax.set_title("cellpose mask")
    plt.show()