Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit9351aca

Browse files
committed
ch
1 parente29df50 commit9351aca

File tree

2 files changed

+61
-4
lines changed

2 files changed

+61
-4
lines changed

‎CHANGELOGS.rst‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Change Logs
44
0.3.1
55
+++++
66

7-
*:pr:`94`: improves translation to GraphBuilder
7+
*:pr:`95`: improves translation to GraphBuilder
88

99
0.3.0
1010
+++++

‎_unittests/ut_translate_api/test_translate_builder.py‎

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
fromonnx_array_api.ext_test_caseimportExtTestCase
99
fromonnx_array_api.light_apiimportstart
1010
fromonnx_array_api.graph_apiimportGraphBuilder
11-
fromonnx_array_api.translate_apiimporttranslate
11+
fromonnx_array_api.translate_apiimporttranslate,Translater
12+
fromonnx_array_api.translate_api.builder_emitterimportBuilderEmitter
1213

1314

1415
OPSET_API=min(19,onnx_opset_version()-1)
@@ -38,7 +39,7 @@ def light_api(
3839
op.Identity(Y, outputs=["Y"])
3940
return Y
4041
41-
g = GraphBuilder({'': 19})
42+
g = GraphBuilder({'': 19}, ir_version=11)
4243
g.make_tensor_input("X", TensorProto.FLOAT, ())
4344
light_api(g.op, "X")
4445
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -89,7 +90,7 @@ def light_api(
8990
op.Identity(Y, outputs=["Y"])
9091
return Y
9192
92-
g = GraphBuilder({'': 19})
93+
g = GraphBuilder({'': 19}, ir_version=11)
9394
g.make_tensor_input("X", TensorProto.FLOAT, ())
9495
light_api(g.op, "X")
9596
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -117,6 +118,62 @@ def light_api(
117118
self.assertNotEmpty(model)
118119
check_model(model)
119120

121+
deftest_exp_f(self):
122+
onx=start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
123+
self.assertIsInstance(onx,ModelProto)
124+
self.assertIn("Exp",str(onx))
125+
ref=ReferenceEvaluator(onx)
126+
a=np.arange(10).astype(np.float32)
127+
got=ref.run(None, {"X":a})[0]
128+
self.assertEqualArray(np.exp(a),got)
129+
130+
tr=Translater(onx,emitter=BuilderEmitter("mm"))
131+
code=tr.export(as_str=True)
132+
133+
expected=dedent(
134+
"""
135+
def light_api(
136+
op: "GraphBuilder",
137+
X: "FLOAT[]",
138+
):
139+
Y = op.Exp(X)
140+
op.Identity(Y, outputs=["Y"])
141+
return Y
142+
143+
144+
def mm() -> "ModelProto":
145+
g = GraphBuilder({'': 19}, ir_version=11)
146+
g.make_tensor_input("X", TensorProto.FLOAT, ())
147+
light_api(g.op, "X")
148+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
149+
model = g.to_onnx()
150+
return model
151+
152+
153+
model = mm()
154+
"""
155+
).strip("\n")
156+
self.assertEqual(expected,code.strip("\n"))
157+
158+
deflight_api(
159+
op:"GraphBuilder",
160+
X:"FLOAT[]",# noqa: F722
161+
):
162+
Y=op.Exp(X)
163+
op.Identity(Y,outputs=["Y"])
164+
returnY
165+
166+
g2=GraphBuilder({"":19})
167+
g2.make_tensor_input("X",TensorProto.FLOAT, ("A",))
168+
light_api(g2.op,"X")
169+
g2.make_tensor_output("Y",TensorProto.FLOAT, ("A",))
170+
onx2=g2.to_onnx()
171+
172+
ref=ReferenceEvaluator(onx2)
173+
a=np.arange(10).astype(np.float32)
174+
got=ref.run(None, {"X":a})[0]
175+
self.assertEqualArray(np.exp(a),got)
176+
120177

121178
if__name__=="__main__":
122179
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp