You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					51 lines
				
				1.8 KiB
			
		
		
			
		
	
	
					51 lines
				
				1.8 KiB
			| 
											3 days ago
										 | # Copied from https://github.com/mlcommons/training/blob/cdd928d4596c142c15a7d86b2eeadbac718c8da2/single_stage_detector/ssd/model/focal_loss.py
 | ||
|  | 
 | ||
|  | import torch
 | ||
|  | import torch.nn.functional as F
 | ||
|  | 
 | ||
|  | 
 | ||
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Compute the sigmoid focal loss used in RetinaNet for dense detection.

    Original implementation from
    https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py.
    Paper: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape holding the raw logits
            (pre-sigmoid predictions) for each element.
        targets: A float tensor with the same shape as ``inputs`` storing the
            binary classification label for each element
            (0 for the negative class and 1 for the positive class).
        alpha: Weighting factor in range (0, 1) to balance positive vs.
            negative examples. Pass a negative value (e.g. -1) to disable
            the weighting. Default: 0.25.
        gamma: Exponent of the modulating factor ``(1 - p_t)`` that
            down-weights easy examples relative to hard ones. Default: 2.
        reduction: 'none' | 'mean' | 'sum'
            'none': the element-wise loss is returned unreduced.
            'mean': the element-wise loss is averaged.
            'sum': the element-wise loss is summed.

    Returns:
        Loss tensor with the reduction option applied.

    Raises:
        ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'.
    """
    # Validate up front so a typo such as "avg" fails loudly instead of
    # silently returning the unreduced loss.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"Invalid value for 'reduction': {reduction!r}. "
            "Supported reduction modes: 'none', 'mean', 'sum'."
        )

    p = torch.sigmoid(inputs)
    # Cross-entropy is taken directly from the logits rather than from `p`;
    # PyTorch documents the fused-with-logits form as more numerically stable.
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t: predicted probability assigned to the ground-truth class.
    p_t = p * targets + (1 - p) * (1 - targets)
    # Modulating factor: shrinks the loss of well-classified (p_t -> 1) elements.
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        # alpha weights positives, (1 - alpha) weights negatives.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
 |