
Commit 6655dac

ufmt format
1 parent a630986 commit 6655dac

File tree

4 files changed (+84, -62 lines changed):

references/classification/train.py (+75, -53)
references/classification/utils.py (+1, -2)
test/test_ops.py (+4, -4)
torchvision/ops/_utils.py (+4, -3)


references/classification/train.py

Lines changed: 75 additions & 53 deletions
@@ -24,7 +24,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
 
-    header = 'Epoch: [{}]'.format(epoch)
+    header = "Epoch: [{}]".format(epoch)
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
@@ -219,12 +219,18 @@ def main(args):
 
     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):
-        optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
-                                    nesterov="nesterov" in opt_name)
-    elif opt_name == 'rmsprop':
-        optimizer = torch.optim.RMSprop(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
-                                        eps=0.0316, alpha=0.9)
-    elif opt_name == 'adamw':
+        optimizer = torch.optim.SGD(
+            parameters,
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name,
+        )
+    elif opt_name == "rmsprop":
+        optimizer = torch.optim.RMSprop(
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
+        )
+    elif opt_name == "adamw":
         optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
     else:
         raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
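
Because the branch checks opt_name.startswith("sgd") and passes nesterov="nesterov" in opt_name, a value such as --opt sgd_nesterov reuses the same constructor with Nesterov momentum switched on. A small self-contained sketch of that behaviour (the parameter tensor and hyperparameter values below are illustrative, not taken from the script):

import torch

# Illustrative stand-ins for the `parameters` list built in main() and the parsed --opt value.
parameters = [torch.nn.Parameter(torch.zeros(2, 2))]
opt_name = "sgd_nesterov"  # e.g. from --opt sgd_nesterov

# Same construction as the reformatted branch above: Nesterov momentum is
# enabled whenever the optimizer name contains the substring "nesterov".
optimizer = torch.optim.SGD(
    parameters,
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    nesterov="nesterov" in opt_name,
)
print(optimizer.defaults["nesterov"])  # True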
@@ -285,18 +291,18 @@ def main(args):
         model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
 
     if args.resume:
-        checkpoint = torch.load(args.resume, map_location='cpu')
-        model_without_ddp.load_state_dict(checkpoint['model'])
+        checkpoint = torch.load(args.resume, map_location="cpu")
+        model_without_ddp.load_state_dict(checkpoint["model"])
         if not args.test_only:
-            optimizer.load_state_dict(checkpoint['optimizer'])
-            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
-            args.start_epoch = checkpoint['epoch'] + 1
+            optimizer.load_state_dict(checkpoint["optimizer"])
+            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+            args.start_epoch = checkpoint["epoch"] + 1
         if model_ema:
             model_ema.load_state_dict(checkpoint["model_ema"])
 
     if args.test_only:
        if model_ema:
-            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
         else:
             evaluate(model, criterion, data_loader_test, device=device)
         return
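
The resume branch above reads a checkpoint dictionary keyed by "model", "optimizer", "lr_scheduler", "epoch" and, when EMA is enabled, "model_ema". The saving side is not part of this hunk; a minimal round-trip sketch with toy objects, assuming that key layout:

import torch
from torch import nn

# Toy stand-ins so the sketch runs on its own; in train.py these are the real
# model, optimizer and scheduler built in main().
model_without_ddp = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model_without_ddp.parameters(), lr=0.1, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Save with the keys the resume branch expects ("model_ema" would be added
# only when --model-ema is enabled).
torch.save(
    {
        "model": model_without_ddp.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": 5,
    },
    "checkpoint.pth",
)

# Loading mirrors the diff: map to CPU first, then restore each piece.
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
start_epoch = checkpoint["epoch"] + 1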
@@ -331,42 +337,52 @@ def main(args):
 
 def get_args_parser(add_help=True):
     import argparse
-    parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help)
-
-    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
-    parser.add_argument('--model', default='resnet18', help='model')
-    parser.add_argument('--device', default='cuda', help='device')
-    parser.add_argument('-b', '--batch-size', default=32, type=int)
-    parser.add_argument('--epochs', default=90, type=int, metavar='N',
-                        help='number of total epochs to run')
-    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
-                        help='number of data loading workers (default: 16)')
-    parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
-    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
-    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
-                        help='momentum')
-    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
-                        metavar='W', help='weight decay (default: 1e-4)',
-                        dest='weight_decay')
-    parser.add_argument('--norm-weight-decay', default=None, type=float,
-                        help='weight decay for Normalization layers (default: None, same value as --wd)')
-    parser.add_argument('--label-smoothing', default=0.0, type=float,
-                        help='label smoothing (default: 0.0)',
-                        dest='label_smoothing')
-    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
-    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
-    parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
-    parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
-    parser.add_argument('--lr-warmup-method', default="constant", type=str,
-                        help='the warmup method (default: constant)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
-    parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
-    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
-    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
-    parser.add_argument('--output-dir', default='.', help='path where to save')
-    parser.add_argument('--resume', default='', help='resume from checkpoint')
-    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
-                        help='start epoch')
+
+    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
+
+    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset")
+    parser.add_argument("--model", default="resnet18", help="model")
+    parser.add_argument("--device", default="cuda", help="device")
+    parser.add_argument("-b", "--batch-size", default=32, type=int)
+    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
+    parser.add_argument(
+        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
+    )
+    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
+    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
+    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=1e-4,
+        type=float,
+        metavar="W",
+        help="weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+    parser.add_argument(
+        "--norm-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for Normalization layers (default: None, same value as --wd)",
+    )
+    parser.add_argument(
+        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
+    )
+    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
+    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
+    parser.add_argument("--lr-scheduler", default="steplr", help="the lr scheduler (default: steplr)")
+    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
+    parser.add_argument(
+        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
+    )
+    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
+    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
+    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
+    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
+    parser.add_argument("--output-dir", default=".", help="path where to save")
+    parser.add_argument("--resume", default="", help="resume from checkpoint")
+    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
     parser.add_argument(
         "--cache-dataset",
         dest="cache_dataset",
@@ -412,11 +428,17 @@ def get_args_parser(add_help=True):
         "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
     )
     parser.add_argument(
-        '--model-ema-steps', type=int, default=32,
-        help='the number of iterations that controls how often to update the EMA model (default: 32)')
+        "--model-ema-steps",
+        type=int,
+        default=32,
+        help="the number of iterations that controls how often to update the EMA model (default: 32)",
+    )
     parser.add_argument(
-        '--model-ema-decay', type=float, default=0.99998,
-        help='decay factor for Exponential Moving Average of model parameters (default: 0.99998)')
+        "--model-ema-decay",
+        type=float,
+        default=0.99998,
+        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
+    )
 
     return parser
 
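Outside this hunk, the reference script hands the parsed namespace from this parser to main(). A quick way to inspect the defaults and destinations defined above (note that --wd stores into args.weight_decay) is to feed parse_args an explicit argument list; the import below assumes you run it from references/classification/:

# Hypothetical interactive check, run from references/classification/.
from train import get_args_parser

args = get_args_parser().parse_args(["--model", "resnet50", "--opt", "adamw", "--lr", "3e-3", "--wd", "0.05"])
print(args.model, args.opt, args.lr, args.weight_decay)  # resnet50 adamw 0.003 0.05
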
references/classification/utils.py

Lines changed: 1 addition & 2 deletions
@@ -179,8 +179,7 @@ def update_parameters(self, model):
             if self.n_averaged == 0:
                 p_swa.detach().copy_(p_model_)
             else:
-                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
-                                                 self.n_averaged.to(device)))
+                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device)))
         self.n_averaged += 1
 
 
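For context, main() builds this wrapper as utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha), and update_parameters feeds each parameter pair through self.avg_fn. A sketch of the usual exponential-moving-average rule such an avg_fn implements (the actual callable is defined elsewhere in utils.py and is assumed here, not shown in this diff):

import torch

def ema_avg(averaged_param, current_param, num_averaged, decay=0.99998):
    # Assumed rule: keep `decay` of the running average and blend in
    # (1 - decay) of the latest parameter; mirrors the three-argument
    # avg_fn call in update_parameters above.
    return decay * averaged_param + (1 - decay) * current_param

running = torch.zeros(3)  # stand-in for p_swa
latest = torch.ones(3)    # stand-in for p_model_
running = ema_avg(running, latest, num_averaged=torch.tensor(1))
print(running)  # each entry equals 2e-05 after one update
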
test/test_ops.py

Lines changed: 4 additions & 4 deletions
@@ -2,17 +2,17 @@
 import os
 from abc import ABC, abstractmethod
 from functools import lru_cache
+from functools import lru_cache
+from typing import Tuple
 from typing import Tuple
 
 import numpy as np
 import pytest
 import torch
-from functools import lru_cache
 from torch import nn, Tensor
 from torch.autograd import gradcheck
 from torch.nn.modules.utils import _pair
 from torchvision import models, ops
-from typing import Tuple
 
 
 class RoIOpTester(ABC):
@@ -1177,7 +1177,7 @@ def test_stochastic_depth(self, mode, p):
 
 
 class TestUtils:
-    @pytest.mark.parametrize('norm_layer', [None, nn.BatchNorm2d, nn.LayerNorm])
+    @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
     def test_split_normalization_params(self, norm_layer):
         model = models.mobilenet_v3_large(norm_layer=norm_layer)
         params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])
@@ -1186,5 +1186,5 @@ def test_split_normalization_params(self, norm_layer):
         assert len(params[1]) == 82
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     pytest.main([__file__])

torchvision/ops/_utils.py

Lines changed: 4 additions & 3 deletions
@@ -1,8 +1,8 @@
 from typing import List, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import nn, Tensor
-from typing import List, Optional, Tuple, Union
 
 
 def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
@@ -39,8 +39,9 @@ def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
     return
 
 
-def split_normalization_params(model: nn.Module,
-                               norm_classes: Optional[List[type]] = None) -> Tuple[List[Tensor], List[Tensor]]:
+def split_normalization_params(
+    model: nn.Module, norm_classes: Optional[List[type]] = None
+) -> Tuple[List[Tensor], List[Tensor]]:
     # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
     if not norm_classes:
         norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
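
split_normalization_params is what backs the --norm-weight-decay flag in train.py: it partitions a model's parameters into normalization-layer parameters and everything else so the two groups can receive different weight decay. A usage sketch (the return order and the optimizer wiring are assumptions based on the test above, not code shown in this diff):

import torch
from torchvision import models, ops

model = models.mobilenet_v3_large()

# Assumed return order: normalization-layer parameters first, remaining
# parameters second (params[0] / params[1] in the test above).
norm_params, other_params = ops._utils.split_normalization_params(model)

param_groups = [
    {"params": other_params, "weight_decay": 1e-4},  # --wd
    {"params": norm_params, "weight_decay": 0.0},    # --norm-weight-decay
]
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)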
