You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					165 lines
				
				6.0 KiB
			
		
		
			
		
	
	
					165 lines
				
				6.0 KiB
			| 
											5 days ago
										 | # copied from https://github.com/mlcommons/training/blob/5c08ce57e7f582cc4558035d8324a2bf4c8ca225/image_segmentation/pytorch/data_loading/pytorch_loader.py
 | ||
|  | 
 | ||
|  | import random
 | ||
|  | import numpy as np
 | ||
|  | import scipy.ndimage
 | ||
|  | from torch.utils.data import Dataset
 | ||
|  | from torchvision import transforms
 | ||
|  | 
 | ||
|  | 
 | ||
def get_train_transforms():
    """Build the augmentation pipeline applied to each training sample.

    The stages run in this order: random flips, dtype cast (image to
    float32, label to uint8), random brightness scaling, and additive
    Gaussian noise.

    Returns:
        A ``transforms.Compose`` object chaining the four stages.
    """
    pipeline = [
        RandFlip(),
        Cast(types=(np.float32, np.uint8)),
        RandomBrightnessAugmentation(factor=0.3, prob=0.1),
        GaussianNoise(mean=0.0, std=0.1, prob=0.1),
    ]
    return transforms.Compose(pipeline)
 | ||
|  | 
 | ||
|  | 
 | ||
class RandBalancedCrop:
    """Crop a fixed-size patch, optionally centred on foreground voxels.

    With probability ``oversampling`` the patch is positioned around a
    connected foreground component of a randomly chosen label class;
    otherwise it is placed uniformly at random. Arrays are indexed as
    ``[:, x, y, z]``, i.e. a leading axis followed by three spatial axes.

    Args:
        patch_size: Three ints, the crop extent along axes 1, 2 and 3.
        oversampling: Probability of taking the foreground-biased crop.
    """

    def __init__(self, patch_size, oversampling):
        self.patch_size = patch_size
        self.oversampling = oversampling

    def __call__(self, data):
        """Replace ``data["image"]`` and ``data["label"]`` with cropped views.

        Args:
            data: Mapping holding ``"image"`` and ``"label"`` arrays.

        Returns:
            The same ``data`` object, updated in place.
        """
        image, label = data["image"], data["label"]
        if random.random() < self.oversampling:
            image, label, _ = self.rand_foreg_cropd(image, label)
        else:
            image, label, _ = self._rand_crop(image, label)
        data.update({"image": image, "label": label})
        return data

    @staticmethod
    def randrange(max_range):
        """Return a random int in ``[0, max_range)``, or 0 when the range is empty.

        ``random.randrange(0)`` raises, so the zero case is handled here.
        """
        return 0 if max_range == 0 else random.randrange(max_range)

    def get_cords(self, cord, idx):
        """Return the ``(low, high)`` crop bounds along patch axis ``idx``."""
        return cord[idx], cord[idx] + self.patch_size[idx]

    def _rand_crop(self, image, label):
        """Crop a patch at a uniformly random position.

        Returns:
            ``(image, label, [low_x, high_x, low_y, high_y, low_z, high_z])``.
        """
        ranges = [s - p for s, p in zip(image.shape[1:], self.patch_size)]
        cord = [self.randrange(x) for x in ranges]
        low_x, high_x = self.get_cords(cord, 0)
        low_y, high_y = self.get_cords(cord, 1)
        low_z, high_z = self.get_cords(cord, 2)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]

    def rand_foreg_cropd(self, image, label):
        """Crop a patch around a connected component of a random foreground class.

        One of the (at most) two largest bounding boxes of the chosen class
        is picked at random and grown or shrunk to ``patch_size``. Falls
        back to ``_rand_crop`` when the label contains no foreground.

        Returns:
            ``(image, label, [low_x, high_x, low_y, high_y, low_z, high_z])``.
        """

        def adjust(foreg_slice, patch_size, label, idx):
            """Resize the bounding box along array axis ``idx`` to the patch extent."""
            # patch_size has no entry for the leading axis, hence idx - 1.
            diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
            # sign == 1 grows the box outward; sign == -1 shrinks it inward.
            sign = -1 if diff < 0 else 1
            diff = abs(diff)
            ladj = self.randrange(diff)
            hadj = diff - ladj
            low = max(0, foreg_slice[idx].start - sign * ladj)
            high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
            # Clamping to the volume may have left the box too short;
            # extend it on the side that was not clamped.
            diff = patch_size[idx - 1] - (high - low)
            if diff > 0 and low == 0:
                high += diff
            elif diff > 0:
                low -= diff
            return low, high

        foreground_classes = np.unique(label[label > 0])
        # np.random.choice raises ValueError on an empty array, so a volume
        # without foreground must fall back to a plain random crop here.
        if foreground_classes.size == 0:
            return self._rand_crop(image, label)
        cl = np.random.choice(foreground_classes)
        # scipy.ndimage.label is the supported spelling; the
        # scipy.ndimage.measurements namespace is deprecated.
        components = scipy.ndimage.label(label == cl)[0]
        foreg_slices = [x for x in scipy.ndimage.find_objects(components) if x is not None]
        slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
        # Keep only the two largest bounding boxes.
        slice_idx = np.argsort(slice_volumes)[-2:]
        foreg_slices = [foreg_slices[i] for i in slice_idx]
        if not foreg_slices:
            return self._rand_crop(image, label)
        foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
        low_x, high_x = adjust(foreg_slice, self.patch_size, label, 1)
        low_y, high_y = adjust(foreg_slice, self.patch_size, label, 2)
        low_z, high_z = adjust(foreg_slice, self.patch_size, label, 3)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]
 | ||
