Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit295e12d

Browse files
authored
Merge pull request#15175 from charris/backport-testing-utils-1.17.x
ENH: Backport improvements to testing functions.
2 parents6d2193f +e188a16 commit295e12d

File tree

2 files changed

+105
-36
lines changed

2 files changed

+105
-36
lines changed

‎numpy/testing/_private/utils.py‎

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
301301
check that all elements of these objects are equal. An exception is raised
302302
at the first conflicting values.
303303
304+
When one of `actual` and `desired` is a scalar and the other is array_like,
305+
the function checks that each element of the array_like object is equal to
306+
the scalar.
307+
304308
This function handles NaN comparisons as if NaN was a "normal" number.
305309
That is, no assertion is raised if both objects have NaNs in the same
306310
positions. This is in contrast to the IEEE standard on NaNs, which says
@@ -391,21 +395,6 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
391395
ifisscalar(desired)!=isscalar(actual):
392396
raiseAssertionError(msg)
393397

394-
# Inf/nan/negative zero handling
395-
try:
396-
isdesnan=gisnan(desired)
397-
isactnan=gisnan(actual)
398-
ifisdesnanandisactnan:
399-
return# both nan, so equal
400-
401-
# handle signed zero specially for floats
402-
ifdesired==0andactual==0:
403-
ifnotsignbit(desired)==signbit(actual):
404-
raiseAssertionError(msg)
405-
406-
except (TypeError,ValueError,NotImplementedError):
407-
pass
408-
409398
try:
410399
isdesnat=isnat(desired)
411400
isactnat=isnat(actual)
@@ -421,6 +410,33 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
421410
except (TypeError,ValueError,NotImplementedError):
422411
pass
423412

413+
# Inf/nan/negative zero handling
414+
try:
415+
isdesnan=gisnan(desired)
416+
isactnan=gisnan(actual)
417+
ifisdesnanandisactnan:
418+
return# both nan, so equal
419+
420+
# handle signed zero specially for floats
421+
array_actual=array(actual)
422+
array_desired=array(desired)
423+
if (array_actual.dtype.charin'Mm'or
424+
array_desired.dtype.charin'Mm'):
425+
# version 1.18
426+
# until this version, gisnan failed for datetime64 and timedelta64.
427+
# Now it succeeds but comparison to scalar with a different type
428+
# emits a DeprecationWarning.
429+
# Avoid that by skipping the next check
430+
raiseNotImplementedError('cannot compare to a scalar '
431+
'with a different type')
432+
433+
ifdesired==0andactual==0:
434+
ifnotsignbit(desired)==signbit(actual):
435+
raiseAssertionError(msg)
436+
437+
except (TypeError,ValueError,NotImplementedError):
438+
pass
439+
424440
try:
425441
# Explicitly use __eq__ for comparison, gh-2552
426442
ifnot (desired==actual):
@@ -703,7 +719,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
703719
header='',precision=6,equal_nan=True,
704720
equal_inf=True):
705721
__tracebackhide__=True# Hide traceback for py.test
706-
fromnumpy.coreimportarray,array2string,isnan,inf,bool_,errstate,all
722+
fromnumpy.coreimportarray,array2string,isnan,inf,bool_,errstate,all,max,object_
707723

708724
x=array(x,copy=False,subok=True)
709725
y=array(y,copy=False,subok=True)
@@ -804,15 +820,19 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
804820
# do not trigger a failure (np.ma.masked != True evaluates as
805821
# np.ma.masked, which is falsy).
806822
ifcond!=True:
807-
mismatch=100.* (reduced.size-reduced.sum(dtype=intp))/ox.size
808-
remarks= ['Mismatch: {:.3g}%'.format(mismatch)]
823+
n_mismatch=reduced.size-reduced.sum(dtype=intp)
824+
n_elements=flagged.sizeifflagged.ndim!=0elsereduced.size
825+
percent_mismatch=100*n_mismatch/n_elements
826+
remarks= [
827+
'Mismatched elements: {} / {} ({:.3g}%)'.format(
828+
n_mismatch,n_elements,percent_mismatch)]
809829

