Skip to content

Commit bbc1aac

Browse files
lezwonvfdev-5datumbox
authored
Add SimpleCopyPaste augmentation (#5825)
* added simple POC * added jitter and crop options * added references * moved simplecopypaste to detection module * working POC for simple copy paste in detection * added comments * remove transforms from class updated the labels added gaussian blur * removed loop for mask calculation * replaced Gaussian blur with functional api * added inplace operations * added changes to accept tuples instead of tensors * - make copy paste functional - make only one copy of batch and target * add inplace support within copy paste functional * Updated code for copy-paste transform * Fixed code formatting * [skip ci] removed manual thresholding * Replaced cropping by resizing data to paste * Removed inplace arg (as useless) and put a check on iscrowd target * code-formatting * Updated copypaste op to make it torch scriptable Added fallbacks to support LSJ * Fixed flake8 * Updates according to the review Co-authored-by: vfdev-5 <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 369317f commit bbc1aac

File tree

2 files changed

+177
-1
lines changed

2 files changed

+177
-1
lines changed

references/detection/train.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from coco_utils import get_coco, get_coco_kp
3232
from engine import train_one_epoch, evaluate
3333
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
34+
from torchvision.transforms import InterpolationMode
35+
from transforms import SimpleCopyPaste
3436

3537

3638
def get_dataset(name, image_set, transform, data_path):
@@ -145,6 +147,13 @@ def get_args_parser(add_help=True):
145147
# Mixed precision training parameters
146148
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
147149

150+
# Use CopyPaste augmentation training parameter
151+
parser.add_argument(
152+
"--use-copypaste",
153+
action="store_true",
154+
help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
155+
)
156+
148157
return parser
149158

150159

@@ -180,8 +189,20 @@ def main(args):
180189
else:
181190
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)
182191

192+
train_collate_fn = utils.collate_fn
193+
if args.use_copypaste:
194+
if args.data_augmentation != "lsj":
195+
raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")
196+
197+
copypaste = SimpleCopyPaste(resize_interpolation=InterpolationMode.BILINEAR, blending=True)
198+
199+
def copypaste_collate_fn(batch):
200+
return copypaste(*utils.collate_fn(batch))
201+
202+
train_collate_fn = copypaste_collate_fn
203+
183204
data_loader = torch.utils.data.DataLoader(
184-
dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
205+
dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn
185206
)
186207

