Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitd248c16

Browse files
authored
Better handling of float 8 in onnx_simple_text_plot (#27)
* better handling of float 8 in onnx_simple_text_plot* add function from_array_extended* doc* refactoring
1 parentc6a3718 commitd248c16

File tree

11 files changed

+170
-12
lines changed

11 files changed

+170
-12
lines changed

‎CHANGELOGS.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ Change Logs
44
0.2.0
55
+++++
66

7-
*:pr:`24`: add ExtendedReferenceEvaluator to support scenario for the Array API onnx does not support
7+
*:pr:`27`: add function from_array_extended to convert
8+
an array to a TensorProto, including bfloat16 and float 8 types
9+
*:pr:`24`: add ExtendedReferenceEvaluator to support scenario
10+
for the Array API onnx does not support
811
*:pr:`22`: support OrtValue in function:func:`ort_profile`
912
*:pr:`17`: implements ArrayAPI
1013
*:pr:`3`: fixes Array API with onnxruntime and scikit-learn

‎_unittests/ut_plotting/test_text_plot.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,50 @@ def test_function_plot(self):
306306
self.assertIn("type=? shape=?",text)
307307
self.assertIn("LinearRegression[custom]",text)
308308

309+
deftest_function_plot_f8(self):
310+
new_domain="custom"
311+
opset_imports= [make_opsetid("",14),make_opsetid(new_domain,1)]
312+
313+
node1=make_node("MatMul", ["X","A"], ["XA"])
314+
node2=make_node("Add", ["XA","B"], ["Y"])
315+
316+
linear_regression=make_function(
317+
new_domain,# domain name
318+
"LinearRegression",# function name
319+
["X","A","B"],# input names
320+
["Y"],# output names
321+
[node1,node2],# nodes
322+
opset_imports,# opsets
323+
[],
324+
)# attribute names
325+
326+
X=make_tensor_value_info("X",TensorProto.FLOAT8E4M3FN, [None,None])
327+
A=make_tensor_value_info("A",TensorProto.FLOAT8E5M2, [None,None])
328+
B=make_tensor_value_info("B",TensorProto.FLOAT8E4M3FNUZ, [None,None])
329+
Y=make_tensor_value_info("Y",TensorProto.FLOAT8E5M2FNUZ,None)
330+
331+
graph=make_graph(
332+
[
333+
make_node(
334+
"LinearRegression", ["X","A","B"], ["Y1"],domain=new_domain
335+
),
336+
make_node("Abs", ["Y1"], ["Y"]),
337+
],
338+
"example",
339+
[X,A,B],
340+
[Y],
341+
)
342+
343+
onnx_model=make_model(
344+
graph,opset_imports=opset_imports,functions=[linear_regression]
345+
)# functions to add)
346+
347+
text=onnx_simple_text_plot(onnx_model)
348+
self.assertIn("function name=LinearRegression domain=custom",text)
349+
self.assertIn("MatMul(X, A) -> XA",text)
350+
self.assertIn("type=? shape=?",text)
351+
self.assertIn("LinearRegression[custom]",text)
352+
309353
deftest_onnx_text_plot_tree_simple(self):
310354
iris=load_iris()
311355
X,y=iris.data.astype(numpy.float32),iris.target
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
importunittest
2+
importnumpyasnp
3+
fromonnximportTensorProto
4+
fromonnx.helperimportmake_graph,make_model,make_node,make_tensor_value_info
5+
fromonnx_array_api.ext_test_caseimportExtTestCase
6+
fromonnx_array_api.referenceimport (
7+
to_array_extended,
8+
from_array_extended,
9+
ExtendedReferenceEvaluator,
10+
)
11+
12+
13+
classTestArrayTensor(ExtTestCase):
14+
deftest_from_array(self):
15+
fordtin (np.float32,np.float16,np.uint16,np.uint8):
16+
withself.subTest(dtype=dt):
17+
a=np.array([0,1,2],dtype=dt)
18+
t=from_array_extended(a,"a")
19+
b=to_array_extended(t)
20+
self.assertEqualArray(a,b)
21+
t2=from_array_extended(b,"a")
22+
self.assertEqual(t.SerializeToString(),t2.SerializeToString())
23+
24+
deftest_from_array_f8(self):
25+
defmake_model_f8(fr,to):
26+
model=make_model(
27+
make_graph(
28+
[make_node("Cast", ["X"], ["Y"],to=to)],
29+
"cast",
30+
[make_tensor_value_info("X",fr,None)],
31+
[make_tensor_value_info("Y",to,None)],
32+
)
33+
)
34+
returnmodel
35+
36+
fordtin (np.float32,np.float16,np.uint16,np.uint8):
37+
withself.subTest(dtype=dt):
38+
a=np.array([0,1,2],dtype=dt)
39+
b=from_array_extended(a,"a")
40+
fortoin [
41+
TensorProto.FLOAT8E4M3FN,
42+
TensorProto.FLOAT8E4M3FNUZ,
43+
TensorProto.FLOAT8E5M2,
44+
TensorProto.FLOAT8E5M2FNUZ,
45+
TensorProto.BFLOAT16,
46+
]:
47+
withself.subTest(fr=b.data_type,to=to):
48+
model=make_model_f8(b.data_type,to)
49+
ref=ExtendedReferenceEvaluator(model)
50+
got=ref.run(None, {"X":a})[0]
51+
back=from_array_extended(got,"a")
52+
self.assertEqual(to,back.data_type)
53+
54+
55+
if__name__=="__main__":
56+
unittest.main(verbosity=2)

‎onnx_array_api/npx/npx_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
importnumpyasnp
44
fromonnximportFunctionProto,ModelProto,NodeProto,TensorProto
55
fromonnx.helperimportmake_tensor,tensor_dtype_to_np_dtype
6-
fromonnx.numpy_helperimportfrom_array
6+
from..referenceimportfrom_array_extendedasfrom_array
77
from .npx_constantsimportFUNCTION_DOMAIN
88
from .npx_core_apiimportcst,make_tuple,npxapi_inline,npxapi_no_inline,var
99
from .npx_typesimport (

‎onnx_array_api/npx/npx_graph_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
make_opsetid,
2525
make_tensor_value_info,
2626
)
27-
fromonnx.numpy_helperimportfrom_array
2827
fromonnx.onnx_cpp2py_export.checkerimportValidationError
2928
fromonnx.onnx_cpp2py_export.shape_inferenceimportInferenceError
3029
fromonnx.shape_inferenceimportinfer_shapes
3130

31+
from ..referenceimportfrom_array_extendedasfrom_array
3232
from .npx_constantsimport_OPSET_TO_IR_VERSION,FUNCTION_DOMAIN,ONNX_DOMAIN
3333
from .npx_function_implementationimportget_function_implementation
3434
from .npx_helperimport (

‎onnx_array_api/npx/npx_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
make_operatorsetid,
1010
make_value_info,
1111
)
12-
fromonnx.numpy_helperimportfrom_array
1312
fromonnx.version_converterimportconvert_version
13+
from ..referenceimportfrom_array_extendedasfrom_array
1414

1515

1616
defrename_in_onnx_graph(

‎onnx_array_api/plotting/_helper.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
ValueInfoProto,
1111
)
1212
fromonnx.helperimporttensor_dtype_to_np_dtype
13-
fromonnx.numpy_helperimportto_array
13+
from..referenceimportto_array_extendedasto_array
1414
from ..npx.npx_typesimportDType
1515

1616

@@ -136,12 +136,25 @@ def _get_type(obj0):
136136
returntensor_dtype_to_np_dtype(TensorProto.DOUBLE)
137137
ifobj.data_type==TensorProto.INT64andhasattr(obj,"int64_data"):
138138
returntensor_dtype_to_np_dtype(TensorProto.INT64)
139-
ifobj.data_type==TensorProto.INT32andhasattr(obj,"int32_data"):
139+
ifobj.data_typein (
140+
TensorProto.INT8,
141+
TensorProto.UINT8,
142+
TensorProto.UINT16,
143+
TensorProto.INT16,
144+
TensorProto.INT32,
145+
TensorProto.FLOAT8E4M3FN,
146+
TensorProto.FLOAT8E4M3FNUZ,
147+
TensorProto.FLOAT8E5M2,
148+
TensorProto.FLOAT8E5M2FNUZ,
149+
)andhasattr(obj,"int32_data"):
140150
returntensor_dtype_to_np_dtype(TensorProto.INT32)
141151
ifhasattr(obj,"raw_data")andlen(obj.raw_data)>0:
142152
arr=to_array(obj)
143153
returnarr.dtype
144-
raiseRuntimeError(f"Unable to guess type from{obj0!r}.")
154+
raiseRuntimeError(
155+
f"Unable to guess type from obj.data_type={obj.data_type} "
156+
f"and obj={obj0!r} -{TensorProto.__dict__}."
157+
)
145158
ifhasattr(obj,"type"):
146159
obj=obj.type
147160
ifhasattr(obj,"tensor_type"):

‎onnx_array_api/plotting/dot_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
fromonnximportGraphProto,ModelProto
55
fromonnx.helperimporttensor_dtype_to_string
6-
fromonnx.numpy_helperimportto_array
76

7+
from ..referenceimportto_array_extendedasto_array
88
from ._helperimportGraph,_get_shape,attributes_as_dict
99

1010

‎onnx_array_api/plotting/text_plot.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
importpprint
22
fromcollectionsimportOrderedDict
3-
43
importnumpy
54
fromonnximportAttributeProto
6-
fromonnx.numpy_helperimportto_array
7-
5+
from ..referenceimportto_array_extendedasto_array
86
from ._helperimport_get_shape,_get_type,attributes_as_dict
97

108

‎onnx_array_api/reference/__init__.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,45 @@
1+
fromtypingimportOptional
2+
importnumpyasnp
3+
fromonnximportTensorProto
4+
fromonnx.numpy_helperimportfrom_arrayasonnx_from_array
5+
fromonnx.reference.ops.op_castimport (
6+
bfloat16,
7+
float8e4m3fn,
8+
float8e4m3fnuz,
9+
float8e5m2,
10+
float8e5m2fnuz,
11+
)
12+
fromonnx.reference.op_runimportto_array_extended
113
from .evaluatorimportExtendedReferenceEvaluator
14+
15+
16+
deffrom_array_extended(tensor:np.array,name:Optional[str]=None)->TensorProto:
17+
"""
18+
Converts an array into a TensorProto.
19+
20+
:param tensor: numpy array
21+
:param name: name
22+
:return: TensorProto
23+
"""
24+
dt=tensor.dtype
25+
ifdt==float8e4m3fnanddt.descr[0][0]=="e4m3fn":
26+
to=TensorProto.FLOAT8E4M3FN
27+
dt_to=np.uint8
28+
elifdt==float8e4m3fnuzanddt.descr[0][0]=="e4m3fnuz":
29+
to=TensorProto.FLOAT8E4M3FNUZ
30+
dt_to=np.uint8
31+
elifdt==float8e5m2anddt.descr[0][0]=="e5m2":
32+
to=TensorProto.FLOAT8E5M2
33+
dt_to=np.uint8
34+
elifdt==float8e5m2fnuzanddt.descr[0][0]=="e5m2fnuz":
35+
to=TensorProto.FLOAT8E5M2FNUZ
36+
dt_to=np.uint8
37+
elifdt==bfloat16anddt.descr[0][0]=="bfloat16":
38+
to=TensorProto.BFLOAT16
39+
dt_to=np.uint16
40+
else:
41+
returnonnx_from_array(tensor,name)
42+
43+
t=onnx_from_array(tensor.astype(dt_to),name)
44+
t.data_type=to
45+
returnt

‎onnx_array_api/validation/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
make_node,
1717
set_model_props,
1818
)
19-
fromonnx.numpy_helperimportfrom_array,to_array
19+
from..referenceimportfrom_array_extendedasfrom_array,to_array_extendedasto_array
2020

2121

2222
defrandomize_proto(

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp