You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
165 lines
6.0 KiB
165 lines
6.0 KiB
# copied from https://github.com/mlcommons/training/blob/5c08ce57e7f582cc4558035d8324a2bf4c8ca225/image_segmentation/pytorch/data_loading/pytorch_loader.py
|
|
|
|
import random
|
|
import numpy as np
|
|
import scipy.ndimage
|
|
from torch.utils.data import Dataset
|
|
from torchvision import transforms
|
|
|
|
|
|
def get_train_transforms():
    """Build the composed augmentation pipeline used for training samples.

    The stages run in this order: random flips, dtype cast (image to
    ``np.float32``, label to ``np.uint8``), random brightness scaling and
    Gaussian noise; the last two are each configured with probability 0.1.

    Returns:
        A ``transforms.Compose`` wrapping the four stages.
    """
    pipeline = [
        RandFlip(),
        Cast(types=(np.float32, np.uint8)),
        RandomBrightnessAugmentation(factor=0.3, prob=0.1),
        GaussianNoise(mean=0.0, std=0.1, prob=0.1),
    ]
    return transforms.Compose(pipeline)
|
|
|
|
|
|
class RandBalancedCrop:
    """Crop a random patch, optionally positioned around foreground voxels.

    With probability ``oversampling`` the patch is placed around a connected
    component of a randomly chosen foreground class; otherwise it is placed
    uniformly at random. The leading axis of both arrays is kept whole; the
    crop is applied to the three axes that follow it.

    Args:
        patch_size: Patch extent along each of the three cropped axes.
        oversampling: Probability of taking a foreground-based crop.
    """

    def __init__(self, patch_size, oversampling):
        self.patch_size = patch_size
        self.oversampling = oversampling

    def __call__(self, data):
        """Crop ``data["image"]`` and ``data["label"]`` to the same patch.

        The dict is updated in place and also returned.
        """
        image, label = data["image"], data["label"]
        if random.random() < self.oversampling:
            image, label, _ = self.rand_foreg_cropd(image, label)
        else:
            image, label, _ = self._rand_crop(image, label)
        data.update({"image": image, "label": label})
        return data

    @staticmethod
    def randrange(max_range):
        """Return a random int in ``[0, max_range)``, or 0 when ``max_range`` is 0.

        ``random.randrange(0)`` raises ``ValueError``, hence the special case.
        """
        return 0 if max_range == 0 else random.randrange(max_range)

    def get_cords(self, cord, idx):
        """Return the ``(low, high)`` bounds of the patch along patch axis ``idx``."""
        return cord[idx], cord[idx] + self.patch_size[idx]

    def _rand_crop(self, image, label):
        """Crop both arrays to a uniformly placed patch.

        Returns:
            ``(image, label, [low_x, high_x, low_y, high_y, low_z, high_z])``.
        """
        # Number of valid start offsets per axis (leading axis excluded).
        ranges = [s - p for s, p in zip(image.shape[1:], self.patch_size)]
        cord = [self.randrange(x) for x in ranges]
        low_x, high_x = self.get_cords(cord, 0)
        low_y, high_y = self.get_cords(cord, 1)
        low_z, high_z = self.get_cords(cord, 2)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]

    def rand_foreg_cropd(self, image, label):
        """Crop both arrays to a patch placed around a foreground component.

        A foreground class (label value > 0) is chosen at random, its connected
        components are located, and one of the (at most) two components with
        the largest bounding boxes is picked. Falls back to ``_rand_crop`` when
        the label contains no foreground at all.

        Returns:
            ``(image, label, [low_x, high_x, low_y, high_y, low_z, high_z])``.
        """

        def adjust(foreg_slice, patch_size, label, idx):
            # ``idx`` addresses the label axis; the matching patch axis is
            # ``idx - 1`` because the leading label axis is not cropped.
            diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
            # Negative diff: the component is larger than the patch, so the
            # window has to shrink (sign -1) instead of grow (sign +1).
            sign = -1 if diff < 0 else 1
            diff = abs(diff)
            # Split the size difference randomly between the two sides.
            ladj = self.randrange(diff)
            hadj = diff - ladj
            # Clamp the window to the bounds of the volume.
            low = max(0, foreg_slice[idx].start - sign * ladj)
            high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
            # Clamping may have made the window shorter than the patch;
            # recover the missing extent on the side that was not clamped.
            diff = patch_size[idx - 1] - (high - low)
            if diff > 0 and low == 0:
                high += diff
            elif diff > 0:
                low -= diff
            return low, high

        foreground_classes = np.unique(label[label > 0])
        if foreground_classes.size == 0:
            # np.random.choice raises ValueError on an empty array, so a label
            # without any foreground gets a plain random crop instead.
            return self._rand_crop(image, label)
        cl = np.random.choice(foreground_classes)
        # scipy.ndimage.label is the public entry point; the
        # scipy.ndimage.measurements namespace is deprecated.
        components, _ = scipy.ndimage.label(label == cl)
        foreg_slices = [x for x in scipy.ndimage.find_objects(components) if x is not None]
        # Keep only the (at most) two components with the largest bounding boxes.
        slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
        slice_idx = np.argsort(slice_volumes)[-2:]
        foreg_slices = [foreg_slices[i] for i in slice_idx]
        if not foreg_slices:
            return self._rand_crop(image, label)
        foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
        low_x, high_x = adjust(foreg_slice, self.patch_size, label, 1)
        low_y, high_y = adjust(foreg_slice, self.patch_size, label, 2)
        low_z, high_z = adjust(foreg_slice, self.patch_size, label, 3)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]
|
|
|
|
|
|
class RandFlip:
    """Mirror image and label together along randomly selected axes."""

    def __init__(self):
        axes = [1, 2, 3]
        self.axis = axes
        # Each axis is flipped independently with probability 1 / number of axes.
        self.prob = 1 / len(axes)

    def flip(self, data, axis):
        """Reverse ``data["image"]`` and ``data["label"]`` along ``axis``.

        Copies of the reversed arrays are stored back into ``data``, which is
        also returned.
        """
        for key in ("image", "label"):
            data[key] = np.flip(data[key], axis=axis).copy()
        return data

    def __call__(self, data):
        """Apply an independent random flip decision for every configured axis."""
        for current_axis in self.axis:
            if random.random() >= self.prob:
                continue
            data = self.flip(data, current_axis)
        return data
|
|
|
|
|
|
class Cast:
    """Convert image and label arrays to fixed dtypes.

    Args:
        types: Indexable pair; element 0 is the image dtype, element 1 the
            label dtype.
    """

    def __init__(self, types):
        self.types = types

    def __call__(self, data):
        """Cast ``data["image"]`` and ``data["label"]`` in place and return ``data``."""
        for position, key in enumerate(("image", "label")):
            data[key] = data[key].astype(self.types[position])
        return data
|
|
|
|
|
|
class RandomBrightnessAugmentation:
    """Randomly rescale the intensity of the image.

    Args:
        factor: Half-width of the interval the random factor is drawn from.
        prob: Probability of applying the augmentation to a sample.
    """

    def __init__(self, factor, prob):
        self.factor = factor
        self.prob = prob

    def __call__(self, data):
        """Scale ``data["image"]`` with probability ``self.prob``; return ``data``."""
        image = data["image"]
        if random.random() >= self.prob:
            return data
        spread = self.factor
        factor = np.random.uniform(low=1.0 - spread, high=1.0 + spread, size=1)
        # The drawn factor is added to 1 before multiplying, so the effective
        # multiplier lies in [2 - spread, 2 + spread].
        # NOTE(review): kept as in the upstream reference - confirm it is intended.
        scaled = (image * (1 + factor)).astype(image.dtype)
        data.update(image=scaled)
        return data
|
|
|
|
|
|
class GaussianNoise:
    """Randomly add Gaussian noise to the image.

    Args:
        mean: Mean of the noise distribution.
        std: Upper bound for the randomly drawn noise standard deviation.
        prob: Probability of applying the augmentation to a sample.
    """

    def __init__(self, mean, std, prob):
        self.mean, self.std, self.prob = mean, std, prob

    def __call__(self, data):
        """Add noise to ``data["image"]`` with probability ``self.prob``; return ``data``."""
        image = data["image"]
        if random.random() >= self.prob:
            return data
        # The standard deviation itself is random, drawn uniformly from [0, std).
        sigma = np.random.uniform(low=0.0, high=self.std)
        noise = np.random.normal(loc=self.mean, scale=sigma, size=image.shape)
        noise = noise.astype(image.dtype)
        data.update(image=image + noise)
        return data
|
|
|
|
|
|
class PytTrain(Dataset):
    """Training dataset: loads an image/label pair, crops it and augments it.

    Args:
        images: Sequence of sources accepted by ``np.load`` for the images.
        labels: Sequence of sources accepted by ``np.load`` for the labels.
        **kwargs: Must contain ``patch_size`` and ``oversampling``, which are
            forwarded to ``RandBalancedCrop``.
    """

    def __init__(self, images, labels, **kwargs):
        self.images = images
        self.labels = labels
        self.train_transforms = get_train_transforms()
        patch_size = kwargs["patch_size"]
        oversampling = kwargs["oversampling"]
        self.patch_size = patch_size
        self.rand_crop = RandBalancedCrop(patch_size=patch_size, oversampling=oversampling)

    def __len__(self):
        """Return the number of image entries."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx``, crop it, augment it and return ``(image, label)``."""
        sample = {"image": np.load(self.images[idx]), "label": np.load(self.labels[idx])}
        sample = self.train_transforms(self.rand_crop(sample))
        return sample["image"], sample["label"]
|
|
|
|
|
|
class PytVal(Dataset):
    """Validation dataset: returns image/label pairs as loaded, without transforms.

    Args:
        images: Sequence of sources accepted by ``np.load`` for the images.
        labels: Sequence of sources accepted by ``np.load`` for the labels.
    """

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        """Return the number of image entries."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and return ``(image, label)`` for sample ``idx``."""
        image = np.load(self.images[idx])
        label = np.load(self.labels[idx])
        return image, label