Skip to content

Commit 85c08ef

Browse files
authored
Generalize check figures equal to work with pytest.marks (#600)
* Allow check_figures_equal to work with pytest parametrized Following cue of matplotlib's check_figures_equal decorator. * Silence pylint complaints on ALLOWED_CHARS & KEYWORD_ONLY variable names * Fix doctest failures on helpers/testing.py * Update documentation on check_figures_equal
1 parent 6f32a2e commit 85c08ef

File tree

2 files changed

+32
-13
lines changed

2 files changed

+32
-13
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,9 +328,8 @@ Here's an example:
328328
@check_figures_equal()
329329
def test_my_plotting_case():
330330
"Test that my plotting function works"
331-
fig_ref = Figure()
331+
fig_ref, fig_test = Figure(), Figure()
332332
fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")
333-
fig_test = Figure()
334333
fig_test.grdimage(grid, projection="W120/15c", cmap="geo")
335334
return fig_ref, fig_test
336335
```

pygmt/helpers/testing.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
"""
22
Helper functions for testing.
33
"""
4-
54
import inspect
65
import os
6+
import string
77

88
from matplotlib.testing.compare import compare_images
9-
109
from ..exceptions import GMTImageComparisonFailure
1110

1211

13-
def check_figures_equal(*, tol=0.0, result_dir="result_images"):
12+
def check_figures_equal(*, extensions=("png",), tol=0.0, result_dir="result_images"):
1413
"""
1514
Decorator for test cases that generate and compare two figures.
1615
17-
The decorated function must take two arguments, *fig_ref* and *fig_test*,
18-
and draw the reference and test images on them. After the function
19-
returns, the figures are saved and compared.
16+
The decorated function must return two arguments, *fig_ref* and *fig_test*,
17+
these two figures will then be saved and compared against each other.
2018
2119
This decorator is practically identical to matplotlib's check_figures_equal
2220
function, but adapted for PyGMT figures. See also the original code at
@@ -25,6 +23,8 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"):
2523
2624
Parameters
2725
----------
26+
extensions : list
27+
The extensions to test. Default is ["png"].
2828
tol : float
2929
The RMS threshold above which the test is considered failed.
3030
result_dir : str
@@ -66,19 +66,30 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"):
6666
... )
6767
>>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass
6868
"""
69+
# pylint: disable=invalid-name
70+
ALLOWED_CHARS = set(string.digits + string.ascii_letters + "_-[]()")
71+
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY
6972

7073
def decorator(func):
74+
import pytest
7175

7276
os.makedirs(result_dir, exist_ok=True)
7377
old_sig = inspect.signature(func)
7478

75-
def wrapper(*args, **kwargs):
79+
@pytest.mark.parametrize("ext", extensions)
80+
def wrapper(*args, ext="png", request=None, **kwargs):
81+
if "ext" in old_sig.parameters:
82+
kwargs["ext"] = ext
83+
if "request" in old_sig.parameters:
84+
kwargs["request"] = request
85+
try:
86+
file_name = "".join(c for c in request.node.name if c in ALLOWED_CHARS)
87+
except AttributeError: # 'NoneType' object has no attribute 'node'
88+
file_name = func.__name__
7689
try:
7790
fig_ref, fig_test = func(*args, **kwargs)
78-
ref_image_path = os.path.join(
79-
result_dir, func.__name__ + "-expected.png"
80-
)
81-
test_image_path = os.path.join(result_dir, func.__name__ + ".png")
91+
ref_image_path = os.path.join(result_dir, f"{file_name}-expected.{ext}")
92+
test_image_path = os.path.join(result_dir, f"{file_name}.{ext}")
8293
fig_ref.savefig(ref_image_path)
8394
fig_test.savefig(test_image_path)
8495

@@ -109,9 +120,18 @@ def wrapper(*args, **kwargs):
109120
for param in old_sig.parameters.values()
110121
if param.name not in {"fig_test", "fig_ref"}
111122
]
123+
if "ext" not in old_sig.parameters:
124+
parameters += [inspect.Parameter("ext", KEYWORD_ONLY)]
125+
if "request" not in old_sig.parameters:
126+
parameters += [inspect.Parameter("request", KEYWORD_ONLY)]
112127
new_sig = old_sig.replace(parameters=parameters)
113128
wrapper.__signature__ = new_sig
114129

130+
# reach a bit into pytest internals to hoist the marks from
131+
# our wrapped function
132+
new_marks = getattr(func, "pytestmark", []) + wrapper.pytestmark
133+
wrapper.pytestmark = new_marks
134+
115135
return wrapper
116136

117137
return decorator

0 commit comments

Comments
 (0)