{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a Cellpose model to segment A549 cells\n",
    "Author: Ke  \n",
    "Data source: Dr. Weikang Wang\n",
    "\n",
    "This notebook fine-tunes the pretrained `cyto2` Cellpose model on annotated A549 images, previews predictions on a few training images, and then runs both the fine-tuned model and the original pretrained model on a separate DIC dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "cellpose.models.Cellpose"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# All imports live in this cell so that every later cell runs on a fresh kernel.\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from cellpose import models\n",
    "from cellpose.io import imread\n",
    "from skimage import measure\n",
    "\n",
    "from livecellx.core.datasets import LiveCellImageDataset, SingleImageDataset\n",
    "from livecellx.preprocess.utils import enhance_contrast, normalize_img_to_uint8\n",
    "from livecellx.segment.cellpose_utils import segment_single_images_by_cellpose\n",
    "\n",
    "# Seeded generator so the randomly previewed samples are the same on every run.\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "# model_type='cyto' or 'nuclei' or 'cyto2'\n",
    "model = models.Cellpose(gpu=True, model_type=\"cyto2\")\n",
    "type(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# path for saving the re-fitted cellpose model\n",
    "model_save_path = Path(\"./notebook_results/cellpose/cellpose_A549_cyto2\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data used to train the Cellpose model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dirs = [\n",
    "    # \"../datasets/wwk_train/A549_cellbody_seg_train/label/train\",\n",
    "    \"../datasets/wwk_train/A549_seg_train/train\",\n",
    "]\n",
    "# Sub-folders read from every dataset directory.\n",
    "raw_img_dir = [Path(path) / \"Img\" for path in data_dirs]\n",
    "dist_img_dir = [Path(path) / \"Bwdist\" for path in data_dirs]\n",
    "mask_img_dir = [Path(path) / \"Interior\" for path in data_dirs]\n",
    "\n",
    "# Fail early if any expected directory is missing.\n",
    "for raw_dir, dist_dir, mask_dir in zip(raw_img_dir, dist_img_dir, mask_img_dir):\n",
    "    for directory in (raw_dir, dist_dir, mask_dir):\n",
    "        assert directory.exists(), f\"{directory} does not exist\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[PosixPath('../datasets/wwk_train/A549_seg_train/train/Img')]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_img_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted so that raw, Bwdist and mask files are paired by position.\n",
    "raw_img_paths = [sorted(path.glob(\"*.tif\")) for path in raw_img_dir]\n",
    "dist_img_paths = [sorted(path.glob(\"*.tif\")) for path in dist_img_dir]\n",
    "mask_img_paths = [sorted(path.glob(\"*.png\")) for path in mask_img_dir]\n",
    "\n",
    "# Each dataset must provide the same number of files in all three folders.\n",
    "for i in range(len(raw_img_dir)):\n",
    "    assert len(raw_img_paths[i]) == len(dist_img_paths[i]) == len(mask_img_paths[i]), (\n",
    "        f\"Number of images in {raw_img_dir[i]}, {dist_img_dir[i]}, {mask_img_dir[i]} do not match, \"\n",
    "        f\"number of images: {len(raw_img_paths[i])}, {len(dist_img_paths[i])}, {len(mask_img_paths[i])}\"\n",
    "    )\n",
    "\n",
    "# Flatten the per-dataset lists into one list each.\n",
    "raw_img_paths = [item for sublist in raw_img_paths for item in sublist]\n",
    "dist_img_paths = [item for sublist in dist_img_paths for item in sublist]\n",
    "mask_img_paths = [item for sublist in mask_img_paths for item in sublist]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read images\n",
    "raw_imgs = [imread(str(path)) for path in raw_img_paths]\n",
    "dist_imgs = [imread(str(path)) for path in dist_img_paths]\n",
    "mask_imgs = [imread(str(path)) for path in mask_img_paths]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(767, 767, 767)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(raw_imgs), len(dist_imgs), len(mask_imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# squeeze images (drop singleton dimensions)\n",
    "raw_imgs = [img.squeeze() for img in raw_imgs]\n",
    "dist_imgs = [img.squeeze() for img in dist_imgs]\n",
    "mask_imgs = [img.squeeze() for img in mask_imgs]\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check that image shapes match"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, (raw_img, dist_img, mask_img) in enumerate(zip(raw_imgs, dist_imgs, mask_imgs)):\n",
    "    assert (\n",
    "        raw_img.shape == dist_img.shape == mask_img.shape\n",
    "    ), f\"Image shapes do not match for image {i}, {raw_img.shape}, {dist_img.shape}, {mask_img.shape}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(767, 767, 767)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(raw_imgs), len(dist_imgs), len(mask_imgs)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Assumption:** when Dr. WWK annotated the datasets, he intentionally avoided overlapping masks, so instance label masks can be obtained simply by connected-component labeling with `skimage.measure.label()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# normalize images\n",
    "raw_imgs = [normalize_img_to_uint8(img) for img in raw_imgs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of empty mask labels: 29 total number of images: 767\n"
     ]
    }
   ],
   "source": [
    "label_mask_imgs = [measure.label(mask_img) for mask_img in mask_imgs]\n",
    "\n",
    "# Count label masks that hold a single value only, i.e. contain no labeled object.\n",
    "empty_mask_label_count = sum(len(np.unique(label_mask)) <= 1 for label_mask in label_mask_imgs)\n",
    "print(f\"Number of empty mask labels: {empty_mask_label_count}\", \"total number of images:\", len(label_mask_imgs))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tune the model\n",
    "\n",
    "The training log below prints `empty masks!` for images whose label mask contains no cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 6%|▌ | 45/767 [00:02<00:38, 18.68it/s]empty masks!\n",
      " 7%|▋ | 56/767 [00:03<00:27, 26.12it/s]empty masks!\n",
      " 9%|▉ | 69/767 [00:03<00:28, 24.59it/s]empty masks!\n",
      " 11%|█ | 85/767 [00:04<00:25, 26.46it/s]empty masks!\n",
      " 14%|█▎ | 104/767 [00:05<00:36, 18.33it/s]empty masks!\n",
      " 14%|█▍ | 107/767 [00:05<00:32, 20.41it/s]empty masks!\n",
      " 26%|██▌ | 196/767 [00:11<00:28, 20.01it/s]empty masks!\n",
      "empty masks!\n",
      " 28%|██▊ | 217/767 [00:12<00:28, 19.06it/s]empty masks!\n",
      " 29%|██▉ | 224/767 [00:12<00:28, 18.82it/s]empty masks!\n",
      " 32%|███▏ | 242/767 [00:13<00:25, 20.36it/s]empty masks!\n",
      " 34%|███▍ | 259/767 [00:14<00:22, 22.49it/s]empty masks!\n",
      " 34%|███▍ | 262/767 [00:14<00:21, 23.74it/s]empty masks!\n",
      " 41%|████ | 316/767 [00:17<00:31, 14.45it/s]empty masks!\n",
      "empty masks!\n",
      " 42%|████▏ | 325/767 [00:17<00:22, 19.62it/s]empty masks!\n",
      " 44%|████▍ | 340/767 [00:18<00:20, 20.45it/s]empty masks!\n",
      " 45%|████▍ | 345/767 [00:18<00:17, 24.18it/s]empty masks!\n",
      "empty masks!\n",
      " 50%|████▉ | 383/767 [00:20<00:29, 12.95it/s]"
     ]
    }
   ],
   "source": [
    "# Train the segmentation network held by the `Cellpose` wrapper (`model.cp`) in place;\n",
    "# channels=[0, 0] treats the inputs as single-channel (grayscale) images.\n",
    "model.cp.train(\n",
    "    train_data=raw_imgs,\n",
    "    train_labels=label_mask_imgs,\n",
    "    channels=[0, 0],\n",
    "    batch_size=4,\n",
    "    n_epochs=500,\n",
    "    save_path=model_save_path,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Randomly show 3 prediction samples from the training images (the count is set by `N_PREVIEW_SAMPLES`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_PREVIEW_SAMPLES = 3\n",
    "\n",
    "for _ in range(N_PREVIEW_SAMPLES):\n",
    "    index = int(rng.integers(0, len(raw_imgs)))\n",
    "    masks = segment_single_images_by_cellpose(raw_imgs[index], model, channels=[[0, 0]], diameter=100)\n",
    "    print(\"masks shape: \", masks.shape)\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
    "    # Show the full image; `raw_imgs[index][0]` would select only its first row/slice along axis 0.\n",
    "    axes[0].imshow(raw_imgs[index])\n",
    "    axes[0].set_title(\"raw image\")\n",
    "    axes[1].imshow(masks)\n",
    "    axes[1].set_title(\"cellpose mask\")\n",
    "    axes[2].imshow(label_mask_imgs[index])\n",
    "    axes[2].set_title(\"label mask\")\n",
    "    plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict with the model trained on your own data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_dir_path = Path(\"../datasets/test_data_STAV-A549/DIC_data\")\n",
    "\n",
    "# NOTE: `mask_dataset_path` is not used by any cell below.\n",
    "mask_dataset_path = Path(\"../datasets/test_data_STAV-A549/mask_data\")\n",
    "dic_dataset = LiveCellImageDataset(dataset_dir_path, ext=\"tif\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_cellpose_predictions(cellpose_model, dataset, diameter):\n",
    "    \"\"\"Segment every image in `dataset` and plot each image next to its predicted mask.\n",
    "\n",
    "    Args:\n",
    "        cellpose_model: model handed to `segment_single_images_by_cellpose`.\n",
    "        dataset: collection of images supporting `len()` and integer indexing.\n",
    "        diameter: cell diameter forwarded to `segment_single_images_by_cellpose`.\n",
    "    \"\"\"\n",
    "    for i in range(len(dataset)):\n",
    "        img = normalize_img_to_uint8(dataset[i])\n",
    "        masks = segment_single_images_by_cellpose(img, cellpose_model, channels=[[0, 0]], diameter=diameter)\n",
    "        fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
    "        axes[0].imshow(enhance_contrast(img))\n",
    "        axes[0].set_title(\"raw image\")\n",
    "        axes[1].imshow(masks)\n",
    "        axes[1].set_title(\"cellpose mask\")\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model fine-tuned above\n",
    "show_cellpose_predictions(model, dic_dataset, diameter=150)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare with the original pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freshly loaded pretrained `cyto2` weights, for comparison with the fine-tuned model\n",
    "original_pretrained_model = models.Cellpose(gpu=True, model_type=\"cyto2\")\n",
    "show_cellpose_predictions(original_pretrained_model, dic_dataset, diameter=80)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "livecell",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "d9c7226d793827cd27273ad20fbb4775c3cb91053ab9378a09de5f8c6f258919"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}