@@ -300,6 +300,42 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
300
300
pred_boxes = torch .stack ((pred_boxes1 , pred_boxes2 , pred_boxes3 , pred_boxes4 ), dim = 1 )
301
301
return pred_boxes
302
302
303
+ def decode_all (self , rel_codes : Tensor , boxes : List [Tensor ]) -> Tensor :
304
+ """
305
+ Vectorized version of `decode_single` method.
306
+
307
+ Args:
308
+ rel_codes (Tensor): encoded boxes
309
+ boxes (List[Tensor]): List of reference boxes.
310
+
311
+ Returns:
312
+ Tensor: the predicted boxes with the encoded relative box offsets.
313
+
314
+ .. note::
315
+ This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.
316
+
317
+ """
318
+
319
+ boxes = torch .stack (boxes ).to (dtype = rel_codes .dtype )
320
+
321
+ ctr_x = 0.5 * (boxes [..., 0 ] + boxes [..., 2 ])
322
+ ctr_y = 0.5 * (boxes [..., 1 ] + boxes [..., 3 ])
323
+
324
+ if self .normalize_by_size :
325
+ boxes_w = boxes [..., 2 ] - boxes [..., 0 ]
326
+ boxes_h = boxes [..., 3 ] - boxes [..., 1 ]
327
+
328
+ list_box_size = torch .stack ((boxes_w , boxes_h , boxes_w , boxes_h ), dim = - 1 )
329
+ rel_codes = rel_codes * list_box_size
330
+
331
+ pred_boxes1 = ctr_x - rel_codes [..., 0 ]
332
+ pred_boxes2 = ctr_y - rel_codes [..., 1 ]
333
+ pred_boxes3 = ctr_x + rel_codes [..., 2 ]
334
+ pred_boxes4 = ctr_y + rel_codes [..., 3 ]
335
+
336
+ pred_boxes = torch .stack ((pred_boxes1 , pred_boxes2 , pred_boxes3 , pred_boxes4 ), dim = - 1 )
337
+ return pred_boxes
338
+
303
339
304
340
class Matcher :
305
341
"""
0 commit comments