Skip to content

Commit 74feb19

Browse files
authored
Skip big models per platform/device (#6539)
* Skip big models per platform/device * Specifying skips on Windows only. * Simplify and clean up code.
1 parent 9b432d0 commit 74feb19

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

test/test_models.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import operator
44
import os
55
import pkgutil
6+
import platform
67
import sys
78
import warnings
89
from collections import OrderedDict
@@ -343,12 +344,25 @@ def _check_input_backprop(model, inputs):
343344
_model_params[m] = {"input_shape": (1, 3, 64, 64)}
344345

345346

346-
# skip big models to reduce memory usage on CI test
347+
# skip big models to reduce memory usage on CI test. We can exclude combinations of (platform-system, device).
347348
skipped_big_models = {
348-
"vit_h_14",
349-
"regnet_y_128gf",
349+
"vit_h_14": {("Windows", "cpu"), ("Windows", "cuda")},
350+
"regnet_y_128gf": {("Windows", "cpu"), ("Windows", "cuda")},
351+
"mvit_v1_b": {("Windows", "cuda")},
352+
"mvit_v2_s": {("Windows", "cuda")},
350353
}
351354

355+
356+
def is_skippable(model_name, device):
357+
if model_name not in skipped_big_models:
358+
return False
359+
360+
platform_system = platform.system()
361+
device_name = str(device).split(":")[0]
362+
363+
return (platform_system, device_name) in skipped_big_models[model_name]
364+
365+
352366
# The following contains configuration and expected values to be used tests that are model specific
353367
_model_tests_values = {
354368
"retinanet_resnet50_fpn": {
@@ -612,7 +626,7 @@ def test_classification_model(model_fn, dev):
612626
"input_shape": (1, 3, 224, 224),
613627
}
614628
model_name = model_fn.__name__
615-
if SKIP_BIG_MODEL and model_name in skipped_big_models:
629+
if SKIP_BIG_MODEL and is_skippable(model_name, dev):
616630
pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
617631
kwargs = {**defaults, **_model_params.get(model_name, {})}
618632
num_classes = kwargs.get("num_classes")
@@ -841,7 +855,7 @@ def test_video_model(model_fn, dev):
841855
"num_classes": 50,
842856
}
843857
model_name = model_fn.__name__
844-
if SKIP_BIG_MODEL and model_name in skipped_big_models:
858+
if SKIP_BIG_MODEL and is_skippable(model_name, dev):
845859
pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
846860
kwargs = {**defaults, **_model_params.get(model_name, {})}
847861
num_classes = kwargs.get("num_classes")

0 commit comments

Comments
 (0)