@@ -12,18 +12,17 @@ class UndefinedCastError(FloatingPointError):
1212pass
1313
1414
15- def display_float32 ( value ,sign = 1 ,exponent = 8 ,mantissa = 23 ):
15+ def display_int ( ival ,sign = 1 ,exponent = 8 ,mantissa = 23 ):
1616"""
17- Displaysa float32 into b .
17+ Displaysan integer as bits .
1818
19- :paramvalue : value to display (float32)
19+ :paramival : value to display (float32)
2020 :param sign: number of bits for the sign
2121 :param exponent: number of bits for the exponent
2222 :param mantissa: number of bits for the mantissa
2323 :return: string
2424 """
2525t = sign + exponent + mantissa
26- ival = int .from_bytes (struct .pack ("<f" ,numpy .float32 (value )),"little" )
2726s = bin (ival )[2 :]
2827s = "0" * (t - len (s ))+ s
2928s1 = s [:sign ]
@@ -32,6 +31,24 @@ def display_float32(value, sign=1, exponent=8, mantissa=23):
3231return "." .join ([s1 ,s2 ,s3 ])
3332
3433
34+ def display_float32 (value ,sign = 1 ,exponent = 8 ,mantissa = 23 ):
35+ """
36+ Displays a float32 into b.
37+
38+ :param value: value to display (float32)
39+ :param sign: number of bits for the sign
40+ :param exponent: number of bits for the exponent
41+ :param mantissa: number of bits for the mantissa
42+ :return: string
43+ """
44+ return display_int (
45+ int .from_bytes (struct .pack ("<f" ,numpy .float32 (value )),"little" ),
46+ sign = sign ,
47+ exponent = exponent ,
48+ mantissa = mantissa ,
49+ )
50+
51+
3552def display_float16 (value ,sign = 1 ,exponent = 5 ,mantissa = 10 ):
3653"""
3754 Displays a float32 into b.
@@ -42,14 +59,9 @@ def display_float16(value, sign=1, exponent=5, mantissa=10):
4259 :param mantissa: number of bits for the mantissa
4360 :return: string
4461 """
45- t = sign + exponent + mantissa
46- ival = numpy .float16 (value ).view ("H" )# pylint: disable=E1121
47- s = bin (ival )[2 :]
48- s = "0" * (t - len (s ))+ s
49- s1 = s [:sign ]
50- s2 = s [sign :sign + exponent ]
51- s3 = s [sign + exponent :]
52- return "." .join ([s1 ,s2 ,s3 ])
62+ return display_int (
63+ numpy .float16 (value ).view ("H" ),sign = sign ,exponent = exponent ,mantissa = mantissa
64+ )
5365
5466
5567def display_fexmx (value ,sign ,exponent ,mantissa ):
@@ -64,14 +76,7 @@ def display_fexmx(value, sign, exponent, mantissa):
6476 :param mantissa: number of bits for the mantissa
6577 :return: string
6678 """
67- t = sign + exponent + mantissa
68- ival = value
69- s = bin (ival )[2 :]
70- s = "0" * (t - len (s ))+ s
71- s1 = s [:sign ]
72- s2 = s [sign :sign + exponent ]
73- s3 = s [sign + exponent :]
74- return "." .join ([s1 ,s2 ,s3 ])
79+ return display_int (value ,sign = sign ,exponent = exponent ,mantissa = mantissa )
7580
7681
7782def display_fe4m3 (value ,sign = 1 ,exponent = 4 ,mantissa = 3 ):
@@ -534,7 +539,9 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
534539else :
535540ret |= ex << 3
536541ret |= m >> 20
537- if m & 0x80000 :
542+ if (m & 0x80000 )and (
543+ (m & 0x100000 )or (m & 0x7FFFF )
544+ ):# round to nearest even
538545if (ret & 0x7F )< 0x7F :
539546# rounding
540547ret += 1
@@ -584,7 +591,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
584591if (ret & 0x7F )== 0x7F :
585592ret &= 0xFE
586593if (m & 0x80000 )and (
587- (m & 0x100000 )or (m & 0x7C000 )
594+ (m & 0x100000 )or (m & 0x7FFFF )
588595 ):# round to nearest even
589596if (ret & 0x7F )< 0x7E :
590597# rounding
@@ -642,7 +649,9 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
642649ex = e - 111 # 127 - 16
643650ret |= ex << 2
644651ret |= m >> 21
645- if m & 0x100000 :
652+ if m & 0x100000 and (
653+ (m & 0xFFFFF )or (m & 0x200000 )
654+ ):# round to nearest even
646655if (ret & 0x7F )< 0x7F :
647656# rounding
648657ret += 1