
Post-paper Detection Optimizations #5444

Merged · 54 commits · Apr 5, 2022
Note: the changes shown below are from 25 of the 54 commits, not the final merged diff.

Commits
0f6fa39
Use frozen BN only if pre-trained.
datumbox Feb 18, 2022
7a94595
Add LSJ and the ability to train from scratch.
datumbox Feb 19, 2022
89a5b9d
Fixing formatter
datumbox Feb 19, 2022
20470c1
Merge branch 'main' into references/detection_recipe
datumbox Feb 20, 2022
22d7f47
Merge branch 'main' into references/detection_recipe
datumbox Feb 24, 2022
a0322dd
Merge branch 'main' into references/detection_recipe
datumbox Feb 24, 2022
2943182
Merge branch 'main' into references/detection_recipe
datumbox Feb 25, 2022
629e149
Merge branch 'main' into references/detection_recipe
datumbox Feb 25, 2022
53fbd71
Merge branch 'main' into references/detection_recipe
datumbox Feb 25, 2022
d3b8dad
Merge branch 'main' into references/detection_recipe
datumbox Mar 2, 2022
5aa97c3
Merge branch 'main' into references/detection_recipe
datumbox Mar 4, 2022
8537c48
Adding `--opt` and `--norm-weight-decay` support in Detection.
datumbox Mar 5, 2022
f7f8e2f
Fix error message
datumbox Mar 5, 2022
ed2a24c
Make ScaleJitter proportional.
datumbox Mar 6, 2022
bc7a8a9
Merge branch 'main' into references/detection_recipe
datumbox Mar 7, 2022
a1786bb
Merge branch 'main' into references/detection_recipe
datumbox Mar 7, 2022
6c12921
Merge branch 'main' into references/detection_recipe
datumbox Mar 7, 2022
bcf0afc
Adding more norm layers in split_normalization_params.
datumbox Mar 8, 2022
9c66a7c
Merge branch 'main' into references/detection_recipe
datumbox Mar 10, 2022
65e4116
Add FixedSizeCrop
datumbox Mar 10, 2022
ab63af6
Temporary fix for fill values on PIL
datumbox Mar 10, 2022
7365cdc
Merge branch 'main' into references/detection_recipe
datumbox Mar 11, 2022
c714c66
Fix the bug on fill.
datumbox Mar 11, 2022
c415639
Merge branch 'main' into references/detection_recipe
datumbox Mar 12, 2022
13fb5b3
Add RandomShortestSize.
datumbox Mar 12, 2022
0d230ab
Skip resize when an augmentation method is used.
datumbox Mar 12, 2022
a187917
multiscale in [480, 800]
datumbox Mar 13, 2022
4b4d300
Merge branch 'main' into references/detection_recipe
datumbox Mar 14, 2022
8dd6975
Merge branch 'main' into references/detection_recipe
datumbox Mar 14, 2022
efcf9ed
Merge branch 'main' into references/detection_recipe
datumbox Mar 22, 2022
7542a94
Add missing star
datumbox Mar 23, 2022
c67893c
Add new RetinaNet variant.
datumbox Mar 23, 2022
0d7917c
Add tests.
datumbox Mar 23, 2022
7354684
Update expected file for old retina
datumbox Mar 23, 2022
bb8aac0
Merge branch 'main' into references/detection_recipe
datumbox Mar 23, 2022
38ef843
Fixing tests
datumbox Mar 23, 2022
cd9c302
Add FrozenBN to retinav2
datumbox Mar 24, 2022
29c57f6
Fix network initialization issues
datumbox Mar 24, 2022
19f2b25
Adding BN support in MaskRCNNHeads and FPN
datumbox Mar 24, 2022
124fd8a
Adding support for FasterRCNNHeads
datumbox Mar 24, 2022
53aa8b7
Introduce norm_layers in backbone utils.
datumbox Mar 25, 2022
f9ba509
Bigger RPN head + 2x rcnn v2 models.
datumbox Mar 25, 2022
e5cbb97
Merge branch 'main' into references/detection_recipe
datumbox Mar 30, 2022
592784d
Adding gIoU support to retinanet
datumbox Mar 30, 2022
2cff640
Fix assert
datumbox Mar 30, 2022
a6f0ea7
Merge branch 'main' into references/detection_recipe
datumbox Mar 31, 2022
61412df
Add back nesterov momentum
datumbox Apr 1, 2022
99479ee
Merge branch 'main' into references/detection_recipe
datumbox Apr 1, 2022
08307ca
Merge branch 'main' into references/detection_recipe
datumbox Apr 1, 2022
a322dd2
Merge branch 'main' into references/detection_recipe
datumbox Apr 1, 2022
eb649e8
Merge branch 'main' into references/detection_recipe
datumbox Apr 1, 2022
24b8643
Rename and extend `FastRCNNConvFCHead` to support arbitrary FCs
datumbox Apr 4, 2022
6488c41
Fix linter
datumbox Apr 4, 2022
00e182a
Merge branch 'main' into references/detection_recipe
datumbox Apr 5, 2022
2 changes: 1 addition & 1 deletion references/classification/train.py
@@ -248,7 +248,7 @@ def main(args):
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

if args.norm_weight_decay is None:
-        parameters = model.parameters()
+        parameters = [p for p in model.parameters() if p.requires_grad]
else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
…
19 changes: 19 additions & 0 deletions references/detection/presets.py
@@ -12,6 +12,25 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)
T.ConvertImageDtype(torch.float),
]
)
elif data_augmentation == "lsj":
self.transforms = T.Compose(
[
T.ScaleJitter(target_size=(1024, 1024)),
T.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
elif data_augmentation == "multiscale":
self.transforms = T.Compose(
[
T.RandomShortestSize(min_size=(640, 672, 704, 736, 768, 800), max_size=1333),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
elif data_augmentation == "ssd":
self.transforms = T.Compose(
[
…
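
For context, the "lsj" preset above can be rebuilt and exercised on its own. A minimal sketch, assuming it is run from references/detection so that transforms.py imports as T (the detection Compose threads both image and target through each transform):

# Sketch only, not part of the diff: the "lsj" pipeline exactly as composed above.
import torch
import transforms as T  # references/detection/transforms.py

mean = (123.0, 117.0, 104.0)  # default fill from the preset above
lsj = T.Compose(
    [
        T.ScaleJitter(target_size=(1024, 1024)),
        T.FixedSizeCrop(size=(1024, 1024), fill=mean),
        T.RandomHorizontalFlip(p=0.5),
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float),
    ]
)
# Applied to a PIL image and a COCO-style target dict, this always yields a
# float tensor of shape (3, 1024, 1024) with boxes/masks adjusted to match.
# train.py selects it via --data-augmentation lsj.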
34 changes: 29 additions & 5 deletions references/detection/train.py
@@ -76,6 +76,7 @@ def get_args_parser(add_help=True):
parser.add_argument(
"-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
)
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
parser.add_argument(
"--lr",
default=0.02,
@@ -92,6 +93,12 @@
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
parser.add_argument(
"--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
)
@@ -151,6 +158,7 @@ def get_args_parser(add_help=True):
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")

# Mixed precision training parameters
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
@@ -161,8 +169,8 @@
def main(args):
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if not args.prototype and args.weights:
-        raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
+    if not args.prototype and (args.weights or args.weights_backbone):
+        raise ValueError("The weights parameters work only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)

@@ -201,6 +209,8 @@ def main(args):

print("Creating model")
kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
if args.data_augmentation == "multiscale":
kwargs["_skip_resize"] = True
if "rcnn" in args.model:
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
@@ -209,7 +219,9 @@ def main(args):
pretrained=args.pretrained, num_classes=num_classes, **kwargs
)
else:
-        model = prototype.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
+        model = prototype.models.detection.__dict__[args.model](
+            weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs
+        )
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -219,8 +231,20 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module

-    params = [p for p in model.parameters() if p.requires_grad]
-    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    if args.norm_weight_decay is None:
+        parameters = [p for p in model.parameters() if p.requires_grad]
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
+    opt_name = args.opt.lower()
+    if opt_name.startswith("sgd"):
+        optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    elif opt_name == "adamw":
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
+    else:
+        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")

scaler = torch.cuda.amp.GradScaler() if args.amp else None

…
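
Putting the two new flags together, here is a minimal sketch of the resulting optimizer setup (illustrative values, not the recipe's defaults; split_normalization_params is the private helper used above):

# Sketch of --opt adamw --norm-weight-decay 0.0 (illustrative values only).
import torch
import torchvision
from torchvision.ops._utils import split_normalization_params

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    pretrained=False, pretrained_backbone=False, num_classes=91
)
wd_groups = [0.0, 1e-4]  # [--norm-weight-decay, --wd]
parameters = [
    {"params": p, "weight_decay": w}
    for p, w in zip(split_normalization_params(model), wd_groups)
    if p  # skip empty groups, e.g. when every norm layer is frozen
]
optimizer = torch.optim.AdamW(parameters, lr=1e-4, weight_decay=1e-4)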
113 changes: 112 additions & 1 deletion references/detection/transforms.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Dict, Optional
+from typing import List, Tuple, Dict, Optional, Union

import torch
import torchvision
@@ -326,3 +326,114 @@ def forward(
)

return image, target


class FixedSizeCrop(nn.Module):
def __init__(self, size, fill=0, padding_mode="constant"):
super().__init__()
size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
self.padding_mode = padding_mode

def _pad(self, img, target, padding):
# Taken from the functional_tensor.py pad
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]

padding = [pad_left, pad_top, pad_right, pad_bottom]
img = F.pad(img, padding, self.fill, self.padding_mode)
if target is not None:
target["boxes"][:, 0::2] += pad_left
target["boxes"][:, 1::2] += pad_top
if "masks" in target:
target["masks"] = F.pad(target["masks"], padding, 0, "constant")

return img, target

def _crop(self, img, target, top, left, height, width):
img = F.crop(img, top, left, height, width)
if target is not None:
boxes = target["boxes"]
boxes[:, 0::2] -= left
boxes[:, 1::2] -= top
boxes[:, 0::2].clamp_(min=0, max=width)
boxes[:, 1::2].clamp_(min=0, max=height)

is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])

target["boxes"] = boxes[is_valid]
target["labels"] = target["labels"][is_valid]
if "masks" in target:
target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)

return img, target

def forward(self, img, target=None):
_, height, width = F.get_dimensions(img)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)

if new_height != height or new_width != width:
offset_height = max(height - self.crop_height, 0)
offset_width = max(width - self.crop_width, 0)

r = torch.rand(1)
top = int(offset_height * r)
left = int(offset_width * r)

img, target = self._crop(img, target, top, left, new_height, new_width)

pad_bottom = max(self.crop_height - new_height, 0)
pad_right = max(self.crop_width - new_width, 0)
if pad_bottom != 0 or pad_right != 0:
img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])

return img, target


class RandomShortestSize(nn.Module):
def __init__(
self,
min_size: Union[List[int], Tuple[int], int],
max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
):
super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size
self.interpolation = interpolation

def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
_, orig_height, orig_width = F.get_dimensions(image)

min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))

new_width = int(orig_width * r)
new_height = int(orig_height * r)

image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)

if target is not None:
target["boxes"][:, 0::2] *= new_width / orig_width
target["boxes"][:, 1::2] *= new_height / orig_height
if "masks" in target:
target["masks"] = F.resize(
target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
)

return image, target
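
A quick smoke test for the two new transforms; a sketch assuming it runs from references/detection so this transforms.py is importable:

# Sketch, not part of the PR: exercise both transforms on a dummy sample.
import torch
from transforms import FixedSizeCrop, RandomShortestSize

img = torch.rand(3, 600, 1000)  # (C, H, W)
target = {
    "boxes": torch.tensor([[50.0, 50.0, 300.0, 200.0]]),
    "labels": torch.tensor([1]),
}

# Crops the side that is too large, pads the one that is too small (bottom/right),
# shifting, clamping and filtering the boxes along the way.
crop = FixedSizeCrop(size=(800, 800))
out, out_target = crop(img, target)
print(out.shape)  # torch.Size([3, 800, 800])

# Draws a random short side from min_size and caps the long side at max_size.
resize = RandomShortestSize(min_size=(640, 672, 704, 736, 768, 800), max_size=1333)
out, out_target = resize(img, target)
print(out.shape)  # e.g. torch.Size([3, 640, 1066]) when 640 is drawn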
3 changes: 2 additions & 1 deletion torchvision/models/detection/faster_rcnn.py
@@ -178,6 +178,7 @@ def __init__(
box_batch_size_per_image=512,
box_positive_fraction=0.25,
bbox_reg_weights=None,
**kwargs,
):

if not hasattr(backbone, "out_channels"):
@@ -253,7 +254,7 @@ def __init__(
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
-        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

super().__init__(backbone, rpn, roi_heads, transform)

…
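
The same **kwargs pass-through is repeated for FCOS, Keypoint R-CNN, Mask R-CNN, RetinaNet and SSD below. Its only current use is forwarding the private _skip_resize option from train.py down to GeneralizedRCNNTransform. A sketch of the flow, assuming the public builders forward unknown kwargs to the model class (as train.py relies on):

# Sketch, not part of the diff: a transform-level kwarg travelling from the
# public builder down to GeneralizedRCNNTransform after this change.
import torchvision

# train.py sets kwargs["_skip_resize"] = True for --data-augmentation
# multiscale; the builder forwards it through FasterRCNN to the transform.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    pretrained=False, pretrained_backbone=False, num_classes=91, _skip_resize=True
)
print(model.transform._skip_resize)  # True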
3 changes: 2 additions & 1 deletion torchvision/models/detection/fcos.py
@@ -366,6 +366,7 @@ def __init__(
nms_thresh: float = 0.6,
detections_per_img: int = 100,
topk_candidates: int = 1000,
**kwargs,
):
super().__init__()
_log_api_usage_once(self)
@@ -397,7 +398,7 @@ def __init__(
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
-        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

self.center_sampling_radius = center_sampling_radius
self.score_thresh = score_thresh
…
2 changes: 2 additions & 0 deletions torchvision/models/detection/keypoint_rcnn.py
@@ -189,6 +189,7 @@ def __init__(
keypoint_head=None,
keypoint_predictor=None,
num_keypoints=None,
**kwargs,
):

assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
@@ -247,6 +248,7 @@ def __init__(
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights,
**kwargs,
)

self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
…
2 changes: 2 additions & 0 deletions torchvision/models/detection/mask_rcnn.py
@@ -189,6 +189,7 @@ def __init__(
mask_roi_pool=None,
mask_head=None,
mask_predictor=None,
**kwargs,
):

assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None)))
@@ -245,6 +246,7 @@ def __init__(
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights,
**kwargs,
)

self.roi_heads.mask_roi_pool = mask_roi_pool
…
3 changes: 2 additions & 1 deletion torchvision/models/detection/retinanet.py
@@ -335,6 +335,7 @@ def __init__(
fg_iou_thresh=0.5,
bg_iou_thresh=0.4,
topk_candidates=1000,
**kwargs,
):
super().__init__()
_log_api_usage_once(self)
@@ -373,7 +374,7 @@ def __init__(
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
-        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
…
3 changes: 2 additions & 1 deletion torchvision/models/detection/ssd.py
@@ -180,6 +180,7 @@ def __init__(
iou_thresh: float = 0.5,
topk_candidates: int = 400,
positive_fraction: float = 0.25,
**kwargs: Any,
):
super().__init__()
_log_api_usage_once(self)
@@ -209,7 +210,7 @@ def __init__(
if image_std is None:
image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(
-            min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size
+            min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size, **kwargs
)

self.score_thresh = score_thresh
…
6 changes: 5 additions & 1 deletion torchvision/models/detection/transform.py
@@ -1,5 +1,5 @@
import math
-from typing import List, Tuple, Dict, Optional
+from typing import List, Tuple, Dict, Optional, Any

import torch
import torchvision
@@ -91,6 +91,7 @@ def __init__(
image_std: List[float],
size_divisible: int = 32,
fixed_size: Optional[Tuple[int, int]] = None,
**kwargs: Any,
):
super().__init__()
if not isinstance(min_size, (list, tuple)):
@@ -101,6 +102,7 @@
self.image_std = image_std
self.size_divisible = size_divisible
self.fixed_size = fixed_size
self._skip_resize = kwargs.pop("_skip_resize", False)

def forward(
self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
@@ -167,6 +169,8 @@ def resize(
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
h, w = image.shape[-2:]
if self.training:
if self._skip_resize:
return image, target
size = float(self.torch_choice(self.min_size))
else:
# FIXME assume for now that testing uses the largest scale
…
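
With _skip_resize=True the internal resize is bypassed only in training mode, since presets such as "multiscale" have already resized on the data-loading side; evaluation keeps the old min_size/max_size behaviour. A minimal sketch:

# Sketch: normalization and size_divisible batching still run; only the
# training-time resize is skipped.
import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

t = GeneralizedRCNNTransform(
    min_size=800, max_size=1333,
    image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225],
    _skip_resize=True,
)
t.train()
images, _ = t([torch.rand(3, 512, 512)])
print(images.tensors.shape)  # torch.Size([1, 3, 512, 512]) -- resize skipped

t.eval()
images, _ = t([torch.rand(3, 512, 512)])
print(images.tensors.shape)  # torch.Size([1, 3, 800, 800]) -- eval still resizes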
8 changes: 7 additions & 1 deletion torchvision/ops/_utils.py
@@ -43,7 +43,13 @@ def split_normalization_params(
) -> Tuple[List[Tensor], List[Tensor]]:
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
if not norm_classes:
-        norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
+        norm_classes = [
+            nn.modules.batchnorm._BatchNorm,
+            nn.LayerNorm,
+            nn.GroupNorm,
+            nn.modules.instancenorm._InstanceNorm,
+            nn.LocalResponseNorm,
+        ]

for t in norm_classes:
if not issubclass(t, nn.Module):
…
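
With the extended defaults, instance norm layers are now routed to the norm group as well (LocalResponseNorm is matched too, though it carries no parameters). A quick sketch:

# Sketch: GroupNorm and InstanceNorm affine parameters land in the norm group,
# the conv weight/bias in the other group.
import torch.nn as nn
from torchvision.ops._utils import split_normalization_params

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.GroupNorm(2, 8), nn.InstanceNorm2d(8, affine=True))
norm_params, other_params = split_normalization_params(model)
print(len(norm_params), len(other_params))  # 4 2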