Skip to content

Commit fb19edf

Browse files
committed
Move check_figures_equal out of decorators.py and back into testing.py
1 parent 04b3f41 commit fb19edf

File tree

4 files changed

+109
-102
lines changed

4 files changed

+109
-102
lines changed

pygmt/helpers/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Functions, classes, decorators, and context managers to help wrap GMT modules.
33
"""
4-
from .decorators import check_figures_equal, fmt_docstring, kwargs_to_strings, use_alias
4+
from .decorators import fmt_docstring, use_alias, kwargs_to_strings
55
from .tempfile import GMTTempFile, unique_name
66
from .utils import (
77
data_kind,

pygmt/helpers/decorators.py

+2-100
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,13 @@
55
arguments, insert common text into docstrings, transform arguments to strings,
66
etc.
77
"""
8-
import functools
9-
import inspect
10-
import os
118
import textwrap
9+
import functools
1210

1311
import numpy as np
14-
from matplotlib.testing.compare import compare_images
1512

16-
from ..exceptions import GMTImageComparisonFailure, GMTInvalidInput
1713
from .utils import is_nonstr_iter
14+
from ..exceptions import GMTInvalidInput
1815

1916
COMMON_OPTIONS = {
2017
"R": """\
@@ -406,98 +403,3 @@ def remove_bools(kwargs):
406403
else:
407404
new_kwargs[arg] = value
408405
return new_kwargs
409-
410-
411-
def check_figures_equal(*, tol=0.0, result_dir="result_images"):
412-
"""
413-
Decorator for test cases that generate and compare two figures.
414-
415-
The decorated function must take two arguments, *fig_ref* and *fig_test*,
416-
and draw the reference and test images on them. After the function
417-
returns, the figures are saved and compared.
418-
419-
This decorator is practically identical to matplotlib's check_figures_equal
420-
function, but adapted for PyGMT figures. See also the original code at
421-
https://matplotlib.org/3.3.1/api/testing_api.html#
422-
matplotlib.testing.decorators.check_figures_equal
423-
424-
Parameters
425-
----------
426-
tol : float
427-
The RMS threshold above which the test is considered failed.
428-
result_dir : str
429-
The directory where the figures will be stored.
430-
431-
Examples
432-
--------
433-
434-
>>> import pytest
435-
>>> @check_figures_equal()
436-
... def test_check_figures_equal(fig_ref, fig_test):
437-
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
438-
... fig_test.basemap(projection="X5c", region=[0, 5, 0, 5], frame="af")
439-
>>> test_check_figures_equal()
440-
441-
>>> import shutil
442-
>>> @check_figures_equal(result_dir="tmp_result_images")
443-
... def test_check_figures_unequal(fig_ref, fig_test):
444-
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
445-
... fig_test.basemap(projection="X5c", region=[0, 3, 0, 3], frame=True)
446-
>>> with pytest.raises(GMTImageComparisonFailure):
447-
... test_check_figures_unequal()
448-
>>> shutil.rmtree(path="tmp_result_images")
449-
450-
"""
451-
452-
def decorator(func):
453-
454-
os.makedirs(result_dir, exist_ok=True)
455-
old_sig = inspect.signature(func)
456-
457-
def wrapper(*args, **kwargs):
458-
try:
459-
from ..figure import Figure # pylint: disable=import-outside-toplevel
460-
461-
fig_ref = Figure()
462-
fig_test = Figure()
463-
func(*args, fig_ref=fig_ref, fig_test=fig_test, **kwargs)
464-
ref_image_path = os.path.join(
465-
result_dir, func.__name__ + "-expected.png"
466-
)
467-
test_image_path = os.path.join(result_dir, func.__name__ + ".png")
468-
fig_ref.savefig(ref_image_path)
469-
fig_test.savefig(test_image_path)
470-
471-
# Code below is adapted for PyGMT, and is originally based on
472-
# matplotlib.testing.decorators._raise_on_image_difference
473-
err = compare_images(
474-
expected=ref_image_path,
475-
actual=test_image_path,
476-
tol=tol,
477-
in_decorator=True,
478-
)
479-
if err is None: # Images are the same
480-
os.remove(ref_image_path)
481-
os.remove(test_image_path)
482-
else: # Images are not the same
483-
for key in ["actual", "expected", "diff"]:
484-
err[key] = os.path.relpath(err[key])
485-
raise GMTImageComparisonFailure(
486-
"images not close (RMS %(rms).3f):\n\t%(actual)s\n\t%(expected)s "
487-
% err
488-
)
489-
finally:
490-
del fig_ref
491-
del fig_test
492-
493-
parameters = [
494-
param
495-
for param in old_sig.parameters.values()
496-
if param.name not in {"fig_test", "fig_ref"}
497-
]
498-
new_sig = old_sig.replace(parameters=parameters)
499-
wrapper.__signature__ = new_sig
500-
501-
return wrapper
502-
503-
return decorator

pygmt/helpers/testing.py

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""
2+
Helper functions for testing.
3+
"""
4+
5+
import inspect
6+
import os
7+
8+
from matplotlib.testing.compare import compare_images
9+
10+
from ..exceptions import GMTImageComparisonFailure
11+
from ..figure import Figure
12+
13+
14+
def check_figures_equal(*, tol=0.0, result_dir="result_images"):
15+
"""
16+
Decorator for test cases that generate and compare two figures.
17+
18+
The decorated function must take two arguments, *fig_ref* and *fig_test*,
19+
and draw the reference and test images on them. After the function
20+
returns, the figures are saved and compared.
21+
22+
This decorator is practically identical to matplotlib's check_figures_equal
23+
function, but adapted for PyGMT figures. See also the original code at
24+
https://matplotlib.org/3.3.1/api/testing_api.html#
25+
matplotlib.testing.decorators.check_figures_equal
26+
27+
Parameters
28+
----------
29+
tol : float
30+
The RMS threshold above which the test is considered failed.
31+
result_dir : str
32+
The directory where the figures will be stored.
33+
34+
Examples
35+
--------
36+
37+
>>> import pytest
38+
>>> import shutil
39+
40+
>>> @check_figures_equal(result_dir="tmp_result_images")
41+
... def test_check_figures_equal(fig_ref, fig_test):
42+
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
43+
... fig_test.basemap(projection="X5c", region=[0, 5, 0, 5], frame="af")
44+
>>> test_check_figures_equal()
45+
46+
>>> @check_figures_equal(result_dir="tmp_result_images")
47+
... def test_check_figures_unequal(fig_ref, fig_test):
48+
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
49+
... fig_test.basemap(projection="X5c", region=[0, 3, 0, 3], frame=True)
50+
>>> with pytest.raises(GMTImageComparisonFailure):
51+
... test_check_figures_unequal()
52+
53+
>>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass
54+
"""
55+
56+
def decorator(func):
57+
58+
os.makedirs(result_dir, exist_ok=True)
59+
old_sig = inspect.signature(func)
60+
61+
def wrapper(*args, **kwargs):
62+
try:
63+
fig_ref = Figure()
64+
fig_test = Figure()
65+
func(*args, fig_ref=fig_ref, fig_test=fig_test, **kwargs)
66+
ref_image_path = os.path.join(
67+
result_dir, func.__name__ + "-expected.png"
68+
)
69+
test_image_path = os.path.join(result_dir, func.__name__ + ".png")
70+
fig_ref.savefig(ref_image_path)
71+
fig_test.savefig(test_image_path)
72+
73+
# Code below is adapted for PyGMT, and is originally based on
74+
# matplotlib.testing.decorators._raise_on_image_difference
75+
err = compare_images(
76+
expected=ref_image_path,
77+
actual=test_image_path,
78+
tol=tol,
79+
in_decorator=True,
80+
)
81+
if err is None: # Images are the same
82+
os.remove(ref_image_path)
83+
os.remove(test_image_path)
84+
else: # Images are not the same
85+
for key in ["actual", "expected", "diff"]:
86+
err[key] = os.path.relpath(err[key])
87+
raise GMTImageComparisonFailure(
88+
"images not close (RMS %(rms).3f):\n\t%(actual)s\n\t%(expected)s "
89+
% err
90+
)
91+
finally:
92+
del fig_ref
93+
del fig_test
94+
95+
parameters = [
96+
param
97+
for param in old_sig.parameters.values()
98+
if param.name not in {"fig_test", "fig_ref"}
99+
]
100+
new_sig = old_sig.replace(parameters=parameters)
101+
wrapper.__signature__ = new_sig
102+
103+
return wrapper
104+
105+
return decorator

pygmt/tests/test_grdimage.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .. import Figure
99
from ..datasets import load_earth_relief
1010
from ..exceptions import GMTInvalidInput
11-
from ..helpers import check_figures_equal
11+
from ..helpers.testing import check_figures_equal
1212

1313

1414
@pytest.fixture(scope="module", name="grid")

0 commit comments

Comments
 (0)