@@ -81,12 +81,33 @@ def pathify(path):
81
81
return Path (path + ext )
82
82
83
83
84
- def _pytest_pyfunc_call (obj , pyfuncitem ):
85
- testfunction = pyfuncitem .obj
86
- funcargs = pyfuncitem .funcargs
87
- testargs = {arg : funcargs [arg ] for arg in pyfuncitem ._fixtureinfo .argnames }
88
- obj .result = testfunction (** testargs )
89
- return True
84
+ def generate_test_name (item ):
85
+ """
86
+ Generate a unique name for the hash for this test.
87
+ """
88
+ if item .cls is not None :
89
+ name = f"{ item .module .__name__ } .{ item .cls .__name__ } .{ item .name } "
90
+ else :
91
+ name = f"{ item .module .__name__ } .{ item .name } "
92
+ return name
93
+
94
+
95
+ def wrap_figure_interceptor (plugin , item ):
96
+ """
97
+ Intercept and store figures returned by test functions.
98
+ """
99
+ # Only intercept figures on marked figure tests
100
+ if get_compare (item ) is not None :
101
+
102
+ # Use the full test name as a key to ensure correct figure is being retrieved
103
+ test_name = generate_test_name (item )
104
+
105
+ def figure_interceptor (store , obj ):
106
+ def wrapper (* args , ** kwargs ):
107
+ store .return_value [test_name ] = obj (* args , ** kwargs )
108
+ return wrapper
109
+
110
+ item .obj = figure_interceptor (plugin , item .obj )
90
111
91
112
92
113
def pytest_report_header (config , startdir ):
@@ -275,6 +296,7 @@ def __init__(self,
275
296
self ._generated_hash_library = {}
276
297
self ._test_results = {}
277
298
self ._test_stats = None
299
+ self .return_value = {}
278
300
279
301
# https://stackoverflow.com/questions/51737378/how-should-i-log-in-my-pytest-plugin
280
302
# turn debug prints on only if "-vv" or more passed
@@ -287,7 +309,7 @@ def generate_filename(self, item):
287
309
Given a pytest item, generate the figure filename.
288
310
"""
289
311
if self .config .getini ('mpl-use-full-test-name' ):
290
- filename = self . generate_test_name (item ) + '.png'
312
+ filename = generate_test_name (item ) + '.png'
291
313
else :
292
314
compare = get_compare (item )
293
315
# Find test name to use as plot name
@@ -298,21 +320,11 @@ def generate_filename(self, item):
298
320
filename = str (pathify (filename ))
299
321
return filename
300
322
301
- def generate_test_name (self , item ):
302
- """
303
- Generate a unique name for the hash for this test.
304
- """
305
- if item .cls is not None :
306
- name = f"{ item .module .__name__ } .{ item .cls .__name__ } .{ item .name } "
307
- else :
308
- name = f"{ item .module .__name__ } .{ item .name } "
309
- return name
310
-
311
323
def make_test_results_dir (self , item ):
312
324
"""
313
325
Generate the directory to put the results in.
314
326
"""
315
- test_name = pathify (self . generate_test_name (item ))
327
+ test_name = pathify (generate_test_name (item ))
316
328
results_dir = self .results_dir / test_name
317
329
results_dir .mkdir (exist_ok = True , parents = True )
318
330
return results_dir
@@ -526,7 +538,7 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
526
538
pytest .fail (f"Can't find hash library at path { hash_library_filename } " )
527
539
528
540
hash_library = self .load_hash_library (hash_library_filename )
529
- hash_name = self . generate_test_name (item )
541
+ hash_name = generate_test_name (item )
530
542
baseline_hash = hash_library .get (hash_name , None )
531
543
summary ['baseline_hash' ] = baseline_hash
532
544
@@ -607,13 +619,17 @@ def pytest_runtest_call(self, item): # noqa
607
619
with plt .style .context (style , after_reset = True ), switch_backend (backend ):
608
620
609
621
# Run test and get figure object
622
+ wrap_figure_interceptor (self , item )
610
623
yield
611
- fig = self .result
624
+ test_name = generate_test_name (item )
625
+ if test_name not in self .return_value :
626
+ # Test function did not complete successfully
627
+ return
628
+ fig = self .return_value [test_name ]
612
629
613
630
if remove_text :
614
631
remove_ticks_and_titles (fig )
615
632
616
- test_name = self .generate_test_name (item )
617
633
result_dir = self .make_test_results_dir (item )
618
634
619
635
summary = {
@@ -677,10 +693,6 @@ def pytest_runtest_call(self, item): # noqa
677
693
if summary ['status' ] == 'skipped' :
678
694
pytest .skip (summary ['status_msg' ])
679
695
680
- @pytest .hookimpl (tryfirst = True )
681
- def pytest_pyfunc_call (self , pyfuncitem ):
682
- return _pytest_pyfunc_call (self , pyfuncitem )
683
-
684
696
def generate_summary_json (self ):
685
697
json_file = self .results_dir / 'results.json'
686
698
with open (json_file , 'w' ) as f :
@@ -732,13 +744,16 @@ class FigureCloser:
732
744
733
745
def __init__ (self , config ):
734
746
self .config = config
747
+ self .return_value = {}
735
748
736
749
@pytest .hookimpl (hookwrapper = True )
737
750
def pytest_runtest_call (self , item ):
751
+ wrap_figure_interceptor (self , item )
738
752
yield
739
753
if get_compare (item ) is not None :
740
- close_mpl_figure (self .result )
741
-
742
- @pytest .hookimpl (tryfirst = True )
743
- def pytest_pyfunc_call (self , pyfuncitem ):
744
- return _pytest_pyfunc_call (self , pyfuncitem )
754
+ test_name = generate_test_name (item )
755
+ if test_name not in self .return_value :
756
+ # Test function did not complete successfully
757
+ return
758
+ fig = self .return_value [test_name ]
759
+ close_mpl_figure (fig )
0 commit comments