|
3 | 3 | import operator
|
4 | 4 | import os
|
5 | 5 | import pkgutil
|
| 6 | +import platform |
6 | 7 | import sys
|
7 | 8 | import warnings
|
8 | 9 | from collections import OrderedDict
|
@@ -343,12 +344,25 @@ def _check_input_backprop(model, inputs):
|
343 | 344 | _model_params[m] = {"input_shape": (1, 3, 64, 64)}
|
344 | 345 |
|
345 | 346 |
|
346 |
| -# skip big models to reduce memory usage on CI test |
| 347 | +# skip big models to reduce memory usage on CI test. We can exclude combinations of (platform-system, device). |
347 | 348 | skipped_big_models = {
|
348 |
| - "vit_h_14", |
349 |
| - "regnet_y_128gf", |
| 349 | + "vit_h_14": {("Windows", "cpu"), ("Windows", "cuda")}, |
| 350 | + "regnet_y_128gf": {("Windows", "cpu"), ("Windows", "cuda")}, |
| 351 | + "mvit_v1_b": {("Windows", "cuda")}, |
| 352 | + "mvit_v2_s": {("Windows", "cuda")}, |
350 | 353 | }
|
351 | 354 |
|
| 355 | + |
| 356 | +def is_skippable(model_name, device): |
| 357 | + if model_name not in skipped_big_models: |
| 358 | + return False |
| 359 | + |
| 360 | + platform_system = platform.system() |
| 361 | + device_name = str(device).split(":")[0] |
| 362 | + |
| 363 | + return (platform_system, device_name) in skipped_big_models[model_name] |
| 364 | + |
| 365 | + |
352 | 366 | # The following contains configuration and expected values to be used tests that are model specific
|
353 | 367 | _model_tests_values = {
|
354 | 368 | "retinanet_resnet50_fpn": {
|
@@ -612,7 +626,7 @@ def test_classification_model(model_fn, dev):
|
612 | 626 | "input_shape": (1, 3, 224, 224),
|
613 | 627 | }
|
614 | 628 | model_name = model_fn.__name__
|
615 |
| - if SKIP_BIG_MODEL and model_name in skipped_big_models: |
| 629 | + if SKIP_BIG_MODEL and is_skippable(model_name, dev): |
616 | 630 | pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
|
617 | 631 | kwargs = {**defaults, **_model_params.get(model_name, {})}
|
618 | 632 | num_classes = kwargs.get("num_classes")
|
@@ -841,7 +855,7 @@ def test_video_model(model_fn, dev):
|
841 | 855 | "num_classes": 50,
|
842 | 856 | }
|
843 | 857 | model_name = model_fn.__name__
|
844 |
| - if SKIP_BIG_MODEL and model_name in skipped_big_models: |
| 858 | + if SKIP_BIG_MODEL and is_skippable(model_name, dev): |
845 | 859 | pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
|
846 | 860 | kwargs = {**defaults, **_model_params.get(model_name, {})}
|
847 | 861 | num_classes = kwargs.get("num_classes")
|
|
0 commit comments