Source code for livecellx.core.pl_utils

from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
from typing import List
from livecellx.core.single_cell import SingleCellStatic
from livecellx.core.utils import crop_or_pad_img
from livecellx.livecell_logger import main_info, main_warning


[docs] def add_colorbar(im, ax, fig): divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="3%", pad=0.05) fig.colorbar(im, cax=cax, orientation="vertical")
def viz_embedding_region( embedding, scs: List[SingleCellStatic], padding=40, x_range=(float("-inf"), float("inf")), y_range=(float("-inf"), float("inf")), title="Single cell in Embedding Space", max_crops=10, fix_dims=None, randomly_select=True, sort_by_x=True, dpi=300, show_mask=False, ): _crops = [] mask_crops = [] _crop_coords = [] if fix_dims is not None: main_info(f"Setting: fix_dims={fix_dims}") indices = range(len(scs)) if randomly_select: indices = np.random.choice(len(scs), len(scs), replace=False) for idx in indices: _sc = scs[idx] # print("embedding[idx]", embedding[idx]) # Embedding in range if x_range[0] <= embedding[idx][0] <= x_range[1] and y_range[0] <= embedding[idx][1] <= y_range[1]: # _sc.show(crop=True, padding=20) # plt.show() if show_mask: mask_crop = _sc.get_mask_crop(padding=padding) img_crop = _sc.get_img_crop(padding=padding) if not (fix_dims is None): img_crop = crop_or_pad_img(img_crop, fix_dims=fix_dims) if show_mask: mask_crop = crop_or_pad_img(mask_crop, fix_dims=fix_dims) _crops.append(img_crop) if show_mask: mask_crops.append(mask_crop) _crop_coords.append(embedding[idx]) if len(_crops) >= max_crops: break if len(_crops) == 0: main_info(f"No single cell found in the embedding space") return if sort_by_x: main_info(f"Sorting crops by x") sort_indices = np.argsort(np.array(_crop_coords)[:, 0]) _crops = [_crops[i] for i in sort_indices] if show_mask: mask_crops = [mask_crops[i] for i in sort_indices] _crop_coords = [_crop_coords[i] for i in sort_indices] # Visualize crops on one orw import matplotlib.pyplot as plt if show_mask: fig, axes = plt.subplots(2, len(_crops), figsize=(len(_crops) * 4, 10), dpi=dpi) else: fig, axes = plt.subplots(1, len(_crops), figsize=(len(_crops) * 4, 5), dpi=dpi) axes = [axes] for idx in range(len(_crops)): axes[0][idx].imshow(_crops[idx]) axes[0][idx].axis("off") if show_mask: axes[1][idx].imshow(mask_crops[idx]) axes[1][idx].axis("off") fig.suptitle(title, fontsize=20) fig.tight_layout(rect=[0, 0.03, 1, 0.95], w_pad=0, h_pad=0) return fig, axes