Commit d63c22e

Merge pull request huggingface#13 from ROCmSoftwarePlatform/include_stable_train_metrics
Adding code to print stable_train_samples_per_second values
2 parents dc78c95 + 26e1ad6 · commit d63c22e

1 file changed: 11 additions, 0 deletions


src/transformers/trainer.py
Lines changed: 11 additions, 0 deletions
@@ -1373,6 +1373,7 @@ def train(
                 # AT THE VERY END!
                 _ = list(train_dataloader.sampler)
 
+        start_train_stable_time = 0
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
@@ -1402,6 +1403,9 @@ def train(
             step = -1
             for step, inputs in enumerate(epoch_iterator):
 
+                if (self.state.global_step == 10):
+                    start_train_stable_time = time.time()
+
                 # Skip past any already trained steps if resuming training
                 if steps_trained_in_current_epoch > 0:
                     steps_trained_in_current_epoch -= 1
@@ -1549,6 +1553,11 @@ def train(
         train_loss = self._total_loss_scalar / self.state.global_step
 
         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
+
+        total_samples = args.max_steps*total_train_batch_size if args.max_steps > 0 else num_examples*num_train_epochs
+        perf_samples = total_samples - 10*total_train_batch_size
+        stable_train_metrics = speed_metrics("stable_train", start_train_stable_time, perf_samples)
+
         self.store_flos()
         metrics["total_flos"] = self.state.total_flos
         metrics["train_loss"] = train_loss
@@ -1559,6 +1568,8 @@ def train(
 
         self.log(metrics)
 
+        self.log(stable_train_metrics)
+
         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
 
         return TrainOutput(self.state.global_step, train_loss, metrics)
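The change itself is small: the trainer records a timestamp once self.state.global_step reaches 10, subtracts the samples consumed by those first 10 warm-up steps from the total sample count, and makes a second speed_metrics call so that a stable_train_samples_per_second figure is logged alongside the usual train metrics. The sketch below reproduces that arithmetic in isolation; the helper name stable_train_speed_metrics and all batch/step numbers are illustrative assumptions, not the committed code (which reuses speed_metrics from the transformers trainer utilities).

    import time

    def stable_train_speed_metrics(start_stable_time, num_samples, prefix="stable_train"):
        # Mirrors what a speed_metrics(prefix, start_time, num_samples) call reports:
        # wall-clock runtime since start_stable_time and a samples-per-second rate,
        # both keyed by the given prefix.
        runtime = time.time() - start_stable_time
        return {
            f"{prefix}_runtime": round(runtime, 4),
            f"{prefix}_samples_per_second": round(num_samples / runtime, 3),
        }

    # Hypothetical numbers, for illustration only: with an effective batch of 32 and
    # 100 optimizer steps in total, the stable window drops the first 10 warm-up steps,
    # i.e. perf_samples = (100 - 10) * 32 = 2880 samples.
    total_train_batch_size = 32
    max_steps = 100
    total_samples = max_steps * total_train_batch_size
    perf_samples = total_samples - 10 * total_train_batch_size

    # Pretend the stable window ran for two minutes of wall-clock time.
    start_train_stable_time = time.time() - 120.0
    print(stable_train_speed_metrics(start_train_stable_time, perf_samples))
    # -> {'stable_train_runtime': 120.0, 'stable_train_samples_per_second': 24.0}

Skipping the first 10 steps keeps one-off costs (dataloader worker spin-up, first-step kernel compilation and allocator warm-up) out of the throughput figure, which is why the stable rate is usually a more reliable basis for comparing hardware than the plain train_samples_per_second value.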
