Skip to content

Commit 7d46930

Browse files
tomsiadevnv-kkudrynski
authored and committed
[SIM/TF2] Add epoch tracking in checkpoints.
1 parent 77292de commit 7d46930

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

TensorFlow2/Recommendation/SIM/main.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838
def init_checkpoint_manager(model, optimizer, save_checkpoint_path, load_checkpoint_path):
3939
checkpoint = tf.train.Checkpoint(
4040
model=model,
41-
optimizer=optimizer
41+
optimizer=optimizer,
42+
epoch=tf.Variable(-1, name='epoch')
4243
)
4344

4445
checkpoint_manager = tf.train.CheckpointManager(
@@ -283,7 +284,7 @@ def model_step(batch, model, model_fn, optimizer, amp, first_batch):
283284
return loss_dict
284285

285286

286-
def run_single_epoch(model, model_fn, data_iterator, optimizer, amp, epoch, benchmark, performance_calculator):
287+
def run_single_epoch(model, model_fn, data_iterator, optimizer, amp, start_epoch, epoch, benchmark, performance_calculator):
287288

288289
for current_step, batch in enumerate(data_iterator):
289290
if benchmark and performance_calculator.completed:
@@ -296,7 +297,7 @@ def run_single_epoch(model, model_fn, data_iterator, optimizer, amp, epoch, benc
296297
n_samples = len(batch[1])
297298
step_throughput = performance_calculator(n_samples)
298299
step_dict["samples/s"] = step_throughput
299-
dllogger.log(data=step_dict, step=(epoch, current_step))
300+
dllogger.log(data=step_dict, step=(start_epoch + epoch, current_step))
300301

301302

302303
def train(model, model_fn, data_iterator_train, data_iterator_test, optimizer, amp, epochs,
@@ -307,8 +308,10 @@ def train(model, model_fn, data_iterator_train, data_iterator_test, optimizer, a
307308

308309
all_epochs_results = []
309310

310-
for epoch in range(epochs):
311-
run_single_epoch(model, model_fn, data_iterator_train, optimizer, amp, epoch, benchmark, performance_calculator)
311+
start_epoch = checkpoint_manager.checkpoint.epoch.numpy().item() + 1
312+
313+
for epoch in range(epochs - start_epoch):
314+
run_single_epoch(model, model_fn, data_iterator_train, optimizer, amp, start_epoch, epoch, benchmark, performance_calculator)
312315

313316
if not benchmark:
314317
# we dump throughput results for consecutive epochs for a regular training job (w/o --benchmark flag)
@@ -321,10 +324,11 @@ def train(model, model_fn, data_iterator_train, data_iterator_test, optimizer, a
321324
results_data.update(results_eval_test)
322325

323326
if save_checkpoint:
327+
checkpoint_manager.checkpoint.epoch.assign(epoch)
324328
checkpoint_manager.save()
325329

326330
if hvd.rank() == 0:
327-
dllogger.log(data=results_data, step=(epoch,))
331+
dllogger.log(data=results_data, step=(start_epoch + epoch,))
328332

329333
performance_calculator.init() # restart for another epoch
330334

TensorFlow2/Recommendation/SIM/scripts/run_model.sh

+6-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ Usage: bash scripts/run_model.sh
2828
--batch_size Batch size.
2929
--results_dir Path to output directory. Default: /tmp/sim.
3030
--log_filename Name of output log file within results_dir. Default: log.json.
31+
--save_checkpoint_path Path to output checkpoint after training.
32+
--load_checkpoint_path Path from which to restore checkpoint for inference or suspend/resume training.
3133
EOF
3234
}
3335

@@ -78,9 +80,12 @@ batch_size_option=$(get_option_or_use_default --global_batch_size $batch_size)
7880
epochs_option=$(get_option_or_use_default --epochs $epochs)
7981
results_dir_option=$(get_option_or_use_default --results_dir $results_dir)
8082
log_filename_option=$(get_option_or_use_default --log_filename $log_filename)
83+
save_checkpoint_path_option=$(get_option_or_use_default --save_checkpoint_path $save_checkpoint_path)
84+
load_checkpoint_path_option=$(get_option_or_use_default --load_checkpoint_path $load_checkpoint_path)
8185

8286
command="mpiexec --allow-run-as-root --bind-to socket -np ${gpus} python main.py --dataset_dir ${data_path} --drop_remainder ${epochs_option}
83-
${xla_arg} ${amp_arg} ${benchmark_arg} ${mode_option} ${benchmark_steps_option} ${batch_size_option} ${results_dir_option} ${log_filename_option}"
87+
${xla_arg} ${amp_arg} ${benchmark_arg} ${mode_option} ${benchmark_steps_option} ${batch_size_option} ${results_dir_option} ${log_filename_option}
88+
${save_checkpoint_path_option} ${load_checkpoint_path_option}"
8489

8590
printf "[INFO] Running:\n%s\n" "${command}"
8691
# run

0 commit comments

Comments (0)