 import torch
 import torchvision
 from torch import nn, Tensor
+from torchvision import ops
 from torchvision.transforms import functional as F
 from torchvision.transforms import transforms as T, InterpolationMode

@@ -437,3 +438,157 @@ def forward(
                 )
 
         return image, target
+
+
+def _copy_paste(
+    image: torch.Tensor,
+    target: Dict[str, Tensor],
+    paste_image: torch.Tensor,
+    paste_target: Dict[str, Tensor],
+    blending: bool = True,
+    resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
+) -> Tuple[torch.Tensor, Dict[str, Tensor]]:
+
+    # Random paste targets selection:
+    num_masks = len(paste_target["masks"])
+
+    if num_masks < 1:
+        # Such a degenerate case with num_masks=0 can happen with LSJ
+        # Let's just return (image, target)
+        return image, target
+
+    # We have to specify the dtype explicitly as torch.long to please TorchScript
+    random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
+    random_selection = torch.unique(random_selection).to(torch.long)
+
+    paste_masks = paste_target["masks"][random_selection]
+    paste_boxes = paste_target["boxes"][random_selection]
+    paste_labels = paste_target["labels"][random_selection]
+
+    masks = target["masks"]
+
+    # Resize the paste data if the source and paste sizes differ.
+    # This was introduced here because the original algorithm works on
+    # equal-sized data (for example, coming from LSJ data augmentations)
+    size1 = image.shape[-2:]
+    size2 = paste_image.shape[-2:]
+    if size1 != size2:
+        paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
+        paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
+        # resize bboxes:
+        ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
+        paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)
+
+    paste_alpha_mask = paste_masks.sum(dim=0) > 0
+
+    if blending:
+        paste_alpha_mask = F.gaussian_blur(
+            paste_alpha_mask.unsqueeze(0),
+            kernel_size=(5, 5),
+            sigma=[
+                2.0,
+            ],
+        )
+
+    # Copy-paste images:
+    image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
+
+    # Copy-paste masks:
+    masks = masks * (~paste_alpha_mask)
+    non_all_zero_masks = masks.sum((-1, -2)) > 0
+    masks = masks[non_all_zero_masks]
+
+    # Do a shallow copy of the target dict
+    out_target = {k: v for k, v in target.items()}
+
+    out_target["masks"] = torch.cat([masks, paste_masks])
+
+    # Copy-paste boxes and labels
+    boxes = ops.masks_to_boxes(masks)
+    out_target["boxes"] = torch.cat([boxes, paste_boxes])
+
+    labels = target["labels"][non_all_zero_masks]
+    out_target["labels"] = torch.cat([labels, paste_labels])
+
+    # Update the optional keys "area" and "iscrowd" if they are present
+    if "area" in target:
+        out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)
+
+    if "iscrowd" in target and "iscrowd" in paste_target:
+        # The size of target["iscrowd"] can differ from the mask size (non_all_zero_masks),
+        # for example, if a previous transform geometrically modified masks/boxes/labels
+        # but did not update "iscrowd"
+        if len(target["iscrowd"]) == len(non_all_zero_masks):
+            iscrowd = target["iscrowd"][non_all_zero_masks]
+            paste_iscrowd = paste_target["iscrowd"][random_selection]
+            out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
+
+    # Check for degenerate boxes and remove them
+    boxes = out_target["boxes"]
+    degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+    if degenerate_boxes.any():
+        valid_targets = ~degenerate_boxes.any(dim=1)
+
+        out_target["boxes"] = boxes[valid_targets]
+        out_target["masks"] = out_target["masks"][valid_targets]
+        out_target["labels"] = out_target["labels"][valid_targets]
+
+        if "area" in out_target:
+            out_target["area"] = out_target["area"][valid_targets]
+        if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
+            out_target["iscrowd"] = out_target["iscrowd"][valid_targets]
+
+    return image, out_target
+
+
+class SimpleCopyPaste(torch.nn.Module):
+    def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR):
+        super().__init__()
+        self.resize_interpolation = resize_interpolation
+        self.blending = blending
+
+    def forward(
+        self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
+    ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
+        torch._assert(
+            isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
+            "images should be a list of tensors",
+        )
+        torch._assert(
+            isinstance(targets, (list, tuple)) and len(images) == len(targets),
+            "targets should be a list of the same size as images",
+        )
+        for target in targets:
+            # Cannot check isinstance(target, dict) inside torch.jit.script
+            # torch._assert(isinstance(target, dict), "targets item should be a dict")
+            for k in ["masks", "boxes", "labels"]:
+                torch._assert(k in target, f"Key {k} should be present in targets")
+                torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")
+
+        # images = [t1, t2, ..., tN]
+        # Define paste_images as the input images rolled by one position:
+        # paste_images = [tN, t1, ..., tN-1]
+        # FYI: in TF they mix data on the dataset level
+        images_rolled = images[-1:] + images[:-1]
+        targets_rolled = targets[-1:] + targets[:-1]
+
+        output_images: List[torch.Tensor] = []
+        output_targets: List[Dict[str, Tensor]] = []
+
+        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
+            output_image, output_data = _copy_paste(
+                image,
+                target,
+                paste_image,
+                paste_target,
+                blending=self.blending,
+                resize_interpolation=self.resize_interpolation,
+            )
+            output_images.append(output_image)
+            output_targets.append(output_data)
+
+        return output_images, output_targets
+
+    def __repr__(self) -> str:
+        s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
+        return s
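
For context, a minimal usage sketch (not part of the patch): it builds two synthetic samples in the detection-reference target format and runs them through SimpleCopyPaste. The synthetic shapes and the copypaste_collate_fn helper in the trailing comment are illustrative assumptions, not code from this PR.

import torch

# Assumes SimpleCopyPaste from the module above is importable.
copy_paste = SimpleCopyPaste(blending=True)

def make_sample(box):
    # One synthetic (image, target) pair with a single rectangular instance mask.
    image = torch.rand(3, 224, 224)
    x1, y1, x2, y2 = box
    mask = torch.zeros(1, 224, 224, dtype=torch.uint8)
    mask[0, y1:y2, x1:x2] = 1
    target = {
        "masks": mask,
        "boxes": torch.tensor([box], dtype=torch.float32),
        "labels": torch.tensor([1], dtype=torch.int64),
    }
    return image, target

samples = [make_sample((10, 10, 60, 60)), make_sample((100, 120, 180, 200))]
images = [s[0] for s in samples]
targets = [s[1] for s in samples]

# Each output image now contains the instance pasted from the other sample,
# and each target carries the union of the surviving and pasted annotations.
out_images, out_targets = copy_paste(images, targets)
print(out_targets[0]["boxes"].shape)  # torch.Size([2, 4])

# In training this would typically be applied batch-wise through the DataLoader,
# e.g. with a collate function along these lines (hypothetical helper):
#
# def copypaste_collate_fn(batch):
#     images, targets = tuple(zip(*batch))
#     return copy_paste(list(images), list(targets))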