@@ -442,25 +442,9 @@ def forward(
442
442
443
443
444
444
class SimpleCopyPaste (torch .nn .Module ):
445
- def __init__ (self , jittering_type : str = "LSJ" ):
445
+ def __init__ (self ):
446
446
super ().__init__ ()
447
447
448
- if jittering_type == "LSJ" :
449
- scale_range = (0.1 , 2.0 )
450
- elif jittering_type == "SSJ" :
451
- scale_range = (0.8 , 1.25 )
452
- else :
453
- # TODO: add invalid option error
454
- raise ValueError ("Invalid jittering type" )
455
-
456
- self .transforms = Compose (
457
- [
458
- ScaleJitter (target_size = (1024 , 1024 ), scale_range = scale_range ),
459
- FixedSizeCrop (size = (1024 , 1024 ), fill = 105 ),
460
- RandomHorizontalFlip (0.5 ),
461
- ]
462
- )
463
-
464
448
def combine_masks (self , masks ):
465
449
return masks .sum (dim = 0 ).greater (0 )
466
450
@@ -472,22 +456,18 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens
472
456
if not batch .is_floating_point ():
473
457
raise TypeError (f"Batch dtype should be a float tensor. Got { batch .dtype } ." )
474
458
475
- for i , (image , mask ) in enumerate (zip (batch , target )):
476
- batch [i ], target [i ] = self .transforms (image , mask )
477
-
478
459
# create copy of batch and target as the original will be modified
479
460
batch_rolled = batch .roll (1 , 0 ).detach ().clone ()
480
461
target_rolled = copy .deepcopy (target [- 1 :] + target [:- 1 ])
481
462
482
- # TODO: select a random subset of objects from one of the images and paste them onto the other image
483
-
484
- # TODO: Smooth out the edges of the pasted objects using a Gaussian filter on the mask
485
-
486
463
# collect binary paste masks for all images
487
464
paste_masks = []
488
465
489
466
for source_image , paste_image , source_data , paste_data in zip (batch , batch_rolled , target , target_rolled ):
490
- paste_alpha_mask = self .combine_masks (paste_data ["masks" ])
467
+ number_of_masks = len (paste_data ["masks" ])
468
+ random_selection = torch .randint (0 , number_of_masks , (number_of_masks ,)).unique ()
469
+
470
+ paste_alpha_mask = self .combine_masks (paste_data ["masks" ][random_selection ])
491
471
paste_masks .append (paste_alpha_mask )
492
472
493
473
# update original masks
@@ -496,21 +476,24 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens
496
476
497
477
# remove masks where no annotations are present (all values are 0)
498
478
mask_filter = source_data ["masks" ].sum ((2 , 1 )).not_equal (0 )
499
- filtered_masks = source_data ["masks" ][mask_filter ]
479
+ source_data ["masks" ] = source_data ["masks" ][mask_filter ]
480
+ source_data ["boxes" ] = ops .masks_to_boxes (source_data ["masks" ])
481
+ source_data ["labels" ] = source_data ["labels" ][mask_filter ]
482
+ source_data ["area" ] = source_data ["area" ][mask_filter ]
483
+ source_data ["iscrowd" ] = source_data ["iscrowd" ][mask_filter ]
500
484
501
- # update bboxes based on new masks
502
- source_data ["boxes" ] = ops .masks_to_boxes (filtered_masks )
503
485
# TODO: update area
504
486
505
487
# concatenate paste data with original data
506
- source_data ["masks" ] = torch .cat ((source_data ["masks" ], paste_data ["masks" ]))
507
- source_data ["boxes" ] = torch .cat ((source_data ["boxes" ], paste_data ["boxes" ]))
508
- source_data ["labels" ] = torch .cat ((source_data ["labels" ], paste_data ["labels" ]))
509
- source_data ["area" ] = torch .cat ((source_data ["area" ], paste_data ["area" ]))
510
- source_data ["iscrowd" ] = torch .cat ((source_data ["iscrowd" ], paste_data ["iscrowd" ]))
488
+ source_data ["masks" ] = torch .cat ((source_data ["masks" ], paste_data ["masks" ][ random_selection ] ))
489
+ source_data ["boxes" ] = torch .cat ((source_data ["boxes" ], paste_data ["boxes" ][ random_selection ] ))
490
+ source_data ["labels" ] = torch .cat ((source_data ["labels" ], paste_data ["labels" ][ random_selection ] ))
491
+ source_data ["area" ] = torch .cat ((source_data ["area" ], paste_data ["area" ][ random_selection ] ))
492
+ source_data ["iscrowd" ] = torch .cat ((source_data ["iscrowd" ], paste_data ["iscrowd" ][ random_selection ] ))
511
493
512
494
# update the original images with paste images
513
- paste_masks = torch .stack (paste_masks )
495
+ paste_masks = torch .stack (paste_masks ).to (torch .uint8 )
496
+ paste_masks = T .GaussianBlur ((5 , 5 ), sigma = 2 )(paste_masks ) # Adds Gaussian Filter
514
497
batch .mul_ (torch .unsqueeze (torch .logical_not (paste_masks ), 1 ))
515
498
516
499
paste_images = batch_rolled * torch .unsqueeze (paste_masks , 1 )
0 commit comments