
Commit 082834d

emapco and SunMarc authored
fix: prevent model access error during Optuna hyperparameter tuning (#36395)
* fix: prevent model access error during Optuna hyperparameter tuning

  The `transformers.integrations.integration_utils.run_hp_search_optuna` function releases model memory and sets `trainer.model` to `None` after each trial. This causes an `AttributeError` when a subsequent `Trainer.train` call attempts to access the model before it has been reinitialized. The error only occurs when the `fp16_full_eval` or `bf16_full_eval` flag is enabled.

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <[email protected]>
1 parent 6513e5e commit 082834d
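To make the failure mode concrete, here is a minimal, self-contained sketch of the sequence the commit message describes. `SketchTrainer` and `run_hp_search` are hypothetical stand-ins for illustration only, not the real `Trainer` or `run_hp_search_optuna` internals:

class SketchTrainer:
    def __init__(self, model_init, fp16_full_eval=False):
        self.model_init = model_init
        self.model = None  # hyperparameter search starts without an instantiated model
        self.fp16_full_eval = fp16_full_eval

    def train(self, trial=None):
        # Before this commit, the guard did not include `self.model_init is None`,
        # so on the second trial this branch dereferenced self.model == None.
        if self.fp16_full_eval and self.model_init is None:
            self.model.half()  # stand-in for self._move_model_to_device(...)
        if self.model_init is not None:
            # train() re-creates the model when a model_init callback exists,
            # which is why the early device move above can be skipped safely.
            self.model = self.model_init(trial)
        return {"objective": 0.0}


def run_hp_search(trainer, n_trials):
    # Mirrors run_hp_search_optuna's cleanup: the model is released after
    # each trial to free memory, leaving trainer.model set to None.
    for trial in range(n_trials):
        trainer.train(trial)
        trainer.model = None


# Two trials complete without error because the guarded branch is skipped
# whenever a model_init callback is supplied.
run_hp_search(SketchTrainer(model_init=lambda trial: object(), fp16_full_eval=True), n_trials=2)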

File tree: 2 files changed, +38 −1 lines changed

src/transformers/trainer.py (+6 −1)

@@ -2180,7 +2180,12 @@ def train(
 
         # do_train is not a reliable argument, as it might not be set and .train() still called, so
         # the following is a workaround:
-        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train and not self.is_model_parallel:
+        if (
+            (args.fp16_full_eval or args.bf16_full_eval)
+            and not args.do_train
+            and not self.is_model_parallel
+            and self.model_init is None
+        ):
             self._move_model_to_device(self.model, args.device)
 
         if "model_path" in kwargs:

tests/trainer/test_trainer.py (+32 −0)

@@ -4998,6 +4998,38 @@ def compute_objective(metrics: Dict[str, float]) -> List[float]:
         )
 
 
+@require_torch
+@require_optuna
+class TrainerHyperParameterOptunaIntegrationTestWithFullEval(unittest.TestCase):
+    def test_hyperparameter_search(self):
+        def hp_space(trial):
+            return {}
+
+        def model_init(trial):
+            if trial is not None:
+                a = trial.suggest_int("a", -4, 4)
+                b = trial.suggest_int("b", -4, 4)
+            else:
+                a = 0
+                b = 0
+            config = RegressionModelConfig(a=a, b=b, double_output=False)
+
+            return RegressionPreTrainedModel(config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=tmp_dir,
+                disable_tqdm=True,
+                model_init=model_init,
+                fp16_full_eval=True,
+            )
+            trainer.hyperparameter_search(
+                direction="minimize",
+                hp_space=hp_space,
+                n_trials=2,
+            )
+
+
 @require_torch
 @require_ray
 class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
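The regression test reuses the existing `RegressionModelConfig`/`RegressionPreTrainedModel` helpers and runs two Optuna trials with `fp16_full_eval=True`, which reproduced the `AttributeError` before the fix (the second trial is the one that starts with `trainer.model` set to `None`). It can be run in isolation with, e.g., `pytest tests/trainer/test_trainer.py -k TrainerHyperParameterOptunaIntegrationTestWithFullEval`, assuming torch and optuna are installed.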