|  | 
 | ||
|  | 
 | ||
class RandFlip:
    """Randomly mirror image and label together along axes 1, 2 and 3.

    Each axis is considered independently and flipped with probability
    ``1 / 3`` (one over the number of candidate axes).
    """

    def __init__(self):
        self.axis = [1, 2, 3]
        self.prob = 1 / len(self.axis)

    def flip(self, data, axis):
        """Mirror both arrays in ``data`` along ``axis`` and return ``data``."""
        for key in ("image", "label"):
            # np.flip returns a view; copy to store a standalone array.
            data[key] = np.flip(data[key], axis=axis).copy()
        return data

    def __call__(self, data):
        """Apply an independent random flip decision for every axis."""
        for candidate in self.axis:
            if random.random() < self.prob:
                data = self.flip(data, candidate)
        return data
 | ||
|  | 
 | ||
|  | 
 | ||
class Cast:
    """Convert the image and label arrays to the configured dtypes.

    Args:
        types: Pair ``(image_dtype, label_dtype)``.
    """

    def __init__(self, types):
        self.types = types

    def __call__(self, data):
        """Cast ``data["image"]`` to ``types[0]`` and ``data["label"]`` to ``types[1]``."""
        for position, key in enumerate(("image", "label")):
            data[key] = data[key].astype(self.types[position])
        return data
 | ||
|  | 
 | ||
|  | 
 | ||
class RandomBrightnessAugmentation:
    """Randomly rescale image intensities.

    Args:
        factor: Half-width of the uniform range the random term is drawn from.
        prob: Probability of applying the rescaling to a sample.
    """

    def __init__(self, factor, prob):
        self.prob = prob
        self.factor = factor

    def __call__(self, data):
        """Scale ``data["image"]`` with probability ``prob``; the dtype is kept."""
        image = data["image"]
        if random.random() < self.prob:
            lower = 1.0 - self.factor
            upper = 1.0 + self.factor
            drawn = np.random.uniform(low=lower, high=upper, size=1)
            # The multiplier is 1 + drawn, i.e. it lies in
            # [2 - factor, 2 + factor].
            # NOTE(review): the extra "1 +" may be unintended, but it is
            # kept as-is -- confirm against the upstream reference.
            scaled = image * (1 + drawn)
            data.update({"image": scaled.astype(image.dtype)})
        return data
 | ||
|  | 
 | ||
|  | 
 | ||
class GaussianNoise:
    """Randomly add Gaussian noise to the image.

    Args:
        mean: Mean of the noise distribution.
        std: Upper bound for the noise standard deviation; the value
            actually used is drawn uniformly from ``[0, std)`` per call.
        prob: Probability of adding noise to a sample.
    """

    def __init__(self, mean, std, prob):
        self.mean = mean
        self.std = std
        self.prob = prob

    def __call__(self, data):
        """Add noise to ``data["image"]`` with probability ``prob``."""
        image = data["image"]
        if random.random() < self.prob:
            sigma = np.random.uniform(low=0.0, high=self.std)
            noise = np.random.normal(loc=self.mean, scale=sigma, size=image.shape)
            # Cast the noise to the image dtype before adding it.
            noisy = image + noise.astype(image.dtype)
            data.update({"image": noisy})
        return data
 | ||
|  | 
 | ||
|  | 
 | ||
class PytTrain(Dataset):
    """Training dataset: load a sample, crop a patch, then augment it.

    Args:
        images: Indexable collection of inputs accepted by ``np.load``.
        labels: Indexable collection of label inputs, aligned with ``images``.
        **kwargs: Must contain ``patch_size`` and ``oversampling``, which
            are forwarded to ``RandBalancedCrop``.
    """

    def __init__(self, images, labels, **kwargs):
        self.images = images
        self.labels = labels
        self.train_transforms = get_train_transforms()
        patch_size = kwargs["patch_size"]
        oversampling = kwargs["oversampling"]
        self.patch_size = patch_size
        self.rand_crop = RandBalancedCrop(patch_size=patch_size, oversampling=oversampling)

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the cropped and augmented ``(image, label)`` pair for ``idx``."""
        sample = {
            "image": np.load(self.images[idx]),
            "label": np.load(self.labels[idx]),
        }
        # Crop first, then run the augmentation pipeline on the patch.
        sample = self.train_transforms(self.rand_crop(sample))
        return sample["image"], sample["label"]
 | ||
|  | 
 | ||
|  | 
 | ||
class PytVal(Dataset):
    """Validation dataset: return each sample as loaded, without cropping
    or augmentation.

    Args:
        images: Indexable collection of inputs accepted by ``np.load``.
        labels: Indexable collection of label inputs, aligned with ``images``.
    """

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and return the ``(image, label)`` pair for ``idx``."""
        image = np.load(self.images[idx])
        label = np.load(self.labels[idx])
        return image, label
 |