openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

165 lines
6.0 KiB

# copied from https://github.com/mlcommons/training/blob/5c08ce57e7f582cc4558035d8324a2bf4c8ca225/image_segmentation/pytorch/data_loading/pytorch_loader.py
import random
import numpy as np
import scipy.ndimage
from torch.utils.data import Dataset
from torchvision import transforms
def get_train_transforms():
    """Build the augmentation pipeline applied to each training sample.

    The stages run in this order: random flips, dtype casting (image to
    ``np.float32``, label to ``np.uint8``), random brightness scaling and
    additive Gaussian noise.

    Returns:
        A ``transforms.Compose`` chaining the four augmentations.
    """
    pipeline = [
        RandFlip(),
        Cast(types=(np.float32, np.uint8)),
        RandomBrightnessAugmentation(factor=0.3, prob=0.1),
        GaussianNoise(mean=0.0, std=0.1, prob=0.1),
    ]
    return transforms.Compose(pipeline)
class RandBalancedCrop:
    """Crop a fixed-size patch, optionally biased towards foreground voxels.

    The arrays are indexed as ``[:, x, y, z]``: a leading axis that is kept
    whole (presumably channels -- confirm against the data files) followed by
    three axes that are cropped to ``patch_size``.

    Args:
        patch_size: Three ints, the patch extent along axes 1, 2 and 3.
        oversampling: Probability (compared against ``random.random()``) of
            using the foreground-centred crop instead of a uniform one.
    """

    def __init__(self, patch_size, oversampling):
        self.patch_size = patch_size
        self.oversampling = oversampling

    def __call__(self, data):
        """Crop ``data["image"]`` and ``data["label"]`` with the same window.

        Args:
            data: Dict holding ``"image"`` and ``"label"`` arrays.

        Returns:
            The same dict object, updated in place with the cropped arrays.
        """
        image, label = data["image"], data["label"]
        if random.random() < self.oversampling:
            image, label, _ = self.rand_foreg_cropd(image, label)
        else:
            image, label, _ = self._rand_crop(image, label)
        data.update({"image": image, "label": label})
        return data

    @staticmethod
    def randrange(max_range):
        """Return a random int in ``[0, max_range)``, or 0 when the range is empty.

        ``random.randrange(0)`` raises, so the zero case is handled explicitly.
        """
        return 0 if max_range == 0 else random.randrange(max_range)

    def get_cords(self, cord, idx):
        """Return the ``(low, high)`` bounds of the patch along spatial index ``idx``."""
        return cord[idx], cord[idx] + self.patch_size[idx]

    def _rand_crop(self, image, label):
        """Crop a patch whose origin is drawn uniformly from the valid range.

        Returns:
            ``(image, label, [low_x, high_x, low_y, high_y, low_z, high_z])``.
        """
        # Largest admissible origin per axis so the patch stays inside the volume.
        ranges = [s - p for s, p in zip(image.shape[1:], self.patch_size)]
        cord = [self.randrange(x) for x in ranges]
        low_x, high_x = self.get_cords(cord, 0)
        low_y, high_y = self.get_cords(cord, 1)
        low_z, high_z = self.get_cords(cord, 2)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]

    def rand_foreg_cropd(self, image, label):
        """Crop a patch positioned around a randomly chosen foreground component.

        One foreground class (label value > 0) is picked at random, its
        connected components are found, and the patch is placed around one of
        the (up to) two components with the largest bounding-box volume.
        Falls back to :meth:`_rand_crop` when the label holds no foreground.

        Returns:
            ``(image, label, [low_x, high_x, low_y, high_y, low_z, high_z])``.
        """

        def adjust(foreg_slice, patch_size, label, idx):
            """Resize the component's extent on array axis ``idx`` to the patch size."""
            # Positive diff: the component is smaller than the patch and the
            # window must grow; negative diff: it is larger and must shrink.
            diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
            sign = -1 if diff < 0 else 1
            diff = abs(diff)
            # Split the size difference randomly between the low and high side.
            ladj = self.randrange(diff)
            hadj = diff - ladj
            low = max(0, foreg_slice[idx].start - sign * ladj)
            high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
            # Clamping to the volume may have shortened the window; give the
            # lost length back on the side that was not clamped.
            diff = patch_size[idx - 1] - (high - low)
            if diff > 0 and low == 0:
                high += diff
            elif diff > 0:
                low -= diff
            return low, high

        foreground_classes = np.unique(label[label > 0])
        if foreground_classes.size == 0:
            # np.random.choice raises ValueError on an empty array, so a label
            # without any foreground must take the plain random crop instead.
            return self._rand_crop(image, label)
        cl = np.random.choice(foreground_classes)
        # scipy.ndimage.label replaces the deprecated ndimage.measurements.label.
        foreg_slices = scipy.ndimage.find_objects(scipy.ndimage.label(label == cl)[0])
        foreg_slices = [x for x in foreg_slices if x is not None]
        slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
        # Keep only the two components with the largest bounding boxes.
        slice_idx = np.argsort(slice_volumes)[-2:]
        foreg_slices = [foreg_slices[i] for i in slice_idx]
        if not foreg_slices:
            return self._rand_crop(image, label)
        foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
        low_x, high_x = adjust(foreg_slice, self.patch_size, label, 1)
        low_y, high_y = adjust(foreg_slice, self.patch_size, label, 2)
        low_z, high_z = adjust(foreg_slice, self.patch_size, label, 3)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]
class RandFlip:
    """Mirror image and label together along randomly selected axes.

    Axes 1, 2 and 3 are each considered independently and flipped with
    probability ``1 / 3``.
    """

    def __init__(self):
        self.axis = [1, 2, 3]
        self.prob = 1 / len(self.axis)

    def flip(self, data, axis):
        """Reverse both arrays of ``data`` along ``axis`` and return ``data``.

        ``np.flip`` yields a view; the ``copy()`` stores a standalone array.
        """
        for key in ("image", "label"):
            data[key] = np.flip(data[key], axis=axis).copy()
        return data

    def __call__(self, data):
        """Draw once per axis and flip the sample along every selected axis."""
        selected_axes = [ax for ax in self.axis if random.random() < self.prob]
        for ax in selected_axes:
            data = self.flip(data, ax)
        return data
class Cast:
    """Convert the image and label arrays to fixed dtypes.

    Args:
        types: Indexable pair; element 0 is the image dtype and element 1
            is the label dtype.
    """

    def __init__(self, types):
        self.types = types

    def __call__(self, data):
        """Cast ``data["image"]`` and ``data["label"]`` in place and return ``data``."""
        for position, key in enumerate(("image", "label")):
            data[key] = data[key].astype(self.types[position])
        return data
class RandomBrightnessAugmentation:
    """Randomly rescale the image intensity by a single scalar.

    Args:
        factor: Half-width of the uniform interval ``[1 - factor, 1 + factor]``
            the random value is drawn from.
        prob: Probability of applying the augmentation to a sample.
    """

    def __init__(self, factor, prob):
        self.prob = prob
        self.factor = factor

    def __call__(self, data):
        """Scale ``data["image"]`` with probability ``prob`` and return ``data``."""
        image = data["image"]
        should_apply = random.random() < self.prob
        if not should_apply:
            return data
        lower = 1.0 - self.factor
        upper = 1.0 + self.factor
        drawn = np.random.uniform(low=lower, high=upper, size=1)
        # NOTE(review): the multiplier is 1 + drawn, and drawn is centred on
        # 1.0, so the image is scaled by roughly 2x rather than ~1x. Kept
        # unchanged here -- confirm this matches the intended reference.
        scaled = image * (1 + drawn)
        # The product is promoted by the float64 draw; restore the input dtype.
        data["image"] = scaled.astype(image.dtype)
        return data
class GaussianNoise:
    """Randomly add Gaussian noise to the image.

    Args:
        mean: Mean of the noise distribution.
        std: Upper bound of the noise standard deviation; the value actually
            used is drawn uniformly from ``[0, std)`` for each application.
        prob: Probability of applying the augmentation to a sample.
    """

    def __init__(self, mean, std, prob):
        self.mean = mean
        self.std = std
        self.prob = prob

    def __call__(self, data):
        """Add noise to ``data["image"]`` with probability ``prob``; return ``data``."""
        image = data["image"]
        should_apply = random.random() < self.prob
        if not should_apply:
            return data
        sigma = np.random.uniform(low=0.0, high=self.std)
        raw_noise = np.random.normal(loc=self.mean, scale=sigma, size=image.shape)
        # Cast the noise to the image dtype before adding it.
        data["image"] = image + raw_noise.astype(image.dtype)
        return data
class PytTrain(Dataset):
    """Training dataset: load a sample, crop a patch, then augment it.

    Args:
        images: Indexable collection of inputs accepted by ``np.load``.
        labels: Indexable collection of the matching label inputs.
        **kwargs: Must contain ``"patch_size"`` and ``"oversampling"``, which
            are forwarded to ``RandBalancedCrop``.
    """

    def __init__(self, images, labels, **kwargs):
        self.images = images
        self.labels = labels
        self.train_transforms = get_train_transforms()
        patch_size = kwargs["patch_size"]
        oversampling = kwargs["oversampling"]
        self.patch_size = patch_size
        self.rand_crop = RandBalancedCrop(patch_size=patch_size, oversampling=oversampling)

    def __len__(self):
        """Return the number of samples, taken from ``images``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the cropped and augmented ``(image, label)`` pair at ``idx``."""
        sample = {
            "image": np.load(self.images[idx]),
            "label": np.load(self.labels[idx]),
        }
        cropped = self.rand_crop(sample)
        augmented = self.train_transforms(cropped)
        return augmented["image"], augmented["label"]
class PytVal(Dataset):
    """Validation dataset: return each sample as loaded, without crop or augmentation.

    Args:
        images: Indexable collection of inputs accepted by ``np.load``.
        labels: Indexable collection of the matching label inputs.
    """

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        """Return the number of samples, taken from ``images``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and return the ``(image, label)`` pair at ``idx``."""
        image = np.load(self.images[idx])
        label = np.load(self.labels[idx])
        return image, label