Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f7bc922

Browse files
authored
Fix float 8 cast (#37)
* fix f8 * fix * cleaning
1 parent 2fde01f commit f7bc922

File tree

2 files changed

+145
-63
lines changed

2 files changed

+145
-63
lines changed

‎_unittests/ut_validation/test_f8.py

Lines changed: 71 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,11 @@
11
importos
22
importpprint
3+
importstruct
34
importunittest
45
importwarnings
56
importnumpy
67
importpandas
8+
fromonnximportTensorProto
79
fromonnx_array_api.validation.f8import (
810
CastFloat8,
911
UndefinedCastError,
@@ -285,6 +287,15 @@ def test_search_float32_into_fe4m3fn(self):
285287
ok=""ifb==nfelse"WRONG",
286288
true=value,
287289
add=add,
290+
exponent=(
291+
int.from_bytes(
292+
struct.pack("<f",numpy.float32(v)),"little"
293+
)
294+
&0x7F800000
295+
)
296+
>>23,
297+
d1=v-fe4m3_to_float32_float(nf),
298+
d2=v-fe4m3_to_float32_float(b),
288299
)
289300
)
290301
ifwrong>0:
@@ -449,10 +460,13 @@ def test_search_e4m3_pow(self):
449460
continue
450461
r2=float32_to_fe4m3(v)
451462
ifr1!=r2:
463+
ex=abs(v-fe4m3_to_float32(r1))==abs(v-fe4m3_to_float32(r2))
452464
raiseAssertionError(
453465
f"p={p}, v={v}, "
454466
f"search={r1}:{display_fe4m3(r1)}={fe4m3_to_float32(r1)} != "
455-
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)}"
467+
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)} "
468+
f"d1={v-fe4m3_to_float32(r1)} d2={v-fe4m3_to_float32(r2)} "
469+
f"|d1|==|d2|={ex}"
456470
)
457471
forpinrange(1,40):
458472
v=-(2** (-p))
@@ -462,10 +476,13 @@ def test_search_e4m3_pow(self):
462476
continue
463477
r2=float32_to_fe4m3(v)
464478
ifr1!=r2:
479+
ex=abs(v-fe4m3_to_float32(r1))==abs(v-fe4m3_to_float32(r2))
465480
raiseAssertionError(
466481
f"p={p}, v={v}, "
467482
f"search={r1}:{display_fe4m3(r1)}={fe4m3_to_float32(r1)} != "
468-
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)}"
483+
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)} "
484+
f"d1={v-fe4m3_to_float32(r1)} d2={v-fe4m3_to_float32(r2)} "
485+
f"|d1|==|d2|={ex}"
469486
)
470487

471488
deftest_search_e5m2_pow(self):
@@ -478,10 +495,13 @@ def test_search_e5m2_pow(self):
478495
continue
479496
r2=float32_to_fe5m2(v)
480497
ifr1!=r2:
498+
ex=abs(v-fe5m2_to_float32(r1))==abs(v-fe5m2_to_float32(r2))
481499
raiseAssertionError(
482500
f"p={p}, v={v}, "
483501
f"search={r1}:{display_fe5m2(r1)}={fe5m2_to_float32(r1)} != "
484-
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)}"
502+
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)} "
503+
f"d1={v-fe4m3_to_float32(r1)} d2={v-fe5m2_to_float32(r2)} "
504+
f"|d1|==|d2|={ex}"
485505
)
486506
forpinrange(1,40):
487507
v=-(2** (-p))
@@ -491,10 +511,13 @@ def test_search_e5m2_pow(self):
491511
continue
492512
r2=float32_to_fe5m2(v)
493513
ifr1!=r2:
514+
ex=abs(v-fe5m2_to_float32(r1))==abs(v-fe5m2_to_float32(r2))
494515
raiseAssertionError(
495516
f"p={p}, v={v}, "
496517
f"search={r1}:{display_fe5m2(r1)}={fe5m2_to_float32(r1)} != "
497-
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)}"
518+
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)} "
519+
f"d1={v-fe4m3_to_float32(r1)} d2={v-fe5m2_to_float32(r2)} "
520+
f"|d1|==|d2|={ex}"
498521
)
499522

500523
deftest_float32_to_fe4m3fn_inf(self):
@@ -1152,13 +1175,50 @@ def test_float8_e5m2fnuz_negative_nan(self):
11521175
self.assertTrue(numpy.isnan(back))
11531176

