Skip to content

Commit 9a17cf9

Browse files
carmoccaawaelchli
authored andcommitted
Support special test parametrizations (#10569)
1 parent f96d769 commit 9a17cf9

File tree

9 files changed

+82
-117
lines changed

9 files changed

+82
-117
lines changed

tests/accelerators/test_ddp.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,8 @@ def setup(self, stage: Optional[str] = None) -> None:
109109

110110

111111
@RunIf(min_gpus=2, min_torch="1.8.1", special=True)
112-
def test_ddp_wrapper_16(tmpdir):
113-
_test_ddp_wrapper(tmpdir, precision=16)
114-
115-
116-
@RunIf(min_gpus=2, min_torch="1.8.1", special=True)
117-
def test_ddp_wrapper_32(tmpdir):
118-
_test_ddp_wrapper(tmpdir, precision=32)
119-
120-
121-
def _test_ddp_wrapper(tmpdir, precision):
112+
@pytest.mark.parametrize("precision", (16, 32))
113+
def test_ddp_wrapper(tmpdir, precision):
122114
"""Test parameters to ignore are carried over for DDP."""
123115

124116
class WeirdModule(torch.nn.Module):

tests/callbacks/test_pruning.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -161,27 +161,18 @@ def test_pruning_callback(
161161

162162

163163
@RunIf(special=True, min_gpus=2)
164-
def test_pruning_callback_ddp_0(tmpdir):
164+
@pytest.mark.parametrize("parameters_to_prune", (False, True))
165+
@pytest.mark.parametrize("use_global_unstructured", (False, True))
166+
def test_pruning_callback_ddp(tmpdir, parameters_to_prune, use_global_unstructured):
165167
train_with_pruning_callback(
166-
tmpdir, parameters_to_prune=False, use_global_unstructured=False, strategy="ddp", gpus=2
168+
tmpdir,
169+
parameters_to_prune=parameters_to_prune,
170+
use_global_unstructured=use_global_unstructured,
171+
strategy="ddp",
172+
gpus=2,
167173
)
168174

169175

170-
@RunIf(special=True, min_gpus=2)
171-
def test_pruning_callback_ddp_1(tmpdir):
172-
train_with_pruning_callback(tmpdir, parameters_to_prune=False, use_global_unstructured=True, strategy="ddp", gpus=2)
173-
174-
175-
@RunIf(special=True, min_gpus=2)
176-
def test_pruning_callback_ddp_2(tmpdir):
177-
train_with_pruning_callback(tmpdir, parameters_to_prune=True, use_global_unstructured=False, strategy="ddp", gpus=2)
178-
179-
180-
@RunIf(special=True, min_gpus=2)
181-
def test_pruning_callback_ddp_3(tmpdir):
182-
train_with_pruning_callback(tmpdir, parameters_to_prune=True, use_global_unstructured=True, strategy="ddp", gpus=2)
183-
184-
185176
@RunIf(min_gpus=2, skip_windows=True)
186177
def test_pruning_callback_ddp_spawn(tmpdir):
187178
train_with_pruning_callback(tmpdir, use_global_unstructured=True, strategy="ddp_spawn", gpus=2)

tests/callbacks/test_tqdm_progress_bar.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -522,20 +522,11 @@ def test_tqdm_progress_bar_can_be_pickled():
522522

523523

524524
@RunIf(min_gpus=2, special=True)
525-
def test_tqdm_progress_bar_max_val_check_interval_0(tmpdir):
526-
_test_progress_bar_max_val_check_interval(
527-
tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.2
528-
)
529-
530-
531-
@RunIf(min_gpus=2, special=True)
532-
def test_tqdm_progress_bar_max_val_check_interval_1(tmpdir):
533-
_test_progress_bar_max_val_check_interval(
534-
tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.5
535-
)
536-
537-
538-
def _test_progress_bar_max_val_check_interval(
525+
@pytest.mark.parametrize(
526+
["total_train_samples", "train_batch_size", "total_val_samples", "val_batch_size", "val_check_interval"],
527+
[(8, 4, 2, 1, 0.2), (8, 4, 2, 1, 0.5)],
528+
)
529+
def test_progress_bar_max_val_check_interval(
539530
tmpdir, total_train_samples, train_batch_size, total_val_samples, val_batch_size, val_check_interval
540531
):
541532
world_size = 2

tests/checkpointing/test_checkpoint_callback_frequency.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,8 @@ def training_step(self, batch, batch_idx):
8888

8989
@mock.patch("torch.save")
9090
@RunIf(special=True, min_gpus=2)
91-
def test_top_k_ddp_0(save_mock, tmpdir):
92-
_top_k_ddp(save_mock, tmpdir, k=1, epochs=1, val_check_interval=1.0, expected=1)
93-
94-
95-
@mock.patch("torch.save")
96-
@RunIf(special=True, min_gpus=2)
97-
def test_top_k_ddp_1(save_mock, tmpdir):
98-
_top_k_ddp(save_mock, tmpdir, k=2, epochs=2, val_check_interval=0.3, expected=4)
99-
100-
101-
def _top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected):
91+
@pytest.mark.parametrize(["k", "epochs", "val_check_interval", "expected"], [(1, 1, 1.0, 1), (2, 2, 0.3, 4)])
92+
def test_top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected):
10293
class TestModel(BoringModel):
10394
def training_step(self, batch, batch_idx):
10495
local_rank = int(os.getenv("LOCAL_RANK"))

tests/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,16 @@ def single_process_pg():
156156
torch.distributed.destroy_process_group()
157157
os.environ.clear()
158158
os.environ.update(orig_environ)
159+
160+
161+
def pytest_collection_modifyitems(items):
162+
if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") != "1":
163+
return
164+
# filter out non-special tests
165+
items[:] = [
166+
item
167+
for item in items
168+
for marker in item.own_markers
169+
# has `@RunIf(special=True)`
170+
if marker.name == "skipif" and marker.kwargs.get("special")
171+
]

tests/helpers/runif.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ def __new__(
150150
env_flag = os.getenv("PL_RUNNING_SPECIAL_TESTS", "0")
151151
conditions.append(env_flag != "1")
152152
reasons.append("Special execution")
153+
# used in tests/conftest.py::pytest_collection_modifyitems
154+
kwargs["special"] = True
153155

154156
if fairscale:
155157
conditions.append(not _FAIRSCALE_AVAILABLE)

tests/models/test_hooks.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -423,16 +423,10 @@ def _predict_batch(trainer, model, batches):
423423

424424

425425
@RunIf(deepspeed=True, min_gpus=1, special=True)
426-
def test_trainer_model_hook_system_fit_deepspeed_automatic_optimization(tmpdir):
427-
_run_trainer_model_hook_system_fit(
428-
dict(gpus=1, precision=16, strategy="deepspeed"), tmpdir, automatic_optimization=True
429-
)
430-
431-
432-
@RunIf(deepspeed=True, min_gpus=1, special=True)
433-
def test_trainer_model_hook_system_fit_deepspeed_manual_optimization(tmpdir):
426+
@pytest.mark.parametrize("automatic_optimization", (True, False))
427+
def test_trainer_model_hook_system_fit_deepspeed(tmpdir, automatic_optimization):
434428
_run_trainer_model_hook_system_fit(
435-
dict(gpus=1, precision=16, strategy="deepspeed"), tmpdir, automatic_optimization=False
429+
dict(gpus=1, precision=16, strategy="deepspeed"), tmpdir, automatic_optimization=automatic_optimization
436430
)
437431

438432

tests/special_tests.sh

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,55 +17,49 @@ set -e
1717
# this environment variable allows special tests to run
1818
export PL_RUNNING_SPECIAL_TESTS=1
1919
# python arguments
20-
defaults='-m coverage run --source pytorch_lightning --append -m pytest --durations=0 --capture=no --disable-warnings'
20+
defaults='-m coverage run --source pytorch_lightning --append -m pytest --capture=no'
2121

22-
# find tests marked as `@RunIf(special=True)`
23-
grep_output=$(grep --recursive --line-number --word-regexp 'tests' 'benchmarks' --regexp 'special=True')
24-
# file paths
25-
files=$(echo "$grep_output" | cut -f1 -d:)
26-
files_arr=($files)
27-
# line numbers
28-
linenos=$(echo "$grep_output" | cut -f2 -d:)
29-
linenos_arr=($linenos)
22+
# find tests marked as `@RunIf(special=True)`. done manually instead of with pytest because it is faster
23+
grep_output=$(grep --recursive --word-regexp 'tests' 'benchmarks' --regexp 'special=True' --include '*.py' --exclude 'tests/conftest.py')
24+
25+
# file paths, remove duplicates
26+
files=$(echo "$grep_output" | cut -f1 -d: | sort | uniq)
27+
28+
# get the list of parametrizations. we need to call them separately. the last two lines are removed.
29+
# note: if there's a syntax error, this will fail with some garbled output
30+
if [[ "$OSTYPE" == "darwin"* ]]; then
31+
parametrizations=$(pytest $files --collect-only --quiet | tail -r | sed -e '1,3d' | tail -r)
32+
else
33+
parametrizations=$(pytest $files --collect-only --quiet | head -n -2)
34+
fi
35+
parametrizations_arr=($parametrizations)
3036

3137
# tests to skip - space separated
32-
blocklist='test_pytorch_profiler_nested_emit_nvtx'
38+
blocklist='tests/profiler/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx'
3339
report=''
3440

35-
for i in "${!files_arr[@]}"; do
36-
file=${files_arr[$i]}
37-
lineno=${linenos_arr[$i]}
38-
39-
# get code from `@RunIf(special=True)` line to EOF
40-
test_code=$(tail -n +"$lineno" "$file")
41+
for i in "${!parametrizations_arr[@]}"; do
42+
parametrization=${parametrizations_arr[$i]}
4143

42-
# read line by line
43-
while read -r line; do
44-
# if it's a test
45-
if [[ $line == def\ test_* ]]; then
46-
# get the name
47-
test_name=$(echo $line | cut -c 5- | cut -f1 -d\()
44+
# check blocklist
45+
if echo $blocklist | grep -F "${parametrization}"; then
46+
report+="Skipped\t$parametrization\n"
47+
continue
48+
fi
4849

49-
# check blocklist
50-
if echo $blocklist | grep --word-regexp "$test_name" > /dev/null; then
51-
report+="Skipped\t$file:$lineno::$test_name\n"
52-
break
53-
fi
50+
# SPECIAL_PATTERN allows filtering the tests to run when debugging.
51+
# use as `SPECIAL_PATTERN="foo_bar" ./special_tests.sh` to run only those
52+
# test with `foo_bar` in their name
53+
if [[ $parametrization != *$SPECIAL_PATTERN* ]]; then
54+
report+="Skipped\t$parametrization\n"
55+
continue
56+
fi
5457

55-
# SPECIAL_PATTERN allows filtering the tests to run when debugging.
56-
# use as `SPECIAL_PATTERN="foo_bar" ./special_tests.sh` to run only those
57-
# test with `foo_bar` in their name
58-
if [[ $line != *$SPECIAL_PATTERN* ]]; then
59-
report+="Skipped\t$file:$lineno::$test_name\n"
60-
break
61-
fi
58+
# run the test
59+
echo "Running ${parametrization}"
60+
python ${defaults} "${parametrization}"
6261

63-
# run the test
64-
report+="Ran\t$file:$lineno::$test_name\n"
65-
python ${defaults} "${file}::${test_name}"
66-
break
67-
fi
68-
done < <(echo "$test_code")
62+
report+="Ran\t$parametrization\n"
6963
done
7064

7165
if nvcc --version; then

tests/trainer/test_trainer.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,29 +1453,26 @@ def test_trainer_predict_cpu(tmpdir, datamodule, enable_progress_bar):
14531453

14541454

14551455
@RunIf(min_gpus=2, special=True)
1456-
@pytest.mark.parametrize("num_gpus", [1, 2])
1457-
def test_trainer_predict_dp(tmpdir, num_gpus):
1458-
predict(tmpdir, strategy="dp", accelerator="gpu", devices=num_gpus)
1459-
1460-
1461-
@RunIf(min_gpus=2, special=True, fairscale=True)
1462-
def test_trainer_predict_ddp(tmpdir):
1463-
predict(tmpdir, strategy="ddp", accelerator="gpu", devices=2)
1464-
1465-
1466-
@RunIf(min_gpus=2, skip_windows=True, special=True)
1467-
def test_trainer_predict_ddp_spawn(tmpdir):
1468-
predict(tmpdir, strategy="dp", accelerator="gpu", devices=2)
1456+
@pytest.mark.parametrize(
1457+
"kwargs",
1458+
[
1459+
{"strategy": "dp", "devices": 1},
1460+
{"strategy": "dp", "devices": 2},
1461+
{"strategy": "ddp", "devices": 2},
1462+
],
1463+
)
1464+
def test_trainer_predict_special(tmpdir, kwargs):
1465+
predict(tmpdir, accelerator="gpu", **kwargs)
14691466

14701467

1471-
@RunIf(min_gpus=1, special=True)
1468+
@RunIf(min_gpus=1)
14721469
def test_trainer_predict_1_gpu(tmpdir):
14731470
predict(tmpdir, accelerator="gpu", devices=1)
14741471

14751472

14761473
@RunIf(skip_windows=True)
1477-
def test_trainer_predict_ddp_cpu(tmpdir):
1478-
predict(tmpdir, strategy="ddp_spawn", accelerator="cpu", devices=2)
1474+
def test_trainer_predict_ddp_spawn(tmpdir):
1475+
predict(tmpdir, strategy="ddp_spawn", accelerator="auto", devices=2)
14791476

14801477

14811478
@pytest.mark.parametrize("dataset_cls", [RandomDataset, RandomIterableDatasetWithLen, RandomIterableDataset])

0 commit comments

Comments
 (0)