openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

94 lines
3.4 KiB

# https://github.com/mlcommons/training/blob/master/image_segmentation/pytorch/model/losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class Dice:
    """Per-sample, per-class (soft) Dice coefficient.

    For each sample and each channel the score is
    ``(2 * sum(target * prediction) + smooth_nr) / (sum(target) + sum(prediction) + smooth_dr)``,
    where the sums run over every axis except the batch axis (0) and the channel axis.

    Args:
        to_onehot_y: One-hot encode ``target`` via ``to_one_hot`` before scoring.
        to_onehot_x: One-hot encode ``prediction`` via ``to_one_hot`` after the optional
            softmax/argmax step.
        use_softmax: Apply ``torch.softmax`` over the channel axis of ``prediction``.
        use_argmax: Replace ``prediction`` by its argmax over the channel axis (this removes
            the channel axis). Only used when ``use_softmax`` is False.
        include_background: When False, channel 0 is dropped from both tensors before scoring.
        layout: ``"NCDHW"`` means channels-first (channel axis 1); any other value is treated
            as channels-last (channel axis -1).
        smooth_nr: Constant added to the numerator.
        smooth_dr: Constant added to the denominator; keeps the ratio finite when a class is
            empty in both tensors (the score is then ``smooth_nr / smooth_dr``).
    """

    def __init__(self,
                 to_onehot_y: bool = True,
                 to_onehot_x: bool = False,
                 use_softmax: bool = True,
                 use_argmax: bool = False,
                 include_background: bool = False,
                 layout: str = "NCDHW",
                 smooth_nr: float = 1e-6,
                 smooth_dr: float = 1e-6):
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.to_onehot_x = to_onehot_x
        self.use_softmax = use_softmax
        self.use_argmax = use_argmax
        self.smooth_nr = smooth_nr
        self.smooth_dr = smooth_dr
        self.layout = layout

    def __call__(self, prediction, target):
        """Return the Dice coefficient with shape ``(batch, channels)``.

        ``channels`` excludes the background channel when ``include_background`` is False.

        Raises:
            AssertionError: if background exclusion is requested for a single-channel
                prediction, or if the (possibly one-hot encoded) shapes of ``target`` and
                ``prediction`` differ.
        """
        if self.layout == "NCDHW":
            channel_axis = 1
            reduce_axis = list(range(2, len(prediction.shape)))
        else:
            channel_axis = -1
            reduce_axis = list(range(1, len(prediction.shape) - 1))
        # Read the channel count before argmax can remove the channel axis.
        num_pred_ch = prediction.shape[channel_axis]

        if self.use_softmax:
            prediction = torch.softmax(prediction, dim=channel_axis)
        elif self.use_argmax:
            prediction = torch.argmax(prediction, dim=channel_axis)

        if self.to_onehot_y:
            target = to_one_hot(target, self.layout, channel_axis)
        if self.to_onehot_x:
            prediction = to_one_hot(prediction, self.layout, channel_axis)

        if not self.include_background:
            assert num_pred_ch > 1, \
                f"To exclude background the prediction needs more than one channel. Got {num_pred_ch}."
            if self.layout == "NCDHW":
                target = target[:, 1:]
                prediction = prediction[:, 1:]
            else:
                target = target[..., 1:]
                prediction = prediction[..., 1:]

        assert (target.shape == prediction.shape), \
            f"Target and prediction shape do not match. Target: ({target.shape}), prediction: ({prediction.shape})."

        intersection = torch.sum(target * prediction, dim=reduce_axis)
        target_sum = torch.sum(target, dim=reduce_axis)
        prediction_sum = torch.sum(prediction, dim=reduce_axis)
        return (2.0 * intersection + self.smooth_nr) / (target_sum + prediction_sum + self.smooth_dr)
def to_one_hot(array, layout, channel_axis, num_classes: int = 3):
    """Convert an integer label tensor into a one-hot ``float32`` tensor.

    Args:
        array: Label tensor (cast to ``long`` before encoding). When it has 5 or more dims,
            ``torch.squeeze`` is applied at ``channel_axis``; that is a no-op unless the axis
            has size 1.
        layout: ``"NCDHW"`` moves the new one-hot axis to position 1 (channels-first). Any
            other value leaves it as the last axis (channels-last).
        channel_axis: Axis squeezed away for inputs of rank >= 5.
        num_classes: Length of the one-hot axis. The default of 3 preserves the previous
            hard-coded value. ``F.one_hot`` rejects labels outside ``[0, num_classes)``.

    Returns:
        A ``float32`` tensor for every layout. Previously, non-``"NCDHW"`` layouts
        returned ``int64``.
    """
    if len(array.shape) >= 5:
        array = torch.squeeze(array, dim=channel_axis)
    array = F.one_hot(array.long(), num_classes=num_classes)
    if layout == "NCDHW":
        # Move the trailing class axis to position 1 for any rank;
        # for 5-D this is permute(0, 4, 1, 2, 3).
        ndim = array.dim()
        array = array.permute(0, ndim - 1, *range(1, ndim - 1))
    return array.float()
class DiceCELoss(nn.Module):
    """Mean of a soft Dice loss and a cross-entropy loss.

    ``forward`` returns ``(mean(1 - dice) + cross_entropy) / 2``.

    Args:
        to_onehot_y: Forwarded to ``Dice``.
        use_softmax: Forwarded to ``Dice``.
        layout: ``"NCDHW"`` for channels-first tensors; any other value is treated as
            channels-last, matching ``Dice``.
        include_background: Forwarded to ``Dice``.
    """

    def __init__(self, to_onehot_y, use_softmax, layout, include_background):
        super(DiceCELoss, self).__init__()
        self.layout = layout
        self.dice = Dice(to_onehot_y=to_onehot_y, use_softmax=use_softmax, layout=layout,
                         include_background=include_background)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, y_pred, y_true):
        """Return the combined loss for logits ``y_pred`` and label map ``y_true``.

        ``y_true`` presumably carries a size-1 channel axis (dim 1 for ``"NCDHW"``, the last
        dim otherwise) holding integer class labels. TODO: confirm against callers.
        """
        if self.layout == "NCDHW":
            ce_logits = y_pred
            ce_labels = torch.squeeze(y_true, dim=1)
        else:
            # nn.CrossEntropyLoss expects the class axis at dim 1, so move the trailing
            # channel axis there and drop the label tensor's trailing channel axis.
            ndim = y_pred.dim()
            ce_logits = y_pred.permute(0, ndim - 1, *range(1, ndim - 1))
            ce_labels = torch.squeeze(y_true, dim=-1)
        cross_entropy = self.cross_entropy(ce_logits, ce_labels.long())
        dice = torch.mean(1.0 - self.dice(y_pred, y_true))
        return (dice + cross_entropy) / 2
class DiceScore:
    """Hard Dice score averaged over the batch axis.

    Wraps ``Dice`` with softmax disabled and ``to_onehot_x=True``. With the default
    ``use_argmax=True``, the prediction is argmax-ed over channels and then one-hot encoded.
    Calling the instance returns the per-sample scores averaged over dim 0.
    """

    def __init__(self, to_onehot_y: bool = True, use_argmax: bool = True, layout: str = "NCDHW",
                 include_background: bool = False):
        self.dice = Dice(
            include_background=include_background,
            layout=layout,
            use_softmax=False,
            use_argmax=use_argmax,
            to_onehot_x=True,
            to_onehot_y=to_onehot_y,
        )

    def __call__(self, y_pred, y_true):
        per_sample_scores = self.dice(y_pred, y_true)
        return per_sample_scores.mean(dim=0)