API Reference

This section provides detailed documentation for the U-FISH Python API.

UFish Class

The main class for U-FISH functionality.

class ufish.api.UFish(device: None | Literal['cpu', 'cuda', 'dml'] | torch.device = None, default_weights_file: str | None = None, local_store_path: str = '~/.ufish/')

Bases: object

__init__(device: None | Literal['cpu', 'cuda', 'dml'] | torch.device = None, default_weights_file: str | None = None, local_store_path: str = '~/.ufish/') → None
Parameters:
  • device – The device to use for inference and training: ‘cpu’, ‘cuda’ or ‘dml’. If None, ‘cuda’ is used when available, otherwise ‘cpu’. ‘dml’ selects DirectML, which allows AMD GPUs to be used on Windows.

  • default_weights_file – The default weights file to use.

  • local_store_path – The local directory in which downloaded weights are stored.
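
The device fallback rule above can be illustrated with a small pure function. This is only a sketch: `resolve_device` and its `cuda_available` flag are hypothetical names, the real class queries PyTorch itself, and the sketch handles device strings only (UFish also accepts a torch.device).

```python
from typing import Optional


def resolve_device(requested: Optional[str], cuda_available: bool) -> str:
    """Return the device name to use, following the fallback rule described above."""
    allowed = {"cpu", "cuda", "dml"}
    if requested is None:
        # No explicit choice: prefer CUDA when present, otherwise use the CPU
        return "cuda" if cuda_available else "cpu"
    if requested not in allowed:
        raise ValueError(f"device must be one of {sorted(allowed)} or None, got {requested!r}")
    # An explicit choice is returned unchanged (availability is not checked here)
    return requested


print(resolve_device(None, cuda_available=False))  # cpu
```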

property device: torch.device

Get the torch device.

init_model(model_type: str = 'ufish', **kwargs) → None

Initialize the model.

Parameters:
  • model_type – The type of the model. For example, ‘ufish’, ‘spot_learn’, …

  • kwargs – Other arguments for the model.

convert_to_onnx(output_path: Path | str) → None

Convert the model to ONNX format.

Parameters:

output_path – The path to the output ONNX file.

load_weights_from_internet(weights_file: str | None = None, max_retry: int = 8, force_download: bool = False) → None

Load weights from the Hugging Face repository.

Parameters:
  • weights_file – The name of the weights file on the internet. See https://huggingface.co/GangCaoLab/U-FISH/tree/main for available weights files.

  • max_retry – The maximum number of retries.

  • force_download – Whether to download the weights even if a local copy already exists.

load_weights_from_path(path: Path | str) → None

Load weights from a local file. The file can be a .pth file or an .onnx file.

Parameters:

path – The path to the weights file.

load_weights(weights_path: str | None = None, weights_file: str | None = None, max_retry: int = 8, force_download: bool = False)

Load weights from a local file or the internet.

Parameters:
  • weights_path – The path to the weights file.

  • weights_file – The name of the weights file on the internet. See https://huggingface.co/GangCaoLab/U-FISH/tree/main for available weights files.

  • max_retry – The maximum number of retries to download the weights.

  • force_download – Whether to download the weights even if a local copy already exists.
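
How `max_retry` might bound the download attempts can be sketched with a generic retry wrapper. `download_with_retry` and its `delay` parameter are hypothetical, and treating `max_retry` as the total number of attempts is an assumption; U-FISH's own loop may count retries differently.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def download_with_retry(download: Callable[[], T], max_retry: int = 8, delay: float = 1.0) -> T:
    """Call `download` up to `max_retry` times and return its first successful result."""
    last_error = None
    for attempt in range(1, max_retry + 1):
        try:
            return download()
        except OSError as err:  # network errors such as ConnectionError subclass OSError
            last_error = err
            if attempt < max_retry:
                time.sleep(delay)  # wait before the next attempt
    raise RuntimeError(f"download failed after {max_retry} attempts") from last_error
```

Catching only OSError keeps programming errors (for example a TypeError in the download callable) from being silently retried.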

infer(img: ndarray) → ndarray

Run the U-Net model on an image and return the enhanced image.

call_spots(enhanced_img: ndarray, method: str = 'local_maxima', **kwargs) → DataFrame

Call spots from an enhanced image.

Parameters:
  • enhanced_img – The enhanced image.

  • method – The method to use for spot calling.

  • kwargs – Other arguments for the spot calling function.
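
To make the ‘local_maxima’ idea concrete, here is a pure-Python sketch: keep pixels of the enhanced image that exceed a threshold and are not smaller than any of their eight neighbours. It is not U-FISH's implementation; `local_maxima_2d` and its `threshold` default are assumptions, and it does not suppress plateaus of equal values.

```python
from typing import List, Tuple


def local_maxima_2d(img: List[List[float]], threshold: float = 0.5) -> List[Tuple[int, int]]:
    """Return (y, x) of pixels above `threshold` that are >= all of their 8 neighbours."""
    h, w = len(img), len(img[0])
    peaks = []
    for y in range(h):
        for x in range(w):
            v = img[y][x]
            if v <= threshold:
                continue
            # Neighbourhood clipped at the image border
            neighbours = (
                img[ny][nx]
                for ny in range(max(0, y - 1), min(h, y + 2))
                for nx in range(max(0, x - 1), min(w, x + 2))
                if (ny, nx) != (y, x)
            )
            if all(v >= n for n in neighbours):
                peaks.append((y, x))
    return peaks
```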

predict(img: ndarray, enh_img: ndarray | None = None, axes: str | None = None, blend_3d: bool = True, batch_size: int = 4, spots_calling_method: str = 'local_maxima', **kwargs) → Tuple[DataFrame, ndarray]

Predict the spots in an image.

Parameters:
  • img – The image to predict. It should be a multidimensional array, for example with shape (c, z, y, x) for a 4D image, (z, y, x) for a 3D image or (y, x) for a 2D image.

  • enh_img – The enhanced image. If None, it will be created. It can be a multidimensional array or a zarr array.

  • axes – The axes of the image, for example ‘czyx’ for a 4D image or ‘yx’ for a 2D image. If None, the axes are inferred from the shape.

  • blend_3d – Whether to blend the 3D image. Used only when the image contains a z axis. If True, the 3D enhanced images are blended along the z, y and x axes.

  • batch_size – The batch size for inference. Used only when the image has three or more dimensions.

  • spots_calling_method – The method to use for spot calling.

  • kwargs – Other arguments for the spot calling function.
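
When `axes` is None, the axes have to be guessed from the number of dimensions. The sketch below assumes a fixed mapping that follows the shape examples above ((y, x), (z, y, x), (c, z, y, x)). `guess_axes` is a hypothetical helper, and the 5D ‘tczyx’ entry is an assumption that this reference does not state.

```python
from typing import Sequence


def guess_axes(shape: Sequence[int]) -> str:
    """Hypothetical default axes for an array shape, keyed on its number of dimensions."""
    defaults = {2: "yx", 3: "zyx", 4: "czyx", 5: "tczyx"}
    try:
        return defaults[len(shape)]
    except KeyError:
        raise ValueError(f"cannot infer axes for a {len(shape)}D image") from None
```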

predict_chunks(img: ndarray, enh_img: ndarray | None = None, axes: str | None = None, blend_3d: bool = True, batch_size: int = 4, chunk_size: Tuple[int | str, ...] | None = None, spots_calling_method: str = 'local_maxima', **kwargs)

Predict the spots in an image chunk by chunk.

Parameters:
  • img – The image to predict. It should be a multidimensional array, for example with shape (c, z, y, x) for a 4D image, (z, y, x) for a 3D image or (y, x) for a 2D image.

  • enh_img – The enhanced image. If None, it will be created. It can be a multidimensional array or a zarr array.

  • axes – The axes of the image, for example ‘czyx’ for a 4D image or ‘yx’ for a 2D image. If None, the axes are inferred from the shape.

  • blend_3d – Whether to blend the 3D image. Used only when the image contains a z axis. If True, the 3D enhanced images are blended along the z, y and x axes.

  • batch_size – The batch size for inference. Used only when the image has three or more dimensions.

  • chunk_size – The chunk size for processing. For example, (1, 512, 512) for a 3D image, (512, 512) for a 2D image. Using ‘image’ as a dimension will use the whole image as a chunk. For example, (1, ‘image’, ‘image’) for a 3D image, (‘image’, ‘image’, ‘image’, 512, 512) for a 5D image. If None, will use the default chunk size.

  • spots_calling_method – The method to use for spot calling.

  • kwargs – Other arguments for the spot calling function.
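
The chunk_size convention, including the ‘image’ placeholder, can be illustrated with a small generator that turns an image shape into per-chunk slices. `iter_chunks` is a hypothetical helper: it shows the tiling rule described above, not how predict_chunks schedules work internally, and its chunks do not overlap.

```python
import itertools
from typing import Iterator, Sequence, Tuple, Union


def iter_chunks(shape: Sequence[int], chunk_size: Sequence[Union[int, str]]) -> Iterator[Tuple[slice, ...]]:
    """Yield tuples of slices that tile an array of `shape`; 'image' spans a whole axis."""
    if len(shape) != len(chunk_size):
        raise ValueError("chunk_size must have one entry per image dimension")
    sizes = [dim if c == "image" else int(c) for dim, c in zip(shape, chunk_size)]
    starts = [range(0, dim, size) for dim, size in zip(shape, sizes)]
    for origin in itertools.product(*starts):
        # The last chunk along an axis may be smaller than the chunk size
        yield tuple(slice(o, min(o + s, dim)) for o, s, dim in zip(origin, sizes, shape))
```

For a (2, 1000, 600) stack with chunk size (1, ‘image’, 512), this yields four chunks: each z plane split into x ranges 0 to 512 and 512 to 600.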

evaluate_result_dp(pred: DataFrame, true: DataFrame, mdist: float = 3.0) → DataFrame

Evaluate the prediction result using deepBlink metrics.

Parameters:
  • pred – The predicted spots.

  • true – The true spots.

  • mdist – The maximum distance to consider a spot as a true positive.

Returns:

A pandas DataFrame containing the evaluation metrics.

evaluate_result(pred: DataFrame, true: DataFrame, cutoff: float = 3.0) → dict

Calculate the F1 score of the prediction result.

Parameters:
  • pred – The predicted spots.

  • true – The true spots.

  • cutoff – The maximum distance to consider a spot as a true positive.
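
The idea behind a distance-cutoff F1 score can be shown in plain Python. Predicted and true spots are paired one-to-one when they lie within `cutoff` pixels, and unmatched spots count as false positives or false negatives. This sketch uses greedy nearest-first matching as a simple stand-in. `f1_with_cutoff` is hypothetical, and evaluate_result may use a different (for example, optimal) assignment.

```python
import math
from typing import Dict, List, Sequence


def f1_with_cutoff(pred: List[Sequence[float]], true: List[Sequence[float]],
                   cutoff: float = 3.0) -> Dict[str, float]:
    """Greedily match predicted to true spots within `cutoff` and score the result."""
    pairs = []
    for i, p in enumerate(pred):
        for j, t in enumerate(true):
            d = math.dist(p, t)
            if d <= cutoff:
                pairs.append((d, i, j))
    pairs.sort()  # nearest candidate pairs first
    used_pred, used_true = set(), set()
    for _, i, j in pairs:
        # Each spot may take part in at most one match
        if i not in used_pred and j not in used_true:
            used_pred.add(i)
            used_true.add(j)
    tp = len(used_pred)
    fp = len(pred) - tp
    fn = len(true) - tp
    precision = tp / (tp + fp) if pred else 0.0
    recall = tp / (tp + fn) if true else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
```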

plot_result(img: ndarray, pred: DataFrame, fig_size: Tuple[int, int] = (10, 10), image_cmap: str = 'gray', marker_size: int = 20, marker_color: str = 'red', marker_style: str = 'x') → Figure

Plot the prediction result.

Parameters:
  • img – The image to plot.

  • pred – The predicted spots.

  • fig_size – The figure size.

  • image_cmap – The colormap for the image.

  • marker_size – The marker size.

  • marker_color – The marker color.

  • marker_style – The marker style.

plot_evaluate(img: ndarray, pred: DataFrame, true: DataFrame, cutoff: float = 3.0, fig_size: Tuple[int, int] = (10, 10), image_cmap: str = 'gray', marker_size: int = 20, tp_color: str = 'green', fp_color: str = 'red', fn_color: str = 'yellow', tp_marker: str = 'x', fp_marker: str = 'x', fn_marker: str = 'x') → Figure

Plot the evaluation result, marking true positives, false positives and false negatives on the image.

Parameters:
  • img – The image to plot.

  • pred – The predicted spots.

  • true – The true spots.

  • cutoff – The maximum distance to consider a spot as a true positive.

  • fig_size – The figure size.

  • image_cmap – The colormap for the image.

  • marker_size – The marker size.

  • tp_color – The color for true positives.

  • fp_color – The color for false positives.

  • fn_color – The color for false negatives.

  • tp_marker – The marker style for true positives.

  • fp_marker – The marker style for false positives.

  • fn_marker – The marker style for false negatives.

train(train_path: str, valid_path: str, root_dir: str | None = None, img_glob: str = '*.tif', coord_glob: str = '*.csv', target_process: str | None = 'gaussian', loss_type: str = 'DiceRMSELoss', loader_workers: int = 4, data_argu: bool = False, argu_prob: float = 0.5, num_epochs: int = 50, batch_size: int = 8, optimizer_type: str = 'Adam', lr: float = 0.001, summary_dir: str = 'runs/ufish', model_save_dir: str = './models', save_period: int = 5)

Train the U-Net model.

Parameters:
  • train_path – The path to the training dataset. Path to a directory containing images and coordinates, or a meta csv file.

  • valid_path – The path to the validation dataset. Path to a directory containing images and coordinates, or a meta csv file.

  • root_dir – The root directory of the dataset. Used when train_path and valid_path are meta CSV files.

  • img_glob – The glob pattern for the image files.

  • coord_glob – The glob pattern for the coordinate files.

  • target_process – The target image processing method: ‘gaussian’, ‘dialation’ or None. If None, no processing is applied. Default: ‘gaussian’.

  • loss_type – The loss function type.

  • loader_workers – The number of workers to use for the data loader.

  • data_argu – Whether to use data augmentation.

  • argu_prob – The probability of applying data augmentation.

  • num_epochs – The number of epochs to train.

  • batch_size – The batch size.

  • optimizer_type – The optimizer type.

  • lr – The learning rate.

  • summary_dir – The directory to save the TensorBoard summary to.

  • model_save_dir – The directory to save the model to.

  • save_period – Save the model every save_period epochs.

Basic Usage

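from skimage import io
from ufish.api import UFish

# Read a 2D or 3D image (the path is a placeholder)
image = io.imread('path/to/image.tif')

# Initialize
ufish = UFish()

# Load the default weights
ufish.load_weights()

# Predict spots and get the enhanced image
spots, enhanced = ufish.predict(image)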

Constructor Parameters

UFish(
    device=None,                  # 'cpu', 'cuda', 'dml', a torch.device, or None to auto-select
    default_weights_file=None,    # Default weights file to use
    local_store_path='~/.ufish/'  # Directory where downloaded weights are stored
)

Core Methods

load_weights

Load model weights from file or default location.

# Load default weights
ufish.load_weights()

# Load from file
ufish.load_weights('path/to/model.onnx')

# Load from HuggingFace
ufish.load_weights_from_internet()

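Parameters:

  • weights_path (str, optional): Path to a local weights file (.pth or .onnx).

  • weights_file (str, optional): Name of a weights file in the Hugging Face repository.

  • max_retry (int): Maximum number of download retries.

  • force_download (bool): Download the weights even if a local copy already exists.

If neither weights_path nor weights_file is given, the default weights are loaded.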

predict

Predict FISH spots in an image.

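spots, enhanced = ufish.predict(
    image,
    enh_img=None,
    axes=None,
    blend_3d=True,
    batch_size=4,
    spots_calling_method='local_maxima'
)

Parameters:

  • image (np.ndarray): Input image, for example with shape (y, x), (z, y, x) or (c, z, y, x).

  • enh_img (np.ndarray or zarr array, optional): A precomputed enhanced image. If None, it is computed.

  • axes (str, optional): Axes of the image, for example 'zyx'. If None, they are inferred from the shape.

  • blend_3d (bool): For images with a z axis, blend the enhanced images computed along the z, y and x axes.

  • batch_size (int): Batch size for inference on images with three or more dimensions.

  • spots_calling_method (str): Spot calling method.

  • Additional keyword arguments are passed to the spot calling function.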

Returns:

  • spots (pd.DataFrame): Detected spots with columns [y, x] or [z, y, x].

  • enhanced (np.ndarray): Enhanced image.

evaluate_result

Evaluate prediction against ground truth.

metrics = ufish.evaluate_result(
    pred_spots,
    true_spots,
    cutoff=3.0
)

Parameters:

  • pred (pd.DataFrame): Predicted spots.

  • true (pd.DataFrame): Ground truth spots.

  • cutoff (float): Maximum distance for matching a predicted spot to a true spot.

Returns:

  • metrics (dict): Dictionary containing:

    • precision: TP / (TP + FP)

    • recall: TP / (TP + FN)

    • f1: F1 score

    • tp: True positives count

    • fp: False positives count

    • fn: False negatives count
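
The f1 entry combines precision and recall. Written out, together with the equivalent form in terms of counts (a standard identity, not specific to U-FISH) and a worked example:

```latex
P = \frac{TP}{TP + FP}, \qquad R = \frac{TP}{TP + FN}, \qquad
F_1 = \frac{2PR}{P + R} = \frac{2\,TP}{2\,TP + FP + FN}

TP = 8,\ FP = 2,\ FN = 4 \;\Rightarrow\; P = 0.8,\ R = \tfrac{2}{3},\ F_1 = \frac{16}{22} \approx 0.727
```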

plot_result

Visualize detected spots on an image.

fig = ufish.plot_result(
    image,
    spots,
    fig_size=(10, 10),
    image_cmap='gray',
    marker_size=20,
    marker_color='red',
    marker_style='x'
)

Parameters:

  • img (np.ndarray): Original image.

  • pred (pd.DataFrame): Detected spots.

  • fig_size (tuple): Figure size.

  • image_cmap (str): Colormap for the image.

  • marker_size (int): Size of spot markers.

  • marker_color (str): Color of spot markers.

  • marker_style (str): Marker style for spots.

Returns:

  • fig (matplotlib.figure.Figure): Figure object.

plot_evaluate

Visualize evaluation results showing TP, FP, FN.

fig = ufish.plot_evaluate(
    image,
    pred_spots,
    true_spots,
    cutoff=3.0,
    fig_size=(10, 10)
)

Parameters:

  • img (np.ndarray): Original image.

  • pred (pd.DataFrame): Predicted spots.

  • true (pd.DataFrame): Ground truth spots.

  • cutoff (float): Matching distance threshold.

  • fig_size (tuple): Figure size.

  • image_cmap (str): Colormap for the image.

  • marker_size (int): Size of spot markers.

  • tp_color, fp_color, fn_color (str): Colors for true positives, false positives and false negatives.

  • tp_marker, fp_marker, fn_marker (str): Marker styles for true positives, false positives and false negatives.

Returns:

  • fig (matplotlib.figure.Figure): Figure with true positives, false positives and false negatives marked on the image.

train

Train or fine-tune the model.

ufish.train(
    train_dir,
    val_dir,
    num_epochs=100,
    batch_size=8,
    lr=1e-4,
    optimizer_type='Adam',
    data_argu=True,
    model_save_dir='./models',
    save_period=10
)

Parameters:

  • train_path (str): Training data directory or meta CSV file.

  • valid_path (str): Validation data directory or meta CSV file.

  • num_epochs (int): Number of training epochs.

  • batch_size (int): Batch size.

  • lr (float): Learning rate.

  • optimizer_type (str): Optimizer type ('Adam' by default).

  • loss_type (str): Loss function type ('DiceRMSELoss' by default).

  • data_argu (bool): Use data augmentation.

  • argu_prob (float): Probability of applying data augmentation.

  • model_save_dir (str): Directory in which model checkpoints are saved.

  • save_period (int): Save a checkpoint every N epochs.

  • summary_dir (str): Directory for TensorBoard summaries.

Checkpoints are written to model_save_dir every save_period epochs, and TensorBoard summaries are written to summary_dir.

Data Module

Functions for data loading and preprocessing.

class ufish.data.Reader
read_coords(path: str, ndim: int) → ndarray
class ufish.data.FileReader(root_dir: str, meta_csv_path: str)

Read images and coordinates from meta_csv and files.

__init__(root_dir: str, meta_csv_path: str)
class ufish.data.ListReader(img_list: List[ndarray], coord_list: List[ndarray])

Read images and coordinates from a list of images and coordinates.

__init__(img_list: List[ndarray], coord_list: List[ndarray])
class ufish.data.DirReader(img_dir: str, coord_dir: str, img_glob: str = '*.tif', coord_glob: str = '*.csv')

Read images and coordinates from a directory of images and coordinates.

__init__(img_dir: str, coord_dir: str, img_glob: str = '*.tif', coord_glob: str = '*.csv')
check_prefix()
class ufish.data.FISHSpotsDataset(reader: Reader, process_func: Callable | None = None, transform=None)
__init__(reader: Reader, process_func: Callable | None = None, transform=None)

FISH spots dataset.

Parameters:
  • reader – The reader to read images and coordinates.

  • process_func – The function to process the target image.

  • transform – The transform to apply to the samples.

static gaussian_filter(mask: ndarray, sigma=1) → ndarray

Apply Gaussian filter to the mask.
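
Gaussian filtering turns each one-pixel spot in the mask into a smooth blob, which gives the network a softer target to regress. Below is a dependency-free sketch of a separable 2D Gaussian blur. The helper names are hypothetical, pixels outside the image are treated as zero, and a real implementation would normally call an optimized routine such as scipy.ndimage.gaussian_filter instead.

```python
import math
from typing import List

Grid = List[List[float]]


def gaussian_kernel_1d(sigma: float, truncate: float = 4.0) -> List[float]:
    """Normalised 1D Gaussian weights covering +/- truncate * sigma."""
    radius = int(truncate * sigma + 0.5)
    weights = [math.exp(-0.5 * (i / sigma) ** 2) for i in range(-radius, radius + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def blur_rows(img: Grid, kernel: List[float]) -> Grid:
    """Convolve every row with `kernel`, treating pixels outside the image as zero."""
    r = len(kernel) // 2
    width = len(img[0])
    return [
        [sum(kernel[k + r] * row[x + k] for k in range(-r, r + 1) if 0 <= x + k < width)
         for x in range(width)]
        for row in img
    ]


def transpose(img: Grid) -> Grid:
    return [list(col) for col in zip(*img)]


def gaussian_filter_2d(mask: Grid, sigma: float = 1.0) -> Grid:
    """Separable Gaussian blur: filter along x, then along y via a transpose."""
    kernel = gaussian_kernel_1d(sigma)
    blurred_x = blur_rows(mask, kernel)
    return transpose(blur_rows(transpose(blurred_x), kernel))
```

Because the kernel is normalised, a spot far enough from the border keeps its total intensity of 1 after blurring.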

static dialate_mask(mask: ndarray, footprint: str = 'disk(2)') → ndarray

Dilate the mask.

Parameters:
  • mask – The mask to dilate.

  • footprint – The footprint to use for dilation, given as a string such as ‘disk(2)’.
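
Dilation with a ‘disk(2)’ footprint grows each spot pixel into a small disk of radius 2. The footprint string suggests scikit-image's disk helper. The pure-Python sketch below reproduces that shape (every offset with dy² + dx² ≤ r²) so the effect is easy to inspect. `disk_offsets` and `dilate` are hypothetical names, not U-FISH functions.

```python
from typing import List, Tuple


def disk_offsets(radius: int) -> List[Tuple[int, int]]:
    """Offsets (dy, dx) inside a disk: dy**2 + dx**2 <= radius**2."""
    return [(dy, dx)
            for dy in range(-radius, radius + 1)
            for dx in range(-radius, radius + 1)
            if dy * dy + dx * dx <= radius * radius]


def dilate(mask: List[List[int]], radius: int = 2) -> List[List[int]]:
    """Binary dilation: every set pixel switches on its disk-shaped neighbourhood."""
    h, w = len(mask), len(mask[0])
    out = [[0] * w for _ in range(h)]
    offsets = disk_offsets(radius)
    for y in range(h):
        for x in range(w):
            if mask[y][x]:
                for dy, dx in offsets:
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < h and 0 <= nx < w:  # clip at the image border
                        out[ny][nx] = 1
    return out
```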

coords_to_target(coords: ndarray, shape: Tuple[int, int]) → ndarray

Convert coordinates to target image.

Parameters:
  • coords – The coordinates to convert.

  • shape – The shape of the target image.
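
A minimal sketch of converting coordinates to a target image, assuming (y, x) coordinates that are rounded to the nearest pixel, with out-of-bounds points discarded. `coords_to_mask` is hypothetical. The resulting binary mask is the kind of input that a smoothing or dilation step such as gaussian_filter or dialate_mask would then process.

```python
from typing import List, Sequence, Tuple


def coords_to_mask(coords: Sequence[Sequence[float]], shape: Tuple[int, int]) -> List[List[int]]:
    """Mark each (y, x) coordinate, rounded to the nearest pixel, in a zero mask."""
    h, w = shape
    mask = [[0] * w for _ in range(h)]
    for y, x in coords:
        iy, ix = int(round(y)), int(round(x))
        # Drop coordinates that fall outside the image after rounding
        if 0 <= iy < h and 0 <= ix < w:
            mask[iy][ix] = 1
    return mask
```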

classmethod from_meta_csv(root_dir: str, meta_csv_path: str, process_func: Callable | None = None, transform=None)

Create a dataset from a meta CSV file.

classmethod from_list(img_list: List[ndarray], coord_list: List[ndarray], process_func: Callable | None = None, transform=None)

Create a dataset from a list of images and coordinates.

classmethod from_dir(img_dir: str, coord_dir: str, img_glob: str = '*.tif', coord_glob: str = '*.csv', process_func: Callable | None = None, transform=None)

Create a dataset from a directory of images and coordinates.

class ufish.data.RandomFlip(p=0.5)
__init__(p=0.5)
class ufish.data.RandomRotation(p=0.5, angle_range=(-90, 90))
__init__(p=0.5, angle_range=(-90, 90))
class ufish.data.RandomTranslation(p=0.5, shift_range=(-256, 256))
__init__(p=0.5, shift_range=(-256, 256))
class ufish.data.GaussianNoise(p=0.5, sigma_range=(0, 0.5))
__init__(p=0.5, sigma_range=(0, 0.5))
class ufish.data.SaltAndPepperNoise(p=0.5, salt_range=(0, 0.0001), pepper_range=(0, 0.0001))
__init__(p=0.5, salt_range=(0, 0.0001), pepper_range=(0, 0.0001))
class ufish.data.ToTensorWrapper
class ufish.data.DataAugmentation(p=0.5, each_transform_p=0.5)
__init__(p=0.5, each_transform_p=0.5)
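
The key property of geometric augmentations is that the image and its target must be transformed together. Otherwise the spot labels drift away from the spots. The class below sketches that idea for a flip. `RandomFlipSketch`, its (img, target) call signature, and flipping each axis independently with probability p are assumptions, not the documented behaviour of ufish.data.RandomFlip.

```python
import random
from typing import List, Optional, Tuple

Grid = List[List[float]]


class RandomFlipSketch:
    """Flip image and target together, left-right and/or up-down, each with probability p."""

    def __init__(self, p: float = 0.5, rng: Optional[random.Random] = None):
        self.p = p
        self.rng = rng or random.Random()

    def __call__(self, img: Grid, target: Grid) -> Tuple[Grid, Grid]:
        if self.rng.random() < self.p:
            # Horizontal flip: reverse every row of both arrays
            img = [row[::-1] for row in img]
            target = [row[::-1] for row in target]
        if self.rng.random() < self.p:
            # Vertical flip: reverse the row order of both arrays
            img, target = img[::-1], target[::-1]
        return img, target
```

Passing a seeded random.Random makes the augmentation reproducible in tests.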

Dataset Classes

FISHSpotsDataset

PyTorch dataset for FISH images and spot coordinates.

from ufish.data import FISHSpotsDataset

dataset = FISHSpotsDataset.from_dir(
    img_dir='data/images',
    coord_dir='data/labels',
    img_glob='*.tif',
    coord_glob='*.csv',
    process_func=None,
    transform=None
)

DataLoader Creation

from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

Utility Functions

Image Processing

from ufish.utils import (
    normalize_image,
    pad_image,
    tile_image,
    stitch_tiles,
    apply_clahe
)

# Normalize image
img_norm = normalize_image(img, percentile=(1, 99))

# Pad image for tiling
img_padded = pad_image(img, tile_size=512, overlap=64)

# Create tiles
tiles = tile_image(img, tile_size=512, overlap=64)

# Stitch tiles back
stitched = stitch_tiles(tiles, original_shape, overlap=64)
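
Tiling with overlap comes down to choosing tile start offsets along each axis. The helper below is a hypothetical sketch of one common scheme: a fixed stride of tile_size - overlap, with the last tile shifted flush to the image edge instead of padding. It is not necessarily how tile_image or pad_image place their tiles.

```python
from typing import List


def tile_starts(length: int, tile_size: int = 512, overlap: int = 64) -> List[int]:
    """Start offsets so that tiles of `tile_size`, overlapping by `overlap`, cover `length`."""
    if not 0 <= overlap < tile_size:
        raise ValueError("overlap must satisfy 0 <= overlap < tile_size")
    if length <= tile_size:
        return [0]
    step = tile_size - overlap
    starts = list(range(0, length - tile_size, step))
    starts.append(length - tile_size)  # last tile ends exactly at the image edge
    return starts
```

For a 1200-pixel axis this gives tiles starting at 0, 448 and 688; the final tile overlaps its neighbour by more than 64 pixels so that no padding is needed.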

Spot Processing

from ufish.utils import (
    filter_border_spots,
    merge_duplicate_spots,
    spots_to_mask,
    mask_to_spots
)

# Filter border spots
filtered = filter_border_spots(spots, image_shape, border=10)

# Merge nearby spots
merged = merge_duplicate_spots(spots, min_distance=3)

# Convert spots to mask
mask = spots_to_mask(spots, image_shape)

# Convert mask to spots
spots = mask_to_spots(mask)

File I/O

from ufish.utils import (
    read_zarr_chunk,
    write_zarr_chunk,
    read_n5_dataset,
    save_spots_csv
)

# Read zarr chunk
chunk = read_zarr_chunk('data.zarr', chunk_coords)

# Save spots to CSV
save_spots_csv(spots, 'output.csv', include_confidence=True)

Model Module

Low-level model operations.

Model Architecture

from ufish.model import UNet2D, UNet3D

# Create 2D U-Net
model_2d = UNet2D(
    in_channels=1,
    out_channels=1,
    features=[16, 32, 64, 128]
)

# Create 3D U-Net
model_3d = UNet3D(
    in_channels=1,
    out_channels=1,
    features=[16, 32, 64]
)

ONNX Export

from ufish.model import export_to_onnx

export_to_onnx(
    model,
    input_shape=(1, 1, 512, 512),
    output_path='model.onnx',
    opset_version=11
)

Advanced Features

Custom Peak Detection

Implement custom peak detection:

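import pandas as pd
from skimage.feature import peak_local_max
from ufish.api import UFish

class CustomUFish(UFish):
    def call_spots(self, enhanced_img, method='local_maxima', **kwargs):
        # Custom peak detection on the enhanced image (`method` is ignored here)
        peaks = peak_local_max(
            enhanced_img,
            min_distance=kwargs.get('min_distance', 3),
            threshold_abs=kwargs.get('threshold', 0.5),
            exclude_border=kwargs.get('exclude_border', True),
            p_norm=2  # Euclidean distance for min_distance
        )
        # call_spots returns a DataFrame with one row per spot
        columns = ['z', 'y', 'x'][-enhanced_img.ndim:]
        return pd.DataFrame(peaks, columns=columns)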

Ensemble Prediction

Combine multiple models:

import numpy as np
from ufish.api import UFish

class EnsembleUFish:
    def __init__(self, model_paths):
        self.models = [UFish() for _ in model_paths]
        for model, path in zip(self.models, model_paths):
            model.load_weights(path)

    def predict(self, image):
        # Enhance the image with every model
        all_enhanced = [model.infer(image) for model in self.models]

        # Average enhanced images
        avg_enhanced = np.mean(all_enhanced, axis=0)

        # Detect spots on the averaged result
        spots = self.models[0].call_spots(avg_enhanced)
        return spots, avg_enhanced

Batch GPU Processing

Process multiple images on GPU efficiently:

import numpy as np
import torch

def batch_predict_gpu(ufish, images, batch_size=16):
    results = []
    device = ufish.device  # use the device UFish selected, so model and data match

    for i in range(0, len(images), batch_size):
        batch = images[i:i+batch_size]
        # Stack equally sized 2D images into a float32 (N, 1, H, W) tensor
        batch_array = np.stack(batch).astype(np.float32)[:, np.newaxis]
        batch_tensor = torch.from_numpy(batch_array).to(device)

        with torch.no_grad():
            enhanced_batch = ufish.model(batch_tensor)

        for enhanced in enhanced_batch:
            # Drop the channel axis before calling spots on each enhanced image
            spots = ufish.call_spots(enhanced[0].cpu().numpy())
            results.append(spots)

    return results