Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commita6a735a

Browse files
committed
fix inf to nan
1 parent97cc690 commita6a735a

File tree

2 files changed

+27
-45
lines changed

2 files changed

+27
-45
lines changed

‎_unittests/ut_validation/test_f8.py‎

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,9 @@ def test_inf_nan(self):
342342
0.203125,
343343
0.75,
344344
numpy.nan,
345-
max(CastFloat8.values_e4m3fn)[0],
346-
max(CastFloat8.values_e4m3fn)[0],
347-
min(CastFloat8.values_e4m3fn)[0],
345+
numpy.nan,
346+
numpy.nan,
347+
-numpy.nan,
348348
],
349349
dtype=numpy.float32,
350350
)
@@ -416,7 +416,7 @@ def test_search_e5m2_pow(self):
416416
)
417417

418418
deftest_float32_to_fe4m3fn_inf(self):
419-
mx=max(CastFloat8.values_e4m3fn)[0]
419+
mx=numpy.float32(numpy.nan)
420420
v0=numpy.float32(mx)
421421
v1=numpy.float32(numpy.inf)
422422
a=search_float32_into_fe4m3(v0)
@@ -429,7 +429,7 @@ def test_float32_to_fe4m3fn_inf(self):
429429
b=float32_to_fe4m3(v1)
430430
self.assertEqual(a,b)
431431

432-
mi=min(CastFloat8.values_e4m3fn)[0]
432+
mi=numpy.float32(-numpy.nan)
433433
v0=numpy.float32(mi)
434434
v1=numpy.float32(-numpy.inf)
435435
a=search_float32_into_fe4m3(v0)
@@ -654,30 +654,18 @@ def test_search_float32_into_fe5m2fnuz(self):
654654
)
655655

656656
deftest_float32_to_fe4m3fnuz_inf(self):
657-
v0=numpy.float32(448)
657+
v0=numpy.float32(numpy.nan)
658658
v1=numpy.float32(numpy.inf)
659659
a=search_float32_into_fe4m3(v0,uz=True)
660660
b=search_float32_into_fe4m3(v1,uz=True)
661661
self.assertEqual(a,b)
662662

663-
v0=numpy.float32(448)
664-
v1=numpy.float32(numpy.inf)
665-
a=float32_to_fe4m3(v0,uz=True)
666-
b=float32_to_fe4m3(v1,uz=True)
667-
self.assertEqual(a,b)
668-
669-
v0=numpy.float32(-448)
663+
v0=numpy.float32(-numpy.nan)
670664
v1=numpy.float32(-numpy.inf)
671665
a=search_float32_into_fe4m3(v0,uz=True)
672666
b=search_float32_into_fe4m3(v1,uz=True)
673667
self.assertEqual(a,b)
674668

675-
v0=numpy.float32(-448)
676-
v1=numpy.float32(-numpy.inf)
677-
a=float32_to_fe4m3(v0,uz=True)
678-
b=float32_to_fe4m3(v1,uz=True)
679-
self.assertEqual(a,b)
680-
681669
v0=numpy.float32(numpy.nan)
682670
v1=numpy.float32(-numpy.nan)
683671
a=search_float32_into_fe4m3(v0,uz=True)
@@ -688,7 +676,7 @@ def test_float32_to_fe4m3fnuz_inf(self):
688676
v1=numpy.float32(-numpy.inf)
689677
a=search_float32_into_fe4m3(v0,uz=True)
690678
b=search_float32_into_fe4m3(v1,uz=True)
691-
self.assertNotEqual(a,b)
679+
self.assertEqual(a,b)
692680

693681
v0=numpy.float32(numpy.nan)
694682
v1=numpy.float32(-numpy.nan)
@@ -700,10 +688,10 @@ def test_float32_to_fe4m3fnuz_inf(self):
700688
v1=numpy.float32(-numpy.inf)
701689
a=float32_to_fe4m3(v0,uz=True)
702690
b=float32_to_fe4m3(v1,uz=True)
703-
self.assertNotEqual(a,b)
691+
self.assertEqual(a,b)
704692

705693
deftest_float32_to_fe5m2fnuz_inf(self):
706-
mx=max(CastFloat8.values_e5m2fnuz)[0]
694+
mx=numpy.nan
707695
v0=numpy.float32(mx)
708696
v1=numpy.float32(numpy.inf)
709697
a=search_float32_into_fe5m2(v0,fn=True,uz=True)
@@ -716,7 +704,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
716704
b=float32_to_fe5m2(v1,fn=True,uz=True)
717705
self.assertEqual(a,b)
718706

719-
mi=min(CastFloat8.values_e5m2fnuz)[0]
707+
mi=numpy.nan
720708
v0=numpy.float32(mi)
721709
v1=numpy.float32(-numpy.inf)
722710
a=search_float32_into_fe5m2(v0,fn=True,uz=True)
@@ -739,7 +727,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
739727
v1=numpy.float32(-numpy.inf)
740728
a=search_float32_into_fe5m2(v0,fn=True,uz=True)
741729
b=search_float32_into_fe5m2(v1,fn=True,uz=True)
742-
self.assertNotEqual(a,b)
730+
self.assertEqual(a,b)
743731

744732
v0=numpy.float32(numpy.nan)
745733
v1=numpy.float32(-numpy.nan)
@@ -751,7 +739,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
751739
v1=numpy.float32(-numpy.inf)
752740
a=float32_to_fe5m2(v0,fn=True,uz=True)
753741
b=float32_to_fe5m2(v1,fn=True,uz=True)
754-
self.assertNotEqual(a,b)
742+
self.assertEqual(a,b)
755743

