Skip to content

Commit e8fedec

Browse files
committed
remove transforms from class
updated the labels; added Gaussian blur
1 parent 63735c1 commit e8fedec

File tree

1 file changed

+17
-34
lines changed

1 file changed

+17
-34
lines changed

references/detection/transforms.py

Lines changed: 17 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -442,25 +442,9 @@ def forward(
442442

443443

444444
class SimpleCopyPaste(torch.nn.Module):
445-
def __init__(self, jittering_type: str = "LSJ"):
445+
def __init__(self):
446446
super().__init__()
447447

448-
if jittering_type == "LSJ":
449-
scale_range = (0.1, 2.0)
450-
elif jittering_type == "SSJ":
451-
scale_range = (0.8, 1.25)
452-
else:
453-
# TODO: add invalid option error
454-
raise ValueError("Invalid jittering type")
455-
456-
self.transforms = Compose(
457-
[
458-
ScaleJitter(target_size=(1024, 1024), scale_range=scale_range),
459-
FixedSizeCrop(size=(1024, 1024), fill=105),
460-
RandomHorizontalFlip(0.5),
461-
]
462-
)
463-
464448
def combine_masks(self, masks):
465449
return masks.sum(dim=0).greater(0)
466450

@@ -472,22 +456,18 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens
472456
if not batch.is_floating_point():
473457
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
474458

475-
for i, (image, mask) in enumerate(zip(batch, target)):
476-
batch[i], target[i] = self.transforms(image, mask)
477-
478459
# create copy of batch and target as the original will be modified
479460
batch_rolled = batch.roll(1, 0).detach().clone()
480461
target_rolled = copy.deepcopy(target[-1:] + target[:-1])
481462

482-
# TODO: select a random subset of objects from one of the images and paste them onto the other image
483-
484-
# TODO: Smooth out the edges of the pasted objects using a Gaussian filter on the mask
485-
486463
# collect binary paste masks for all images
487464
paste_masks = []
488465

489466
for source_image, paste_image, source_data, paste_data in zip(batch, batch_rolled, target, target_rolled):
490-
paste_alpha_mask = self.combine_masks(paste_data["masks"])
467+
number_of_masks = len(paste_data["masks"])
468+
random_selection = torch.randint(0, number_of_masks, (number_of_masks,)).unique()
469+
470+
paste_alpha_mask = self.combine_masks(paste_data["masks"][random_selection])
491471
paste_masks.append(paste_alpha_mask)
492472

493473
# update original masks
@@ -496,21 +476,24 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens
496476

497477
# remove masks where no annotations are present (all values are 0)
498478
mask_filter = source_data["masks"].sum((2, 1)).not_equal(0)
499-
filtered_masks = source_data["masks"][mask_filter]
479+
source_data["masks"] = source_data["masks"][mask_filter]
480+
source_data["boxes"] = ops.masks_to_boxes(source_data["masks"])
481+
source_data["labels"] = source_data["labels"][mask_filter]
482+
source_data["area"] = source_data["area"][mask_filter]
483+
source_data["iscrowd"] = source_data["iscrowd"][mask_filter]
500484

501-
# update bboxes based on new masks
502-
source_data["boxes"] = ops.masks_to_boxes(filtered_masks)
503485
# TODO: update area
504486

505487
# concatenate paste data with original data
506-
source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"]))
507-
source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"]))
508-
source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"]))
509-
source_data["area"] = torch.cat((source_data["area"], paste_data["area"]))
510-
source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"]))
488+
source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"][random_selection]))
489+
source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"][random_selection]))
490+
source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"][random_selection]))
491+
source_data["area"] = torch.cat((source_data["area"], paste_data["area"][random_selection]))
492+
source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"][random_selection]))
511493

512494
# update the original images with paste images
513-
paste_masks = torch.stack(paste_masks)
495+
paste_masks = torch.stack(paste_masks).to(torch.uint8)
496+
paste_masks = T.GaussianBlur((5, 5), sigma=2)(paste_masks) # Adds Gaussian Filter
514497
batch.mul_(torch.unsqueeze(torch.logical_not(paste_masks), 1))
515498

516499
paste_images = batch_rolled * torch.unsqueeze(paste_masks, 1)

0 commit comments

Comments
 (0)