Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit448a0f2

Browse files
committed
Adds validation for float 8
1 parent50c4d52 commit448a0f2

File tree

2 files changed

+44
-18
lines changed

2 files changed

+44
-18
lines changed

‎_unittests/ut_validation/test_f8.py‎

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,10 @@ def test_search_float32_into_fe5m2(self):
237237
else:
238238
add=v-value
239239
iflen(w)>0:
240-
raiseAssertionError(f"A warning was thrown for v={v}, value={value}, w={w[0]}.")
240+
raiseAssertionError(
241+
f"A warning was thrown for v={v}, "
242+
f"value={value}, w={w[0]}."
243+
)
241244
else:
242245
v=value+add
243246
b=search_float32_into_fe5m2(v)
@@ -306,9 +309,9 @@ def test_inf_nan(self):
306309
0.203125,
307310
0.75,
308311
numpy.nan,
309-
numpy.nan,
310-
-numpy.nan,
311-
-numpy.nan,
312+
max(CastFloat8.values_e4m3fn)[0],
313+
max(CastFloat8.values_e4m3fn)[0],
314+
min(CastFloat8.values_e4m3fn)[0],
312315
],
313316
dtype=numpy.float32,
314317
)
@@ -380,26 +383,27 @@ def test_search_e5m2_pow(self):
380383
)
381384

382385
deftest_float32_to_fe4m3fn_inf(self):
383-
mx=
384-
v0=numpy.float32(448)
386+
mx=max(CastFloat8.values_e4m3fn)[0]
387+
v0=numpy.float32(mx)
385388
v1=numpy.float32(numpy.inf)
386389
a=search_float32_into_fe4m3(v0)
387390
b=search_float32_into_fe4m3(v1)
388391
self.assertEqual(a,b)
389392

390-
v0=numpy.float32(448)
393+
v0=numpy.float32(mx)
391394
v1=numpy.float32(numpy.inf)
392395
a=float32_to_fe4m3(v0)
393396
b=float32_to_fe4m3(v1)
394397
self.assertEqual(a,b)
395398

396-
v0=numpy.float32(-448)
399+
mi=min(CastFloat8.values_e4m3fn)[0]
400+
v0=numpy.float32(mi)
397401
v1=numpy.float32(-numpy.inf)
398402
a=search_float32_into_fe4m3(v0)
399403
b=search_float32_into_fe4m3(v1)
400404
self.assertEqual(a,b)
401405

402-
v0=numpy.float32(-448)
406+
v0=numpy.float32(mi)
403407
v1=numpy.float32(-numpy.inf)
404408
a=float32_to_fe4m3(v0)
405409
b=float32_to_fe4m3(v1)
@@ -666,18 +670,32 @@ def test_float32_to_fe4m3fnuz_inf(self):
666670
self.assertNotEqual(a,b)
667671

668672
deftest_float32_to_fe5m2fnuz_inf(self):
669-
v0=numpy.float32(65536)
673+
mx=max(CastFloat8.values_e5m2fnuz)[0]
674+
v0=numpy.float32(mx)
670675
v1=numpy.float32(numpy.inf)
671676
a=search_float32_into_fe5m2(v0,fn=True,uz=True)
672677
b=search_float32_into_fe5m2(v1,fn=True,uz=True)
673678
self.assertEqual(a,b)
674679

675-
v0=numpy.float32(65536)
680+
v0=numpy.float32(mx)
676681
v1=numpy.float32(numpy.inf)
677682
a=float32_to_fe5m2(v0,fn=True,uz=True)
678683
b=float32_to_fe5m2(v1,fn=True,uz=True)
679684
self.assertEqual(a,b)
680685

