@@ -38,7 +38,8 @@
 def init_checkpoint_manager(model, optimizer, save_checkpoint_path, load_checkpoint_path):
     checkpoint = tf.train.Checkpoint(
         model=model,
-        optimizer=optimizer
+        optimizer=optimizer,
+        epoch=tf.Variable(-1, name='epoch')
     )

     checkpoint_manager = tf.train.CheckpointManager(
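For reference, a minimal sketch of how a checkpoint built this way is restored on a resumed run. The toy model, optimizer, and checkpoint_dir below are hypothetical stand-ins, not taken from this diff:

import tensorflow as tf

# Hypothetical stand-ins for the real model/optimizer wired up elsewhere.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
checkpoint_dir = '/tmp/ckpt_demo'

checkpoint = tf.train.Checkpoint(
    model=model,
    optimizer=optimizer,
    epoch=tf.Variable(-1, name='epoch'),  # -1 marks "no epoch finished yet"
)
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=1)

# latest_checkpoint is None on a fresh run, so epoch keeps its -1 default
# and the training loop derives start_epoch = 0 from it.
if manager.latest_checkpoint:
    checkpoint.restore(manager.latest_checkpoint)
print(checkpoint.epoch.numpy())  # -1 on a fresh run, last saved epoch otherwise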
@@ -283,7 +284,7 @@ def model_step(batch, model, model_fn, optimizer, amp, first_batch):
     return loss_dict


-def run_single_epoch(model, model_fn, data_iterator, optimizer, amp, epoch, benchmark, performance_calculator):
+def run_single_epoch(model, model_fn, data_iterator, optimizer, amp, start_epoch, epoch, benchmark, performance_calculator):

    for current_step, batch in enumerate(data_iterator):
        if benchmark and performance_calculator.completed:
@@ -296,7 +297,7 @@ def run_single_epoch(model, model_fn, data_iterator, optimizer, amp, epoch, benchmark, performance_calculator):
         n_samples = len(batch[1])
         step_throughput = performance_calculator(n_samples)
         step_dict["samples/s"] = step_throughput
-        dllogger.log(data=step_dict, step=(epoch, current_step))
+        dllogger.log(data=step_dict, step=(start_epoch + epoch, current_step))


 def train(model, model_fn, data_iterator_train, data_iterator_test, optimizer, amp, epochs,
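The tuple passed as step is what keeps a resumed run on the original timeline. A small illustration with hypothetical values (start_epoch would be recovered from the checkpoint; none of these numbers come from this diff):

start_epoch = 5     # last saved epoch was 4, so the resumed run continues at 5
epoch = 0           # local loop index of the resumed run
current_step = 0
# dllogger now records step (5, 0) instead of restarting at (0, 0).
assert (start_epoch + epoch, current_step) == (5, 0)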
@@ -307,8 +308,10 @@ def train(model, model_fn, data_iterator_train, data_iterator_test, optimizer, amp, epochs,

     all_epochs_results = []

-    for epoch in range(epochs):
-        run_single_epoch(model, model_fn, data_iterator_train, optimizer, amp, epoch, benchmark, performance_calculator)
+    start_epoch = checkpoint_manager.checkpoint.epoch.numpy().item() + 1
+
+    for epoch in range(epochs - start_epoch):
+        run_single_epoch(model, model_fn, data_iterator_train, optimizer, amp, start_epoch, epoch, benchmark, performance_calculator)

         if not benchmark:
             # we dump throughput results for consecutive epochs for a regular training job (w/o --benchmark flag)
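To make the resume arithmetic concrete, a worked example with hypothetical numbers (a 10-epoch job interrupted after the checkpoint for epoch 3 was written):

epochs = 10
saved_epoch = 3                          # checkpoint.epoch after restore
start_epoch = saved_epoch + 1            # first epoch still to run
remaining = range(epochs - start_epoch)  # 6 local epochs, indices 0..5
assert [start_epoch + e for e in remaining] == [4, 5, 6, 7, 8, 9]

On a fresh run the restored epoch is still -1, so start_epoch is 0 and the loop reduces to the original range(epochs).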
@@ -321,10 +324,11 @@ def train(model, model_fn, data_iterator_train, data_iterator_test, optimizer, amp, epochs,
             results_data.update(results_eval_test)

         if save_checkpoint:
+            checkpoint_manager.checkpoint.epoch.assign(start_epoch + epoch)
             checkpoint_manager.save()

         if hvd.rank() == 0:
-            dllogger.log(data=results_data, step=(epoch,))
+            dllogger.log(data=results_data, step=(start_epoch + epoch,))

         performance_calculator.init()  # restart for another epoch
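Why the assign happens before save(): CheckpointManager.save() serializes the current values of every tracked object, so epoch must already hold the just-finished global epoch at that moment. A self-contained demonstration, with a hypothetical directory and epoch value:

import tensorflow as tf

checkpoint = tf.train.Checkpoint(epoch=tf.Variable(-1, name='epoch'))
manager = tf.train.CheckpointManager(checkpoint, '/tmp/epoch_demo', max_to_keep=1)

checkpoint.epoch.assign(7)   # record the global epoch, then persist it
path = manager.save()

restored = tf.train.Checkpoint(epoch=tf.Variable(-1, name='epoch'))
restored.restore(path)
assert restored.epoch.numpy() == 7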