File tree 3 files changed +24
-3
lines changed 3 files changed +24
-3
lines changed Original file line number Diff line number Diff line change @@ -144,7 +144,7 @@ def main(args):
144
144
sampler = test_sampler , num_workers = args .workers , pin_memory = True )
145
145
146
146
print ("Creating model" )
147
- model = torchvision .models .__dict__ [args .model ]()
147
+ model = torchvision .models .__dict__ [args .model ](pretrained = args . pretrained )
148
148
model .to (device )
149
149
if args .distributed and args .sync_bn :
150
150
model = torch .nn .SyncBatchNorm .convert_sync_batchnorm (model )
@@ -242,6 +242,12 @@ def parse_args():
242
242
help = "Only test the model" ,
243
243
action = "store_true" ,
244
244
)
245
+ parser .add_argument (
246
+ "--pretrained" ,
247
+ dest = "pretrained" ,
248
+ help = "Use pre-trained models from the modelzoo" ,
249
+ action = "store_true" ,
250
+ )
245
251
246
252
# distributed training parameters
247
253
parser .add_argument ('--world-size' , default = 1 , type = int ,
Original file line number Diff line number Diff line change @@ -76,7 +76,8 @@ def main(args):
76
76
collate_fn = utils .collate_fn )
77
77
78
78
print ("Creating model" )
79
- model = torchvision .models .detection .__dict__ [args .model ](num_classes = num_classes )
79
+ model = torchvision .models .detection .__dict__ [args .model ](num_classes = num_classes ,
80
+ pretrained = args .pretrained )
80
81
model .to (device )
81
82
82
83
model_without_ddp = model
@@ -156,6 +157,12 @@ def main(args):
156
157
help = "Only test the model" ,
157
158
action = "store_true" ,
158
159
)
160
+ parser .add_argument (
161
+ "--pretrained" ,
162
+ dest = "pretrained" ,
163
+ help = "Use pre-trained models from the modelzoo" ,
164
+ action = "store_true" ,
165
+ )
159
166
160
167
# distributed training parameters
161
168
parser .add_argument ('--world-size' , default = 1 , type = int ,
Original file line number Diff line number Diff line change @@ -121,7 +121,9 @@ def main(args):
121
121
sampler = test_sampler , num_workers = args .workers ,
122
122
collate_fn = utils .collate_fn )
123
123
124
- model = torchvision .models .segmentation .__dict__ [args .model ](num_classes = num_classes , aux_loss = args .aux_loss )
124
+ model = torchvision .models .segmentation .__dict__ [args .model ](num_classes = num_classes ,
125
+ aux_loss = args .aux_loss ,
126
+ pretrained = args .pretrained )
125
127
model .to (device )
126
128
if args .distributed :
127
129
model = torch .nn .SyncBatchNorm .convert_sync_batchnorm (model )
@@ -205,6 +207,12 @@ def parse_args():
205
207
help = "Only test the model" ,
206
208
action = "store_true" ,
207
209
)
210
+ parser .add_argument (
211
+ "--pretrained" ,
212
+ dest = "pretrained" ,
213
+ help = "Use pre-trained models from the modelzoo" ,
214
+ action = "store_true" ,
215
+ )
208
216
# distributed training parameters
209
217
parser .add_argument ('--world-size' , default = 1 , type = int ,
210
218
help = 'number of distributed processes' )
You can’t perform that action at this time.
0 commit comments