Skip to content

Commit 310ee99

Browse files
Merge pull request #171 from ConorMacBride/fix-unittest
Fix tests which exit before returning a figure or use `unittest.TestCase`
2 parents 48e652f + d84892b commit 310ee99

File tree

5 files changed

+203
-36
lines changed

5 files changed

+203
-36
lines changed

pytest_mpl/plugin.py

+45-30
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,33 @@ def pathify(path):
8181
return Path(path + ext)
8282

8383

84-
def _pytest_pyfunc_call(obj, pyfuncitem):
85-
testfunction = pyfuncitem.obj
86-
funcargs = pyfuncitem.funcargs
87-
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
88-
obj.result = testfunction(**testargs)
89-
return True
84+
def generate_test_name(item):
    """
    Generate a unique name for the hash for this test.

    The name is the dotted path ``module[.class].test`` built from the
    pytest item, so tests with the same function name in different
    classes or modules do not collide.
    """
    parts = [item.module.__name__]
    if item.cls is not None:
        parts.append(item.cls.__name__)
    parts.append(item.name)
    return ".".join(parts)
93+
94+
95+
def wrap_figure_interceptor(plugin, item):
    """
    Intercept and store figures returned by test functions.

    Replaces the test item's callable with a wrapper that runs the original
    test function and records its return value in ``plugin.return_value``
    keyed by the full test name, so the plugin can retrieve the figure
    after the test has run.
    """
    # Only intercept figures on marked figure tests.
    if get_compare(item) is None:
        return

    # Use the full test name as a key to ensure the correct figure is retrieved.
    test_name = generate_test_name(item)

    def make_interceptor(store, func):
        def wrapper(*args, **kwargs):
            store.return_value[test_name] = func(*args, **kwargs)
        return wrapper

    item.obj = make_interceptor(plugin, item.obj)
90111

91112

