Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit2fde01f

Browse files
authored
Fix float 8 conversion (#36)
1 parentcd73f71 commit2fde01f

File tree

6 files changed

+51
-27
lines changed

6 files changed

+51
-27
lines changed

‎.gitignore‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ _doc/examples/data/*.optimized.onnx
2222
_doc/examples/*.html
2323
_doc/_static/require.js
2424
_doc/_static/viz.js
25+
_doc/LICENSE.txt
26+
_doc/CHANGELOGS.rst
2527
_unittests/ut__main/*.png
2628
_unittests/ut__main/_cache/*
2729
_unittests/ut__main/*.html

‎_doc/api/f8.rst‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Float 8
2+
=======
3+
4+
..automodule::onnx_array_api.validation.f8
5+
:members:

‎_doc/api/index.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ API
2020
reference
2121
tools
2222
profiling
23+
f8

‎_doc/conf.py‎

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,7 @@
122122
"onnxruntime":"https://onnxruntime.ai/",
123123
"numpy":"https://numpy.org/",
124124
"numba":"https://numba.pydata.org/",
125-
"onnx-array-api": (
126-
"http://www.xavierdupre.fr/app/onnx-array-api/helpsphinx/index.html"
127-
),
125+
"onnx-array-api": ("https://sdpython.github.io/doc/onnx-array-api/dev/"),
128126
"pyinstrument":"https://github.com/joerick/pyinstrument",
129127
"python":"https://www.python.org/",
130128
"scikit-learn":"https://scikit-learn.org/stable/",

‎_unittests/ut_validation/test_f8.py‎

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,16 @@ def test_float8_e5m2fnuz_negative_nan(self):
11511151
back=fe4m3_to_float32(to,fn=True,uz=True)
11521152
self.assertTrue(numpy.isnan(back))
11531153

1154+
deftest_fe4m3fn_to_float32_bug(self):
1155+
cases= [(1.8131605,1.875)]
1156+
forval,expectedincases:
1157+
withself.subTest(value=val,expected=expected):
1158+
res=fe4m3_to_float32(search_float32_into_fe4m3(val))
1159+
self.assertEqual(expected,res)
1160+
res=fe4m3_to_float32(float32_to_fe4m3(val))
1161+
self.assertEqual(expected,res)
1162+
11541163

11551164
if__name__=="__main__":
1156-
TestF8().test_search_float32_into_fe4m3fn_simple()
1165+
TestF8().test_fe4m3fn_to_float32_bug()
11571166
unittest.main(verbosity=2)

‎onnx_array_api/validation/f8.py‎

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,17 @@ class UndefinedCastError(FloatingPointError):
1212
pass
1313

1414

15-
defdisplay_float32(value,sign=1,exponent=8,mantissa=23):
15+
defdisplay_int(ival,sign=1,exponent=8,mantissa=23):
1616
"""
17-
Displaysa float32 into b.
17+
Displaysan integer as bits.
1818
19-
:paramvalue: value to display (float32)
19+
:paramival: value to display (float32)
2020
:param sign: number of bits for the sign
2121
:param exponent: number of bits for the exponent
2222
:param mantissa: number of bits for the mantissa
2323
:return: string
2424
"""
2525
t=sign+exponent+mantissa
26-
ival=int.from_bytes(struct.pack("<f",numpy.float32(value)),"little")
2726
s=bin(ival)[2:]
2827
s="0"* (t-len(s))+s
2928
s1=s[:sign]
@@ -32,6 +31,24 @@ def display_float32(value, sign=1, exponent=8, mantissa=23):
3231
return".".join([s1,s2,s3])
3332

3433

34+
defdisplay_float32(value,sign=1,exponent=8,mantissa=23):
35+
"""
36+
Displays a float32 into b.
37+
38+
:param value: value to display (float32)
39+
:param sign: number of bits for the sign
40+
:param exponent: number of bits for the exponent
41+
:param mantissa: number of bits for the mantissa
42+
:return: string
43+
"""
44+
returndisplay_int(
45+
int.from_bytes(struct.pack("<f",numpy.float32(value)),"little"),
46+
sign=sign,
47+
exponent=exponent,
48+
mantissa=mantissa,
49+
)
50+
51+
3552
defdisplay_float16(value,sign=1,exponent=5,mantissa=10):
3653
"""
3754
Displays a float32 into b.
@@ -42,14 +59,9 @@ def display_float16(value, sign=1, exponent=5, mantissa=10):
4259
:param mantissa: number of bits for the mantissa
4360
:return: string
4461
"""
45-
t=sign+exponent+mantissa
46-
ival=numpy.float16(value).view("H")# pylint: disable=E1121
47-
s=bin(ival)[2:]
48-
s="0"* (t-len(s))+s
49-
s1=s[:sign]
50-
s2=s[sign :sign+exponent]
51-
s3=s[sign+exponent :]
52-
return".".join([s1,s2,s3])
62+
returndisplay_int(
63+
numpy.float16(value).view("H"),sign=sign,exponent=exponent,mantissa=mantissa
64+
)
5365

5466

5567
defdisplay_fexmx(value,sign,exponent,mantissa):
@@ -64,14 +76,7 @@ def display_fexmx(value, sign, exponent, mantissa):
6476
:param mantissa: number of bits for the mantissa
6577
:return: string
6678
"""
67-
t=sign+exponent+mantissa
68-
ival=value
69-
s=bin(ival)[2:]
70-
s="0"* (t-len(s))+s
71-
s1=s[:sign]
72-
s2=s[sign :sign+exponent]
73-
s3=s[sign+exponent :]
74-
return".".join([s1,s2,s3])
79+
returndisplay_int(value,sign=sign,exponent=exponent,mantissa=mantissa)
7580

7681

7782
defdisplay_fe4m3(value,sign=1,exponent=4,mantissa=3):
@@ -534,7 +539,9 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
534539
else:
535540
ret|=ex<<3
536541
ret|=m>>20
537-
ifm&0x80000:
542+
if (m&0x80000)and (
543+
(m&0x100000)or (m&0x7FFFF)
544+
):# round to nearest even
538545
if (ret&0x7F)<0x7F:
539546
# rounding
540547
ret+=1
@@ -584,7 +591,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
584591
if (ret&0x7F)==0x7F:
585592
ret&=0xFE
586593
if (m&0x80000)and (
587-
(m&0x100000)or (m&0x7C000)
594+
(m&0x100000)or (m&0x7FFFF)
588595
):# round to nearest even
589596
if (ret&0x7F)<0x7E:
590597
# rounding
@@ -642,7 +649,9 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
642649
ex=e-111# 127 - 16
643650
ret|=ex<<2
644651
ret|=m>>21
645-
ifm&0x100000:
652+
ifm&0x100000and (
653+
(m&0xFFFFF)or (m&0x200000)
654+
):# round to nearest even
646655
if (ret&0x7F)<0x7F:
647656
# rounding
648657
ret+=1

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp