API Reference
This section provides detailed documentation for the U-FISH Python API.
UFish Class
The main class for U-FISH functionality.
- class ufish.api.UFish(device: None | Literal['cpu', 'cuda', 'dml'] | torch.device = None, default_weights_file: str | None = None, local_store_path: str = '~/.ufish/')[source]
Bases: object
- __init__(device: None | Literal['cpu', 'cuda', 'dml'] | torch.device = None, default_weights_file: str | None = None, local_store_path: str = '~/.ufish/') None[source]
- Parameters:
device – The device to use for training. ‘cpu’ or ‘cuda’ or ‘dml’. If None, will use ‘cuda’ if available, otherwise ‘cpu’. ‘dml’ is for using AMD GPUs on Windows.
default_weights_file – The default weights file to use.
local_store_path – The local path to store the weights.
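The fallback described for device (use ‘cuda’ if available, otherwise ‘cpu’) can be sketched as a small pure function. This is an illustration only, not U-FISH's code: resolve_device and its cuda_available flag are hypothetical names, and in practice the flag would come from a check such as torch.cuda.is_available().

```python
from typing import Optional

VALID_DEVICES = ("cpu", "cuda", "dml")


def resolve_device(device: Optional[str] = None, cuda_available: bool = False) -> str:
    """Pick a device name following the documented fallback rule.

    `cuda_available` is a stand-in for a real CUDA check (hypothetical flag).
    """
    if device is None:
        # No explicit choice: prefer CUDA when present, else fall back to CPU.
        return "cuda" if cuda_available else "cpu"
    if device not in VALID_DEVICES:
        raise ValueError(f"unknown device {device!r}; expected one of {VALID_DEVICES}")
    return device


chosen = resolve_device(None, cuda_available=False)
```

An explicit value such as 'dml' is passed through unchanged, so the fallback only applies when device is None.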
- property device: torch.device
Get the torch device.
- init_model(model_type: str = 'ufish', **kwargs) None[source]
Initialize the model.
- Parameters:
model_type – The type of the model. For example, ‘ufish’, ‘spot_learn’, …
kwargs – Other arguments for the model.
- convert_to_onnx(output_path: Path | str) None[source]
Convert the model to ONNX format.
- Parameters:
output_path – The path to the output ONNX file.
- load_weights_from_internet(weights_file: str | None = None, max_retry: int = 8, force_download: bool = False) None[source]
Load weights from the huggingface repo.
- Parameters:
weights_file – The name of the weights file on the internet. See https://huggingface.co/GangCaoLab/U-FISH/tree/main for available weights files.
max_retry – The maximum number of retries.
force_download – Whether to force download the weights.
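The max_retry and force_download parameters suggest a cache-then-retry pattern. The sketch below shows one way such a pattern can work; it is not U-FISH's implementation. download_with_retry and the fetch callable are hypothetical, and the choice to retry only on OSError is an assumption.

```python
import os
import tempfile
from typing import Callable


def download_with_retry(
    fetch: Callable[[str], bytes],  # hypothetical downloader: file name -> bytes
    weights_file: str,
    store_dir: str,
    max_retry: int = 8,
    force_download: bool = False,
) -> str:
    """Return the local path of `weights_file`, downloading it when needed."""
    local_path = os.path.join(store_dir, weights_file)
    # Reuse the cached copy unless a fresh download is forced.
    if os.path.exists(local_path) and not force_download:
        return local_path
    last_error = None
    for _ in range(max_retry):
        try:
            data = fetch(weights_file)
        except OSError as err:
            last_error = err  # remember the failure and try again
            continue
        os.makedirs(store_dir, exist_ok=True)
        with open(local_path, "wb") as f:
            f.write(data)
        return local_path
    raise RuntimeError(f"download failed after {max_retry} attempts") from last_error


# Usage with a stand-in downloader that always succeeds.
with tempfile.TemporaryDirectory() as tmp:
    path = download_with_retry(lambda name: b"fake-weights", "demo.onnx", tmp)
    cached = os.path.exists(path)
```

With this structure, a second call for the same file returns immediately from the local store, and force_download=True bypasses that shortcut.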
- load_weights_from_path(path: Path | str) None[source]
Load weights from a local file. The file can be a .pth file or an .onnx file.
- Parameters:
path – The path to the weights file.
- load_weights(weights_path: str | None = None, weights_file: str | None = None, max_retry: int = 8, force_download: bool = False)[source]
Load weights from a local file or the internet.
- Parameters:
weights_path – The path to the weights file.
weights_file – The name of the weights file on the internet. See https://huggingface.co/GangCaoLab/U-FISH/tree/main for available weights files.
max_retry – The maximum number of retries to download the weights.
force_download – Whether to force download the weights.
- call_spots(enhanced_img: ndarray, method: str = 'local_maxima', **kwargs) DataFrame[source]
Call spots from enhanced image.
- Parameters:
enhanced_img – The enhanced image.
method – The method to use for spot calling.
kwargs – Other arguments for the spot calling function.
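To make the default method name ‘local_maxima’ concrete, here is a dependency-free sketch of local-maximum spot calling on a small 2D enhanced image. It is an illustration of the general technique, not the function U-FISH calls: local_maxima_2d and its threshold argument are hypothetical, and it returns a plain list rather than a DataFrame.

```python
from typing import List, Sequence, Tuple


def local_maxima_2d(
    img: Sequence[Sequence[float]],
    threshold: float = 0.5,  # hypothetical intensity cutoff
) -> List[Tuple[int, int]]:
    """Return (y, x) of pixels above `threshold` that strictly exceed all 8 neighbours."""
    height, width = len(img), len(img[0])
    spots = []
    for y in range(height):
        for x in range(width):
            value = img[y][x]
            if value <= threshold:
                continue
            # Collect the neighbours that exist (edges have fewer than 8).
            neighbours = [
                img[ny][nx]
                for ny in range(max(0, y - 1), min(height, y + 2))
                for nx in range(max(0, x - 1), min(width, x + 2))
                if (ny, nx) != (y, x)
            ]
            if all(value > n for n in neighbours):
                spots.append((y, x))
    return spots


enhanced = [
    [0.0, 0.1, 0.0, 0.0],
    [0.1, 0.9, 0.1, 0.0],
    [0.0, 0.1, 0.0, 0.6],
    [0.0, 0.0, 0.2, 0.7],
]
spots = local_maxima_2d(enhanced)
```

Here the pixel at (2, 3) is above the threshold but is not reported, because its neighbour at (3, 3) is brighter.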
- predict(img: ndarray, enh_img: ndarray | None = None, axes: str | None = None, blend_3d: bool = True, batch_size: int = 4, spots_calling_method: str = 'local_maxima', **kwargs) Tuple[DataFrame, ndarray][source]
Predict the spots in an image.
- Parameters:
img – The image to predict, it should be a multi dimensional array. For example, shape (c, z, y, x) for a 4D image, shape (z, y, x) for a 3D image, shape (y, x) for a 2D image.
enh_img – The enhanced image, if None, will be created. It can be a multi dimensional array or a zarr array.
axes – The axes of the image. For example, ‘czyx’ for a 4D image, ‘yx’ for a 2D image. If None, will try to infer the axes from the shape.
blend_3d – Whether to blend the 3D image. Used only when the image contains a z axis. If True, will blend the 3D enhanced images along the z, y, x axes.
batch_size – The batch size for inference. Used only when the image dimension is 3 or higher.
spots_calling_method – The method to use for spot calling.
kwargs – Other arguments for the spot calling function.
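The axes string pairs one letter with each dimension of img. The sketch below validates such a string against a shape and, when axes is None, guesses the letters from the number of dimensions. The guessing rule (take the trailing letters of 'tczyx') is an assumption made for this example; how U-FISH actually infers axes is not described here, and resolve_axes is a hypothetical helper.

```python
from typing import Dict, Optional, Tuple

DEFAULT_ORDER = "tczyx"  # assumed convention for this sketch only


def resolve_axes(shape: Tuple[int, ...], axes: Optional[str] = None) -> Dict[str, int]:
    """Map each axis letter to its size, guessing the letters when `axes` is None."""
    if axes is None:
        if not 2 <= len(shape) <= len(DEFAULT_ORDER):
            raise ValueError(f"cannot infer axes for a {len(shape)}D image")
        axes = DEFAULT_ORDER[-len(shape):]  # e.g. a 3D shape gives 'zyx'
    if len(axes) != len(shape):
        raise ValueError(f"axes {axes!r} does not match shape {shape}")
    if len(set(axes)) != len(axes) or not {"y", "x"} <= set(axes):
        raise ValueError("axes must be unique and contain both 'y' and 'x'")
    return dict(zip(axes, shape))


sizes_3d = resolve_axes((5, 64, 64))
sizes_4d = resolve_axes((2, 5, 64, 32), "czyx")
```

Passing axes explicitly avoids any ambiguity when, for example, a 3D array is a multi-channel 2D image rather than a z-stack.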
- predict_chunks(img: ndarray, enh_img: ndarray | None = None, axes: str | None = None, blend_3d: bool = True, batch_size: int = 4, chunk_size: Tuple[int | str, ...] | None = None, spots_calling_method: str = 'local_maxima', **kwargs)[source]
Predict the spots in an image chunk by chunk.
- Parameters:
img – The image to predict, it should be a multi dimensional array. For example, shape (c, z, y, x) for a 4D image, shape (z, y, x) for a 3D image, shape (y, x) for a 2D image.
enh_img – The enhanced image, if None, will be created. It can be a multi dimensional array or a zarr array.
axes – The axes of the image. For example, ‘czyx’ for a 4D image, ‘yx’ for a 2D image. If None, will try to infer the axes from the shape.
blend_3d – Whether to blend the 3D image. Used only when the image contains a z axis. If True, will blend the 3D enhanced images along the z, y, x axes.
batch_size – The batch size for inference. Used only when the image dimension is 3 or higher.
chunk_size – The chunk size for processing. For example, (1, 512, 512) for a 3D image, (512, 512) for a 2D image. Using ‘image’ as a dimension will use the whole image as a chunk. For example, (1, ‘image’, ‘image’) for a 3D image, (‘image’, ‘image’, ‘image’, 512, 512) for a 5D image. If None, will use the default chunk size.
spots_calling_method – The method to use for spot calling.
kwargs – Other arguments for the spot calling function.
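The chunk_size convention, where the string ‘image’ stands for the full extent of a dimension, can be illustrated with a short sketch. resolve_chunk_size and iter_chunks are hypothetical helpers written for this example; they show how such a specification could be turned into slices and are not taken from U-FISH.

```python
from itertools import product
from typing import Iterator, Sequence, Tuple, Union

ChunkSpec = Sequence[Union[int, str]]


def resolve_chunk_size(shape: Sequence[int], chunk_size: ChunkSpec) -> Tuple[int, ...]:
    """Replace 'image' with the full dimension and clip chunks to the image size."""
    if len(chunk_size) != len(shape):
        raise ValueError("chunk_size must have one entry per image dimension")
    return tuple(
        dim if c == "image" else min(int(c), dim)
        for c, dim in zip(chunk_size, shape)
    )


def iter_chunks(shape: Sequence[int], chunk_size: ChunkSpec) -> Iterator[Tuple[slice, ...]]:
    """Yield one tuple of slices per chunk, covering the whole image."""
    sizes = resolve_chunk_size(shape, chunk_size)
    starts_per_dim = [range(0, dim, size) for dim, size in zip(shape, sizes)]
    for starts in product(*starts_per_dim):
        yield tuple(
            slice(start, min(start + size, dim))
            for start, size, dim in zip(starts, sizes, shape)
        )


shape = (4, 1000, 1200)                 # a (z, y, x) image
resolved = resolve_chunk_size(shape, (1, "image", 512))
chunks = list(iter_chunks(shape, (1, "image", 512)))
```

For this shape the specification gives 4 planes times 3 column blocks, and the last block along x is narrower than 512 because 1200 is not a multiple of 512.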
- evaluate_result_dp(pred: DataFrame, true: DataFrame, mdist: float = 3.0) DataFrame[source]
Evaluate the prediction result using deepblink metrics.
- Parameters:
pred – The predicted spots.
true – The true spots.
mdist – The maximum distance to consider a spot as a true positive.
- Returns:
A pandas dataframe containing the evaluation metrics.
- evaluate_result(pred: DataFrame, true: DataFrame, cutoff: float = 3.0) dict[source]
Calculate the F1 score of the prediction result.
- Parameters:
pred – The predicted spots.
true – The true spots.
cutoff – The maximum distance to consider a spot as a true positive.
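The cutoff parameter defines when a predicted spot counts as a true positive. A minimal sketch of distance-based matching follows, assuming a greedy closest-pair-first strategy and an inclusive cutoff; the matching strategy U-FISH uses may differ, and match_spots is a hypothetical helper operating on plain (y, x) tuples rather than DataFrames.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def match_spots(
    pred: Sequence[Point], true: Sequence[Point], cutoff: float = 3.0
) -> Tuple[int, int, int]:
    """Greedily match predicted to true spots; return (tp, fp, fn)."""
    # All candidate pairs within the cutoff, closest first.
    pairs = sorted(
        (math.dist(p, t), i, j)
        for i, p in enumerate(pred)
        for j, t in enumerate(true)
        if math.dist(p, t) <= cutoff
    )
    used_pred, used_true = set(), set()
    for _, i, j in pairs:
        # Each spot may take part in at most one match.
        if i not in used_pred and j not in used_true:
            used_pred.add(i)
            used_true.add(j)
    tp = len(used_pred)
    return tp, len(pred) - tp, len(true) - tp


pred = [(10, 10), (20, 20), (50, 50)]
true = [(11, 10), (20, 22), (80, 80)]
tp, fp, fn = match_spots(pred, true, cutoff=3.0)
```

In this example two predictions lie within 3 pixels of a true spot, so one prediction is a false positive and one true spot is missed.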
- plot_result(img: ndarray, pred: DataFrame, fig_size: Tuple[int, int] = (10, 10), image_cmap: str = 'gray', marker_size: int = 20, marker_color: str = 'red', marker_style: str = 'x') Figure[source]
Plot the prediction result.
- Parameters:
img – The image to plot.
pred – The predicted spots.
fig_size – The figure size.
image_cmap – The colormap for the image.
marker_size – The marker size.
marker_color – The marker color.
marker_style – The marker style.
- plot_evaluate(img: ndarray, pred: DataFrame, true: DataFrame, cutoff: float = 3.0, fig_size: Tuple[int, int] = (10, 10), image_cmap: str = 'gray', marker_size: int = 20, tp_color: str = 'green', fp_color: str = 'red', fn_color: str = 'yellow', tp_marker: str = 'x', fp_marker: str = 'x', fn_marker: str = 'x') Figure[source]
Plot the prediction result against the ground truth, marking true positives, false positives, and false negatives.
- Parameters:
img – The image to plot.
pred – The predicted spots.
true – The true spots.
cutoff – The maximum distance to consider a spot as a true positive.
fig_size – The figure size.
image_cmap – The colormap for the image.
marker_size – The marker size.
tp_color – The color for true positive.
fp_color – The color for false positive.
fn_color – The color for false negative.
tp_marker – The marker style for true positive.
fp_marker – The marker style for false positive.
fn_marker – The marker style for false negative.
- train(train_path: str, valid_path: str, root_dir: str | None = None, img_glob: str = '*.tif', coord_glob: str = '*.csv', target_process: str | None = 'gaussian', loss_type: str = 'DiceRMSELoss', loader_workers: int = 4, data_argu: bool = False, argu_prob: float = 0.5, num_epochs: int = 50, batch_size: int = 8, optimizer_type: str = 'Adam', lr: float = 0.001, summary_dir: str = 'runs/ufish', model_save_dir: str = './models', save_period: int = 5)[source]
Train the U-Net model.
- Parameters:
train_path – The path to the training dataset. Path to a directory containing images and coordinates, or a meta csv file.
valid_path – The path to the validation dataset. Path to a directory containing images and coordinates, or a meta csv file.
root_dir – The root directory of the dataset. Used when the dataset is given as a meta csv file.
img_glob – The glob pattern for the image files.
coord_glob – The glob pattern for the coordinate files.
target_process – The target image processing method: ‘gaussian’, ‘dialation’ (spelled as in the API), or None. If None, no processing is applied. Defaults to ‘gaussian’.
loss_type – The loss function type.
loader_workers – The number of workers to use for the data loader.
data_argu – Whether to use data augmentation.
argu_prob – The probability to use data augmentation.
num_epochs – The number of epochs to train.
batch_size – The batch size.
optimizer_type – The optimizer type.
lr – The learning rate.
summary_dir – The directory to save the TensorBoard summary to.
model_save_dir – The directory to save the model to.
save_period – Save the model every save_period epochs.
Basic Usage
from ufish.api import UFish
# Initialize
ufish = UFish()
# Load weights
ufish.load_weights()
# Predict (image is a NumPy array, for example a 2D image of shape (y, x))
spots, enhanced = ufish.predict(image)
Constructor Parameters
UFish(
    device=None,                  # 'cpu', 'cuda', 'dml', a torch.device, or None ('cuda' if available, otherwise 'cpu')
    default_weights_file=None,    # Default weights file to use
    local_store_path='~/.ufish/'  # Local path to store the weights
)
Core Methods
load_weights
Load model weights from file or default location.
# Load default weights
ufish.load_weights()
# Load from file
ufish.load_weights('path/to/model.onnx')
# Load from HuggingFace
ufish.load_weights_from_internet()
Parameters:
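weights_path (str, optional): Path to a local weights file. If None, loads default weights.
weights_file (str, optional): Name of the weights file on the internet.
max_retry (int): Maximum number of retries to download the weights. Default 8.
force_download (bool): Whether to force download the weights. Default False.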
predict
Predict FISH spots in an image.
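spots, enhanced = ufish.predict(
    image,
    enh_img=None,
    axes=None,
    blend_3d=True,
    batch_size=4,
    spots_calling_method='local_maxima'
)
Parameters:
image (np.ndarray): Input image, passed as img. A multi-dimensional array, for example shape (y, x), (z, y, x), or (c, z, y, x).
enh_img (np.ndarray, optional): The enhanced image. If None, it is created.
axes (str, optional): The axes of the image, for example ‘yx’. If None, the axes are inferred from the shape.
blend_3d (bool): Whether to blend the 3D enhanced images. Used only when the image contains a z axis.
batch_size (int): Batch size for inference. Used only when the image dimension is 3 or higher.
spots_calling_method (str): The method to use for spot calling.
**kwargs: Other arguments for the spot calling function.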
Returns:
spots (pd.DataFrame): Detected spots with columns [y, x] or [z, y, x].
enhanced (np.ndarray): Enhanced image.
evaluate_result
Evaluate prediction against ground truth.
metrics = ufish.evaluate_result(
    pred_spots,
    true_spots,
    cutoff=3.0
)
Parameters:
pred_spots (pd.DataFrame): Predicted spots, passed as pred.
true_spots (pd.DataFrame): Ground truth spots, passed as true.
cutoff (float): Maximum distance to consider a spot as a true positive.
Returns:
metrics (dict): Dictionary containing:
precision: TP / (TP + FP)
recall: TP / (TP + FN)
f1: F1 score
tp: True positives count
fp: False positives count
fn: False negatives count
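The metric definitions listed above follow directly from the three counts. A small worked sketch, assuming the counts are already known; detection_metrics is a hypothetical helper, and returning 0.0 for an empty denominator is a convention chosen for this example.

```python
from typing import Dict


def detection_metrics(tp: int, fp: int, fn: int) -> Dict[str, float]:
    """Compute precision, recall and F1 from matched-spot counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    # F1 is the harmonic mean of precision and recall.
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1,
            "tp": tp, "fp": fp, "fn": fn}


metrics = detection_metrics(tp=8, fp=2, fn=2)
```

With 8 true positives, 2 false positives, and 2 false negatives, precision and recall are both 8/10, so the F1 score is also 0.8 (up to floating-point rounding).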
plot_result
Visualize detected spots on image.
fig = ufish.plot_result(
    image,
    spots,
    fig_size=(10, 10),
    image_cmap='gray',
    marker_size=20,
    marker_color='red',
    marker_style='x'
)
Parameters:
image (np.ndarray): The image to plot, passed as img.
spots (pd.DataFrame): The predicted spots, passed as pred.
fig_size (tuple): Figure size.
image_cmap (str): Colormap for the image.
marker_size (int): Size of spot markers.
marker_color (str): Color of spot markers.
marker_style (str): Style of spot markers.
Returns:
fig (matplotlib.figure.Figure): Figure object.
plot_evaluate
Visualize evaluation results showing TP, FP, FN.
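fig = ufish.plot_evaluate(
    image,
    pred_spots,
    true_spots,
    cutoff=3.0,
    fig_size=(10, 10),
    image_cmap='gray',
    marker_size=20,
    tp_color='green',
    fp_color='red',
    fn_color='yellow'
)
Parameters:
image (np.ndarray): The image to plot, passed as img.
pred_spots (pd.DataFrame): Predicted spots, passed as pred.
true_spots (pd.DataFrame): Ground truth spots, passed as true.
cutoff (float): Maximum distance to consider a spot as a true positive.
fig_size (tuple): Figure size.
image_cmap (str): Colormap for the image.
marker_size (int): Marker size.
tp_color, fp_color, fn_color (str): Colors for true positives, false positives, and false negatives.
tp_marker, fp_marker, fn_marker (str): Marker styles for true positives, false positives, and false negatives. Default ‘x’.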
Returns:
fig (matplotlib.figure.Figure): Figure with three panels showing predictions, ground truth, and evaluation.
train
Train or fine-tune the model.
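ufish.train(
    train_path,
    valid_path,
    root_dir=None,
    img_glob='*.tif',
    coord_glob='*.csv',
    target_process='gaussian',
    loss_type='DiceRMSELoss',
    loader_workers=4,
    data_argu=False,
    argu_prob=0.5,
    num_epochs=50,
    batch_size=8,
    optimizer_type='Adam',
    lr=0.001,
    summary_dir='runs/ufish',
    model_save_dir='./models',
    save_period=5
)
Parameters:
train_path (str): Training dataset: a directory containing images and coordinates, or a meta csv file.
valid_path (str): Validation dataset: a directory containing images and coordinates, or a meta csv file.
root_dir (str, optional): Root directory of the dataset when using a meta csv file.
img_glob (str): Glob pattern for the image files.
coord_glob (str): Glob pattern for the coordinate files.
target_process (str, optional): Target image processing method.
loss_type (str): Loss function type.
loader_workers (int): Number of workers for the data loader.
data_argu (bool): Whether to use data augmentation.
argu_prob (float): Probability of applying data augmentation.
num_epochs (int): Number of training epochs.
batch_size (int): Batch size.
optimizer_type (str): Optimizer type.
lr (float): Learning rate.
summary_dir (str): Directory to save the TensorBoard summary to.
model_save_dir (str): Directory to save the model to.
save_period (int): Save the model every save_period epochs.
Returns:
No return value is documented in the signature; the model is saved to model_save_dir every save_period epochs.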
Data Module
Functions for data loading and preprocessing.
- class ufish.data.FileReader(root_dir: str, meta_csv_path: str)[source]
Read images and coordinates from meta_csv and files.
- class ufish.data.ListReader(img_list: List[ndarray], coord_list: List[ndarray])[source]
Read images and coordinates from a list of images and coordinates.
- class ufish.data.DirReader(img_dir: str, coord_dir: str, img_glob: str = '*.tif', coord_glob: str = '*.csv')[source]
Read images and coordinates from a directory of images and coordinates.
- class ufish.data.FISHSpotsDataset(reader: Reader, process_func: Callable | None = None, transform=None)[source]
- __init__(reader: Reader, process_func: Callable | None = None, transform=None)[source]
FISH spots dataset.
- Parameters:
reader – The reader to read images and coordinates.
process_func – The function to process the target image.
transform – The transform to apply to the samples.
- static dialate_mask(mask: ndarray, footprint: str = 'disk(2)') ndarray[source]
Dilate the mask.
- Parameters:
mask – The mask to dilate.
footprint – The footprint to use for dilation.
- coords_to_target(coords: ndarray, shape: Tuple[int, int]) ndarray[source]
Convert coordinates to target image.
- Parameters:
coords – The coordinates to convert.
shape – The shape of the target image.
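Converting coordinates to a target image amounts to rasterising the spot positions onto an empty grid. The sketch below shows the idea with nested lists. It is not the library method: coords_to_target_sketch is a hypothetical name, and the (y, x) column order, rounding to the nearest pixel, and silently dropping out-of-bounds spots are all assumptions.

```python
from typing import List, Sequence, Tuple


def coords_to_target_sketch(
    coords: Sequence[Tuple[float, float]],  # assumed (y, x) order
    shape: Tuple[int, int],
) -> List[List[float]]:
    """Return a binary image of `shape` with 1.0 at each rounded spot position."""
    height, width = shape
    target = [[0.0] * width for _ in range(height)]
    for y, x in coords:
        row, col = int(round(y)), int(round(x))
        # Skip spots that fall outside the image.
        if 0 <= row < height and 0 <= col < width:
            target[row][col] = 1.0
    return target


target = coords_to_target_sketch([(1.2, 2.7), (9.0, 9.0)], shape=(3, 4))
```

A binary target like this is the kind of image that a process_func such as Gaussian smoothing or dilation would then widen into a trainable signal.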
- classmethod from_meta_csv(root_dir: str, meta_csv_path: str, process_func: Callable | None = None, transform=None)[source]
Create a dataset from a meta CSV file.
- class ufish.data.SaltAndPepperNoise(p=0.5, salt_range=(0, 0.0001), pepper_range=(0, 0.0001))[source]
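The signature of SaltAndPepperNoise is given without a description, so the sketch below only illustrates what a salt-and-pepper transform with these parameter names could do. The interpretation is an assumption: p is taken as the probability of applying the transform, and salt_range and pepper_range as ranges from which the fraction of affected pixels is drawn. salt_and_pepper is a hypothetical function, not the class itself.

```python
import random
from typing import List, Optional, Tuple


def salt_and_pepper(
    img: List[List[float]],
    p: float = 0.5,
    salt_range: Tuple[float, float] = (0, 0.0001),
    pepper_range: Tuple[float, float] = (0, 0.0001),
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Return a noisy copy of `img` (pixel values assumed to lie in [0, 1])."""
    rng = rng or random.Random()
    noisy = [row[:] for row in img]
    if rng.random() >= p:  # apply the transform with probability p
        return noisy
    salt_frac = rng.uniform(*salt_range)
    pepper_frac = rng.uniform(*pepper_range)
    for row in noisy:
        for x in range(len(row)):
            u = rng.random()
            if u < salt_frac:
                row[x] = 1.0  # salt: brightest value
            elif u < salt_frac + pepper_frac:
                row[x] = 0.0  # pepper: darkest value
    return noisy


image = [[0.5] * 4 for _ in range(4)]
unchanged = salt_and_pepper(image, p=0.0, rng=random.Random(0))
```

The very small default ranges mean that only a tiny fraction of pixels would be altered under this reading, which keeps sparse spot signals intact.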
Dataset Classes
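FISHSpotsDataset
PyTorch dataset for FISH images and spot coordinates.
from ufish.data import DirReader, FISHSpotsDataset
# Read images and coordinates from two directories
reader = DirReader(
    img_dir='data/images',
    coord_dir='data/labels',
    img_glob='*.tif',
    coord_glob='*.csv'
)
dataset = FISHSpotsDataset(
    reader=reader,
    process_func=None,
    transform=None
)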
DataLoader Creation
from torch.utils.data import DataLoader
dataloader = DataLoader(
dataset,
batch_size=8,
shuffle=True,
num_workers=4,
pin_memory=True
)
Utility Functions
Image Processing
from ufish.utils import (
normalize_image,
pad_image,
tile_image,
stitch_tiles,
apply_clahe
)
# Normalize image
img_norm = normalize_image(img, percentile=(1, 99))
# Pad image for tiling
img_padded = pad_image(img, tile_size=512, overlap=64)
# Create tiles
tiles = tile_image(img, tile_size=512, overlap=64)
# Stitch tiles back
stitched = stitch_tiles(tiles, original_shape, overlap=64)
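How tile_size and overlap interact is easiest to see along a single axis. The sketch below computes tile start positions so that neighbouring tiles share at least overlap pixels and the last tile ends flush with the image edge. tile_origins is a hypothetical helper written for this example; the tiling strategy of the functions above is not documented here and may differ.

```python
from typing import List


def tile_origins(length: int, tile_size: int = 512, overlap: int = 64) -> List[int]:
    """Start offsets of tiles covering `length` pixels along one axis."""
    if not 0 <= overlap < tile_size:
        raise ValueError("overlap must be non-negative and smaller than tile_size")
    if length <= tile_size:
        return [0]  # a single tile covers the whole axis
    stride = tile_size - overlap
    origins = list(range(0, length - tile_size, stride))
    origins.append(length - tile_size)  # final tile ends flush with the edge
    return origins


origins = tile_origins(1024, tile_size=512, overlap=64)
```

For a 2D image the same function can be applied to the height and the width separately, and the tile corners are all combinations of the two lists.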
Spot Processing
from ufish.utils import (
filter_border_spots,
merge_duplicate_spots,
spots_to_mask,
mask_to_spots
)
# Filter border spots
filtered = filter_border_spots(spots, image_shape, border=10)
# Merge nearby spots
merged = merge_duplicate_spots(spots, min_distance=3)
# Convert spots to mask
mask = spots_to_mask(spots, image_shape)
# Convert mask to spots
spots = mask_to_spots(mask)
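Merging nearby spots with a min_distance threshold can be done in several ways. The sketch below uses a simple greedy clustering that replaces each cluster by its centroid. It is an illustration under that assumption, not the behaviour of merge_duplicate_spots; merge_close_spots is a hypothetical helper working on plain coordinate tuples.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def centroid(points: Sequence[Point]) -> Point:
    """Coordinate-wise mean of a non-empty list of points."""
    return tuple(sum(axis) / len(points) for axis in zip(*points))


def merge_close_spots(spots: Sequence[Point], min_distance: float = 3.0) -> List[Point]:
    """Greedily group spots closer than `min_distance` and return group centroids."""
    clusters: List[List[Point]] = []
    for spot in spots:
        for cluster in clusters:
            # Join the first cluster whose current centroid is close enough.
            if math.dist(spot, centroid(cluster)) < min_distance:
                cluster.append(spot)
                break
        else:
            clusters.append([spot])
    return [centroid(cluster) for cluster in clusters]


merged = merge_close_spots([(10, 10), (11, 11), (50, 50)], min_distance=3)
```

The result depends on the order of the input spots, which is the usual trade-off of a greedy approach compared with a full clustering method.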
File I/O
from ufish.utils import (
read_zarr_chunk,
write_zarr_chunk,
read_n5_dataset,
save_spots_csv
)
# Read zarr chunk
chunk = read_zarr_chunk('data.zarr', chunk_coords)
# Save spots to CSV
save_spots_csv(spots, 'output.csv', include_confidence=True)
Model Module
Low-level model operations.
Model Architecture
from ufish.model import UNet2D, UNet3D
# Create 2D U-Net
model_2d = UNet2D(
in_channels=1,
out_channels=1,
features=[16, 32, 64, 128]
)
# Create 3D U-Net
model_3d = UNet3D(
in_channels=1,
out_channels=1,
features=[16, 32, 64]
)
ONNX Export
from ufish.model import export_to_onnx
export_to_onnx(
model,
input_shape=(1, 1, 512, 512),
output_path='model.onnx',
opset_version=11
)
Advanced Features
Custom Peak Detection
Implement custom peak detection:
from skimage.feature import peak_local_max
from ufish.api import UFish
class CustomUFish(UFish):
    def detect_peaks(self, enhanced_image, **kwargs):
        # Custom peak detection logic; returns an array with one row of coordinates per peak
        peaks = peak_local_max(
            enhanced_image,
            min_distance=kwargs.get('min_distance', 3),
            threshold_abs=kwargs.get('threshold', 0.5),
            exclude_border=kwargs.get('exclude_border', True),
            p_norm=2  # Custom parameter: Euclidean distance between peaks
        )
        return peaks
Ensemble Prediction
Combine multiple models:
import numpy as np
from ufish.api import UFish
class EnsembleUFish:
    def __init__(self, model_paths):
        # One UFish instance per weights file
        self.models = [UFish() for _ in model_paths]
        for model, path in zip(self.models, model_paths):
            model.load_weights(path)
    def predict(self, image):
        all_enhanced = []
        for model in self.models:
            # predict returns (spots, enhanced image); keep only the enhanced image
            _, enhanced = model.predict(image)
            all_enhanced.append(enhanced)
        # Average enhanced images
        avg_enhanced = np.mean(all_enhanced, axis=0)
        # Detect spots on averaged result with the documented call_spots method
        spots = self.models[0].call_spots(avg_enhanced)
        return spots, avg_enhanced
Batch GPU Processing
Process multiple images on GPU efficiently:
import numpy as np
import torch
def batch_predict_gpu(ufish, images, batch_size=16):
    # Assumes same-shaped 2D images and that ufish.model holds a loaded torch module on ufish.device
    results = []
    device = ufish.device
    for i in range(0, len(images), batch_size):
        batch = images[i:i + batch_size]
        # Stack to a float32 tensor of shape (batch, 1, y, x)
        batch_tensor = torch.stack(
            [torch.from_numpy(img.astype(np.float32)) for img in batch]
        ).unsqueeze(1)
        batch_tensor = batch_tensor.to(device)
        with torch.no_grad():
            enhanced_batch = ufish.model(batch_tensor)
        for enhanced in enhanced_batch:
            # Drop the channel axis and call spots on each enhanced image
            spots = ufish.call_spots(enhanced.squeeze().cpu().numpy())
            results.append(spots)
    return results