Source code for livecellx.core.datasets

import argparse
import glob
import gzip
import json
import os.path
import sys
import time
from collections import deque
from datetime import timedelta
from pathlib import Path, PurePosixPath, WindowsPath, PureWindowsPath
from typing import Callable, List, Dict, Mapping, Union

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch import Tensor
from torch.nn import init
from torch.utils.data import DataLoader, random_split
import uuid

from livecellx.livecell_logger import main_debug


[docs] def read_img_default(url: str, **kwargs) -> np.ndarray: img = Image.open(url) img = np.array(img) return img
def interpolate_time_data(time2data, max_time=None, fill_mode="blank", shape=None): """Interpolate the times in the array""" import dask.array as da from dask import delayed if len(time2data) == None: return None if fill_mode == "blank": shape = list(time2data.items())[1].shape if shape is None else shape blank_img = da.from_delayed(delayed(np.zeros)(shape), shape=shape, dtype=int) res_arr = [] for time in range(max_time + 1): if time in time2data: res_arr.append(time2data[time]) else: res_arr.append(blank_img) return res_arr else: raise NotImplementedError() class LiveCellImageDatasetManager: def __init__( self, name2cache: Dict[str, "LiveCellImageDataset"] = None, path2cache: Dict[str, "LiveCellImageDataset"] = None ): if name2cache is None: name2cache = {} if path2cache is None: path2cache = {} self.name2cache = name2cache self.path2cache = path2cache def read_json(self, path): if not os.path.exists(path): raise ValueError("path does not exist: %s" % path) if self.path2cache.get(path) is not None: return self.path2cache[path] with open(path, "r") as f: json_dict = json.load(f) if ( json_dict["name"] in self.name2cache and self.name2cache[json_dict["name"]].get_dataset_path() == json_dict["data_dir_path"] ): return self.name2cache[json_dict["name"]] dataset = LiveCellImageDataset().load_from_json_dict(json_dict) self.name2cache[dataset.name] = dataset self.path2cache[path] = dataset return dataset default_dataset_manager = LiveCellImageDatasetManager() # TODO: add a method to get/cache all labels in a mask dataset at a specific time t
[docs] class LiveCellImageDataset(torch.utils.data.Dataset): """Dataset for loading images into RAM, possibly cache images and load them on demand. This class only contains one channel's imaging data. For multichannel data, we assume you have a single image for each channel. For the case where your images are stored in a single file, #TODO: you can use the MultiChannelImageDataset class. """ DEFAULT_OUT_DIR = "./livecell-datasets" def __init__( self, dir_path=None, time2url: Mapping[int, Union[str, Path]] = None, name=None, # "livecell-base", ext="tif", max_cache_size=50, max_img_num=None, force_posix_path=True, read_img_url_func: Callable = read_img_default, index_by_time=True, is_windows_path=False, ): """Initialize the dataset. Parameters ---------- dir_path : _type_, optional _description_, by default None time2url : Dict[int, str], optional _description_, by default None name : str, optional _description_, by default "livecell-base" ext : str, optional _description_, by default "tif" max_cache_size : int, optional _description_, by default 50 num_imgs : _type_, optional _description_, by default None force_posix_path : bool, optional _description_, by default True read_img_url_func : Callable, optional _description_, by default read_img_default index_by_time : bool, optional _description_, by default True """ self.read_img_url_func = read_img_url_func self.index_by_time = index_by_time # force posix path if dir_path is passed in if isinstance(dir_path, str): # dir_path = Path(dir_path) dir_path = PurePosixPath(dir_path) elif isinstance(dir_path, Path) and force_posix_path: dir_path = PurePosixPath(dir_path) self.data_dir_path = dir_path self.ext = ext if time2url is None: self.update_time2url_from_dir_path() elif isinstance(time2url, list): self.time2url = {i: path for i, path in enumerate(time2url)} else: self.time2url = time2url # force posix path if force_posix_path: # TODO: fix pathlib issues on windows; if is_windows_path: self.time2url = {time: str(PureWindowsPath(path).as_posix()) for time, path in self.time2url.items()} else: # TODO: decide prevent users from accidentally using windows path? self.time2url = { time: str(Path(path).as_posix()).replace("\\", "/") for time, path in self.time2url.items() } if max_img_num is not None: tmp_tuples = list(self.time2url.items()) tmp_tuples = sorted(tmp_tuples, key=lambda x: x[0]) tmp_tuples = tmp_tuples[:max_img_num] self.time2url = {time: path for time, path in tmp_tuples} # sort times and urls based on times time_urls = sorted(list(self.time2url.items()), key=lambda x: x[0]) self.times = [time for time, url in time_urls] self.urls = [url for time, url in time_urls] self.cache_img_idx_to_img = {} self.max_cache_size = max_cache_size self.img_idx_queue = deque() # randomly generate a name if name is None: self.name = str(uuid.uuid4()) else: self.name = name
[docs] def reindex_time2url_sequential(self): """Reindex the time2url dictionary""" cur_idx = 0 new_time2url = {} for time in self.times: new_time2url[cur_idx] = self.time2url[time] cur_idx += 1 self.update_time2url(new_time2url)
[docs] def update_time2url(self, time2url: dict): """Update the time2url dictionary""" self.time2url = time2url self.times = list(self.time2url.keys()) print("%d %s img file paths loaded;" % (len(self.time2url), self.ext))
[docs] def update_time2url_from_dir_path(self): """ Updates the time2url dictionary with the file paths of all files in the specified data directory with the specified extension. Returns: time2url (dict): A dictionary mapping time points to file paths. """ if self.data_dir_path is None: self.time2url = {} return assert self.ext, "ext must be specified" self.time2url = sorted(glob.glob(str((Path(self.data_dir_path) / Path("*.%s" % (self.ext)))))) self.time2url = {i: path for i, path in enumerate(self.time2url)} self.update_time2url(self.time2url) return self.time2url
def __len__(self): return len(self.time2url)
[docs] def insert_cache(self, img, idx): self.cache_img_idx_to_img[idx] = img self.img_idx_queue.append(idx) # Do not move this block to the top of the function: corner case for max_cache_size = 0 if len(self.cache_img_idx_to_img) > self.max_cache_size: pop_index = self.img_idx_queue.popleft() pop_img = self.cache_img_idx_to_img[pop_index] self.cache_img_idx_to_img.pop(pop_index) del pop_img
# TODO: refactor path -> url
[docs] def get_img_path(self, time) -> str: """Get the path of the image at some time Parameters ---------- time : _type_ _description_ Returns ------- str _description_ """ return self.time2url[time]
[docs] def get_dataset_name(self): return self.name
[docs] def get_dataset_path(self): return self.data_dir_path
def __getitem__(self, idx) -> np.ndarray: if idx in self.cache_img_idx_to_img: return self.cache_img_idx_to_img[idx] # TODO: optimize if self.index_by_time: img = self.get_img_by_time(idx) else: img = self.get_img_by_idx(idx) return img
[docs] def to_json_dict(self) -> dict: """Return the dataset info as a dictionary object""" return { "name": self.name, "data_dir_path": str(self.data_dir_path), "max_cache_size": int(self.max_cache_size), "ext": self.ext, "time2url": self.time2url, }
[docs] def get_default_json_path(self, out_dir=None, posix=True): """Return the default json path for this dataset""" filename = Path("livecell-dataset-%s.json" % (self.name)) if out_dir is None: out_dir = self.DEFAULT_OUT_DIR res_path = Path(out_dir) / filename if posix: res_path = res_path.as_posix() return res_path
# TODO: refactor
[docs] def write_json(self, path=None, overwrite=True, out_dir=None): """Write the dataset info to a local json file. Returns a json string if path is None.""" # If path and out_dir are None, use default out_dir if path is None and out_dir is None: out_dir = self.DEFAULT_OUT_DIR if path is None and (out_dir is not None): path = Path(out_dir) / Path("livecell-dataset-%s.json" % (self.name)) path = Path(path) if not path.parent.exists(): path.parent.mkdir(parents=True, exist_ok=True) if (not overwrite) and os.path.exists(path): main_debug("[LiveCellDataset] skip writing to an existing path: %s" % (path)) return with open(path, "w+") as f: json.dump(self.to_json_dict(), f)
[docs] def load_from_json_dict(self, json_dict, update_time2url_from_dir_path=False, is_integer_time=True): """Load from a json dict. If update_img_paths is True, then we will update the img_path_list based on the data_dir_path. Parameters ---------- json_dict : _type_ _description_ update_img_paths : bool, optional _description_, by default False Returns ------- _type_ _description_ """ self.name = json_dict["name"] self.data_dir_path = json_dict["data_dir_path"] self.ext = json_dict["ext"] if update_time2url_from_dir_path: self.update_time2url_from_dir_path() else: self.time2url = json_dict["time2url"] if is_integer_time: self.time2url = {int(time): url for time, url in self.time2url.items()} self.times = sorted(list(self.time2url.keys())) self.max_cache_size = json_dict["max_cache_size"] return self
[docs] @staticmethod def load_from_json_file(path, use_cache=True, **kwargs): if use_cache: return default_dataset_manager.read_json(path) path = Path(path) with open(path, "r") as f: json_dict = json.load(f) return LiveCellImageDataset().load_from_json_dict(json_dict, **kwargs)
[docs] def to_dask(self, times=None, ram=False, interpolate_missing=True): """convert to a dask array for napari visualization""" import dask.array as da from dask import delayed if times is None: times = self.times if ram: time2arr = {time: da.from_array(self.time2url[time]) for time in times} filled_arr = interpolate_time_data(time2arr, max_time=max(times)) return da.stack(filled_arr) # TODO interpolate missing time points lazy_reader = delayed(self.read_img_url_func) lazy_arrays = [lazy_reader(self.time2url[time]) for time in times] img_shape = self.infer_shape() if interpolate_missing: time2lazy_array = { time: da.from_delayed(lazy_array, shape=img_shape, dtype=int) for time, lazy_array in zip(times, lazy_arrays) } dask_arrays = interpolate_time_data(time2lazy_array, max_time=max(times), shape=img_shape) else: dask_arrays = [da.from_delayed(lazy_array, shape=img_shape, dtype=int) for lazy_array in lazy_arrays] return da.stack(dask_arrays)
[docs] def get_img_by_idx(self, idx): """Get an image by some index in the times list""" time = self.times[idx] url = self.urls[idx] img = self.read_img_url_func(url) return img
[docs] def get_img_by_time(self, time) -> np.ndarray: """Get an image by time""" return self.read_img_url_func(self.time2url[time])
[docs] def get_img_by_url(self, url_str: str, exact_match=False, return_path_and_time=False, ignore_missing=False): """ Returns the image corresponding to the given URL string. Args: url_str (str): The URL string to search for. exact_match (bool): If True, only exact matches will be considered. Otherwise, substring matches will also be considered. return_path_and_time (bool): If True, the function will return a tuple containing the image, URL string, and time. Otherwise, only the image will be returned. ignore_missing (bool): If True, the function will return None if the URL string is not found. Otherwise, a ValueError will be raised. Returns: If return_path_and_time is True, a tuple containing the image, URL string, and time. Otherwise, only the image will be returned. Raises: ValueError: If the URL string is not found and ignore_missing is False, or if multiple URLs match the given string and exact_match is True. """ found_url = None found_time = None def _cmp_equal(x, y): return x == y def _cmp_substr(x, y): return (x in y) or (y in x) cmp_func = _cmp_equal if exact_match else _cmp_substr for time, full_url in self.time2url.items(): if (found_url is not None) and cmp_func(url_str, full_url): raise ValueError("Duplicate url found: %s" % url_str) if cmp_func(url_str, full_url): found_url = full_url found_time = time if found_url is None: if ignore_missing: return None, None, None if return_path_and_time else None else: raise ValueError("url not found: %s" % url_str) if return_path_and_time: return self.get_img_by_time(found_time), found_url, found_time return self.get_img_by_time(found_time)
[docs] def infer_shape(self): """Infer the shape of the images in the dataset""" img = self.get_img_by_time(self.times[0]) return img.shape
[docs] def subset_by_time(self, min_time, max_time, prefix="_subset"): """Return a subset of the dataset based on time [min, max)""" times2url = {} for time in self.times: if time >= min_time and time < max_time: times2url[time] = self.time2url[time] return LiveCellImageDataset( time2url=times2url, name="_subset_" + self.name, ext=self.ext, read_img_url_func=self.read_img_url_func, )
[docs] def get_sorted_times(self): """Get the times in the dataset""" return sorted(list(self.time2url.keys()))
[docs] def time_span(self): """Get the time span of the dataset""" return self.get_sorted_times()[0], self.get_sorted_times()[-1]
[docs] def has_time(self, time): return time in self.time2url
[docs] class SingleImageDataset(LiveCellImageDataset): DEFAULT_TIME = 0 def __init__(self, img, name=None, ext=".png", in_memory=True): super().__init__( time2url={SingleImageDataset.DEFAULT_TIME: "InMemory"}, name=name, ext=ext, read_img_url_func=self.read_single_img_from_mem, index_by_time=True, ) self.img = img self.url = None # TODO: handle to case where img is not in memory self.in_memory = in_memory
[docs] def read_single_img_from_mem(self, url): return self.img.copy()
[docs] def get_img_by_time(self, time=None) -> np.ndarray: return self.read_single_img_from_mem(self.url)
[docs] def get_img_by_idx(self, idx=None): return self.read_single_img_from_mem(self.url)
# TODO # class MultiChannelImageDataset(torch.utils.data.Dataset):