
Commit 396f170

override for fix from huggingface/transformers#37162
1 parent 6c305bb commit 396f170

6 files changed (+78, -10)


src/axolotl/core/trainers/base.py

Lines changed: 4 additions & 1 deletion
@@ -28,6 +28,7 @@
 from axolotl.core.trainers.mixins import (
     OptimizerMixin,
+    RngLoaderMixin,
     SchedulerMixin,
     SequenceParallelMixin,
 )
@@ -40,7 +41,9 @@
 LOG = logging.getLogger(__name__)


-class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
+class AxolotlTrainer(
+    SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
+):
     """Extend the base Trainer for axolotl helpers"""

     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
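
The mixin must appear before Trainer in the bases list so its method wins the MRO lookup. A minimal, self-contained sketch (stand-in classes, not the real transformers/axolotl ones) of why this ordering is sufficient to override _load_rng_state:

# MRO sketch with stand-in classes; not the actual Trainer or mixin.
class Trainer:
    def _load_rng_state(self, checkpoint):
        return "upstream implementation (with the bug)"

class RngLoaderMixin(Trainer):
    def _load_rng_state(self, checkpoint):
        return "patched implementation"

class AxolotlTrainer(RngLoaderMixin, Trainer):
    """Mixin listed before Trainer, so its override is found first."""

print([c.__name__ for c in AxolotlTrainer.__mro__])
# ['AxolotlTrainer', 'RngLoaderMixin', 'Trainer', 'object']
print(AxolotlTrainer()._load_rng_state(None))  # -> "patched implementation"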

src/axolotl/core/trainers/dpo/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer

-from axolotl.core.trainers.mixins import SchedulerMixin
+from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
 from axolotl.core.trainers.utils import (
     sanitize_kwargs_for_ds_tagging,
     sanitize_kwargs_for_tagging,
@@ -23,7 +23,7 @@
     import smdistributed.modelparallel.torch as smp


-class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
+class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
     """
     Extend the base DPOTrainer for axolotl helpers
     """

src/axolotl/core/trainers/grpo/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -8,13 +8,13 @@
 from trl import GRPOTrainer
 from trl.extras.profiling import profiling_decorator

-from axolotl.core.trainers.base import SchedulerMixin
+from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin

 if is_deepspeed_available():
     import deepspeed


-class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
+class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
     """
     Extend the base GRPOTrainer for axolotl helpers
     """

src/axolotl/core/trainers/mixins/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,5 +4,6 @@
 # flake8: noqa

 from .optimizer import OptimizerMixin
+from .rng_state_loader import RngLoaderMixin
 from .scheduler import SchedulerMixin
 from .sequence_parallel import SequenceParallelMixin
src/axolotl/core/trainers/mixins/rng_state_loader.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+"""
+Temporary fix/override for bug in resume from checkpoint
+
+See https://github.com/huggingface/transformers/pull/37162
+"""
+
+import logging
+import os
+import random
+
+import numpy as np
+import torch
+from transformers import Trainer, is_torch_npu_available
+from transformers.trainer import safe_globals
+from transformers.trainer_pt_utils import set_rng_state_for_device
+from transformers.training_args import ParallelMode
+
+LOG = logging.getLogger(__name__)
+
+
+class RngLoaderMixin(Trainer):
+    """
+    mixin for method override to load RNG states from a checkpoint
+    """
+
+    def _load_rng_state(self, checkpoint):
+        # Load RNG states from `checkpoint`
+        if checkpoint is None:
+            return
+
+        if self.args.world_size > 1:
+            process_index = self.args.process_index
+            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
+            if not os.path.isfile(rng_file):
+                LOG.info(
+                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
+                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+                )
+                return
+        else:
+            rng_file = os.path.join(checkpoint, "rng_state.pth")
+            if not os.path.isfile(rng_file):
+                LOG.info(
+                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                    "fashion, reproducibility is not guaranteed."
+                )
+                return
+
+        with safe_globals():
+            checkpoint_rng_state = torch.load(rng_file)  # nosec B614
+        random.setstate(checkpoint_rng_state["python"])
+        np.random.set_state(checkpoint_rng_state["numpy"])
+        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+
+        is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
+        if torch.cuda.is_available():
+            set_rng_state_for_device(
+                "CUDA", torch.cuda, checkpoint_rng_state, is_distributed
+            )
+        if is_torch_npu_available():
+            set_rng_state_for_device(
+                "NPU", torch.npu, checkpoint_rng_state, is_distributed
+            )
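
For reference, a sketch of the round-trip this override completes. The "python", "numpy", and "cpu" keys are taken from the loading code above; treating them as exactly what the upstream Trainer writes into rng_state.pth is an assumption here (device generators such as CUDA are handled separately via set_rng_state_for_device):

import random

import numpy as np
import torch

# Save: capture each RNG stream, mirroring what a checkpoint stores.
rng_state = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),
    "cpu": torch.random.get_rng_state(),
}
torch.save(rng_state, "rng_state.pth")

# ... RNG streams advance as execution continues ...
random.random(); np.random.rand(); torch.rand(1)

# Restore: the same three calls the mixin makes after torch.load().
restored = torch.load("rng_state.pth")
random.setstate(restored["python"])
np.random.set_state(restored["numpy"])
torch.random.set_rng_state(restored["cpu"])

Restoring all three streams matters because data shuffling, dropout, and augmentation may each draw from a different generator.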

src/axolotl/core/trainers/trl.py

Lines changed: 6 additions & 5 deletions
@@ -13,6 +13,7 @@
     RewardTrainer,
 )

+from axolotl.core.trainers.mixins import RngLoaderMixin
 from axolotl.core.trainers.mixins.scheduler import SchedulerMixin


@@ -74,7 +75,7 @@ def train(
     )


-class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
+class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
     """
     Extend the base ORPOTrainer for axolotl helpers
     """
@@ -154,15 +155,15 @@ def get_batch_loss_metrics(
     return loss, metrics


-class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
+class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
     """
     Extend the base KTOTrainer for axolotl helpers
     """

     tag_names = ["axolotl", "kto"]


-class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
+class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
     """
     Extend the base CPOTrainer for axolotl helpers
     """
@@ -244,15 +245,15 @@ def get_batch_loss_metrics(
     return loss, metrics


-class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
+class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
     """
     Extend the base RewardTrainer for axolotl helpers
     """

     tag_names = ["axolotl", "reward"]


-class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
+class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
     """
     Extend the base trl.PRMTrainer for axolotl helpers
     """
