Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Improves F8 conversion#40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
sdpython merged 1 commit intomainfromf8f
Aug 26, 2023
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
View file
Open in desktop
Binary file not shown.
27 changes: 26 additions & 1 deletion_unittests/ut_validation/test_f8.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -1220,7 +1220,32 @@ def test_fe4m3fn_to_float32_bug(self):
continue
raise AssertionError(f"Unexpected value for pt={pt}.")

def test_inf(self):
for x, e in [(numpy.float32(numpy.inf), 126), (numpy.float32(-numpy.inf), 254)]:
f8 = float32_to_fe4m3(x)
self.assertEqual(e, f8)

def test_nan(self):
expected = 127
values = [
(
None,
int.from_bytes(struct.pack("<f", numpy.float32(numpy.nan)), "little"),
numpy.float32(numpy.nan),
expected,
)
]
for i in range(0, 23):
v = 0x7F800000 | (1 << i)
f = numpy.uint32(v).view(numpy.float32)
values.append((i, v, f, expected))
values.append((i, v, -f, expected | 128))

for i, v, x, e in values:
with self.subTest(x=x, e=e, h=hex(v), i=i):
f8 = float32_to_fe4m3(x)
self.assertEqual(e, f8)


if __name__ == "__main__":
TestF8().test_fe4m3fn_to_float32_bug()
unittest.main(verbosity=2)
32 changes: 20 additions & 12 deletionsonnx_array_api/validation/f8.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -503,15 +503,18 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
"""
if not fn:
raise NotImplementedError("fn=False is not implemented.")
b = int.from_bytes(struct.pack("<f", numpy.float32(x)), "little")
if not isinstance(x, numpy.float32):
x = numpy.float32(x)
b = int.from_bytes(struct.pack("<f", x), "little")
ret = (b & 0x80000000) >> 24 # sign
if uz:
if (b & 0x7FC00000) == 0x7FC00000:
return 0x80
if numpy.isinf(x):
if (b & 0x7FFFFFFF) == 0x7F800000:
# infinity
if saturate:
return ret | 127
return 0x80
if (b & 0x7F800000) == 0x7F800000:
return 0x80
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa

Expand DownExpand Up@@ -558,12 +561,14 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
ret = 0
return int(ret)
else:
if (b & 0x7FC00000) == 0x7FC00000:
return 0x7F | ret
if numpy.isinf(x):
if (b & 0x7FFFFFFF) == 0x7F800000:
# infinity
if saturate:
return ret | 126
return 0x7F | ret
if (b & 0x7F800000) == 0x7F800000:
# non
return 0x7F | ret
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa

Expand DownExpand Up@@ -624,13 +629,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
ret = (b & 0x80000000) >> 24 # sign

if fn and uz:
if (b & 0x7FC00000) == 0x7FC00000:
return 0x80
if (b & 0x7FFFFFFF) == 0x7F800000:
# inf
if saturate:
return ret | 0x7F
return 0x80
if (b & 0x7F800000) == 0x7F800000:
# nan
return 0x80
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa

Expand DownExpand Up@@ -675,12 +681,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
ret = 0
return int(ret)
elif not fn and not uz:
if (b & 0x7FC00000) == 0x7FC00000:
return 0x7F | ret
if numpy.isinf(x):
if (b & 0x7FFFFFFF) == 0x7F800000:
# inf
if saturate:
return 0x7B | ret
return 0x7C | ret
if (b & 0x7F800000) == 0x7F800000:
# nan
return 0x7F | ret
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa

Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp