Source code for ufish.api

import os
import os.path as osp
import time
import typing as T
from pathlib import Path
import importlib
from functools import partial, cached_property

import numpy as np
import pandas as pd

from .utils.log import logger

if T.TYPE_CHECKING:
    import torch
    from torch import nn
    from matplotlib.figure import Figure
    import onnxruntime


BASE_STORE_URL = 'https://huggingface.co/GangCaoLab/U-FISH/resolve/main/'
DEFAULT_WEIGHTS_FILE = 'v1.0-alldata-ufish_c32.onnx'
STATC_STORE_PATH = osp.abspath(
    osp.join(osp.dirname(__file__), "model/weights/"))


[docs] class UFish():
def __init__(
    self,
    device: T.Union[
        None, T.Literal['cpu', 'cuda', 'dml'], "torch.device"] = None,
    default_weights_file: T.Optional[str] = None,
    local_store_path: str = '~/.ufish/'
) -> None:
    """Create an API object with no model and no ONNX session loaded.

    Args:
        device: Requested compute device: 'cpu', 'cuda', 'dml' or a
            ``torch.device``. The value is only stored here; it is
            resolved into an actual torch device lazily by the
            ``device`` property.
        default_weights_file: Name of the weights file used when no
            explicit file name is given. Falls back to
            ``DEFAULT_WEIGHTS_FILE`` when None.
        local_store_path: Local directory where downloaded weights are
            cached. A leading ``~`` is expanded to the home directory.
    """
    # Requested device and inference-mode flag (resolved/toggled lazily).
    self._device = device
    self._infer_mode = False

    # Exactly one of these is expected to be populated after loading.
    self.model: T.Optional["nn.Module"] = None
    self.ort_session: T.Optional["onnxruntime.InferenceSession"] = None

    self.default_weights_file = (
        DEFAULT_WEIGHTS_FILE
        if default_weights_file is None
        else default_weights_file
    )
    self.store_base_url = BASE_STORE_URL
    self.local_store_path = Path(osp.expanduser(local_store_path))

    # Path of the most recently requested weights file, if any.
    self.weight_path: T.Optional[str] = None
@cached_property
def device(self) -> "torch.device":
    """Torch device for this object, resolved once on first access."""
    return self._get_torch_device()


def _get_torch_device(self) -> "torch.device":
    """Turn the stored device request into a torch device.

    * A ``torch.device`` is returned unchanged, and ``self._device`` is
      replaced by its type string.
    * ``None`` triggers auto-detection: 'cuda' when CUDA is available,
      otherwise on Windows 'dml' when ``torch_directml`` can be
      imported, otherwise 'cpu'.
    * 'cuda' silently degrades to the CPU device (with a warning) when
      CUDA is not available.
    * 'dml' returns the device provided by ``torch_directml``.
    * Any other value is handed to ``torch.device`` as is.
    """
    import torch

    requested = self._device
    if isinstance(requested, torch.device):
        self._device = requested.type
        return requested

    if self._device is None:
        detected = 'cpu'
        if torch.cuda.is_available():
            detected = 'cuda'
        elif os.name == 'nt':
            # DirectML is optional; any failure to import it means CPU.
            try:
                importlib.import_module('torch_directml')
                logger.info(
                    "Using DirectML for training on Windows.")
                detected = 'dml'
            except Exception:
                detected = 'cpu'
        self._device = detected

    kind = self._device
    if kind == "cuda":
        if not torch.cuda.is_available():
            logger.warning(
                "CUDA is not available, using CPU instead.")
            return torch.device('cpu')
        return torch.device('cuda')
    if kind == "dml":
        import torch_directml
        return torch_directml.device()
    return torch.device(kind)
def init_model(
        self,
        model_type: str = 'ufish',
        **kwargs) -> None:
    """Build a new network and move it to the configured device.

    Args:
        model_type: Which network to build. One of 'ufish',
            'spot_learn', 'det_net', or a path ending in ``.py`` to a
            Python file that defines a class named ``Net``.
        kwargs: Keyword arguments forwarded to the network constructor.

    Raises:
        ValueError: If ``model_type`` is none of the supported values.
        KeyError: If a ``.py`` file is given but defines no ``Net``.
    """
    if model_type == 'ufish':
        from .model.network.ufish_net import UFishNet
        self.model = UFishNet(**kwargs)
    elif model_type == 'spot_learn':
        from .model.network.spot_learn import SpotLearn
        self.model = SpotLearn(**kwargs)
    elif model_type == 'det_net':
        from .model.network.det_net import DetNet
        self.model = DetNet(**kwargs)
    elif model_type.endswith('.py'):
        # SECURITY: this executes the whole file with full interpreter
        # privileges. Only pass model definition files you trust.
        with open(model_type, 'r') as f:
            code = f.read()
        namespace = {}  # do not shadow the ``vars`` builtin
        exec(code, namespace, namespace)
        self.model = namespace['Net'](**kwargs)
    else:
        raise ValueError(f'Unknown model type: {model_type}')

    # The new model has not been put in eval mode (nor traced) yet.
    # Clear the flag so that _turn_on_infer_mode() does not return early
    # because of a previously used model.
    self._infer_mode = False

    params = sum(p.numel() for p in self.model.parameters())
    logger.info(
        f'Initializing {model_type} model with kwargs: {kwargs}')
    logger.info(f'Number of parameters: {params}')
    self.model = self.model.to(self.device)
def convert_to_onnx(
    self,
    output_path: T.Union[Path, str],
) -> None:
    """Export the current torch model to an ONNX file.

    The model is first switched to inference mode with tracing
    requested, then exported with opset 11 using a random
    ``(1, 1, 512, 512)`` example input. The batch, y and x axes are
    declared dynamic for both the input and the output.

    Args:
        output_path: Destination path of the ONNX file.

    Raises:
        RuntimeError: If no torch model has been initialized.
    """
    if self.model is None:
        raise RuntimeError('Model is not initialized.')
    self._turn_on_infer_mode(trace_model=True)

    import torch
    import torch.onnx

    output_path = str(output_path)
    logger.info(
        f'Converting model to ONNX format, saving to {output_path}')

    example_input = torch.rand(1, 1, 512, 512).to(self.device)
    dyn_axes = {0: 'batch_size', 2: 'y', 3: 'x'}
    export_options = dict(
        input_names=['input'],
        output_names=['output'],
        opset_version=11,
        do_constant_folding=True,
        dynamic_axes={
            'input': dyn_axes,
            'output': dyn_axes,
        },
    )
    torch.onnx.export(
        self.model, example_input, output_path, **export_options)
def _turn_on_infer_mode(self, trace_model: bool = False) -> None:
    """Switch the torch model to inference mode, at most once.

    Does nothing when the inference flag is already set. Otherwise the
    flag is set, the model is put in eval mode and, if requested,
    replaced by a ``torch.jit.trace`` of itself.

    Args:
        trace_model: Whether to trace the model with a random
            ``(1, 1, 512, 512)`` input placed on the same device as
            the model's first parameter.
    """
    if not self._infer_mode:
        self._infer_mode = True
        assert self.model is not None
        self.model.eval()
        if trace_model:
            import torch
            first_param = next(self.model.parameters())
            example = torch.rand(1, 1, 512, 512).to(first_param.device)
            self.model = torch.jit.trace(self.model, example)
def load_weights_from_internet(
    self,
    weights_file: T.Optional[str] = None,
    max_retry: int = 8,
    force_download: bool = False,
) -> None:
    """Fetch a weights file from the huggingface repo and load it.

    The file is cached under ``self.local_store_path``. An existing
    cached copy is reused unless ``force_download`` is set.

    Args:
        weights_file: The name of the weights file on the internet.
            See https://huggingface.co/GangCaoLab/U-FISH/tree/main
            for available weights files. Defaults to
            ``self.default_weights_file``.
        max_retry: The maximum number of download attempts.
        force_download: Whether to download even if a cached copy
            exists.

    Raises:
        RuntimeError: If every download attempt failed.
    """
    import torch

    file_name = weights_file or self.default_weights_file
    url = self.store_base_url + file_name
    target = self.local_store_path / file_name

    use_cached = target.exists() and (not force_download)
    if use_cached:
        logger.info(
            f'Local weights {target} exists, '
            'skip downloading.'
        )
    else:
        logger.info(
            f'Downloading weights from {url}, '
            f'storing to {target}')
        target.parent.mkdir(parents=True, exist_ok=True)
        attempts_left = max_retry
        while attempts_left > 0:
            try:
                torch.hub.download_url_to_file(url, target)
            except Exception as err:
                logger.warning(f'Error downloading weights: {err}')
                attempts_left -= 1
                time.sleep(0.5)
            else:
                break
        else:
            # Reached only when the loop ran out of attempts.
            raise RuntimeError(
                f'Error downloading weights from {url}.')
    self.load_weights_from_path(target)
def load_weights_from_path(
    self,
    path: T.Union[Path, str],
) -> None:
    """Load weights from a local ``.pth`` or ``.onnx`` file.

    The path is recorded in ``self.weight_path`` before the file
    extension is checked.

    Args:
        path: The path to the weights file.

    Raises:
        ValueError: If the path ends with neither ``.pth`` nor
            ``.onnx``.
    """
    path_str = str(path)
    self.weight_path = path_str

    if path_str.endswith('.pth'):
        loader = self._load_pth_file
    elif path_str.endswith('.onnx'):
        loader = self._load_onnx
    else:
        raise ValueError(
            'Weights file must be a pth file or an onnx file.')
    loader(path_str)
def load_weights(
    self,
    weights_path: T.Optional[str] = None,
    weights_file: T.Optional[str] = None,
    max_retry: int = 8,
    force_download: bool = False,
) -> "UFish":
    """Load weights from a local file or the internet.

    Resolution order:

    1. ``weights_path`` given: load that local file.
    2. ``weights_file`` given: download (or reuse the cached copy of)
       that file from the huggingface repo.
    3. Neither given: use ``self.default_weights_file``. The copy
       bundled with the package is preferred; if the package does not
       ship that file it is fetched from the internet instead.

    Args:
        weights_path: The path to the weights file.
        weights_file: The name of the weights file on the internet.
            See https://huggingface.co/GangCaoLab/U-FISH/tree/main
            for available weights files.
        max_retry: The maximum number of retries to download the
            weights.
        force_download: Whether to force download the weights.

    Returns:
        This object, so calls can be chained.
    """
    if weights_path is not None:
        self.load_weights_from_path(weights_path)
        return self

    if weights_file is not None:
        self.load_weights_from_internet(
            weights_file=weights_file,
            max_retry=max_retry,
            force_download=force_download,
        )
        return self

    # Honour the default chosen at construction time instead of always
    # using the module-level constant.
    default_file = self.default_weights_file
    bundled_path = osp.join(STATC_STORE_PATH, default_file)
    if osp.exists(bundled_path):
        self.load_weights_from_path(bundled_path)
    else:
        # Not shipped with the package: fall back to the remote store.
        self.load_weights_from_internet(
            weights_file=default_file,
            max_retry=max_retry,
            force_download=force_download,
        )
    return self
def _load_pth_file(self, path: T.Union[Path, str]) -> None:
    """Load a torch state dict into the model.

    A default model is created first if none exists. Any active ONNX
    session is dropped so that inference uses the torch model.

    Args:
        path: The path to the pth weights file.
    """
    import torch
    if self.model is None:
        self.init_model()
    assert self.model is not None
    path = str(path)
    logger.info(f'Loading weights from {path}')
    # NOTE: torch.load unpickles the file; only load trusted weights.
    state_dict = torch.load(path, map_location=self.device)
    self.model.load_state_dict(state_dict)
    self.ort_session = None


def _load_onnx(
    self,
    onnx_path: T.Union[Path, str],
    providers: T.Optional[T.List[str]] = None,
) -> None:
    """Create an onnxruntime session from a local ONNX file.

    The torch model is dropped so that inference uses the session.

    Args:
        onnx_path: The path to the ONNX file.
        providers: The execution providers to use. When empty or None,
            one provider is chosen from the requested device: CUDA for
            'cuda', DirectML for 'dml', CPU otherwise.
    """
    import onnxruntime
    onnx_path = str(onnx_path)
    logger.info(f'Loading ONNX from {onnx_path}')

    # The requested device may still be a torch.device object (it is
    # only converted to a string once the ``device`` property has been
    # resolved). Compare on its type string so that e.g.
    # torch.device('cuda') selects the CUDA provider. getattr keeps
    # this free of a torch import.
    device_kind = getattr(self._device, 'type', self._device)
    if device_kind == 'cuda':
        default_provider = 'CUDAExecutionProvider'
    elif device_kind == 'dml':
        default_provider = 'DmlExecutionProvider'
    else:
        default_provider = 'CPUExecutionProvider'
    providers = providers or [default_provider]

    self.ort_session = onnxruntime.InferenceSession(
        onnx_path, providers=providers)
    self.model = None
def infer(self, img: np.ndarray) -> np.ndarray:
    """Run the network on an image batch.

    The ONNX session takes precedence over the torch model when both
    are present.

    Args:
        img: Input array, passed unchanged to the active backend.

    Returns:
        The backend's output array.

    Raises:
        RuntimeError: If neither an ONNX session nor a torch model is
            available.
    """
    if self.ort_session is not None:
        return self._infer_onnx(img)
    if self.model is not None:
        return self._infer_torch(img)
    raise RuntimeError(
        'Both torch model and ONNX model are not initialized.')
def _infer_torch(self, img: np.ndarray) -> np.ndarray:
    """Run the torch model on an image batch.

    Args:
        img: Input array; converted to a float32 tensor.

    Returns:
        The model output as a numpy array on the CPU.
    """
    self._turn_on_infer_mode()
    assert self.model is not None
    import torch
    # Always move the input to the device the model was moved to.
    # Checking only for 'cuda' left the input on the CPU for other
    # accelerators (e.g. DirectML) and ignored an explicit cuda index.
    tensor = torch.from_numpy(img).float().to(self.device)
    with torch.no_grad():
        output = self.model(tensor)
    return output.detach().cpu().numpy()


def _infer_onnx(self, img: np.ndarray) -> np.ndarray:
    """Run the ONNX session on an image batch.

    The array is fed to the session's first input and the first
    output is returned.
    """
    assert self.ort_session is not None
    input_name = self.ort_session.get_inputs()[0].name
    ort_outs = self.ort_session.run(None, {input_name: img})
    return ort_outs[0]


def _enhance_img2d(self, img: np.ndarray) -> np.ndarray:
    """Enhance a 2D image.

    Two leading singleton axes are added before inference and
    stripped from the result.
    """
    return self.infer(img[np.newaxis, np.newaxis])[0, 0]


def _enhance_img3d(
        self, img: np.ndarray, batch_size: int = 4) -> np.ndarray:
    """Enhance a 3D image slice by slice along its first axis.

    Args:
        img: The 3D image.
        batch_size: Number of slices sent through the network at once.

    Returns:
        A float32 array with the same shape as ``img``.
    """
    logger.info(
        f'Enhancing 3D image in shape {img.shape}, '
        f'batch size: {batch_size}')
    output = np.zeros_like(img, dtype=np.float32)
    n_slices = output.shape[0]
    for i in range(0, n_slices, batch_size):
        logger.info(
            f'Enhancing slice {i}-{i+batch_size}/{n_slices}')
        # Insert a singleton second axis for inference and drop it
        # again when storing the result.
        _slice = img[i:i+batch_size][:, np.newaxis]
        output[i:i+batch_size] = self.infer(_slice)[:, 0]
    return output


def _enhance_2d_or_3d(
    self,
    img: np.ndarray,
    axes: str,
    batch_size: int = 4,
    blend_3d: bool = False,
) -> np.ndarray:
    """Enhance a 2D or 3D image.

    The image is rescaled with ``scale_image`` first.

    Args:
        img: The 2D or 3D image.
        axes: Axis names of ``img``.
        batch_size: Batch size used for 3D images.
        blend_3d: Whether to enhance a 3D image through
            ``enhance_blend_3d``. Ignored, with a warning, when
            ``axes`` contains no 'z'.

    Raises:
        ValueError: If the image is neither 2D nor 3D.
    """
    from .utils.img import scale_image
    img = scale_image(img, warning=True)
    if img.ndim == 2:
        return self._enhance_img2d(img)
    if img.ndim != 3:
        raise ValueError('Image must be 2D or 3D.')

    if blend_3d and 'z' not in axes:
        # Do what the warning says: skip blending and use the plain
        # slice-wise enhancement instead.
        logger.warning(
            'Image does not have a z axis, ' +
            'cannot blend along z axis.')
        blend_3d = False

    if blend_3d:
        from .utils.img import enhance_blend_3d
        logger.info(
            "Blending 3D image from 3 directions: z, y, x.")
        return enhance_blend_3d(
            img, self._enhance_img3d, axes=axes,
            batch_size=batch_size)
    return self._enhance_img3d(img, batch_size=batch_size)
def call_spots(
    self,
    enhanced_img: np.ndarray,
    method: str = 'local_maxima',
    **kwargs,
) -> pd.DataFrame:
    """Call spots from an enhanced image.

    Args:
        enhanced_img: The enhanced image, 2D or 3D.
        method: The spot calling method. 'cc_center' selects
            ``call_spots_cc_center``; any other value selects
            ``call_spots_local_maxima``.
        kwargs: Other arguments for the spot calling function.

    Returns:
        Whatever the selected spot calling function returns.
    """
    assert enhanced_img.ndim in (2, 3), 'Image must be 2D or 3D.'
    if method == 'cc_center':
        from .utils.spot_calling import call_spots_cc_center as spot_caller
    else:
        from .utils.spot_calling import call_spots_local_maxima as spot_caller  # noqa
    return spot_caller(enhanced_img, **kwargs)
def _pred_2d_or_3d(
    self,
    img: np.ndarray,
    axes: str,
    blend_3d: bool = False,
    batch_size: int = 4,
    spots_calling_method: str = 'local_maxima',
    **kwargs,
) -> T.Tuple[pd.DataFrame, np.ndarray]:
    """Enhance a 2D or 3D image and call spots on the result.

    Args:
        img: The 2D or 3D image.
        axes: Axis names, one character per image dimension.
        blend_3d: Request 3D blending; only forwarded as true when
            ``axes`` contains 'z'.
        batch_size: Batch size forwarded to the enhancement step.
        spots_calling_method: Method name forwarded to ``call_spots``.
        kwargs: Other arguments for the spot calling function.

    Returns:
        The spots table and the enhanced image.
    """
    assert img.ndim in (2, 3), 'Image must be 2D or 3D.'
    assert len(axes) == img.ndim, \
        "axes and image dimension must have the same length"

    do_blend = (blend_3d and ('z' in axes))
    enhanced = self._enhance_2d_or_3d(
        img, axes,
        batch_size=batch_size,
        blend_3d=do_blend,
    )
    spots = self.call_spots(
        enhanced,
        method=spots_calling_method,
        **kwargs)
    return spots, enhanced
def predict(
    self,
    img: np.ndarray,
    enh_img: T.Optional[np.ndarray] = None,
    axes: T.Optional[str] = None,
    blend_3d: bool = True,
    batch_size: int = 4,
    spots_calling_method: str = 'local_maxima',
    **kwargs,
) -> T.Tuple[pd.DataFrame, np.ndarray]:
    """Predict the spots in an image.

    Args:
        img: The image to predict, it should be a multi dimensional
            array. For example, shape (c, z, y, x) for a 4D image,
            shape (z, y, x) for a 3D image, shape (y, x) for a 2D
            image. Non-ndarray inputs are converted with ``np.array``.
        enh_img: Optional output buffer. When given, the enhanced
            image is written into it in place. It can be a multi
            dimensional array or a zarr array.
        axes: The axes of the image. For example, 'czyx' for a 4D
            image, 'yx' for a 2D image. If None, will try to infer
            the axes from the shape.
        blend_3d: Whether to blend the 3D image. Used only when the
            image contains a z axis. If True, will blend the 3D
            enhanced images along the z, y, x axes.
        batch_size: The batch size for inference. Used only when the
            image dimension is 3 or higher.
        spots_calling_method: The method to use for spot calling.
        kwargs: Other arguments for the spot calling function.

    Returns:
        The spots table and the enhanced image.
    """
    from .utils.img import (
        infer_img_axes, check_img_axes,
        map_predfunc_to_img
    )

    if axes is None:
        logger.info("Axes not specified, infering from image shape.")
        axes = infer_img_axes(img.shape)
        logger.info(f"Infered axes: {axes}, image shape: {img.shape}")
    check_img_axes(img, axes)

    if not isinstance(img, np.ndarray):
        img = np.array(img)

    # Bind everything except (image, axes); those are supplied by
    # map_predfunc_to_img.
    per_image_predict = partial(
        self._pred_2d_or_3d,
        blend_3d=blend_3d,
        batch_size=batch_size,
        spots_calling_method=spots_calling_method,
        **kwargs,
    )
    spots, enhanced = map_predfunc_to_img(per_image_predict, img, axes)

    if enh_img is not None:
        enh_img[:] = enhanced
    return spots, enhanced
