Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commite90ce31

Browse files
authored
Improves F8 conversion (#40)
1 parent85e2e52 commite90ce31

File tree

3 files changed

+46
-13
lines changed

3 files changed

+46
-13
lines changed
6.49 KB
Binary file not shown.

‎_unittests/ut_validation/test_f8.py‎

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,32 @@ def test_fe4m3fn_to_float32_bug(self):
12201220
continue
12211221
raiseAssertionError(f"Unexpected value for pt={pt}.")
12221222

1223+
deftest_inf(self):
1224+
forx,ein [(numpy.float32(numpy.inf),126), (numpy.float32(-numpy.inf),254)]:
1225+
f8=float32_to_fe4m3(x)
1226+
self.assertEqual(e,f8)
1227+
1228+
deftest_nan(self):
1229+
expected=127
1230+
values= [
1231+
(
1232+
None,
1233+
int.from_bytes(struct.pack("<f",numpy.float32(numpy.nan)),"little"),
1234+
numpy.float32(numpy.nan),
1235+
expected,
1236+
)
1237+
]
1238+
foriinrange(0,23):
1239+
v=0x7F800000| (1<<i)
1240+
f=numpy.uint32(v).view(numpy.float32)
1241+
values.append((i,v,f,expected))
1242+
values.append((i,v,-f,expected|128))
1243+
1244+
fori,v,x,einvalues:
1245+
withself.subTest(x=x,e=e,h=hex(v),i=i):
1246+
f8=float32_to_fe4m3(x)
1247+
self.assertEqual(e,f8)
1248+
12231249

12241250
if__name__=="__main__":
1225-
TestF8().test_fe4m3fn_to_float32_bug()
12261251
unittest.main(verbosity=2)

‎onnx_array_api/validation/f8.py‎

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -503,15 +503,18 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
503503
"""
504504
ifnotfn:
505505
raiseNotImplementedError("fn=False is not implemented.")
506-
b=int.from_bytes(struct.pack("<f",numpy.float32(x)),"little")
506+
ifnotisinstance(x,numpy.float32):
507+
x=numpy.float32(x)
508+
b=int.from_bytes(struct.pack("<f",x),"little")
507509
ret= (b&0x80000000)>>24# sign
508510
ifuz:
509-
if (b&0x7FC00000)==0x7FC00000:
510-
return0x80
511-
ifnumpy.isinf(x):
511+
if (b&0x7FFFFFFF)==0x7F800000:
512+
# infinity
512513
ifsaturate:
513514
returnret|127
514515
return0x80
516+
if (b&0x7F800000)==0x7F800000:
517+
return0x80
515518
e= (b&0x7F800000)>>23# exponent
516519
m=b&0x007FFFFF# mantissa
517520

@@ -558,12 +561,14 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
558561
ret=0
559562
returnint(ret)
560563
else:
561-
if (b&0x7FC00000)==0x7FC00000:
562-
return0x7F|ret
563-
ifnumpy.isinf(x):
564+
if (b&0x7FFFFFFF)==0x7F800000:
565+
# infinity
564566
ifsaturate:
565567
returnret|126
566568
return0x7F|ret
569+
if (b&0x7F800000)==0x7F800000:
570+
# non
571+
return0x7F|ret
567572
e= (b&0x7F800000)>>23# exponent
568573
m=b&0x007FFFFF# mantissa
569574

@@ -624,13 +629,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
624629
ret= (b&0x80000000)>>24# sign
625630

626631
iffnanduz:
627-
if (b&0x7FC00000)==0x7FC00000:
628-
return0x80
629632
if (b&0x7FFFFFFF)==0x7F800000:
630633
# inf
631634
ifsaturate:
632635
returnret|0x7F
633636
return0x80
637+
if (b&0x7F800000)==0x7F800000:
638+
# nan
639+
return0x80
634640
e= (b&0x7F800000)>>23# exponent
635641
m=b&0x007FFFFF# mantissa
636642

@@ -675,12 +681,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
675681
ret=0
676682
returnint(ret)
677683
elifnotfnandnotuz:
678-
if (b&0x7FC00000)==0x7FC00000:
679-
return0x7F|ret
680-
ifnumpy.isinf(x):
684+
if (b&0x7FFFFFFF)==0x7F800000:
685+
# inf
681686
ifsaturate:
682687
return0x7B|ret
683688
return0x7C|ret
689+
if (b&0x7F800000)==0x7F800000:
690+
# nan
691+
return0x7F|ret
684692
e= (b&0x7F800000)>>23# exponent
685693
m=b&0x007FFFFF# mantissa
686694

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp