@@ -237,42 +237,10 @@ def __init__(self, normalize_by_size: bool = True) -> None:
237
237
"""
238
238
self .normalize_by_size = normalize_by_size
239
239
240
- def encode_single (self , reference_boxes : Tensor , proposals : Tensor ) -> Tensor :
240
+ def encode (self , reference_boxes : Tensor , proposals : Tensor ) -> Tensor :
241
241
"""
242
242
Encode a set of proposals with respect to some reference boxes
243
243
244
- Args:
245
- reference_boxes (Tensor): reference boxes
246
- proposals (Tensor): boxes to be encoded
247
-
248
- Returns:
249
- Tensor: the encoded relative box offsets that can be used to
250
- decode the boxes.
251
- """
252
- # get the center of reference_boxes
253
- reference_boxes_ctr_x = 0.5 * (reference_boxes [:, 0 ] + reference_boxes [:, 2 ])
254
- reference_boxes_ctr_y = 0.5 * (reference_boxes [:, 1 ] + reference_boxes [:, 3 ])
255
-
256
- # get box regression transformation deltas
257
- target_l = reference_boxes_ctr_x - proposals [:, 0 ]
258
- target_t = reference_boxes_ctr_y - proposals [:, 1 ]
259
- target_r = proposals [:, 2 ] - reference_boxes_ctr_x
260
- target_b = proposals [:, 3 ] - reference_boxes_ctr_y
261
-
262
- targets = torch .stack ((target_l , target_t , target_r , target_b ), dim = 1 )
263
- if self .normalize_by_size :
264
- reference_boxes_w = reference_boxes [:, 2 ] - reference_boxes [:, 0 ]
265
- reference_boxes_h = reference_boxes [:, 3 ] - reference_boxes [:, 1 ]
266
- reference_boxes_size = torch .stack (
267
- (reference_boxes_w , reference_boxes_h , reference_boxes_w , reference_boxes_h ), dim = 1
268
- )
269
- targets = targets / reference_boxes_size
270
-
271
- return targets
272
-
273
- def encode_all (self , reference_boxes : Tensor , proposals : Tensor ) -> Tensor :
274
- """
275
- vectorized version of `encode_single`
276
244
Args:
277
245
reference_boxes (Tensor): reference boxes
278
246
proposals (Tensor): boxes to be encoded
@@ -304,7 +272,8 @@ def encode_all(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
304
272
targets = targets / reference_boxes_size
305
273
return targets
306
274
307
- def decode_single (self , rel_codes : Tensor , boxes : Tensor ) -> Tensor :
275
+ def decode (self , rel_codes : Tensor , boxes : Tensor ) -> Tensor :
276
+
308
277
"""
309
278
From a set of original boxes and encoded relative box offsets,
310
279
get the decoded boxes.
@@ -313,35 +282,6 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
313
282
rel_codes (Tensor): encoded boxes
314
283
boxes (Tensor): reference boxes.
315
284
316
- Returns:
317
- Tensor: the predicted boxes with the encoded relative box offsets.
318
- """
319
-
320
- boxes = boxes .to (rel_codes .dtype )
321
-
322
- ctr_x = 0.5 * (boxes [:, 0 ] + boxes [:, 2 ])
323
- ctr_y = 0.5 * (boxes [:, 1 ] + boxes [:, 3 ])
324
- if self .normalize_by_size :
325
- boxes_w = boxes [:, 2 ] - boxes [:, 0 ]
326
- boxes_h = boxes [:, 3 ] - boxes [:, 1 ]
327
- boxes_size = torch .stack ((boxes_w , boxes_h , boxes_w , boxes_h ), dim = 1 )
328
- rel_codes = rel_codes * boxes_size
329
-
330
- pred_boxes1 = ctr_x - rel_codes [:, 0 ]
331
- pred_boxes2 = ctr_y - rel_codes [:, 1 ]
332
- pred_boxes3 = ctr_x + rel_codes [:, 2 ]
333
- pred_boxes4 = ctr_y + rel_codes [:, 3 ]
334
- pred_boxes = torch .stack ((pred_boxes1 , pred_boxes2 , pred_boxes3 , pred_boxes4 ), dim = 1 )
335
- return pred_boxes
336
-
337
- def decode_all (self , rel_codes : Tensor , boxes : List [Tensor ]) -> Tensor :
338
- """
339
- Vectorized version of `decode_single` method.
340
-
341
- Args:
342
- rel_codes (Tensor): encoded boxes
343
- boxes (List[Tensor]): List of reference boxes.
344
-
345
285
Returns:
346
286
Tensor: the predicted boxes with the encoded relative box offsets.
347
287
@@ -350,7 +290,7 @@ def decode_all(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
350
290
351
291
"""
352
292
353
- boxes = torch . stack ( boxes ) .to (dtype = rel_codes .dtype )
293
+ boxes = boxes .to (dtype = rel_codes .dtype )
354
294
355
295
ctr_x = 0.5 * (boxes [..., 0 ] + boxes [..., 2 ])
356
296
ctr_y = 0.5 * (boxes [..., 1 ] + boxes [..., 3 ])
0 commit comments