def predict_chunks(
    self,
    img: np.ndarray,
    enh_img: T.Optional[np.ndarray] = None,
    axes: T.Optional[str] = None,
    blend_3d: bool = True,
    batch_size: int = 4,
    chunk_size: T.Optional[T.Tuple[T.Union[int, str], ...]] = None,
    spots_calling_method: str = 'local_maxima',
    **kwargs,
):
    """Predict the spots in an image chunk by chunk.

    Each chunk is run through ``predict``; the chunk's spot
    coordinates are shifted by the chunk's start offsets and its
    enhanced image is written into the matching region of the output.

    Args:
        img: The image to predict, it should be a multi dimensional
            array. For example, shape (c, z, y, x) for a 4D image,
            shape (z, y, x) for a 3D image, shape (y, x) for a 2D
            image.
        enh_img: Output buffer for the enhanced image. If None, a
            float32 array shaped like ``img`` is created. It can be
            a multi dimensional array or a zarr array.
        axes: The axes of the image. For example, 'czyx' for a 4D
            image, 'yx' for a 2D image. If None, will try to infer
            the axes from the shape.
        blend_3d: Whether to blend the 3D image. Used only when the
            image contains a z axis. If True, will blend the 3D
            enhanced images along the z, y, x axes.
        batch_size: The batch size for inference. Used only when the
            image dimension is 3 or higher.
        chunk_size: The chunk size for processing. For example,
            (1, 512, 512) for a 3D image, (512, 512) for a 2D image.
            Using 'image' as a dimension will use the whole image as
            a chunk. For example, (1, 'image', 'image') for a 3D
            image, ('image', 'image', 'image', 512, 512) for a 5D
            image. If None, will use the default chunk size.
        spots_calling_method: The method to use for spot calling.
        kwargs: Other arguments for the spot calling function.

    Returns:
        The concatenated spots table and the enhanced image.
    """
    from .utils.img import (
        check_img_axes, chunks_iterator,
        process_chunk_size, infer_img_axes)

    if axes is None:
        axes = infer_img_axes(img.shape)
    check_img_axes(img, axes)

    if chunk_size is None:
        from .utils.img import get_default_chunk_size
        chunk_size = get_default_chunk_size(axes)
        logger.info(f"Chunk size not specified, using {chunk_size}.")
    chunk_size = process_chunk_size(chunk_size, img.shape)
    logger.info(f"Chunk size: {chunk_size}")

    if enh_img is None:
        enh_img = np.zeros_like(img, dtype=np.float32)

    chunk_tables = []
    for c_range, chunk in chunks_iterator(img, chunk_size):
        logger.info(
            f"Processing chunk: {c_range}, chunk shape: {chunk.shape}")
        chunk_df, chunk_enh = self.predict(
            chunk, axes=axes, blend_3d=blend_3d,
            batch_size=batch_size,
            spots_calling_method=spots_calling_method,
            **kwargs)

        # Shift chunk-local coordinates to whole-image coordinates.
        offsets = [c_range[k][0] for k in range(len(axes))]
        chunk_df += offsets
        chunk_tables.append(chunk_df)

        # Crop the enhanced chunk to the extent of its range before
        # writing it into the output.
        crop = tuple(slice(0, (r[1]-r[0])) for r in c_range)
        region = tuple(slice(*r) for r in c_range)
        enh_img[region] = chunk_enh[crop]

    df = pd.concat(chunk_tables, ignore_index=True)
    return df, enh_img
def evaluate_result_dp(
    self,
    pred: pd.DataFrame,
    true: pd.DataFrame,
    mdist: float = 3.0,
) -> pd.DataFrame:
    """Evaluate the prediction result using deepblink metrics.

    Only the columns of ``pred`` whose names start with 'axis' are
    used; the same columns are taken from ``true``.

    Args:
        pred: The predicted spots.
        true: The true spots.
        mdist: The maximum distance to consider a spot as a true
            positive.

    Returns:
        A pandas dataframe containing the evaluation metrics.
    """
    from .utils.metrics_deepblink import compute_metrics

    coord_cols = [
        col for col in list(pred.columns) if col.startswith('axis')]
    return compute_metrics(
        pred[coord_cols].values,
        true[coord_cols].values,
        mdist=mdist)
def evaluate_result(
    self,
    pred: pd.DataFrame,
    true: pd.DataFrame,
    cutoff: float = 3.0,
) -> dict:
    """Calculate the F1 score of the prediction result.

    All columns of both tables are passed to the metric function as
    plain arrays.

    Args:
        pred: The predicted spots.
        true: The true spots.
        cutoff: The maximum distance to consider a spot as a true
            positive.

    Returns:
        The result of ``compute_metrics``.
    """
    from .utils.metrics import compute_metrics

    pred_coords = pred.values
    true_coords = true.values
    return compute_metrics(pred_coords, true_coords, cutoff=cutoff)
