Skip to content

Commit 80591f2

Browse files
committed
Add option to skip certain unit tests (#2878)
Summary: Pull Request resolved: #2878 Add option to skip certain tests - useful when a few tests are broken and we need to skip them to gain visibility in other unit tests. This usually happens on our CPU unit tests, so only modifying this script. To use, simply add the name of the tests to skip in this txt file. such as: ``` test_sharding_fused_ebc_as_top_level ``` You can also use the class name for the test: e.g. ``` ModelParallelSparseOnlyTestGloo ``` Differential Revision: D72815908
1 parent 2dc7dc4 commit 80591f2

File tree

4 files changed

+14
-26
lines changed

4 files changed

+14
-26
lines changed

.github/scripts/tests_to_skip.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
_disabled_in_oss_compatibility

.github/workflows/unittest_ci_cpu.yml

+9-1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ jobs:
7272
python -c "import numpy"
7373
echo "numpy succeeded"
7474
conda install -n build_binary -y pytest
75+
# Read the list of tests to skip from a file, ignoring empty lines and comments
76+
skip_expression=$(awk '!/^($|#)/ {printf " and not %s", $0}' ./.github/scripts/tests_to_skip.txt)
77+
# Check if skip_expression is effectively empty
78+
if [ -z "$skip_expression" ]; then
79+
skip_expression=""
80+
else
81+
skip_expression=${skip_expression:5} # Remove the leading " and "
82+
fi
7583
conda run -n build_binary \
7684
python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors \
77-
--ignore-glob=**/test_utils/
85+
--ignore-glob=**/test_utils/ -k "$skip_expression"

torchrec/ir/tests/test_serializer.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,7 @@ def test_serialize_deserialize_ebc(self) -> None:
253253
self.assertEqual(deserialized.shape, orginal.shape)
254254
self.assertTrue(torch.allclose(deserialized, orginal))
255255

256-
# pyre-ignore[56]: Pyre was not able to infer the type of argument
257-
@unittest.skipIf(
258-
torch.cuda.device_count() == 0,
259-
"skip this test in OSS (no GPU available) because torch.export uses training ir in OSS",
260-
)
261-
def test_dynamic_shape_ebc(self) -> None:
262-
# TODO: https://fb.workplace.com/groups/1028545332188949/permalink/1138699244506890/
256+
def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
263257
model = self.generate_model()
264258
feature1 = KeyedJaggedTensor.from_offsets_sync(
265259
keys=["f1", "f2", "f3"],

torchrec/models/experimental/test_transformerdlrm.py

+3-18
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,7 @@ def test_larger(self) -> None:
6161
concat_dense = inter_arch(dense_features, sparse_features)
6262
self.assertEqual(concat_dense.size(), (B, D * (F + 1)))
6363

64-
# pyre-ignore[56]: Pyre was not able to infer the type of argument
65-
@unittest.skipIf(
66-
torch.cuda.device_count() == 0,
67-
"skip this test in OSS (no GPU available) because seed might be different in OSS",
68-
)
69-
def test_correctness(self) -> None:
64+
def test_correctness_disabled_in_oss_compatibility(self) -> None:
7065
D = 4
7166
B = 3
7267
# multi-head attentions
@@ -170,12 +165,7 @@ def test_correctness(self) -> None:
170165
)
171166
)
172167

173-
# pyre-ignore[56]: Pyre was not able to infer the type of argument
174-
@unittest.skipIf(
175-
torch.cuda.device_count() == 0,
176-
"skip this test in OSS (no GPU available) because seed might be different in OSS",
177-
)
178-
def test_numerical_stability(self) -> None:
168+
def test_numerical_stability_disabled_in_oss_compatibility(self) -> None:
179169
D = 4
180170
B = 3
181171
# multi-head attentions
@@ -204,12 +194,7 @@ def test_numerical_stability(self) -> None:
204194

205195

206196
class DLRMTransformerTest(unittest.TestCase):
207-
# pyre-ignore[56]: Pyre was not able to infer the type of argument
208-
@unittest.skipIf(
209-
torch.cuda.device_count() == 0,
210-
"skip this test in OSS (no GPU available) because seed might be different in OSS",
211-
)
212-
def test_basic(self) -> None:
197+
def test_basic_disabled_in_oss_compatibility(self) -> None:
213198
torch.manual_seed(0)
214199
B = 2
215200
D = 8

0 commit comments

Comments
 (0)