11541177
deftest_fe4m3fn_to_float32_bug(self):
1155-
cases= [(1.8131605,1.875)]
1156-
forval,expectedincases:
1157-
withself.subTest(value=val,expected=expected):
1158-
res=fe4m3_to_float32(search_float32_into_fe4m3(val))
1159-
self.assertEqual(expected,res)
1160-
res=fe4m3_to_float32(float32_to_fe4m3(val))
1161-
self.assertEqual(expected,res)
1178+
cases= [
1179+
(0.00439453125,0.00390625,TensorProto.FLOAT8E4M3FN),
1180+
(0.005859375,0.005859375,TensorProto.FLOAT8E4M3FN),
1181+
(0.005759375,0.005859375,TensorProto.FLOAT8E4M3FN),
1182+
(0.0046875,0.00390625,TensorProto.FLOAT8E4M3FN),
1183+
(0.001953125,0.001953125,TensorProto.FLOAT8E4M3FN),
1184+
(0.0029296875,0.00390625,TensorProto.FLOAT8E4M3FN),
1185+
(0.002053125,0.001953125,TensorProto.FLOAT8E4M3FN),
1186+
(0.00234375,0.001953125,TensorProto.FLOAT8E4M3FN),
1187+
(0.0087890625,0.0078125,TensorProto.FLOAT8E4M3FN),
1188+
(0.001171875,0.001953125,TensorProto.FLOAT8E4M3FN),
1189+
(1.8131605,1.875,TensorProto.FLOAT8E4M3FN),
1190+
(-100,-96,TensorProto.FLOAT8E4M3FNUZ),
1191+
(416,384,TensorProto.FLOAT8E5M2FNUZ),
1192+
]
1193+
forval,expected,ptincases:
1194+
withself.subTest(value=val,expected=expected,proto=pt):
1195+
ifpt==TensorProto.FLOAT8E4M3FN:
1196+
res=fe4m3_to_float32(search_float32_into_fe4m3(val))
1197+
self.assertEqual(expected,res)
1198+
res=fe4m3_to_float32(float32_to_fe4m3(val))
1199+
self.assertEqual(expected,res)
1200+
continue
1201+
ifpt==TensorProto.FLOAT8E4M3FNUZ:
1202+
res=fe4m3_to_float32(
1203+
search_float32_into_fe4m3(val,uz=True),uz=True
1204+
)
1205+
self.assertEqual(expected,res)
1206+
res=fe4m3_to_float32(float32_to_fe4m3(val,uz=True),uz=True)
1207+
self.assertEqual(expected,res)
1208+
continue
1209+
ifpt==TensorProto.FLOAT8E5M2FNUZ:
1210+
res=fe5m2_to_float32(
1211+
search_float32_into_fe5m2(val,fn=True,uz=True),
1212+
fn=True,
1213+
uz=True,
1214+
)
1215+
self.assertEqual(expected,res)
1216+
res=fe5m2_to_float32(
1217+
float32_to_fe5m2(val,fn=True,uz=True),fn=True,uz=True
1218+
)
1219+
self.assertEqual(expected,res)
1220+
continue
1221+
raiseAssertionError(f"Unexpected value for pt={pt}.")
11621222

11631223

11641224
if__name__=="__main__":

‎onnx_array_api/validation/f8.py

Lines changed: 74 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -399,10 +399,9 @@ def find_closest_value(value, sorted_values):
399399
ifd1<d2:
400400
returnsorted_values[a][1]
401401
ifd1==d2:
402-
raiseUndefinedCastError(
403-
f"Unable to cast{value}, d1={d1}, d2={d2}, "
404-
f"options are{sorted_values[a][1]} and{sorted_values[b][1]}."
405-
)
402+
# Applies rule tie to even
403+
ca,cb=sorted_values[a][1],sorted_values[b][1]
404+
returncbifca&1==1elseca
406405
returnsorted_values[b][1]
407406
returnsorted_values[a][1]
408407

