diff --git a/test/test_models.py b/test/test_models.py index 0acef4dcef6..3e2d9ddf4c2 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -603,7 +603,7 @@ def test_classification_model(model_fn, dev): "input_shape": (1, 3, 224, 224), } model_name = model_fn.__name__ - if dev == "cuda" and SKIP_BIG_MODEL and model_name in skipped_big_models: + if SKIP_BIG_MODEL and model_name in skipped_big_models: pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model") kwargs = {**defaults, **_model_params.get(model_name, {})} num_classes = kwargs.get("num_classes")