187208
data_loader_test = torch.utils.data.DataLoader(

references/detection/transforms.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
import torchvision
55
from torch import nn, Tensor
6+
from torchvision import ops
67
from torchvision.transforms import functional as F
78
from torchvision.transforms import transforms as T, InterpolationMode
89

@@ -437,3 +438,157 @@ def forward(
437438
)
438439

439440
return image, target
441+
442+
443+
def _copy_paste(
    image: torch.Tensor,
    target: Dict[str, Tensor],
    paste_image: torch.Tensor,
    paste_target: Dict[str, Tensor],
    blending: bool = True,
    resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[torch.Tensor, Dict[str, Tensor]]:
    """Paste a random subset of the instances of ``paste_image`` onto ``image``.

    Args:
        image: Image that receives the pasted instances.
        target: Annotations of ``image``. Reads ``"masks"`` and ``"labels"``;
            ``"area"`` and ``"iscrowd"`` are updated when present.
        paste_image: Image the instances are taken from. It is resized to the
            spatial size of ``image`` when the two sizes differ.
        paste_target: Annotations of ``paste_image``. Reads ``"masks"``,
            ``"boxes"`` and ``"labels"``, plus ``"iscrowd"`` when present.
        blending: If True, the union of the pasted masks goes through a 5x5
            Gaussian blur (sigma 2.0) before it is used for compositing.
        resize_interpolation: Interpolation used to resize ``paste_image``;
            pasted masks are always resized with nearest interpolation.

    Returns:
        ``(image, target)``. When ``paste_target`` holds no masks the inputs
        are returned untouched. Otherwise the composited image is returned
        together with a shallow copy of ``target`` whose entries were rebuilt:
        original instances whose mask became empty are dropped, the pasted
        instances are appended, and instances with degenerate boxes are
        filtered out.
    """
    num_masks = len(paste_target["masks"])
    if num_masks == 0:
        # Such degenerate case with num_masks=0 can happen with LSJ:
        # nothing to paste, hand the inputs back as they are.
        return image, target

    # Sample indices with replacement and de-duplicate them, which yields a
    # random, non-empty subset of instances. The explicit cast to torch.long
    # is there to please TorchScript.
    picked = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
    picked = torch.unique(picked).to(torch.long)

    paste_masks = paste_target["masks"][picked]
    paste_boxes = paste_target["boxes"][picked]
    paste_labels = paste_target["labels"][picked]

    orig_masks = target["masks"]

    # The original algorithm works on equal-sized data (e.g. coming from LSJ
    # augmentations); as an extension, paste data of a different size is
    # resized to the size of the destination image.
    dst_size = image.shape[-2:]
    src_size = paste_image.shape[-2:]
    if dst_size != src_size:
        paste_image = F.resize(paste_image, dst_size, interpolation=resize_interpolation)
        paste_masks = F.resize(paste_masks, dst_size, interpolation=F.InterpolationMode.NEAREST)
        # Rescale the boxes: each box is viewed as two 2-element points that
        # are multiplied by (last-dim ratio, second-to-last-dim ratio).
        # assumes boxes are stored as (x1, y1, x2, y2) — TODO confirm
        last_dim_ratio = dst_size[1] / src_size[1]
        second_last_dim_ratio = dst_size[0] / src_size[0]
        ratios = torch.tensor((last_dim_ratio, second_last_dim_ratio), device=paste_boxes.device)
        as_points = paste_boxes.view(-1, 2, 2)
        paste_boxes = as_points.mul(ratios).view(paste_boxes.shape)

    # Union of all selected instance masks.
    paste_alpha_mask = paste_masks.sum(dim=0) > 0

    if blending:
        # NOTE(review): the `~` applied below needs a bool/integer tensor, so
        # the blurred mask is presumably still boolean — confirm against
        # F.gaussian_blur's handling of bool inputs.
        paste_alpha_mask = F.gaussian_blur(
            paste_alpha_mask.unsqueeze(0),
            kernel_size=(5, 5),
            sigma=[
                2.0,
            ],
        )

    keep_mask = ~paste_alpha_mask

    # Composite the images: keep the original outside the paste region and
    # take the paste image inside it.
    image = (image * keep_mask) + (paste_image * paste_alpha_mask)

    # Cut the paste region out of the original masks and drop the instances
    # that ended up fully covered.
    remaining_masks = orig_masks * keep_mask
    non_all_zero_masks = remaining_masks.sum((-1, -2)) > 0
    remaining_masks = remaining_masks[non_all_zero_masks]

    # Shallow copy so that the caller's dict is left untouched.
    out_target = {name: value for name, value in target.items()}

    out_target["masks"] = torch.cat([remaining_masks, paste_masks])

    # Boxes of the surviving original instances are recomputed from their
    # (possibly shrunk) masks; pasted boxes are appended as they are.
    remaining_boxes = ops.masks_to_boxes(remaining_masks)
    out_target["boxes"] = torch.cat([remaining_boxes, paste_boxes])

    remaining_labels = target["labels"][non_all_zero_masks]
    out_target["labels"] = torch.cat([remaining_labels, paste_labels])

    # Optional key: area is recomputed as the per-mask pixel sum.
    if "area" in target:
        out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)

    # Optional key: iscrowd. target["iscrowd"] can differ in length from the
    # masks (e.g. when a previous transform geometrically modified
    # masks/boxes/labels without updating "iscrowd"); in that case it is left
    # as it was.
    if (
        "iscrowd" in target
        and "iscrowd" in paste_target
        and len(target["iscrowd"]) == len(non_all_zero_masks)
    ):
        remaining_iscrowd = target["iscrowd"][non_all_zero_masks]
        paste_iscrowd = paste_target["iscrowd"][picked]
        out_target["iscrowd"] = torch.cat([remaining_iscrowd, paste_iscrowd])

    # Remove instances whose box has a non-positive extent along any axis.
    all_boxes = out_target["boxes"]
    degenerate = all_boxes[:, 2:] <= all_boxes[:, :2]
    if degenerate.any():
        valid = ~degenerate.any(dim=1)

        out_target["boxes"] = all_boxes[valid]
        out_target["masks"] = out_target["masks"][valid]
        out_target["labels"] = out_target["labels"][valid]

        if "area" in out_target:
            out_target["area"] = out_target["area"][valid]
        if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid):
            out_target["iscrowd"] = out_target["iscrowd"][valid]

    return image, out_target
542+
543+
544+
class SimpleCopyPaste(torch.nn.Module):
    """Batch-level copy-paste augmentation built on top of ``_copy_paste``.

    Every image of the batch receives instances taken from the previous image
    of the batch; the first image receives them from the last one.

    Args:
        blending: Forwarded to ``_copy_paste`` as ``blending``.
        resize_interpolation: Forwarded to ``_copy_paste`` as
            ``resize_interpolation``.
    """

    def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR):
        super().__init__()
        self.blending = blending
        self.resize_interpolation = resize_interpolation

    def forward(
        self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
        """Apply copy-paste to a batch.

        Args:
            images: List (or tuple) of image tensors.
            targets: List (or tuple) with one annotation dict per image; each
                dict must hold tensors under "masks", "boxes" and "labels".

        Returns:
            Two lists of the same length as the inputs: the augmented images
            and their updated targets.
        """
        torch._assert(
            isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
            "images should be a list of tensors",
        )
        torch._assert(
            isinstance(targets, (list, tuple)) and len(images) == len(targets),
            "targets should be a list of the same size as images",
        )
        for target in targets:
            # Can not check for instance type dict with inside torch.jit.script
            # torch._assert(isinstance(target, dict), "targets item should be a dict")
            for k in ["masks", "boxes", "labels"]:
                torch._assert(k in target, f"Key {k} should be present in targets")
                torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")

        # images = [t1, t2, ..., tN]
        # The paste sources are the inputs rotated by one position:
        # paste_images = [tN, t1, ..., tN-1]
        # FYI: in TF they mix data on the dataset level
        paste_images = images[-1:] + images[:-1]
        paste_targets = targets[-1:] + targets[:-1]

        pasted_images: List[torch.Tensor] = []
        pasted_targets: List[Dict[str, Tensor]] = []

        for dst_image, dst_target, src_image, src_target in zip(images, targets, paste_images, paste_targets):
            new_image, new_target = _copy_paste(
                dst_image,
                dst_target,
                src_image,
                src_target,
                blending=self.blending,
                resize_interpolation=self.resize_interpolation,
            )
            pasted_images.append(new_image)
            pasted_targets.append(new_target)

        return pasted_images, pasted_targets

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"

0 commit comments

Comments
 (0)