
Commit b3b7448

Vectorize box decoding in FCOS (#6203)
* basic structure
* added constraints
* fixed errors
* thanks to vadim!
* addressing the comments and added docstring
* Apply suggestions from code review

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 329b978 commit b3b7448

2 files changed: +40 -6 lines

torchvision/models/detection/_utils.py

Lines changed: 36 additions & 0 deletions
@@ -300,6 +300,42 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
         pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1)
         return pred_boxes
 
+    def decode_all(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
+        """
+        Vectorized version of `decode_single` method.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (List[Tensor]): List of reference boxes.
+
+        Returns:
+            Tensor: the predicted boxes with the encoded relative box offsets.
+
+        .. note::
+            This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.
+
+        """
+
+        boxes = torch.stack(boxes).to(dtype=rel_codes.dtype)
+
+        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
+        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
+
+        if self.normalize_by_size:
+            boxes_w = boxes[..., 2] - boxes[..., 0]
+            boxes_h = boxes[..., 3] - boxes[..., 1]
+
+            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
+            rel_codes = rel_codes * list_box_size
+
+        pred_boxes1 = ctr_x - rel_codes[..., 0]
+        pred_boxes2 = ctr_y - rel_codes[..., 1]
+        pred_boxes3 = ctr_x + rel_codes[..., 2]
+        pred_boxes4 = ctr_y + rel_codes[..., 3]
+
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
+        return pred_boxes
+
 
 class Matcher:
     """

torchvision/models/detection/fcos.py

Lines changed: 4 additions & 6 deletions
@@ -87,14 +87,12 @@ def compute_loss(
         loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
 
         # regression loss: GIoU loss
-        # TODO: vectorize this instead of using a for loop
-        pred_boxes = [
-            self.box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
-            for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression)
-        ]
+
+        pred_boxes = self.box_coder.decode_all(bbox_regression, anchors)
+
         # amp issue: pred_boxes need to convert float
         loss_bbox_reg = generalized_box_iou_loss(
-            torch.stack(pred_boxes)[foregroud_mask].float(),
+            pred_boxes[foregroud_mask],
             torch.stack(all_gt_boxes_targets)[foregroud_mask],
             reduction="sum",
         )
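The call site no longer needs torch.stack, since decode_all already returns one batched tensor; the boolean foregroud_mask then selects foreground boxes directly. A small sketch of that indexing (shapes below are assumptions for illustration):

import torch

# Hypothetical shapes mirroring the loss above: a (B, N, 4) batch of decoded
# boxes and a (B, N) boolean foreground mask.
pred_boxes = torch.rand(2, 5, 4)
foregroud_mask = torch.rand(2, 5) > 0.5

# Boolean indexing gathers the masked positions into a flat (K, 4) tensor,
# the per-box layout consumed by generalized_box_iou_loss.
selected = pred_boxes[foregroud_mask]
print(selected.shape)  # torch.Size([K, 4]), with K == foregroud_mask.sum()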
