Skip to content

Commit 7668101

Browse files
authored
round epoch only in console (#30237)
1 parent fe2d20d commit 7668101

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/transformers/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3048,7 +3048,7 @@ def log(self, logs: Dict[str, float]) -> None:
30483048
The values to log.
30493049
"""
30503050
if self.state.epoch is not None:
3051-
logs["epoch"] = round(self.state.epoch, 2)
3051+
logs["epoch"] = self.state.epoch
30523052
if self.args.include_num_input_tokens_seen:
30533053
logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
30543054

src/transformers/trainer_callback.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616
Callbacks to use with the Trainer class and customize the training loop.
1717
"""
18+
import copy
1819
import dataclasses
1920
import json
2021
from dataclasses import dataclass
@@ -520,7 +521,12 @@ def on_predict(self, args, state, control, **kwargs):
520521

521522
def on_log(self, args, state, control, logs=None, **kwargs):
522523
if state.is_world_process_zero and self.training_bar is not None:
524+
# avoid modifying the logs object as it is shared between callbacks
525+
logs = copy.deepcopy(logs)
523526
_ = logs.pop("total_flos", None)
527+
# round numbers so that it looks better in console
528+
if "epoch" in logs:
529+
logs["epoch"] = round(logs["epoch"], 2)
524530
self.training_bar.write(str(logs))
525531

526532
def on_train_end(self, args, state, control, **kwargs):

0 commit comments

Comments (0)