2020search_float32_into_fe5m2 ,
2121)
2222from onnx_array_api .ext_test_case import ExtTestCase
23+ from ml_dtypes import float8_e4m3fn ,float8_e5m2
24+
25+
26+ def new_cvt_float32_to_e4m3fn (x ):
27+ return numpy .array (x ,dtype = numpy .float32 ).astype (float8_e4m3fn )
28+
29+
30+ def new_cvt_e4m3fn_to_float32 (x ):
31+ return numpy .array (x ,dtype = float8_e4m3fn ).astype (numpy .float32 )
32+
33+
34+ def new_cvt_float32_to_e5m2 (x ):
35+ return numpy .array (x ,dtype = numpy .float32 ).astype (float8_e5m2 )
36+
37+
38+ def new_cvt_e5m2_to_float32 (x ):
39+ return numpy .array (x ,dtype = float8_e5m2 ).astype (numpy .float32 )
2340
2441
2542class TestF8 (ExtTestCase ):
@@ -76,6 +93,17 @@ def test_fe4m3fn_to_float32_all(self):
7693continue
7794self .assertEqual (a ,b )
7895
96+ def test_fe4m3fn_to_float32_all_ml_types (self ):
97+ for i in range (0 ,256 ):
98+ a = fe4m3_to_float32_float (i )
99+ b = fe4m3_to_float32 (i )
100+ c = new_cvt_float32_to_e4m3fn (b )
101+ if numpy .isnan (a ):
102+ self .assertTrue (numpy .isnan (b ))
103+ continue
104+ self .assertEqual (float (a ),float (c ))
105+ self .assertEqual (a ,b )
106+
79107def test_display_float (self ):
80108f = 45
81109s = display_float32 (f )
@@ -164,6 +192,7 @@ def test_search_float32_into_fe5m2_equal(self):
164192 ):
165193b = search_float32_into_fe5m2 (value )
166194nf = float32_to_fe5m2 (value )
195+ cf = new_cvt_float32_to_e5m2 (value )
167196if expected in {253 ,254 ,255 ,125 ,126 ,127 }:# nan
168197self .assertIn (b , {253 ,254 ,255 ,125 ,126 ,127 })
169198self .assertIn (nf , {253 ,254 ,255 ,125 ,126 ,127 })
@@ -173,6 +202,10 @@ def test_search_float32_into_fe5m2_equal(self):
173202else :
174203self .assertIn (b , (0 ,128 ))
175204self .assertIn (nf , (0 ,128 ))
205+ if numpy .isnan (float (cf )):
206+ self .assertTrue (numpy .isnan (fe5m2_to_float32 (nf )))
207+ continue
208+ self .assertEqual (fe5m2_to_float32 (nf ),float (cf ))
176209
177210def test_search_float32_into_fe4m3fn (self ):
178211values = [(fe4m3_to_float32_float (i ),i )for i in range (0 ,256 )]
@@ -739,6 +772,33 @@ def test_simple_fe4m3(self):
739772back = [fe4m3_to_float32 (c ,uz = True )for c in cvt ]
740773self .assertEqual (values ,back )
741774
775+ # ml-dtypes
776+
777+ def test_inf_nan_ml_dtypes (self ):
778+ x = numpy .float32 (numpy .inf )
779+ g1 = float32_to_fe4m3 (x )
780+ g2 = float32_to_fe5m2 (x )
781+ i1 = fe4m3_to_float32 (g1 )
782+ i2 = fe5m2_to_float32 (g2 )
783+ self .assertEqual (i1 ,448 )
784+ self .assertTrue (numpy .isinf (i2 ))
785+ m1 = new_cvt_float32_to_e4m3fn (x )
786+ m2 = new_cvt_float32_to_e5m2 (x )
787+ self .assertTrue (numpy .isnan (m1 ))# different from ONNX choice
788+ self .assertTrue (numpy .isinf (m2 ))
789+
790+ x = numpy .float32 (numpy .nan )
791+ g1 = float32_to_fe4m3 (x )
792+ g2 = float32_to_fe5m2 (x )
793+ i1 = fe4m3_to_float32 (g1 )
794+ i2 = fe5m2_to_float32 (g2 )
795+ self .assertTrue (numpy .isnan (i1 ))
796+ self .assertTrue (numpy .isnan (i2 ))
797+ m1 = new_cvt_float32_to_e4m3fn (x )
798+ m2 = new_cvt_float32_to_e5m2 (x )
799+ self .assertTrue (numpy .isnan (m1 ))
800+ self .assertTrue (numpy .isnan (m2 ))
801+
742802
743803if __name__ == "__main__" :
744804TestF8 ().test_search_float32_into_fe4m3fn_simple ()