1
1
"""
2
2
Helper functions for testing.
3
3
"""
4
-
5
4
import inspect
6
5
import os
6
+ import string
7
7
8
8
from matplotlib .testing .compare import compare_images
9
-
10
9
from ..exceptions import GMTImageComparisonFailure
11
10
12
11
13
- def check_figures_equal (* , tol = 0.0 , result_dir = "result_images" ):
12
+ def check_figures_equal (* , extensions = ( "png" ,), tol = 0.0 , result_dir = "result_images" ):
14
13
"""
15
14
Decorator for test cases that generate and compare two figures.
16
15
17
- The decorated function must take two arguments, *fig_ref* and *fig_test*,
18
- and draw the reference and test images on them. After the function
19
- returns, the figures are saved and compared.
16
+ The decorated function must return two arguments, *fig_ref* and *fig_test*,
17
+ these two figures will then be saved and compared against each other.
20
18
21
19
This decorator is practically identical to matplotlib's check_figures_equal
22
20
function, but adapted for PyGMT figures. See also the original code at
@@ -25,6 +23,8 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"):
25
23
26
24
Parameters
27
25
----------
26
+ extensions : list
27
+ The extensions to test. Default is ["png"].
28
28
tol : float
29
29
The RMS threshold above which the test is considered failed.
30
30
result_dir : str
@@ -66,19 +66,30 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"):
66
66
... )
67
67
>>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass
68
68
"""
69
+ # pylint: disable=invalid-name
70
+ ALLOWED_CHARS = set (string .digits + string .ascii_letters + "_-[]()" )
71
+ KEYWORD_ONLY = inspect .Parameter .KEYWORD_ONLY
69
72
70
73
def decorator (func ):
74
+ import pytest
71
75
72
76
os .makedirs (result_dir , exist_ok = True )
73
77
old_sig = inspect .signature (func )
74
78
75
- def wrapper (* args , ** kwargs ):
79
+ @pytest .mark .parametrize ("ext" , extensions )
80
+ def wrapper (* args , ext = "png" , request = None , ** kwargs ):
81
+ if "ext" in old_sig .parameters :
82
+ kwargs ["ext" ] = ext
83
+ if "request" in old_sig .parameters :
84
+ kwargs ["request" ] = request
85
+ try :
86
+ file_name = "" .join (c for c in request .node .name if c in ALLOWED_CHARS )
87
+ except AttributeError : # 'NoneType' object has no attribute 'node'
88
+ file_name = func .__name__
76
89
try :
77
90
fig_ref , fig_test = func (* args , ** kwargs )
78
- ref_image_path = os .path .join (
79
- result_dir , func .__name__ + "-expected.png"
80
- )
81
- test_image_path = os .path .join (result_dir , func .__name__ + ".png" )
91
+ ref_image_path = os .path .join (result_dir , f"{ file_name } -expected.{ ext } " )
92
+ test_image_path = os .path .join (result_dir , f"{ file_name } .{ ext } " )
82
93
fig_ref .savefig (ref_image_path )
83
94
fig_test .savefig (test_image_path )
84
95
@@ -109,9 +120,18 @@ def wrapper(*args, **kwargs):
109
120
for param in old_sig .parameters .values ()
110
121
if param .name not in {"fig_test" , "fig_ref" }
111
122
]
123
+ if "ext" not in old_sig .parameters :
124
+ parameters += [inspect .Parameter ("ext" , KEYWORD_ONLY )]
125
+ if "request" not in old_sig .parameters :
126
+ parameters += [inspect .Parameter ("request" , KEYWORD_ONLY )]
112
127
new_sig = old_sig .replace (parameters = parameters )
113
128
wrapper .__signature__ = new_sig
114
129
130
+ # reach a bit into pytest internals to hoist the marks from
131
+ # our wrapped function
132
+ new_marks = getattr (func , "pytestmark" , []) + wrapper .pytestmark
133
+ wrapper .pytestmark = new_marks
134
+
115
135
return wrapper
116
136
117
137
return decorator
0 commit comments