import os
import random
from pathlib import Path
import typing as T
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
import scipy.ndimage as ndi
from skimage.io import imread
from skimage.morphology import dilation
import skimage.morphology as morphology # noqa: F401
from .utils.img import scale_image
from .utils.log import logger
[docs]
class Reader:
def __len__(self):
pass
def __getitem__(self, idx: int):
pass
[docs]
def read_coords(self, path: str, ndim: int) -> np.ndarray:
coordinates = pd.read_csv(path)
axes = [f'axis-{i}' for i in range(ndim)]
coords = coordinates[axes].values
return coords
[docs]
class FileReader(Reader):
"""Read images and coordinates from
meta_csv and files."""
[docs]
def __init__(self, root_dir: str, meta_csv_path: str):
self.root_dir = root_dir
self.meta_data = pd.read_csv(meta_csv_path)
def __len__(self):
return len(self.meta_data)
def __getitem__(self, idx: int):
img_path = os.path.join(self.root_dir, self.meta_data.iloc[idx, 0])
image = imread(img_path)
coord_path = os.path.join(self.root_dir, self.meta_data.iloc[idx, 1])
coords = self.read_coords(coord_path, image.ndim)
sample = {'image': image, 'coords': coords}
return sample
[docs]
class ListReader(Reader):
"""Read images and coordinates from
a list of images and coordinates."""
[docs]
def __init__(
self,
img_list: T.List[np.ndarray],
coord_list: T.List[np.ndarray]):
self.img_list = img_list
self.coord_list = coord_list
def __len__(self):
return len(self.img_list)
def __getitem__(self, idx: int):
image = self.img_list[idx]
coordinates = self.coord_list[idx]
sample = {'image': image, 'coords': coordinates}
return sample
[docs]
class DirReader(Reader):
"""Read images and coordinates from
a directory of images and coordinates."""
[docs]
def __init__(
self,
img_dir: str,
coord_dir: str,
img_glob: str = '*.tif',
coord_glob: str = '*.csv',
):
self.img_dir = Path(img_dir)
self.coord_dir = Path(coord_dir)
self.img_paths = sorted(self.img_dir.glob(img_glob))
self.coord_paths = sorted(self.coord_dir.glob(coord_glob))
assert len(self.img_paths) == len(self.coord_paths), \
"Number of images and coordinates must match."
self.check_prefix()
[docs]
def check_prefix(self):
for img_path, coord_path in zip(self.img_paths, self.coord_paths):
img_prefix = img_path.stem
coord_prefix = coord_path.stem
if img_prefix != coord_prefix:
logger.warning(
f"Image prefix {img_prefix} does not match "
f"coordinate prefix {coord_prefix}")
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx: int):
img_path = self.img_paths[idx]
coord_path = self.coord_paths[idx]
image = imread(img_path)
coords = self.read_coords(str(coord_path), image.ndim)
sample = {'image': image, 'coords': coords}
return sample
[docs]
class FISHSpotsDataset(Dataset):
[docs]
def __init__(
self, reader: Reader,
process_func: T.Optional[T.Callable] = None,
transform=None):
"""FISH spots dataset.
Args:
reader: The reader to read images and coordinates.
process_func: The function to process the target image.
transform: The transform to apply to the samples.
"""
self.reader = reader
self.transform = transform
self.process_func = process_func
[docs]
@staticmethod
def gaussian_filter(mask: np.ndarray, sigma=1) -> np.ndarray:
"""Apply Gaussian filter to the mask."""
if mask.max() == 0:
return mask
peak = np.stack(np.where(mask > 0), axis=1)
res = ndi.gaussian_filter(mask, sigma=sigma)
peak_val = res[tuple(peak.T)]
res /= peak_val.min()
return res
[docs]
@staticmethod
def dialate_mask(
mask: np.ndarray,
footprint: str = 'disk(2)'
) -> np.ndarray:
"""Dialate the mask.
Args:
mask: The mask to dialate.
footprint: The footprint to use for dialation.
"""
_footprint = eval(f"morphology.{footprint}")
return dilation(mask, footprint=_footprint)
def __len__(self):
return len(self.reader)
def __getitem__(self, idx: int):
data = self.reader[idx]
image, coords = data['image'], data['coords']
target = self.coords_to_target(coords, image.shape)
image = scale_image(image)
image = np.expand_dims(image, axis=0)
sample = {'image': image, 'target': target}
if self.transform:
sample = self.transform(sample)
return sample
[docs]
def coords_to_target(
self, coords: np.ndarray,
shape: T.Tuple[int, int],
) -> np.ndarray:
"""Convert coordinates to target image.
Args:
coords: The coordinates to convert.
shape: The shape of the target image.
"""
mask = np.zeros(shape, dtype=np.float32)
# remove out-of-bound coordinates
c = coords
c = (c + 0.5).astype(np.uint32)
c = c[(c[:, 0] >= 0) & (c[:, 0] < shape[0])]
c = c[(c[:, 1] >= 0) & (c[:, 1] < shape[1])]
mask[c[:, 0], c[:, 1]] = 1
if self.process_func:
mask = self.process_func(mask)
mask = np.expand_dims(mask, axis=0)
return mask
[docs]
@classmethod
def from_list(
cls,
img_list: T.List[np.ndarray],
coord_list: T.List[np.ndarray],
process_func: T.Optional[T.Callable] = None,
transform=None):
"""Create a dataset from a list of images and coordinates."""
reader = ListReader(img_list, coord_list)
return cls(reader, process_func, transform)
[docs]
@classmethod
def from_dir(
cls,
img_dir: str,
coord_dir: str,
img_glob: str = '*.tif',
coord_glob: str = '*.csv',
process_func: T.Optional[T.Callable] = None,
transform=None):
"""Create a dataset from a directory of images and coordinates."""
reader = DirReader(img_dir, coord_dir, img_glob, coord_glob)
return cls(reader, process_func, transform)
[docs]
class RandomFlip:
[docs]
def __init__(self, p=0.5):
self.p = p
def __call__(self, sample):
image, target = sample['image'], sample['target']
if random.random() < self.p:
image = np.flip(image, axis=2).copy()
target = np.flip(target, axis=2).copy()
if random.random() < self.p:
image = np.flip(image, axis=1).copy()
target = np.flip(target, axis=1).copy()
return {'image': image, 'target': target}
[docs]
class RandomRotation:
[docs]
def __init__(self, p=0.5, angle_range=(-90, 90)):
self.p = p
self.angle_range = angle_range
def __call__(self, sample):
if random.random() > self.p:
return sample
image, target = sample['image'], sample['target']
angle = random.uniform(self.angle_range[0], self.angle_range[1])
image = ndi.rotate(
image, angle, axes=(1, 2), order=1, reshape=False)
target = ndi.rotate(
target, angle, axes=(1, 2), order=1, reshape=False)
return {'image': image, 'target': target}
[docs]
class RandomTranslation:
[docs]
def __init__(self, p=0.5, shift_range=(-256, 256)):
self.p = p
self.shift_range = shift_range
def __call__(self, sample):
if random.random() > self.p:
return sample
image, target = sample['image'], sample['target']
shift_y = random.uniform(self.shift_range[0], self.shift_range[1])
shift_x = random.uniform(self.shift_range[0], self.shift_range[1])
image = ndi.shift(
image, (0, shift_y, shift_x), cval=0.0)
target = ndi.shift(
target, (0, shift_y, shift_x), cval=0.0)
return {'image': image, 'target': target}
[docs]
class GaussianNoise:
[docs]
def __init__(self, p=0.5, sigma_range=(0, 0.5)):
self.p = p
self.sigma_range = sigma_range
def __call__(self, sample):
if random.random() > self.p:
return sample
image, target = sample['image'], sample['target']
sigma = random.uniform(self.sigma_range[0], self.sigma_range[1])
noise = np.random.normal(
0, sigma, image.shape).astype(np.float32)
image += noise
image = np.clip(image, 0, 255)
return {'image': image, 'target': target}
[docs]
class SaltAndPepperNoise:
[docs]
def __init__(
self, p=0.5,
salt_range=(0, 1e-4),
pepper_range=(0, 1e-4)):
self.p = p
self.salt_range = salt_range
self.pepper_range = pepper_range
def __call__(self, sample):
if random.random() > self.p:
return sample
p_salt = random.uniform(
self.salt_range[0], self.salt_range[1])
p_pepper = random.uniform(
self.pepper_range[0], self.pepper_range[1])
image, target = sample['image'], sample['target']
mask = np.random.choice(
[0, 1, 2], size=image.shape,
p=[(1-p_salt-p_pepper), p_salt, p_pepper])
image[mask == 1] = 0.0
image[mask == 2] = 255.0
return {'image': image, 'target': target}
[docs]
class ToTensorWrapper:
def __call__(self, sample):
return {
'image': torch.tensor(sample['image'], dtype=torch.float32),
'target': torch.tensor(sample['target'], dtype=torch.float32)
}
[docs]
class DataAugmentation:
[docs]
def __init__(self, p=0.5, each_transform_p=0.5):
self.p = p
_p = each_transform_p
self.transforms = [
RandomFlip(p=_p),
RandomRotation(p=_p),
RandomTranslation(p=_p),
GaussianNoise(p=_p),
SaltAndPepperNoise(p=_p),
]
self.to_tensor = ToTensorWrapper()
def __call__(self, sample):
if random.random() <= self.p:
for transform in self.transforms:
sample = transform(sample)
sample = self.to_tensor(sample)
return sample