@@ -301,6 +301,10 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
301301 check that all elements of these objects are equal. An exception is raised
302302 at the first conflicting values.
303303
304+ When one of `actual` and `desired` is a scalar and the other is array_like,
305+ the function checks that each element of the array_like object is equal to
306+ the scalar.
307+
304308 This function handles NaN comparisons as if NaN was a "normal" number.
305309 That is, no assertion is raised if both objects have NaNs in the same
306310 positions. This is in contrast to the IEEE standard on NaNs, which says
@@ -391,21 +395,6 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
391395if isscalar (desired )!= isscalar (actual ):
392396raise AssertionError (msg )
393397
394- # Inf/nan/negative zero handling
395- try :
396- isdesnan = gisnan (desired )
397- isactnan = gisnan (actual )
398- if isdesnan and isactnan :
399- return # both nan, so equal
400-
401- # handle signed zero specially for floats
402- if desired == 0 and actual == 0 :
403- if not signbit (desired )== signbit (actual ):
404- raise AssertionError (msg )
405-
406- except (TypeError ,ValueError ,NotImplementedError ):
407- pass
408-
409398try :
410399isdesnat = isnat (desired )
411400isactnat = isnat (actual )
@@ -421,6 +410,33 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
421410except (TypeError ,ValueError ,NotImplementedError ):
422411pass
423412
413+ # Inf/nan/negative zero handling
414+ try :
415+ isdesnan = gisnan (desired )
416+ isactnan = gisnan (actual )
417+ if isdesnan and isactnan :
418+ return # both nan, so equal
419+
420+ # handle signed zero specially for floats
421+ array_actual = array (actual )
422+ array_desired = array (desired )
423+ if (array_actual .dtype .char in 'Mm' or
424+ array_desired .dtype .char in 'Mm' ):
425+ # version 1.18
426+ # until this version, gisnan failed for datetime64 and timedelta64.
427+ # Now it succeeds but comparison to scalar with a different type
428+ # emits a DeprecationWarning.
429+ # Avoid that by skipping the next check
430+ raise NotImplementedError ('cannot compare to a scalar '
431+ 'with a different type' )
432+
433+ if desired == 0 and actual == 0 :
434+ if not signbit (desired )== signbit (actual ):
435+ raise AssertionError (msg )
436+
437+ except (TypeError ,ValueError ,NotImplementedError ):
438+ pass
439+
424440try :
425441# Explicitly use __eq__ for comparison, gh-2552
426442if not (desired == actual ):
@@ -703,7 +719,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
703719header = '' ,precision = 6 ,equal_nan = True ,
704720equal_inf = True ):
705721__tracebackhide__ = True # Hide traceback for py.test
706- from numpy .core import array ,array2string ,isnan ,inf ,bool_ ,errstate ,all
722+ from numpy .core import array ,array2string ,isnan ,inf ,bool_ ,errstate ,all , max , object_
707723
708724x = array (x ,copy = False ,subok = True )
709725y = array (y ,copy = False ,subok = True )
@@ -804,15 +820,19 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
804820# do not trigger a failure (np.ma.masked != True evaluates as
805821# np.ma.masked, which is falsy).
806822if cond != True :
807- mismatch = 100. * (reduced .size - reduced .sum (dtype = intp ))/ ox .size
808- remarks = ['Mismatch: {:.3g}%' .format (mismatch )]
823+ n_mismatch = reduced .size - reduced .sum (dtype = intp )
824+ n_elements = flagged .size if flagged .ndim != 0 else reduced .size
825+ percent_mismatch = 100 * n_mismatch / n_elements
826+ remarks = [
827+ 'Mismatched elements: {} / {} ({:.3g}%)' .format (
828+ n_mismatch ,n_elements ,percent_mismatch )]
809829
810830with errstate (invalid = 'ignore' ,divide = 'ignore' ):
811831# ignore errors for non-numeric types
812832with contextlib .suppress (TypeError ):
813833error = abs (x - y )
814- max_abs_error = error . max ()
815- if error . dtype == 'object' :
834+ max_abs_error = max (error )
835+ if getattr ( error , ' dtype' , object_ ) == object_ :
816836remarks .append ('Max absolute difference: '
817837+ str (max_abs_error ))
818838else :
@@ -826,8 +846,8 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
826846if all (~ nonzero ):
827847max_rel_error = array (inf )
828848else :
829- max_rel_error = (error [nonzero ]/ abs (y [nonzero ])). max ( )
830- if error . dtype == 'object' :
849+ max_rel_error = max (error [nonzero ]/ abs (y [nonzero ]))
850+ if getattr ( error , ' dtype' , object_ ) == object_ :
831851remarks .append ('Max relative difference: '
832852+ str (max_rel_error ))
833853else :
@@ -854,10 +874,11 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
854874 Raises an AssertionError if two array_like objects are not equal.
855875
856876 Given two array_like objects, check that the shape is equal and all
857- elements of these objects are equal. An exception is raised at
858- shape mismatch or conflicting values. In contrast to the standard usage
859- in numpy, NaNs are compared like numbers, no assertion is raised if
860- both objects have NaNs in the same positions.
877+ elements of these objects are equal (but see the Notes for the special
878+ handling of a scalar). An exception is raised at shape mismatch or
879+ conflicting values. In contrast to the standard usage in numpy, NaNs
880+ are compared like numbers, no assertion is raised if both objects have
881+ NaNs in the same positions.
861882
862883 The usual caution for verifying equality with floating point numbers is
863884 advised.
@@ -884,14 +905,20 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
884905 relative and/or absolute precision.
885906 assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal
886907
908+ Notes
909+ -----
910+ When one of `x` and `y` is a scalar and the other is array_like, the
911+ function checks that each element of the array_like object is equal to
912+ the scalar.
913+
887914 Examples
888915 --------
889916 The first assert does not raise an exception:
890917
891918 >>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
892919 ... [np.exp(0),2.33333, np.nan])
893920
894- Assert fails with numericalinprecision with floats:
921+ Assert fails with numericalimprecision with floats:
895922
896923 >>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
897924 ... [1, np.sqrt(np.pi)**2, np.nan])
@@ -912,6 +939,12 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
912939 ... [1, np.sqrt(np.pi)**2, np.nan],
913940 ... rtol=1e-10, atol=0)
914941
942+ As mentioned in the Notes section, `assert_array_equal` has special
943+ handling for scalars. Here the test checks that each value in `x` is 3:
944+
945+ >>> x = np.full((2, 5), fill_value=3)
946+ >>> np.testing.assert_array_equal(x, 3)
947+
915948 """
916949__tracebackhide__ = True # Hide traceback for py.test
917950assert_array_compare (operator .__eq__ ,x ,y ,err_msg = err_msg ,
@@ -1150,7 +1183,7 @@ def assert_string_equal(actual, desired):
11501183if desired == actual :
11511184return
11521185
1153- diff = list (difflib .Differ ().compare (actual .splitlines (1 ),desired .splitlines (1 )))
1186+ diff = list (difflib .Differ ().compare (actual .splitlines (True ),desired .splitlines (True )))
11541187diff_list = []
11551188while diff :
11561189d1 = diff .pop (0 )