Skip to content

Commit dcffca7

Browse files
authored
Parametrize deepspeed hook test (#11308)
1 parent 7b0272a commit dcffca7

File tree

1 file changed

+3
-12
lines changed

1 file changed

+3
-12
lines changed

tests/models/test_hooks.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -429,29 +429,20 @@ def _predict_batch(trainer, model, batches):
429429
return out
430430

431431

432-
@RunIf(deepspeed=True, min_gpus=1, standalone=True)
433-
@pytest.mark.parametrize("automatic_optimization", (True, False))
434-
def test_trainer_model_hook_system_fit_deepspeed(tmpdir, automatic_optimization):
435-
_run_trainer_model_hook_system_fit(
436-
dict(gpus=1, precision=16, strategy="deepspeed"), tmpdir, automatic_optimization=automatic_optimization
437-
)
438-
439-
440432
@pytest.mark.parametrize(
441433
"kwargs",
442434
[
443435
{},
444436
# these precision plugins modify the optimization flow, so testing them explicitly
445437
pytest.param(dict(gpus=1, precision=16, amp_backend="native"), marks=RunIf(min_gpus=1)),
446438
pytest.param(dict(gpus=1, precision=16, amp_backend="apex"), marks=RunIf(amp_apex=True, min_gpus=1)),
439+
pytest.param(
440+
dict(gpus=1, precision=16, strategy="deepspeed"), marks=RunIf(deepspeed=True, min_gpus=1, standalone=True)
441+
),
447442
],
448443
)
449444
@pytest.mark.parametrize("automatic_optimization", (True, False))
450445
def test_trainer_model_hook_system_fit(tmpdir, kwargs, automatic_optimization):
451-
_run_trainer_model_hook_system_fit(kwargs, tmpdir, automatic_optimization)
452-
453-
454-
def _run_trainer_model_hook_system_fit(kwargs, tmpdir, automatic_optimization):
455446
called = []
456447

457448
class TestModel(HookedModel):

0 commit comments

Comments
 (0)