@@ -67,6 +67,9 @@ def make_parser():
67
67
help = 'manually set random seed for torch' )
68
68
parser .add_argument ('--checkpoint' , type = str , default = None ,
69
69
help = 'path to model checkpoint file' )
70
+ parser .add_argument ('--torchvision-weights-version' , type = str , default = "IMAGENET1K_V2" ,
71
+ choices = ['IMAGENET1K_V1' , 'IMAGENET1K_V2' , 'DEFAULT' ],
72
+ help = 'The torchvision weights version to use when --checkpoint is not specified' )
70
73
parser .add_argument ('--save' , type = str , default = None ,
71
74
help = 'save model checkpoints in the specified directory' )
72
75
parser .add_argument ('--mode' , type = str , default = 'training' ,
@@ -97,9 +100,19 @@ def make_parser():
97
100
' backbone model declared with the --backbone argument.'
98
101
' When it is not provided, pretrained model from torchvision'
99
102
' will be downloaded.' )
100
- parser .add_argument ('--num-workers' , type = int , default = 4 )
101
- parser .add_argument ('--amp' , action = 'store_true' ,
102
- help = 'Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.' )
103
+ parser .add_argument ('--num-workers' , type = int , default = 8 )
104
+ parser .add_argument ("--amp" , dest = 'amp' , action = "store_true" ,
105
+ help = "Enable Automatic Mixed Precision (AMP)." )
106
+ parser .add_argument ("--no-amp" , dest = 'amp' , action = "store_false" ,
107
+ help = "Disable Automatic Mixed Precision (AMP)." )
108
+ parser .set_defaults (amp = True )
109
+ parser .add_argument ("--allow-tf32" , dest = 'allow_tf32' , action = "store_true" ,
110
+ help = "Allow TF32 computations on supported GPUs." )
111
+ parser .add_argument ("--no-allow-tf32" , dest = 'allow_tf32' , action = "store_false" ,
112
+ help = "Disable TF32 computations." )
113
+ parser .set_defaults (allow_tf32 = True )
114
+ parser .add_argument ('--data-layout' , default = "channels_last" , choices = ['channels_first' , 'channels_last' ],
115
+ help = "Model data layout. It's recommended to use channels_first with --no-amp" )
103
116
parser .add_argument ('--log-interval' , type = int , default = 20 ,
104
117
help = 'Logging interval.' )
105
118
parser .add_argument ('--json-summary' , type = str , default = None ,
@@ -150,7 +163,9 @@ def train(train_loop_func, logger, args):
150
163
val_dataset = get_val_dataset (args )
151
164
val_dataloader = get_val_dataloader (val_dataset , args )
152
165
153
- ssd300 = SSD300 (backbone = ResNet (args .backbone , args .backbone_path ))
166
+ ssd300 = SSD300 (backbone = ResNet (backbone = args .backbone ,
167
+ backbone_path = args .backbone_path ,
168
+ weights = args .torchvision_weights_version ))
154
169
args .learning_rate = args .learning_rate * args .N_gpu * (args .batch_size / 32 )
155
170
start_epoch = 0
156
171
iteration = 0
@@ -223,6 +238,7 @@ def train(train_loop_func, logger, args):
223
238
obj ['model' ] = ssd300 .module .state_dict ()
224
239
else :
225
240
obj ['model' ] = ssd300 .state_dict ()
241
+ os .makedirs (args .save , exist_ok = True )
226
242
save_path = os .path .join (args .save , f'epoch_{ epoch } .pt' )
227
243
torch .save (obj , save_path )
228
244
logger .log ('model path' , save_path )
@@ -261,6 +277,8 @@ def log_params(logger, args):
261
277
if args .local_rank == 0 :
262
278
os .makedirs ('./models' , exist_ok = True )
263
279
280
+ torch .backends .cuda .matmul .allow_tf32 = args .allow_tf32
281
+ torch .backends .cudnn .allow_tf32 = args .allow_tf32
264
282
torch .backends .cudnn .benchmark = True
265
283
266
284
# write json only on the main thread
0 commit comments