686+
mi=min(CastFloat8.values_e5m2fnuz)[0]
687+
v0=numpy.float32(mi)
688+
v1=numpy.float32(-numpy.inf)
689+
a=search_float32_into_fe5m2(v0,fn=True,uz=True)
690+
b=search_float32_into_fe5m2(v1,fn=True,uz=True)
691+
self.assertEqual(a,b)
692+
693+
v0=numpy.float32(mi)
694+
v1=numpy.float32(-numpy.inf)
695+
a=float32_to_fe5m2(v0,fn=True,uz=True)
696+
b=float32_to_fe5m2(v1,fn=True,uz=True)
697+
self.assertEqual(a,b)
698+
681699
v0=numpy.float32(numpy.nan)
682700
v1=numpy.float32(-numpy.nan)
683701
a=search_float32_into_fe5m2(v0,fn=True,uz=True)
@@ -688,7 +706,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
688706
v1=numpy.float32(-numpy.inf)
689707
a=search_float32_into_fe5m2(v0,fn=True,uz=True)
690708
b=search_float32_into_fe5m2(v1,fn=True,uz=True)
691-
self.assertEqual(a,b)
709+
self.assertNotEqual(a,b)
692710

693711
v0=numpy.float32(numpy.nan)
694712
v1=numpy.float32(-numpy.nan)
@@ -700,7 +718,7 @@ def test_float32_to_fe5m2fnuz_inf(self):
700718
v1=numpy.float32(-numpy.inf)
701719
a=float32_to_fe5m2(v0,fn=True,uz=True)
702720
b=float32_to_fe5m2(v1,fn=True,uz=True)
703-
self.assertEqual(a,b)
721+
self.assertNotEqual(a,b)
704722

705723
deftest_simple_fe4m3(self):
706724
values= [448]

‎onnx_array_api/validation/f8.py‎

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,16 @@ def search_float32_into_fe4m3(value: float, fn: bool = True, uz: bool = False) -
378378
b=int.from_bytes(struct.pack("<f",numpy.float32(value)),"little")
379379
ret= (b&0x80000000)>>24# sign
380380
ifuz:
381-
ifnumpy.isnan(value)ornumpy.isinf(value):
381+
ifnumpy.isnan(value):
382382
return0x80
383+
ifnumpy.isinf(value):
384+
returnret|0x7F
383385
set_values=CastFloat8.values_e4m3fnuz
384386
else:
385-
ifnumpy.isnan(value)ornumpy.isinf(value):
387+
ifnumpy.isnan(value):
386388
return0x7F|ret
389+
ifnumpy.isinf(value):
390+
return0x7E|ret
387391
set_values=CastFloat8.values_e4m3fn
388392
f=numpy.float32(value)
389393
i=CastFloat8.find_closest_value(f,set_values)
@@ -403,8 +407,10 @@ def search_float32_into_fe5m2(value: float, fn: bool = False, uz: bool = False)
403407
ret= (b&0x80000000)>>24# sign
404408

405409
iffnanduz:
406-
ifnumpy.isnan(value)ornumpy.isinf(value):
410+
ifnumpy.isnan(value):
407411
return0x80
412+
ifnumpy.isinf(value):
413+
returnret|0x7F
408414
set_values=CastFloat8.values_e5m2fnuz
409415
elifnotfnandnotuz:
410416
ifnumpy.isnan(value):
@@ -435,7 +441,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False):
435441
if (b&0x7FC00000)==0x7FC00000:
436442
return0x80
437443
ifnumpy.isinf(x):
438-
return0x80
444+
returnret|0x7F# saturation
439445
e= (b&0x7F800000)>>23# exponent
440446
m=b&0x007FFFFF# mantissa
441447

@@ -472,7 +478,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False):
472478
if (b&0x7FC00000)==0x7FC00000:
473479
return0x7F|ret
474480
ifnumpy.isinf(x):
475-
return0x7F|ret
481+
return0x7E|ret# saturation
476482
e= (b&0x7F800000)>>23# exponent
477483
m=b&0x007FFFFF# mantissa
478484

@@ -524,6 +530,8 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False):
524530
iffnanduz:
525531
if (b&0x7FC00000)==0x7FC00000:
526532
return0x80
533+
if (b&0x7FFFFFFF)==0x7F800000:
534+
returnret|0x7F
527535
e= (b&0x7F800000)>>23# exponent
528536
m=b&0x007FFFFF# mantissa
529537

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp