
Commit 691a34f

Merge: [BERT/PyT] [ELECTRA/TF2] resume p2 option, fix early stopping
2 parents 340db9e + 544a2d6 commit 691a34f

2 files changed (+8, -4 lines)

PyTorch/LanguageModeling/BERT/run_pretraining.py (+7, -3)

@@ -250,6 +250,10 @@ def parse_arguments():
                         default=False,
                         action='store_true',
                         help="Whether to train with seq len 512")
+    parser.add_argument('--resume_phase2',
+                        default=False,
+                        action='store_true',
+                        help="Whether to resume training with seq len 512")
     parser.add_argument('--allreduce_post_accumulation',
                         default=False,
                         action='store_true',
@@ -427,13 +431,13 @@ def prepare_model_and_optimizer(args, device, sequence_output_is_dense):
     model.checkpoint_activations(args.checkpoint_activations)
 
     if args.resume_from_checkpoint:
-        # For phase2, need to reset the learning rate and step count in the checkpoint
-        if args.phase2 or args.init_checkpoint :
+        # For phase2 from scratch, need to reset the learning rate and step count in the checkpoint. Else restore values in checkpoint.
+        if (args.phase2 and not args.resume_phase2) or args.init_checkpoint :
             for group in checkpoint['optimizer']['param_groups'] :
                 group['step'].zero_()
                 group['lr'].fill_(args.learning_rate)
         else :
-            if 'grad_scaler' in checkpoint and not args.phase2:
+            if 'grad_scaler' in checkpoint and (not args.phase2 or args.resume_phase2):
                 grad_scaler.load_state_dict(checkpoint['grad_scaler'])
             optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

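The BERT change adds a --resume_phase2 flag so that resuming an interrupted phase-2 run is distinguished from starting phase 2 fresh from a phase-1 checkpoint. Previously, any checkpoint load with --phase2 set would zero the optimizer step counters and reset the learning rate, which is correct when entering phase 2 for the first time but discards the LR schedule of an in-progress phase-2 run. A minimal standalone sketch of the resulting decision logic (the function name and argument shapes are illustrative, not the repository's structure):

def restore_or_reset(checkpoint, optimizer, grad_scaler, args):
    """Sketch of the checkpoint-resume branch added in this commit."""
    if (args.phase2 and not args.resume_phase2) or args.init_checkpoint:
        # Entering phase 2 fresh (or from --init_checkpoint): the phase-1
        # step count and LR are stale, so reset them in the saved state.
        for group in checkpoint['optimizer']['param_groups']:
            group['step'].zero_()                  # step counters are tensors here
            group['lr'].fill_(args.learning_rate)  # reseed the LR schedule
    else:
        # Resuming mid-run (phase 1, or phase 2 with --resume_phase2):
        # the saved optimizer and grad-scaler state are still valid.
        if 'grad_scaler' in checkpoint and (not args.phase2 or args.resume_phase2):
            grad_scaler.load_state_dict(checkpoint['grad_scaler'])
        optimizer.load_state_dict(checkpoint['optimizer'])

Relaunching an interrupted phase-2 job would then pass something like --phase2 --resume_from_checkpoint --resume_phase2 alongside the usual flags (the exact invocation depends on the launcher scripts, which this commit does not touch).
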
TensorFlow2/LanguageModeling/ELECTRA/run_pretraining.py (+1, -1)

@@ -477,7 +477,7 @@ def main(e2e_start_time):
                 iter_save_path = iter_manager.save(checkpoint_number=step)
                 log(" ** Saved iterator checkpoint for step {}: {}".format(step, iter_save_path), all_rank=True)
             local_step += 1
-            if (local_step % (config.steps_this_run * args.gradient_accumulation_steps) == 0):
+            if config.steps_this_run != -1 and (local_step % (config.steps_this_run * args.gradient_accumulation_steps) == 0):
                 #terminating run sooner as steps_this_run has been reached
                 log("terminating as steps_this_run:{} has been reached".format(config.steps_this_run))
                 break

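The ELECTRA change fixes spurious early stopping. The new config.steps_this_run != -1 guard implies that -1 is the sentinel for "run the full schedule"; without the guard, Python's modulo makes the old test fire at the first gradient-accumulation boundary, because n % -k == 0 whenever n is a multiple of k. A small self-contained illustration (the values are made up):

# steps_this_run == -1 stands for "no early stop" (sentinel implied by the fix)
steps_this_run = -1
gradient_accumulation_steps = 4

for local_step in range(1, 9):
    old_check = (local_step % (steps_this_run * gradient_accumulation_steps) == 0)
    new_check = steps_this_run != -1 and old_check
    print(local_step, old_check, new_check)

# old_check is True at local_step = 4 and 8 (since 4 % -4 == 0), so the
# unguarded version would break out of training at the first accumulation
# boundary; the guarded version never fires when steps_this_run is -1.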