756744
deftest_simple_fe4m3(self):
757745
values= [448]
@@ -780,7 +768,7 @@ def test_inf_nan_ml_dtypes(self):
780768
g2=float32_to_fe5m2(x)
781769
i1=fe4m3_to_float32(g1)
782770
i2=fe5m2_to_float32(g2)
783-
self.assertEqual(i1,448)
771+
self.assertNotEqual(i1,448)
784772
self.assertTrue(numpy.isinf(i2))
785773
m1=new_cvt_float32_to_e4m3fn(x)
786774
m2=new_cvt_float32_to_e5m2(x)

‎onnx_array_api/validation/f8.py‎

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
importstruct
22
importnumpy
33

4+
# display functions
45

56
defdisplay_float32(value,sign=1,exponent=8,mantissa=23):
67
"""
@@ -90,6 +91,9 @@ def display_fe5m2(value, sign=1, exponent=4, mantissa=3):
9091
returndisplay_fexmx(value,sign=1,exponent=5,mantissa=2)
9192

9293

94+
# cast from float 8 to float 32
95+
96+
9397
deffe4m3_to_float32_float(ival:int,fn:bool=True,uz:bool=False)->float:
9498
"""
9599
Casts a float 8 encoded as an integer into a float.
@@ -243,7 +247,6 @@ def fe4m3_to_float32(ival: int, fn: bool = True, uz: bool = False) -> float:
243247
f=numpy.uint32(res).view(numpy.float32)# pylint: disable=E1121
244248
returnf
245249

246-
247250
deffe5m2_to_float32(ival:int,fn:bool=False,uz:bool=False)->float:
248251
"""
249252
Casts a float E5M2 encoded as an integer into a float.
@@ -292,6 +295,7 @@ def fe5m2_to_float32(ival: int, fn: bool = False, uz: bool = False) -> float:
292295
f=numpy.uint32(res).view(numpy.float32)# pylint: disable=E1121
293296
returnf
294297

298+
# cast from float32 to float 8
295299

296300
classCastFloat8:
297301
"""
@@ -378,16 +382,12 @@ def search_float32_into_fe4m3(value: float, fn: bool = True, uz: bool = False) -
378382
b=int.from_bytes(struct.pack("<f",numpy.float32(value)),"little")
379383
ret= (b&0x80000000)>>24# sign
380384
ifuz:
381-
ifnumpy.isnan(value):
385+
ifnumpy.isnan(value)ornumpy.isinf(value):
382386
return0x80
383-
ifnumpy.isinf(value):
384-
returnret|0x7F
385387
set_values=CastFloat8.values_e4m3fnuz
386388
else:
387-
ifnumpy.isnan(value):
389+
ifnumpy.isnan(value)ornumpy.isinf(value):
388390
return0x7F|ret
389-
ifnumpy.isinf(value):
390-
return0x7E|ret
391391
set_values=CastFloat8.values_e4m3fn
392392
f=numpy.float32(value)
393393
i=CastFloat8.find_closest_value(f,set_values)
@@ -407,10 +407,8 @@ def search_float32_into_fe5m2(value: float, fn: bool = False, uz: bool = False)
407407
ret= (b&0x80000000)>>24# sign
408408

409409
iffnanduz:
410-
ifnumpy.isnan(value):
410+
ifnumpy.isnan(value)ornumpy.isinf(value):
411411
return0x80
412-
ifnumpy.isinf(value):
413-
returnret|0x7F
414412
set_values=CastFloat8.values_e5m2fnuz
415413
elifnotfnandnotuz:
416414
ifnumpy.isnan(value):
@@ -438,10 +436,8 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False):
438436
b=int.from_bytes(struct.pack("<f",numpy.float32(x)),"little")
439437
ret= (b&0x80000000)>>24# sign
440438
ifuz:
441-
if (b&0x7FC00000)==0x7FC00000:
439+
if (b&0x7FC00000)==0x7FC00000ornumpy.isinf(x):
442440
return0x80
443-
ifnumpy.isinf(x):
444-
returnret|0x7F# saturation
445441
e= (b&0x7F800000)>>23# exponent
446442
m=b&0x007FFFFF# mantissa
447443

@@ -475,10 +471,8 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False):
475471
ret|=0x7F# 01111110
476472
returnint(ret)
477473
else:
478-
if (b&0x7FC00000)==0x7FC00000:
474+
if (b&0x7FC00000)==0x7FC00000ornumpy.isinf(x):
479475
return0x7F|ret
480-
ifnumpy.isinf(x):
481-
return0x7E|ret# saturation
482476
e= (b&0x7F800000)>>23# exponent
483477
m=b&0x007FFFFF# mantissa
484478

@@ -528,10 +522,10 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False):
528522
ret= (b&0x80000000)>>24# sign
529523

530524
iffnanduz:
531-
if (b&0x7FC00000)==0x7FC00000:
525+
if (b&0x7FC00000)==0x7FC00000:# NaN
526+
return0x80
527+
if (b&0x7FFFFFFF)==0x7F800000:# Inf
532528
return0x80
533-
if (b&0x7FFFFFFF)==0x7F800000:
534-
returnret|0x7F
535529
e= (b&0x7F800000)>>23# exponent
536530
m=b&0x007FFFFF# mantissa
537531

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp