You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
25 lines
1.5 KiB
25 lines
1.5 KiB
from test.external.mlperf_retinanet.model.boxes import box_iou
|
|
from test.external.mlperf_retinanet.model.utils import Matcher
|
|
|
|
import torch
|
|
|
|
# This applies the filtering in https://github.com/mlcommons/training/blob/cdd928d4596c142c15a7d86b2eeadbac718c8da2/single_stage_detector/ssd/model/retinanet.py#L117
|
|
# and https://github.com/mlcommons/training/blob/cdd928d4596c142c15a7d86b2eeadbac718c8da2/single_stage_detector/ssd/model/retinanet.py#L203
|
|
# to match with tinygrad's dataloader implementation.
|
|
def postprocess_targets(targets, anchors):
    """Filter each image's targets down to the entries matched to its anchors.

    For every image, the anchors are matched against the image's boxes and
    ``"boxes"`` / ``"labels"`` are then replaced by the entries gathered with
    the non-negative match indices.

    Args:
        targets: per-image dicts holding ``"boxes"`` and ``"labels"`` tensors.
            The dicts are modified in place.
        anchors: per-image anchor tensors, paired with ``targets`` by position.

    Returns:
        The same ``targets`` object that was passed in.
    """
    # NOTE(review): 0.5 / 0.4 are presumably the high / low IoU thresholds of
    # the matcher -- confirm against Matcher's signature.
    matcher = Matcher(0.5, 0.4, allow_low_quality_matches=True)

    def match_anchors(image_anchors, image_targets):
        """Return the match-index tensor for a single image."""
        gt_boxes = image_targets['boxes']
        if gt_boxes.numel() == 0:
            # No boxes for this image: flag every anchor with -1 so that
            # nothing is selected for it below.
            return torch.full(
                (image_anchors.size(0),),
                -1,
                dtype=torch.int64,
                device=image_anchors.device,
            )
        return matcher(box_iou(gt_boxes, image_anchors))

    # Compute the matches for all images before touching any target dict.
    all_matches = [
        match_anchors(image_anchors, image_targets)
        for image_anchors, image_targets in zip(anchors, targets)
    ]

    for image_targets, matches in zip(targets, all_matches):
        # Keep only non-negative indices; negative values mean "no match".
        selected = matches[matches >= 0]
        image_targets["boxes"] = image_targets["boxes"][selected]
        image_targets["labels"] = image_targets["labels"][selected]

    return targets
|
|
|