Skip to content

Commit 97a585b

Browse files
seismanweiji14
andauthored
Add testing.check_figures_equal to avoid storing baseline images (#555)
* Turn check_figures_equal into a decorator function Also moved test_check_figures_* to a doctest under check_figures_equal. * Ensure pytest fixtures can be used with check_figures_equal decorator * Add notes on using check_figures_equal to CONTRIBUTING.md * Extra checks to ensure image files exist or not Co-authored-by: Wei Ji <[email protected]>
1 parent a3d6f84 commit 97a585b

File tree

4 files changed

+163
-4
lines changed

4 files changed

+163
-4
lines changed

CONTRIBUTING.md

+32-2
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,38 @@ Leave a comment in the PR and we'll help you out.
310310

311311
### Testing plots
312312

313-
We use the [pytest-mpl](https://github.com/matplotlib/pytest-mpl) plug-in to test plot
314-
generating code.
313+
Writing an image-based test is only slightly more difficult than a simple test.
314+
The main consideration is that you must specify the "baseline" or reference
315+
image, and compare it with a "generated" or test image. This is handled using
316+
the *decorator* functions `@check_figures_equal` and
317+
`@pytest.mark.mpl_image_compare` whose usage are further described below.
318+
319+
#### Using check_figures_equal
320+
321+
This approach draws the same figure using two different methods (the reference
322+
method and the tested method), and checks that both of them are the same.
323+
It takes two `pygmt.Figure` objects ('fig_ref' and 'fig_test'), generates a png
324+
image, and checks for the Root Mean Square (RMS) error between the two.
325+
Here's an example:
326+
327+
```python
328+
@check_figures_equal()
329+
def test_my_plotting_case(fig_ref, fig_test):
330+
"Test that my plotting function works"
331+
fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")
332+
fig_test.grdimage(grid, projection="W120/15c", cmap="geo")
333+
```
334+
335+
Note: This is the recommended way to test plots whenever possible, such as when
336+
we want to compare a reference GMT plot created from NetCDF files with one
337+
generated by PyGMT that passes through several layers of virtualfile machinery.
338+
Using this method will help save space in the git repository by not having to
339+
store baseline images as with the other method below.
340+
341+
#### Using mpl_image_compare
342+
343+
This method uses the [pytest-mpl](https://github.com/matplotlib/pytest-mpl)
344+
plug-in to test plot generating code.
315345
Every time the tests are run, `pytest-mpl` compares the generated plots with known
316346
correct ones stored in `pygmt/tests/baseline`.
317347
If your test created a `pygmt.Figure` object, you can test it by adding a *decorator* and

pygmt/exceptions.py

+6
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,9 @@ class GMTVersionError(GMTError):
4444
"""
4545
Raised when an incompatible version of GMT is being used.
4646
"""
47+
48+
49+
class GMTImageComparisonFailure(AssertionError):
50+
"""
51+
Raised when a comparison between two images fails.
52+
"""

pygmt/helpers/testing.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""
2+
Helper functions for testing.
3+
"""
4+
5+
import inspect
6+
import os
7+
8+
from matplotlib.testing.compare import compare_images
9+
10+
from ..exceptions import GMTImageComparisonFailure
11+
from ..figure import Figure
12+
13+
14+
def check_figures_equal(*, tol=0.0, result_dir="result_images"):
15+
"""
16+
Decorator for test cases that generate and compare two figures.
17+
18+
The decorated function must take two arguments, *fig_ref* and *fig_test*,
19+
and draw the reference and test images on them. After the function
20+
returns, the figures are saved and compared.
21+
22+
This decorator is practically identical to matplotlib's check_figures_equal
23+
function, but adapted for PyGMT figures. See also the original code at
24+
https://matplotlib.org/3.3.1/api/testing_api.html#
25+
matplotlib.testing.decorators.check_figures_equal
26+
27+
Parameters
28+
----------
29+
tol : float
30+
The RMS threshold above which the test is considered failed.
31+
result_dir : str
32+
The directory where the figures will be stored.
33+
34+
Examples
35+
--------
36+
37+
>>> import pytest
38+
>>> import shutil
39+
40+
>>> @check_figures_equal(result_dir="tmp_result_images")
41+
... def test_check_figures_equal(fig_ref, fig_test):
42+
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
43+
... fig_test.basemap(projection="X5c", region=[0, 5, 0, 5], frame="af")
44+
>>> test_check_figures_equal()
45+
>>> assert len(os.listdir("tmp_result_images")) == 0
46+
>>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass
47+
48+
>>> @check_figures_equal(result_dir="tmp_result_images")
49+
... def test_check_figures_unequal(fig_ref, fig_test):
50+
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
51+
... fig_test.basemap(projection="X5c", region=[0, 3, 0, 3], frame=True)
52+
>>> with pytest.raises(GMTImageComparisonFailure):
53+
... test_check_figures_unequal()
54+
>>> for suffix in ["", "-expected", "-failed-diff"]:
55+
... assert os.path.exists(
56+
... os.path.join(
57+
... "tmp_result_images",
58+
... f"test_check_figures_unequal{suffix}.png",
59+
... )
60+
... )
61+
>>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass
62+
"""
63+
64+
def decorator(func):
65+
66+
os.makedirs(result_dir, exist_ok=True)
67+
old_sig = inspect.signature(func)
68+
69+
def wrapper(*args, **kwargs):
70+
try:
71+
fig_ref = Figure()
72+
fig_test = Figure()
73+
func(*args, fig_ref=fig_ref, fig_test=fig_test, **kwargs)
74+
ref_image_path = os.path.join(
75+
result_dir, func.__name__ + "-expected.png"
76+
)
77+
test_image_path = os.path.join(result_dir, func.__name__ + ".png")
78+
fig_ref.savefig(ref_image_path)
79+
fig_test.savefig(test_image_path)
80+
81+
# Code below is adapted for PyGMT, and is originally based on
82+
# matplotlib.testing.decorators._raise_on_image_difference
83+
err = compare_images(
84+
expected=ref_image_path,
85+
actual=test_image_path,
86+
tol=tol,
87+
in_decorator=True,
88+
)
89+
if err is None: # Images are the same
90+
os.remove(ref_image_path)
91+
os.remove(test_image_path)
92+
else: # Images are not the same
93+
for key in ["actual", "expected", "diff"]:
94+
err[key] = os.path.relpath(err[key])
95+
raise GMTImageComparisonFailure(
96+
"images not close (RMS %(rms).3f):\n\t%(actual)s\n\t%(expected)s "
97+
% err
98+
)
99+
finally:
100+
del fig_ref
101+
del fig_test
102+
103+
parameters = [
104+
param
105+
for param in old_sig.parameters.values()
106+
if param.name not in {"fig_test", "fig_ref"}
107+
]
108+
new_sig = old_sig.replace(parameters=parameters)
109+
wrapper.__signature__ = new_sig
110+
111+
return wrapper
112+
113+
return decorator

pygmt/tests/test_grdimage.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
Test Figure.grdimage
33
"""
44
import numpy as np
5-
import xarray as xr
65
import pytest
6+
import xarray as xr
77

88
from .. import Figure
9-
from ..exceptions import GMTInvalidInput
109
from ..datasets import load_earth_relief
10+
from ..exceptions import GMTInvalidInput
11+
from ..helpers.testing import check_figures_equal
1112

1213

1314
@pytest.fixture(scope="module", name="grid")
@@ -93,3 +94,12 @@ def test_grdimage_over_dateline(xrgrid):
9394
xrgrid.gmt.gtype = 1 # geographic coordinate system
9495
fig.grdimage(grid=xrgrid, region="g", projection="A0/0/1c", V="i")
9596
return fig
97+
98+
99+
@check_figures_equal()
100+
def test_grdimage_central_longitude(grid, fig_ref, fig_test):
101+
"""
102+
Test that plotting a grid centred at different longitudes/meridians work.
103+
"""
104+
fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")
105+
fig_test.grdimage(grid, projection="W120/15c", cmap="geo")

0 commit comments

Comments
 (0)