@@ -520,28 +519,35 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
520519
ife<116:
521520
pass
522521
elife<117:
523-
ret|=1
522+
# first positive number
523+
ifm>0:
524+
ret|=1
524525
if (m>>23)&1:
525526
# rounding
526527
ret+=1
527-
elife<120:# 127 - 8 + 1
528-
d=119-e
529-
ret|=1<< (2-d)
530-
ret|=m>> (21+d)
531-
if (m>> (20+d))&1:
528+
elife<120:
529+
# denormalized number
530+
ex=e-119
531+
ret|=1<< (2+ex)
532+
ret|=m>> (21-ex)
533+
mask=1<< (20-ex)
534+
ifm&maskand (
535+
ret&1
536+
orm& (mask-1)>0
537+
or (m&maskandm& (mask<<1)andm& (mask-1)==0)
538+
):
532539
# rounding
533540
ret+=1
534-
elife<135:# 127 + 8
541+
elife<135:
542+
# normalized number
535543
ex=e-119# 127 - 8
536544
ifex==0:
537545
ret|=0x4
538546
ret|=m>>21
539547
else:
540548
ret|=ex<<3
541549
ret|=m>>20
542-
if (m&0x80000)and (
543-
(m&0x100000)or (m&0x7FFFF)
544-
):# round to nearest even
550+
ifm&0x80000and ((m&0x100000)or (m&0x7FFFF)):
545551
if (ret&0x7F)<0x7F:
546552
# rounding
547553
ret+=1
@@ -569,19 +575,25 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
569575
ife<117:
570576
pass
571577
elife<118:
572-
ret|=1
573-
if (m>>23)&1:
574-
# rounding
575-
ret+=1
576-
elife<121:# 127 - 7 + 1
577-
d=120-e
578-
ret|=1<< (2-d)
579-
ret|=m>> (21+d)
580-
if (m>> (20+d))&1:
578+
# first positive number
579+
ifm>0:
580+
ret|=1
581+
elife<121:
582+
# denormalized number
583+
ex=e-120
584+
ret|=1<< (2+ex)
585+
ret|=m>> (21-ex)
586+
mask=1<< (20-ex)
587+
ifm&maskand (
588+
ret&1
589+
orm& (mask-1)>0
590+
or (m&maskandm& (mask<<1)andm& (mask-1)==0)
591+
):
581592
# rounding
582593
ret+=1
583-
elife<136:# 127 + 8 + 1
584-
ex=e-120# 127 - 7
594+
elife<136:
595+
# normalized number
596+
ex=e-120
585597
ifex==0:
586598
ret|=0x4
587599
ret|=m>>21
@@ -590,9 +602,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
590602
ret|=m>>20
591603
if (ret&0x7F)==0x7F:
592604
ret&=0xFE
593-
if (m&0x80000)and (
594-
(m&0x100000)or (m&0x7FFFF)
595-
):# round to nearest even
605+
if (m&0x80000)and ((m&0x100000)or (m&0x7FFFF)):
596606
if (ret&0x7F)<0x7E:
597607
# rounding
598608
ret+=1
@@ -633,25 +643,31 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
633643
ife<109:
634644
pass
635645
elife<110:
636-
ret|=1
646+
# first positive number
647+
ifm>0:
648+
ret|=1
637649
if (m>>23)&1:
638650
# rounding
639-
# may be unused
640651
ret+=1
641-
elife<112:# 127 - 16 + 1
642-
d=111-e
643-
ret|=1<< (1-d)
644-
ret|=m>> (22+d)
645-
if (m>> (21+d))&1:
652+
elife<112:
653+
# denormalized number
654+
ex=e-111
655+
ret|=1<< (1+ex)
656+
ret|=m>> (22-ex)
657+
mask=1<< (21-ex)
658+
ifm&maskand (
659+
ret&1
660+
orm& (mask-1)>0
661+
or (m&maskandm& (mask<<1)andm& (mask-1)==0)
662+
):
646663
# rounding
647664
ret+=1
648-
elife<143:# 127 + 15 + 1
649-
ex=e-111# 127 - 16
665+
elife<143:
666+
# normalized number
667+
ex=e-111
650668
ret|=ex<<2
651669
ret|=m>>21
652-
ifm&0x100000and (
653-
(m&0xFFFFF)or (m&0x200000)
654-
):# round to nearest even
670+
ifm&0x100000and ((m&0xFFFFF)or (m&0x200000)):
655671
if (ret&0x7F)<0x7F:
656672
# rounding
657673
ret+=1
@@ -681,25 +697,31 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
681697
ife<110:
682698
pass
683699
elife<111:
684-
ret|=1
700+
# first positive number
701+
ifm>0:
702+
ret|=1
685703
if (m>>23)&1:
686704
# rounding
687-
# may be unused
688705
ret+=1
689-
elife<113:# 127 - 15 + 1
690-
d=112-e
691-
ret|=1<< (1-d)
692-
ret|=m>> (22+d)
693-
if (m>> (21+d))&1:
706+
elife<113:
707+
# denormalized number
708+
ex=e-112
709+
ret|=1<< (1+ex)
710+
ret|=m>> (22-ex)
711+
mask=1<< (21-ex)
712+
ifm&maskand (
713+
ret&1
714+
orm& (mask-1)>0
715+
or (m&maskandm& (mask<<1)andm& (mask-1)==0)
716+
):
694717
# rounding
695718
ret+=1
696-
elife<143:# 127 + 15 + 1
697-
ex=e-112# 127 - 15
719+
elife<143:
720+
# normalized number
721+
ex=e-112
698722
ret|=ex<<2
699723
ret|=m>>21
700-
ifm&0x100000and (
701-
(m&0xFFFFF)or (m&0x200000)
702-
):# round to nearest even
724+
ifm&0x100000and ((m&0xFFFFF)or (m&0x200000)):
703725
if (ret&0x7F)<0x7B:
704726
# rounding
705727
ret+=1

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp