Skip to content

Commit 96fa820

Browse files
cleanup for box encoding and decoding in FCOS (#6277)
* cleaning up box decoding * minor nits * cleanup for box encoding also added.
1 parent b30fa5c commit 96fa820

File tree

3 files changed

+18
-76
lines changed

3 files changed

+18
-76
lines changed

test/test_models_detection_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def test_box_linear_coder(self):
3030

3131
proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float()
3232

33-
rel_codes = box_coder.encode_single(boxes, proposals)
34-
pred_boxes = box_coder.decode_single(rel_codes, boxes)
33+
rel_codes = box_coder.encode(boxes, proposals)
34+
pred_boxes = box_coder.decode(rel_codes, boxes)
3535
torch.allclose(proposals, pred_boxes)
3636

3737
@pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)])

torchvision/models/detection/_utils.py

Lines changed: 4 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -237,42 +237,10 @@ def __init__(self, normalize_by_size: bool = True) -> None:
237237
"""
238238
self.normalize_by_size = normalize_by_size
239239

240-
def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
240+
def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
241241
"""
242242
Encode a set of proposals with respect to some reference boxes
243243
244-
Args:
245-
reference_boxes (Tensor): reference boxes
246-
proposals (Tensor): boxes to be encoded
247-
248-
Returns:
249-
Tensor: the encoded relative box offsets that can be used to
250-
decode the boxes.
251-
"""
252-
# get the center of reference_boxes
253-
reference_boxes_ctr_x = 0.5 * (reference_boxes[:, 0] + reference_boxes[:, 2])
254-
reference_boxes_ctr_y = 0.5 * (reference_boxes[:, 1] + reference_boxes[:, 3])
255-
256-
# get box regression transformation deltas
257-
target_l = reference_boxes_ctr_x - proposals[:, 0]
258-
target_t = reference_boxes_ctr_y - proposals[:, 1]
259-
target_r = proposals[:, 2] - reference_boxes_ctr_x
260-
target_b = proposals[:, 3] - reference_boxes_ctr_y
261-
262-
targets = torch.stack((target_l, target_t, target_r, target_b), dim=1)
263-
if self.normalize_by_size:
264-
reference_boxes_w = reference_boxes[:, 2] - reference_boxes[:, 0]
265-
reference_boxes_h = reference_boxes[:, 3] - reference_boxes[:, 1]
266-
reference_boxes_size = torch.stack(
267-
(reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=1
268-
)
269-
targets = targets / reference_boxes_size
270-
271-
return targets
272-
273-
def encode_all(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
274-
"""
275-
vectorized version of `encode_single`
276244
Args:
277245
reference_boxes (Tensor): reference boxes
278246
proposals (Tensor): boxes to be encoded
@@ -304,7 +272,8 @@ def encode_all(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
304272
targets = targets / reference_boxes_size
305273
return targets
306274

307-
def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
275+
def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
276+
308277
"""
309278
From a set of original boxes and encoded relative box offsets,
310279
get the decoded boxes.
@@ -313,35 +282,6 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
313282
rel_codes (Tensor): encoded boxes
314283
boxes (Tensor): reference boxes.
315284
316-
Returns:
317-
Tensor: the predicted boxes with the encoded relative box offsets.
318-
"""
319-
320-
boxes = boxes.to(rel_codes.dtype)
321-
322-
ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2])
323-
ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3])
324-
if self.normalize_by_size:
325-
boxes_w = boxes[:, 2] - boxes[:, 0]
326-
boxes_h = boxes[:, 3] - boxes[:, 1]
327-
boxes_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=1)
328-
rel_codes = rel_codes * boxes_size
329-
330-
pred_boxes1 = ctr_x - rel_codes[:, 0]
331-
pred_boxes2 = ctr_y - rel_codes[:, 1]
332-
pred_boxes3 = ctr_x + rel_codes[:, 2]
333-
pred_boxes4 = ctr_y + rel_codes[:, 3]
334-
pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1)
335-
return pred_boxes
336-
337-
def decode_all(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
338-
"""
339-
Vectorized version of `decode_single` method.
340-
341-
Args:
342-
rel_codes (Tensor): encoded boxes
343-
boxes (List[Tensor]): List of reference boxes.
344-
345285
Returns:
346286
Tensor: the predicted boxes with the encoded relative box offsets.
347287
@@ -350,7 +290,7 @@ def decode_all(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
350290
351291
"""
352292

353-
boxes = torch.stack(boxes).to(dtype=rel_codes.dtype)
293+
boxes = boxes.to(dtype=rel_codes.dtype)
354294

355295
ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
356296
ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])

torchvision/models/detection/fcos.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,13 @@ def compute_loss(
7474
all_gt_classes_targets.append(gt_classes_targets)
7575
all_gt_boxes_targets.append(gt_boxes_targets)
7676

77-
all_gt_classes_targets = torch.stack(all_gt_classes_targets)
77+
# List[Tensor] to Tensor conversion of `all_gt_boxes_target`, `all_gt_classes_targets` and `anchors`
78+
all_gt_boxes_targets, all_gt_classes_targets, anchors = (
79+
torch.stack(all_gt_boxes_targets),
80+
torch.stack(all_gt_classes_targets),
81+
torch.stack(anchors),
82+
)
83+
7884
# compute foregroud
7985
foregroud_mask = all_gt_classes_targets >= 0
8086
num_foreground = foregroud_mask.sum().item()
@@ -84,14 +90,10 @@ def compute_loss(
8490
gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0
8591
loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
8692

87-
# regression loss: GIoU loss
88-
89-
pred_boxes = self.box_coder.decode_all(bbox_regression, anchors)
90-
91-
# List[Tensor] to Tensor conversion of `all_gt_boxes_target` and `anchors`
92-
all_gt_boxes_targets, anchors = torch.stack(all_gt_boxes_targets), torch.stack(anchors)
93-
9493
# amp issue: pred_boxes need to convert float
94+
pred_boxes = self.box_coder.decode(bbox_regression, anchors)
95+
96+
# regression loss: GIoU loss
9597
loss_bbox_reg = generalized_box_iou_loss(
9698
pred_boxes[foregroud_mask],
9799
all_gt_boxes_targets[foregroud_mask],
@@ -100,7 +102,7 @@ def compute_loss(
100102

101103
# ctrness loss
102104

103-
bbox_reg_targets = self.box_coder.encode_all(anchors, all_gt_boxes_targets)
105+
bbox_reg_targets = self.box_coder.encode(anchors, all_gt_boxes_targets)
104106

105107
if len(bbox_reg_targets) == 0:
106108
gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
@@ -522,7 +524,7 @@ def postprocess_detections(
522524
anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
523525
labels_per_level = topk_idxs % num_classes
524526

525-
boxes_per_level = self.box_coder.decode_single(
527+
boxes_per_level = self.box_coder.decode(
526528
box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
527529
)
528530
boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

0 commit comments

Comments
 (0)