def plot_result(
    self,
    img: np.ndarray,
    pred: pd.DataFrame,
    fig_size: T.Tuple[int, int] = (10, 10),
    image_cmap: str = 'gray',
    marker_size: int = 20,
    marker_color: str = 'red',
    marker_style: str = 'x',
) -> "Figure":
    """Plot the image with the predicted spots drawn on top.

    Args:
        img: The image to plot.
        pred: The predicted spots.
        fig_size: The figure size.
        image_cmap: The colormap for the image.
        marker_size: The marker size.
        marker_color: The marker color.
        marker_style: The marker style.

    Returns:
        The figure held by the plotting helper.
    """
    from .utils.plot import Plot2d

    plotter = Plot2d()
    # Configure the helper's defaults before the figure is created.
    defaults = (
        ('default_figsize', fig_size),
        ('default_marker_size', marker_size),
        ('default_marker_color', marker_color),
        ('default_marker_style', marker_style),
        ('default_imshow_cmap', image_cmap),
    )
    for attr_name, value in defaults:
        setattr(plotter, attr_name, value)

    plotter.new_fig()
    plotter.image(img)
    plotter.spots(pred.values)
    return plotter.fig
def plot_evaluate(
    self,
    img: np.ndarray,
    pred: pd.DataFrame,
    true: pd.DataFrame,
    cutoff: float = 3.0,
    fig_size: T.Tuple[int, int] = (10, 10),
    image_cmap: str = 'gray',
    marker_size: int = 20,
    tp_color: str = 'green',
    fp_color: str = 'red',
    fn_color: str = 'yellow',
    tp_marker: str = 'x',
    fp_marker: str = 'x',
    fn_marker: str = 'x',
) -> "Figure":
    """Plot the image with predicted and true spots compared.

    Args:
        img: The image to plot.
        pred: The predicted spots.
        true: The true spots.
        cutoff: The maximum distance to consider a spot as a true
            positive.
        fig_size: The figure size.
        image_cmap: The colormap for the image.
        marker_size: The marker size.
        tp_color: The color for true positive.
        fp_color: The color for false positive.
        fn_color: The color for false negative.
        tp_marker: The marker style for true positive.
        fp_marker: The marker style for false positive.
        fn_marker: The marker style for false negative.

    Returns:
        The figure held by the plotting helper.
    """
    from .utils.plot import Plot2d

    plotter = Plot2d()
    plotter.default_figsize = fig_size
    plotter.default_marker_size = marker_size
    plotter.default_imshow_cmap = image_cmap
    plotter.new_fig()
    plotter.image(img)

    colors = dict(
        tp_color=tp_color, fp_color=fp_color, fn_color=fn_color)
    markers = dict(
        tp_marker=tp_marker, fp_marker=fp_marker, fn_marker=fn_marker)
    plotter.evaluate_result(
        pred.values, true.values,
        cutoff=cutoff,
        **colors,
        **markers,
    )
    return plotter.fig
def _load_dataset(
    self,
    path: str,
    root_dir_path: T.Optional[str] = None,
    img_glob: str = '*.tif',
    coord_glob: str = '*.csv',
    process_func=None,
    transform=None,
):
    """Load a dataset from a directory or from a meta csv file.

    Whether ``path`` is a directory is decided before
    ``root_dir_path`` is applied.

    Args:
        path: A directory with images and coordinate files, or a
            meta csv file.
        root_dir_path: For a directory, prefix joined in front of
            ``path``. For a meta csv, the data root directory;
            defaults to the csv file's parent directory.
        img_glob: Glob pattern for image files (directory mode only).
        coord_glob: Glob pattern for coordinate files (directory mode
            only).
        process_func: Forwarded to the dataset constructor.
        transform: Forwarded to the dataset constructor.

    Returns:
        The constructed ``FISHSpotsDataset``.
    """
    from .data import FISHSpotsDataset

    _path = Path(path)

    if not _path.is_dir():
        # Anything that is not a directory must be a meta csv.
        logger.info(f"Loading dataset using meta csv: {_path}")
        assert _path.suffix == '.csv', \
            "Meta file must be a csv file."
        root_dir = root_dir_path or _path.parent
        logger.info(f'Data root directory: {root_dir}')
        return FISHSpotsDataset.from_meta_csv(
            root_dir=str(root_dir), meta_csv_path=str(_path),
            process_func=process_func, transform=transform)

    if root_dir_path is not None:
        logger.info(f"Dataset's root directory: {root_dir_path}")
        _path = Path(root_dir_path) / _path
    logger.info(f"Loading dataset from dir: {_path}")
    logger.info(
        f'Image glob: {img_glob}, Coordinate glob: {coord_glob}')
    dir_str = str(_path)
    # Images and coordinates are read from the same directory.
    return FISHSpotsDataset.from_dir(
        dir_str, dir_str,
        img_glob=img_glob, coord_glob=coord_glob,
        process_func=process_func, transform=transform)
