@@ -381,41 +381,40 @@ def test_plot(fig_test, fig_ref):
381
381
fig_test.subplots().plot([1, 3, 5])
382
382
fig_ref.subplots().plot([0, 1, 2], [1, 3, 5])
383
383
"""
384
-
384
+ POSITIONAL_OR_KEYWORD = inspect . Parameter . POSITIONAL_OR_KEYWORD
385
385
def decorator (func ):
386
386
import pytest
387
387
388
388
_ ,result_dir = _image_directories (func )
389
389
390
- if len (inspect .signature (func ).parameters )== 2 :
391
- # Free-standing function.
392
- @pytest .mark .parametrize ("ext" ,extensions )
393
- def wrapper (ext ):
394
- fig_test = plt .figure ("test" )
395
- fig_ref = plt .figure ("reference" )
396
- func (fig_test ,fig_ref )
397
- test_image_path = result_dir / (func .__name__ + "." + ext )
398
- ref_image_path = (
399
- result_dir / (func .__name__ + "-expected." + ext ))
400
- fig_test .savefig (test_image_path )
401
- fig_ref .savefig (ref_image_path )
402
- _raise_on_image_difference (
403
- ref_image_path ,test_image_path ,tol = tol )
404
-
405
- elif len (inspect .signature (func ).parameters )== 3 :
406
- # Method.
407
- @pytest .mark .parametrize ("ext" ,extensions )
408
- def wrapper (self ,ext ):
409
- fig_test = plt .figure ("test" )
410
- fig_ref = plt .figure ("reference" )
411
- func (self ,fig_test ,fig_ref )
412
- test_image_path = result_dir / (func .__name__ + "." + ext )
413
- ref_image_path = (
414
- result_dir / (func .__name__ + "-expected." + ext ))
415
- fig_test .savefig (test_image_path )
416
- fig_ref .savefig (ref_image_path )
417
- _raise_on_image_difference (
418
- ref_image_path ,test_image_path ,tol = tol )
390
+ @pytest .mark .parametrize ("ext" ,extensions )
391
+ def wrapper (* args ,ext ,** kwargs ):
392
+ fig_test = plt .figure ("test" )
393
+ fig_ref = plt .figure ("reference" )
394
+ func (* args ,fig_test = fig_test ,fig_ref = fig_ref ,** kwargs )
395
+ test_image_path = result_dir / (func .__name__ + "." + ext )
396
+ ref_image_path = result_dir / (
397
+ func .__name__ + "-expected." + ext
398
+ )
399
+ fig_test .savefig (test_image_path )
400
+ fig_ref .savefig (ref_image_path )
401
+ _raise_on_image_difference (
402
+ ref_image_path ,test_image_path ,tol = tol
403
+ )
404
+
405
+ sig = inspect .signature (func )
406
+ new_sig = sig .replace (
407
+ parameters = ([param
408
+ for param in sig .parameters .values ()
409
+ if param .name not in {"fig_test" ,"fig_ref" }]
410
+ + [inspect .Parameter ("ext" ,POSITIONAL_OR_KEYWORD )])
411
+ )
412
+ wrapper .__signature__ = new_sig
413
+
414
+ # reach a bit into pytest internals to hoist the marks from
415
+ # our wrapped function
416
+ new_marks = getattr (func ,"pytestmark" , [])+ wrapper .pytestmark
417
+ wrapper .pytestmark = new_marks
419
418
420
419
return wrapper
421
420