92113
def pytest_report_header(config, startdir):
@@ -275,6 +296,7 @@ def __init__(self,
275296
self._generated_hash_library = {}
276297
self._test_results = {}
277298
self._test_stats = None
299+
self.return_value = {}
278300

279301
# https://stackoverflow.com/questions/51737378/how-should-i-log-in-my-pytest-plugin
280302
# turn debug prints on only if "-vv" or more passed
@@ -287,7 +309,7 @@ def generate_filename(self, item):
287309
Given a pytest item, generate the figure filename.
288310
"""
289311
if self.config.getini('mpl-use-full-test-name'):
290-
filename = self.generate_test_name(item) + '.png'
312+
filename = generate_test_name(item) + '.png'
291313
else:
292314
compare = get_compare(item)
293315
# Find test name to use as plot name
@@ -298,21 +320,11 @@ def generate_filename(self, item):
298320
filename = str(pathify(filename))
299321
return filename
300322

301-
def generate_test_name(self, item):
302-
"""
303-
Generate a unique name for the hash for this test.
304-
"""
305-
if item.cls is not None:
306-
name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
307-
else:
308-
name = f"{item.module.__name__}.{item.name}"
309-
return name
310-
311323
def make_test_results_dir(self, item):
    """
    Generate the directory to put the results in.

    Returns a per-test subdirectory of ``self.results_dir`` named after
    the full test name, creating it (and any parents) if needed.
    """
    subdir = pathify(generate_test_name(item))
    out_dir = self.results_dir / subdir
    out_dir.mkdir(exist_ok=True, parents=True)
    return out_dir
@@ -526,7 +538,7 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
526538
pytest.fail(f"Can't find hash library at path {hash_library_filename}")
527539

528540
hash_library = self.load_hash_library(hash_library_filename)
529-
hash_name = self.generate_test_name(item)
541+
hash_name = generate_test_name(item)
530542
baseline_hash = hash_library.get(hash_name, None)
531543
summary['baseline_hash'] = baseline_hash
532544

@@ -607,13 +619,17 @@ def pytest_runtest_call(self, item): # noqa
607619
with plt.style.context(style, after_reset=True), switch_backend(backend):
608620

609621
# Run test and get figure object
622+
wrap_figure_interceptor(self, item)
610623
yield
611-
fig = self.result
624+
test_name = generate_test_name(item)
625+
if test_name not in self.return_value:
626+
# Test function did not complete successfully
627+
return
628+
fig = self.return_value[test_name]
612629

613630
if remove_text:
614631
remove_ticks_and_titles(fig)
615632

616-
test_name = self.generate_test_name(item)
617633
result_dir = self.make_test_results_dir(item)
618634

619635
summary = {
@@ -677,10 +693,6 @@ def pytest_runtest_call(self, item): # noqa
677693
if summary['status'] == 'skipped':
678694
pytest.skip(summary['status_msg'])
679695

680-
@pytest.hookimpl(tryfirst=True)
681-
def pytest_pyfunc_call(self, pyfuncitem):
682-
return _pytest_pyfunc_call(self, pyfuncitem)
683-
684696
def generate_summary_json(self):
685697
json_file = self.results_dir / 'results.json'
686698
with open(json_file, 'w') as f:
@@ -732,13 +744,16 @@ class FigureCloser:
732744

733745
def __init__(self, config):
    # Pytest configuration object for this session.
    self.config = config
    # Maps full test name -> figure returned by that test; filled in by
    # the interceptor installed via wrap_figure_interceptor.
    self.return_value = {}
735748

736749
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item):
    """
    Close the figure returned by a figure test after the test has run.

    ``wrap_figure_interceptor`` records the test's returned figure in
    ``self.return_value``; once the test body has executed, close that
    figure so open figures do not accumulate.
    """
    wrap_figure_interceptor(self, item)
    yield
    if get_compare(item) is not None:
        test_name = generate_test_name(item)
        if test_name not in self.return_value:
            # Test function did not complete successfully (e.g. it failed,
            # skipped or exited before returning a figure).
            return
        # Pop the entry so the mapping does not grow unboundedly over a
        # long test session; the figure is only needed here.
        fig = self.return_value.pop(test_name)
        close_mpl_figure(fig)

setup.cfg

+7
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ test =
4242

4343
[tool:pytest]
4444
testpaths = "tests"
45+
markers =
46+
image: run test during image comparison only mode.
47+
hash: run test during hash comparison only mode.
48+
filterwarnings =
49+
error
50+
ignore:distutils Version classes are deprecated
51+
ignore:the imp module is deprecated in favour of importlib
4552

4653
[flake8]
4754
max-line-length = 100

tests/conftest.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import pytest
from packaging.version import Version

pytest_plugins = ["pytester"]

# pytest < 6.2 only ships the legacy ``testdir`` fixture; expose it under
# the modern ``pytester`` name so the test suite can use one spelling.
if Version(pytest.__version__) < Version("6.2.0"):

    @pytest.fixture
    def pytester(testdir):
        return testdir

tests/test_pytest_mpl.py

+142-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import subprocess
55
from pathlib import Path
6+
from unittest import TestCase
67

78
import matplotlib
89
import matplotlib.ft2font
@@ -259,6 +260,23 @@ def test_succeeds(self):
259260
return fig
260261

261262

263+
class TestClassWithTestCase(TestCase):
    """Regression test for a bug that occurred when using unittest.TestCase."""

    def setUp(self):
        # Data prepared by unittest's setUp must be visible to the wrapped
        # figure test, which is what the original bug broke.
        self.x = [1, 2, 3]

    @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir_local,
                                   filename='test_succeeds.png',
                                   tolerance=DEFAULT_TOLERANCE)
    def test_succeeds(self):
        figure = plt.figure()
        axes = figure.add_subplot(1, 1, 1)
        axes.plot(self.x)
        return figure
278+
279+
262280
# hashlib
263281

264282
@pytest.mark.skipif(not hash_library.exists(), reason="No hash library for this mpl version")
@@ -514,8 +532,27 @@ def test_fails(self):
514532
return fig
515533
"""
516534

535+
# Source written to a temporary test file; it must remain valid Python when
# executed in a pytest subprocess. (Internal indentation reconstructed from
# the surrounding class definitions — the scrape lost the original leading
# whitespace.)
TEST_FAILING_UNITTEST_TESTCASE = """
from unittest import TestCase
import pytest
import matplotlib.pyplot as plt
class TestClassWithTestCase(TestCase):
    def setUp(self):
        self.x = [1, 2, 3]
    @pytest.mark.mpl_image_compare
    def test_fails(self):
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(self.x)
        return fig
"""
517549

518-
@pytest.mark.parametrize("code", [TEST_FAILING_CLASS, TEST_FAILING_CLASS_SETUP_METHOD])
550+
551+
@pytest.mark.parametrize("code", [
552+
TEST_FAILING_CLASS,
553+
TEST_FAILING_CLASS_SETUP_METHOD,
554+
TEST_FAILING_UNITTEST_TESTCASE,
555+
])
519556
def test_class_fail(code, tmpdir):
520557

521558
test_file = tmpdir.join('test.py').strpath
@@ -529,3 +566,107 @@ def test_class_fail(code, tmpdir):
529566
# If we don't use --mpl option, the test should succeed
530567
code = call_pytest([test_file])
531568
assert code == 0
569+
570+
571+
@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_fail(pytester, runpytest_args):
    """A figure test that calls pytest.fail is reported as failed."""
    source = """
    import pytest
    @pytest.mark.mpl_image_compare
    def test_fail():
        pytest.fail("Manually failed by user.")
    """
    pytester.makepyfile(source)
    result = pytester.runpytest(*runpytest_args)
    result.assert_outcomes(failed=1)
    result.stdout.fnmatch_lines("FAILED*Manually failed by user.*")
584+
585+
586+
@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_skip(pytester, runpytest_args):
    """A figure test that calls pytest.skip is reported as skipped."""
    source = """
    import pytest
    @pytest.mark.mpl_image_compare
    def test_skip():
        pytest.skip("Manually skipped by user.")
    """
    pytester.makepyfile(source)
    result = pytester.runpytest(*runpytest_args)
    result.assert_outcomes(skipped=1)
598+
599+
600+
@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_importorskip(pytester, runpytest_args):
    """A figure test skipped via pytest.importorskip is reported as skipped."""
    source = """
    import pytest
    @pytest.mark.mpl_image_compare
    def test_importorskip():
        pytest.importorskip("nonexistantmodule")
    """
    pytester.makepyfile(source)
    result = pytester.runpytest(*runpytest_args)
    result.assert_outcomes(skipped=1)
612+
613+
614+
@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_xfail(pytester, runpytest_args):
    """A figure test that calls pytest.xfail is reported as xfailed."""
    source = """
    import pytest
    @pytest.mark.mpl_image_compare
    def test_xfail():
        pytest.xfail()
    """
    pytester.makepyfile(source)
    result = pytester.runpytest(*runpytest_args)
    result.assert_outcomes(xfailed=1)
626+
627+
628+
@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_exit_success(pytester, runpytest_args):
    """pytest.exit(returncode=0) inside a figure test ends the session cleanly."""
    source = """
    import pytest
    @pytest.mark.mpl_image_compare
    def test_exit_success():
        pytest.exit("Manually exited by user.", returncode=0)
    """
    pytester.makepyfile(source)
    result = pytester.runpytest(*runpytest_args)
    # No test outcomes are recorded because the session was exited.
    result.assert_outcomes()
    assert result.ret == 0
    result.stdout.fnmatch_lines("*Exit*Manually exited by user.*")
642+
643+
644+
@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_exit_failure(pytester, runpytest_args):
    """pytest.exit(returncode=1) inside a figure test fails the session."""
    source = """
    import pytest
    @pytest.mark.mpl_image_compare
    def test_exit_fail():
        pytest.exit("Manually exited by user.", returncode=1)
    """
    pytester.makepyfile(source)
    result = pytester.runpytest(*runpytest_args)
    # No test outcomes are recorded because the session was exited.
    result.assert_outcomes()
    assert result.ret == 1
    result.stdout.fnmatch_lines("*Exit*Manually exited by user.*")
658+
659+
660+
@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_function_raises(pytester, runpytest_args):
    """An exception raised inside a figure test is reported as a failure."""
    source = """
    import pytest
    @pytest.mark.mpl_image_compare
    def test_raises():
        raise ValueError("User code raised an exception.")
    """
    pytester.makepyfile(source)
    result = pytester.runpytest(*runpytest_args)
    result.assert_outcomes(failed=1)
    result.stdout.fnmatch_lines("FAILED*ValueError*User code*")

tox.ini

-5
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,3 @@ description = check code style, e.g. with flake8
5151
deps = pre-commit
5252
commands =
5353
pre-commit run --all-files
54-
55-
[pytest]
56-
markers =
57-
image: run test during image comparison only mode.
58-
hash: run test during hash comparison only mode.

0 commit comments

Comments
 (0)