diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index fee83432388856..f2056afee4e423 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -2457,3 +2457,34 @@ def adjust_int_max_str_digits(max_digits): #For recursion tests, easily exceeds default recursion limit EXCEEDS_RECURSION_LIMIT = 5000 + +# A decorator to skip tests that are not rerunnable. +def skip_rerun(reason): + """ + This decorator skips the decorated test case + if it has already been run. + """ + def decorator(test_case): + + original_setUpClass = test_case.setUpClass + original_tearDownClass = test_case.tearDownClass + + test_case._has_run = False + + @classmethod + def setUpClass(cls): + if cls._has_run: + raise unittest.SkipTest(reason) + original_setUpClass() + + @classmethod + def tearDownClass(cls): + original_tearDownClass() + cls._has_run = True + + test_case.setUpClass = setUpClass + test_case.tearDownClass = tearDownClass + + return test_case + + return decorator diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index cd1679dd23281c..58a99fc0168a34 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -18,6 +18,7 @@ from test.support import script_helper from test.support import socket_helper from test.support import warnings_helper +from test.support import skip_rerun TESTFN = os_helper.TESTFN @@ -684,6 +685,31 @@ def test_has_strftime_extensions(self): else: self.assertTrue(support.has_strftime_extensions) + def test_skip_rerun(self): + @skip_rerun("This test case does not support rerunning") + class TestCase(unittest.TestCase): + counter = 0 + + def test(self): + self.assertEqual(self.counter, 0) + self.__class__.counter += 1 + self.assertEqual(self.counter, 1) + + test_case = TestCase("test") + + output = io.StringIO() + unittest.TextTestRunner(output).run(unittest.TestSuite([test_case])) + output.seek(0) + self.assertRegex(output.read(), "Ran 1 test") + + output = io.StringIO() + unittest.TextTestRunner(output).run(unittest.TestSuite([test_case])) + output.seek(0) + self.assertRegex(output.read(), "Ran 0 tests") + + self.assertEqual(TestCase.counter, 1) + + # XXX -follows a list of untested API # make_legacy_pyc # is_resource_enabled