def train(
    self,
    train_path: str,
    valid_path: str,
    root_dir: T.Optional[str] = None,
    img_glob: str = '*.tif',
    coord_glob: str = '*.csv',
    target_process: T.Optional[str] = 'gaussian',
    loss_type: str = 'DiceRMSELoss',
    loader_workers: int = 4,
    data_argu: bool = False,
    argu_prob: float = 0.5,
    num_epochs: int = 50,
    batch_size: int = 8,
    optimizer_type: str = "Adam",
    lr: float = 1e-3,
    summary_dir: str = "runs/ufish",
    model_save_dir: str = "./models",
    save_period: int = 5,
):
    """Train the model on a training and a validation dataset.

    A default model is initialized first when none exists.

    Args:
        train_path: The path to the training dataset. Path to a
            directory containing images and coordinates, or a meta
            csv file.
        valid_path: The path to the validation dataset. Path to a
            directory containing images and coordinates, or a meta
            csv file.
        root_dir: The root directory of the dataset. If using meta
            csv, the root directory of the dataset.
        img_glob: The glob pattern for the image files.
        coord_glob: The glob pattern for the coordinate files.
        target_process: The target image processing method.
            'gaussian' or 'dialation' or None. Any other string is
            passed as ``footprint`` to the dilation function. If
            None, no processing will be applied. default 'gaussian'.
        loss_type: The loss function type.
        loader_workers: The number of workers to use for the data
            loader.
        data_argu: Whether to use data augmentation. Applied to the
            training dataset only.
        argu_prob: The probability to use data augmentation.
        num_epochs: The number of epochs to train.
        batch_size: The batch size.
        optimizer_type: The optimizer type.
        lr: The learning rate.
        summary_dir: The directory to save the TensorBoard summary
            to.
        model_save_dir: The directory to save the model to.
        save_period: Save the model every `save_period` epochs.
    """
    from .model.train import train_on_dataset
    from .data import FISHSpotsDataset

    if self.model is None:
        logger.info('Model is not initialized. Will initialize a new one.')
        self.init_model()
    assert self.model is not None

    transform = None
    if data_argu:
        logger.info(
            'Using data augmentation. ' +
            f'Probability: {argu_prob}'
        )
        from .data import DataAugmentation
        transform = DataAugmentation(p=argu_prob)

    logger.info(f'Using {target_process} as target process.')
    process_func: T.Optional[T.Callable] = None
    if target_process == 'gaussian':
        process_func = FISHSpotsDataset.gaussian_filter
    elif target_process == 'dialation':
        process_func = FISHSpotsDataset.dialate_mask
    elif isinstance(target_process, str):
        # Any other string is treated as a footprint for dilation.
        process_func = partial(
            FISHSpotsDataset.dialate_mask, footprint=target_process)

    # Options shared by the training and the validation dataset.
    dataset_options = dict(
        root_dir_path=root_dir,
        img_glob=img_glob, coord_glob=coord_glob,
        process_func=process_func,
    )

    logger.info(f"Loading training dataset from {train_path}")
    train_dataset = self._load_dataset(
        train_path, transform=transform, **dataset_options)

    logger.info(f"Loading validation dataset from {valid_path}")
    valid_dataset = self._load_dataset(valid_path, **dataset_options)

    logger.info(
        f"Training dataset size: {len(train_dataset)}, "
        f"Validation dataset size: {len(valid_dataset)}"
    )
    logger.info(
        f"Number of epochs: {num_epochs}, "
        f"Batch size: {batch_size}, "
        f"Learning rate: {lr}"
    )

    train_on_dataset(
        self.model,
        train_dataset, valid_dataset,
        device=self.device,
        loss_type=loss_type,
        loader_workers=loader_workers,
        num_epochs=num_epochs,
        batch_size=batch_size,
        optimizer_type=optimizer_type,
        lr=lr,
        summary_dir=summary_dir,
        model_save_dir=model_save_dir,
        save_period=save_period,
    )