810830
witherrstate(invalid='ignore',divide='ignore'):
811831
# ignore errors for non-numeric types
812832
withcontextlib.suppress(TypeError):
813833
error=abs(x-y)
814-
max_abs_error=error.max()
815-
iferror.dtype=='object':
834+
max_abs_error=max(error)
835+
ifgetattr(error,'dtype',object_)==object_:
816836
remarks.append('Max absolute difference: '
817837
+str(max_abs_error))
818838
else:
@@ -826,8 +846,8 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
826846
ifall(~nonzero):
827847
max_rel_error=array(inf)
828848
else:
829-
max_rel_error= (error[nonzero]/abs(y[nonzero])).max()
830-
iferror.dtype=='object':
849+
max_rel_error=max(error[nonzero]/abs(y[nonzero]))
850+
ifgetattr(error,'dtype',object_)==object_:
831851
remarks.append('Max relative difference: '
832852
+str(max_rel_error))
833853
else:
@@ -854,10 +874,11 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
854874
Raises an AssertionError if two array_like objects are not equal.
855875
856876
Given two array_like objects, check that the shape is equal and all
857-
elements of these objects are equal. An exception is raised at
858-
shape mismatch or conflicting values. In contrast to the standard usage
859-
in numpy, NaNs are compared like numbers, no assertion is raised if
860-
both objects have NaNs in the same positions.
877+
elements of these objects are equal (but see the Notes for the special
878+
handling of a scalar). An exception is raised at shape mismatch or
879+
conflicting values. In contrast to the standard usage in numpy, NaNs
880+
are compared like numbers, no assertion is raised if both objects have
881+
NaNs in the same positions.
861882
862883
The usual caution for verifying equality with floating point numbers is
863884
advised.
@@ -884,14 +905,20 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
884905
relative and/or absolute precision.
885906
assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal
886907
908+
Notes
909+
-----
910+
When one of `x` and `y` is a scalar and the other is array_like, the
911+
function checks that each element of the array_like object is equal to
912+
the scalar.
913+
887914
Examples
888915
--------
889916
The first assert does not raise an exception:
890917
891918
>>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
892919
... [np.exp(0),2.33333, np.nan])
893920
894-
Assert fails with numericalinprecision with floats:
921+
Assert fails with numericalimprecision with floats:
895922
896923
>>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
897924
... [1, np.sqrt(np.pi)**2, np.nan])
@@ -912,6 +939,12 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
912939
... [1, np.sqrt(np.pi)**2, np.nan],
913940
... rtol=1e-10, atol=0)
914941
942+
As mentioned in the Notes section, `assert_array_equal` has special
943+
handling for scalars. Here the test checks that each value in `x` is 3:
944+
945+
>>> x = np.full((2, 5), fill_value=3)
946+
>>> np.testing.assert_array_equal(x, 3)
947+
915948
"""
916949
__tracebackhide__=True# Hide traceback for py.test
917950
assert_array_compare(operator.__eq__,x,y,err_msg=err_msg,
@@ -1150,7 +1183,7 @@ def assert_string_equal(actual, desired):
11501183
ifdesired==actual:
11511184
return
11521185

1153-
diff=list(difflib.Differ().compare(actual.splitlines(1),desired.splitlines(1)))
1186+
diff=list(difflib.Differ().compare(actual.splitlines(True),desired.splitlines(True)))
11541187
diff_list= []
11551188
whilediff:
11561189
d1=diff.pop(0)

‎numpy/testing/tests/test_utils.py‎

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ def foo(t):
9090
fortin ['S1','U1']:
9191
foo(t)
9292

93+
deftest_0_ndim_array(self):
94+
x=np.array(473963742225900817127911193656584771)
95+
y=np.array(18535119325151578301457182298393896)
96+
assert_raises(AssertionError,self._assert_func,x,y)
97+
98+
y=x
99+
self._assert_func(x,y)
100+
101+
x=np.array(43)
102+
y=np.array(10)
103+
assert_raises(AssertionError,self._assert_func,x,y)
104+
105+
y=x
106+
self._assert_func(x,y)
107+
93108
deftest_generic_rank3(self):
94109
"""Test rank 3 array for all dtypes."""
95110
deffoo(t):
@@ -520,7 +535,7 @@ def test_error_message(self):
520535
withpytest.raises(AssertionError)asexc_info:
521536
self._assert_func(x,y,decimal=12)
522537
msgs=str(exc_info.value).split('\n')
523-
assert_equal(msgs[3],'Mismatch:100%')
538+
assert_equal(msgs[3],'Mismatched elements: 3 / 3 (100%)')
524539
assert_equal(msgs[4],'Max absolute difference: 1.e-05')
525540
assert_equal(msgs[5],'Max relative difference: 3.33328889e-06')
526541
assert_equal(
@@ -536,7 +551,7 @@ def test_error_message(self):
536551
withpytest.raises(AssertionError)asexc_info:
537552
self._assert_func(x,y)
538553
msgs=str(exc_info.value).split('\n')
539-
assert_equal(msgs[3],'Mismatch:33.3%')
554+
assert_equal(msgs[3],'Mismatched elements: 1 / 3 (33.3%)')
540555
assert_equal(msgs[4],'Max absolute difference: 1.e-05')
541556
assert_equal(msgs[5],'Max relative difference: 3.33328889e-06')
542557
assert_equal(msgs[6],' x: array([1. , 2. , 3.00003])')
@@ -548,7 +563,7 @@ def test_error_message(self):
548563
withpytest.raises(AssertionError)asexc_info:
549564
self._assert_func(x,y)
550565
msgs=str(exc_info.value).split('\n')
551-
assert_equal(msgs[3],'Mismatch:50%')
566+
assert_equal(msgs[3],'Mismatched elements: 1 / 2 (50%)')
552567
assert_equal(msgs[4],'Max absolute difference: 1.')
553568
assert_equal(msgs[5],'Max relative difference: 1.')
554569
assert_equal(msgs[6],' x: array([inf, 0.])')
@@ -560,10 +575,30 @@ def test_error_message(self):
560575
withpytest.raises(AssertionError)asexc_info:
561576
self._assert_func(x,y)
562577
msgs=str(exc_info.value).split('\n')
563-
assert_equal(msgs[3],'Mismatch:100%')
578+
assert_equal(msgs[3],'Mismatched elements: 2 / 2 (100%)')
564579
assert_equal(msgs[4],'Max absolute difference: 2')
565580
assert_equal(msgs[5],'Max relative difference: inf')
566581

582+
deftest_error_message_2(self):
583+
"""Check the message is formatted correctly when either x or y is a scalar."""
584+
x=2
585+
y=np.ones(20)
586+
withpytest.raises(AssertionError)asexc_info:
587+
self._assert_func(x,y)
588+
msgs=str(exc_info.value).split('\n')
589+
assert_equal(msgs[3],'Mismatched elements: 20 / 20 (100%)')
590+
assert_equal(msgs[4],'Max absolute difference: 1.')
591+
assert_equal(msgs[5],'Max relative difference: 1.')
592+
593+
y=2
594+
x=np.ones(20)
595+
withpytest.raises(AssertionError)asexc_info:
596+
self._assert_func(x,y)
597+
msgs=str(exc_info.value).split('\n')
598+
assert_equal(msgs[3],'Mismatched elements: 20 / 20 (100%)')
599+
assert_equal(msgs[4],'Max absolute difference: 1.')
600+
assert_equal(msgs[5],'Max relative difference: 0.5')
601+
567602
deftest_subclass_that_cannot_be_bool(self):
568603
# While we cannot guarantee testing functions will always work for
569604
# subclasses, the tests should ideally rely only on subclasses having
@@ -588,9 +623,9 @@ class TestApproxEqual(object):
588623
defsetup(self):
589624
self._assert_func=assert_approx_equal
590625

591-
deftest_simple_arrays(self):
592-
x=np.array([1234.22])
593-
y=np.array([1234.23])
626+
deftest_simple_0d_arrays(self):
627+
x=np.array(1234.22)
628+
y=np.array(1234.23)
594629

595630
self._assert_func(x,y,significant=5)
596631
self._assert_func(x,y,significant=6)
@@ -855,7 +890,8 @@ def test_report_fail_percentage(self):
855890
withpytest.raises(AssertionError)asexc_info:
856891
assert_allclose(a,b)
857892
msg=str(exc_info.value)
858-
assert_('Mismatch: 25%\nMax absolute difference: 1\n'
893+
assert_('Mismatched elements: 1 / 4 (25%)\n'
894+
'Max absolute difference: 1\n'
859895
'Max relative difference: 0.5'inmsg)
860896

861897
deftest_equal_nan(self):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp