@@ -1373,6 +1373,7 @@ def train(
                 # AT THE VERY END!
                 _ = list(train_dataloader.sampler)
 
+        start_train_stable_time = 0
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
@@ -1402,6 +1403,9 @@ def train(
             step = -1
             for step, inputs in enumerate(epoch_iterator):
 
+                if self.state.global_step == 10:
+                    start_train_stable_time = time.time()
+
                 # Skip past any already trained steps if resuming training
                 if steps_trained_in_current_epoch > 0:
                     steps_trained_in_current_epoch -= 1
@@ -1549,6 +1553,11 @@ def train(
         train_loss = self._total_loss_scalar / self.state.global_step
 
         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
+
+        total_samples = args.max_steps * total_train_batch_size if args.max_steps > 0 else num_examples * num_train_epochs
+        perf_samples = total_samples - 10 * total_train_batch_size
+        stable_train_metrics = speed_metrics("stable_train", start_train_stable_time, perf_samples)
+
         self.store_flos()
         metrics["total_flos"] = self.state.total_flos
         metrics["train_loss"] = train_loss
@@ -1559,6 +1568,8 @@ def train(
 
         self.log(metrics)
 
+        self.log(stable_train_metrics)
+
         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
 
         return TrainOutput(self.state.global_step, train_loss, metrics)
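For context, here is a minimal standalone sketch of what the added lines compute. It assumes the usual shape of the `speed_metrics` helper from `transformers.trainer_utils` (runtime plus samples/steps per second for a named split; exact keys and rounding vary by version), and the batch size and step counts are made-up placeholders, not values from this patch:

```python
import time

# Sketch of transformers' speed_metrics helper (trainer_utils); key names
# and rounding may differ across versions.
def speed_metrics(split, start_time, num_samples=None, num_steps=None):
    runtime = time.time() - start_time
    result = {f"{split}_runtime": round(runtime, 4)}
    if num_samples is not None:
        result[f"{split}_samples_per_second"] = round(num_samples / runtime, 3)
    if num_steps is not None:
        result[f"{split}_steps_per_second"] = round(num_steps / runtime, 3)
    return result

# Placeholder values (hypothetical, not from the patch):
warmup_steps = 10            # the patch hard-codes 10 warm-up steps
total_train_batch_size = 64  # per-device batch * grad accumulation * world size
max_steps = 1000             # stands in for args.max_steps

# Same arithmetic as the patch: drop the samples seen during warm-up.
total_samples = max_steps * total_train_batch_size
perf_samples = total_samples - warmup_steps * total_train_batch_size

# In the patch this timestamp is taken once global_step reaches 10; here
# we pretend steady-state training ran for 5 seconds.
start_train_stable_time = time.time() - 5.0
print(speed_metrics("stable_train", start_train_stable_time, perf_samples))
```

By starting the clock at step 10 and subtracting the warm-up samples, the `stable_train_*` metrics exclude one-time startup costs (graph compilation, cache population, dataloader spin-up), so they reflect steady-state throughput rather than end-to-end wall time, while the existing `train_*` metrics still cover the full run.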