Skip to content

Commit 115d2eb

Browse files
authored
Add pretrained arg to reference scripts (#935)
Allows for easily evaluating the pre-trained models in the modelzoo
1 parent 6e5599e commit 115d2eb

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

references/classification/train.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def main(args):
144144
sampler=test_sampler, num_workers=args.workers, pin_memory=True)
145145

146146
print("Creating model")
147-
model = torchvision.models.__dict__[args.model]()
147+
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
148148
model.to(device)
149149
if args.distributed and args.sync_bn:
150150
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -242,6 +242,12 @@ def parse_args():
242242
help="Only test the model",
243243
action="store_true",
244244
)
245+
parser.add_argument(
246+
"--pretrained",
247+
dest="pretrained",
248+
help="Use pre-trained models from the modelzoo",
249+
action="store_true",
250+
)
245251

246252
# distributed training parameters
247253
parser.add_argument('--world-size', default=1, type=int,

references/detection/train.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def main(args):
7676
collate_fn=utils.collate_fn)
7777

7878
print("Creating model")
79-
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes)
79+
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes,
80+
pretrained=args.pretrained)
8081
model.to(device)
8182

8283
model_without_ddp = model
@@ -156,6 +157,12 @@ def main(args):
156157
help="Only test the model",
157158
action="store_true",
158159
)
160+
parser.add_argument(
161+
"--pretrained",
162+
dest="pretrained",
163+
help="Use pre-trained models from the modelzoo",
164+
action="store_true",
165+
)
159166

160167
# distributed training parameters
161168
parser.add_argument('--world-size', default=1, type=int,

references/segmentation/train.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def main(args):
121121
sampler=test_sampler, num_workers=args.workers,
122122
collate_fn=utils.collate_fn)
123123

124-
model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, aux_loss=args.aux_loss)
124+
model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes,
125+
aux_loss=args.aux_loss,
126+
pretrained=args.pretrained)
125127
model.to(device)
126128
if args.distributed:
127129
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -205,6 +207,12 @@ def parse_args():
205207
help="Only test the model",
206208
action="store_true",
207209
)
210+
parser.add_argument(
211+
"--pretrained",
212+
dest="pretrained",
213+
help="Use pre-trained models from the modelzoo",
214+
action="store_true",
215+
)
208216
# distributed training parameters
209217
parser.add_argument('--world-size', default=1, type=int,
210218
help='number of distributed processes')

0 commit comments

Comments
 (0)