diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py
index d808ecffed3..12b3784099f 100644
--- a/torchvision/models/detection/_utils.py
+++ b/torchvision/models/detection/_utils.py
@@ -300,6 +300,42 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
         pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1)
         return pred_boxes
 
+    def decode_all(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
+        """
+        Vectorized version of the ``decode_single`` method.
+
+        Args:
+            rel_codes (Tensor): Encoded boxes.
+            boxes (List[Tensor]): List of reference boxes.
+
+        Returns:
+            Tensor: The predicted boxes with the encoded relative box offsets.
+
+        .. note::
+            This method assumes that ``rel_codes`` and ``boxes`` have the same size along the 0th dimension, i.e. ``len(rel_codes) == len(boxes)``.
+
+        """
+
+        boxes = torch.stack(boxes).to(dtype=rel_codes.dtype)
+
+        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
+        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
+
+        if self.normalize_by_size:
+            boxes_w = boxes[..., 2] - boxes[..., 0]
+            boxes_h = boxes[..., 3] - boxes[..., 1]
+
+            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
+            rel_codes = rel_codes * list_box_size
+
+        pred_boxes1 = ctr_x - rel_codes[..., 0]
+        pred_boxes2 = ctr_y - rel_codes[..., 1]
+        pred_boxes3 = ctr_x + rel_codes[..., 2]
+        pred_boxes4 = ctr_y + rel_codes[..., 3]
+
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
+        return pred_boxes
+
 
 class Matcher:
     """
diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py
index b19da6637bb..9851b7f7c05 100644
--- a/torchvision/models/detection/fcos.py
+++ b/torchvision/models/detection/fcos.py
@@ -87,14 +87,12 @@ def compute_loss(
         loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
 
         # regression loss: GIoU loss
-        # TODO: vectorize this instead of using a for loop
-        pred_boxes = [
-            self.box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
-            for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression)
-        ]
+
+        pred_boxes = self.box_coder.decode_all(bbox_regression, anchors)
+
         # amp issue: pred_boxes need to convert float
         loss_bbox_reg = generalized_box_iou_loss(
-            torch.stack(pred_boxes)[foregroud_mask].float(),
+            pred_boxes[foregroud_mask],
             torch.stack(all_gt_boxes_targets)[foregroud_mask],
             reduction="sum",
         )
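
Testing notes (not part of the patch): a minimal sketch that checks the new vectorized path against the old per-image loop that compute_loss used. It assumes the patch above is applied so that BoxLinearCoder.decode_all exists; the shapes and box values below are purely illustrative.

    # Equivalence check: decode_all vs. stacking per-image decode_single calls.
    import torch
    from torchvision.models.detection._utils import BoxLinearCoder


    def make_boxes(n: int) -> torch.Tensor:
        # Well-formed (x1, y1, x2, y2) boxes with positive width/height.
        xy = torch.rand(n, 2) * 50
        wh = torch.rand(n, 2) * 20 + 1.0
        return torch.cat([xy, xy + wh], dim=1)


    num_images, num_anchors = 2, 8
    coder = BoxLinearCoder(normalize_by_size=True)

    anchors = [make_boxes(num_anchors) for _ in range(num_images)]  # per-image reference boxes
    rel_codes = torch.rand(num_images, num_anchors, 4)  # encoded (l, t, r, b) offsets

    # New vectorized path: one call over the whole batch.
    vectorized = coder.decode_all(rel_codes, anchors)

    # Old path: decode_single per image, then stack (as compute_loss did before).
    looped = torch.stack([coder.decode_single(r, a) for r, a in zip(rel_codes, anchors)])

    torch.testing.assert_close(vectorized, looped)  # both are (num_images, num_anchors, 4)

Both paths perform the same elementwise arithmetic, so the outputs should match exactly; the vectorized version simply replaces the Python-level loop with batched tensor ops, avoiding one kernel dispatch per image.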