Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit3de3c5d

Browse files
authored
Supports for local functions in translator (#96)
* fix suffix* one fix* fix* fix ut* fix ir_version* doc
1 parent664e084 commit3de3c5d

File tree

6 files changed

+257
-32
lines changed

6 files changed

+257
-32
lines changed

‎CHANGELOGS.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.3.1
55
+++++
66

7+
*:pr:`96`: supports local functions in translator
78
*:pr:`95`: improves translation to GraphBuilder
89

910
0.3.0

‎_unittests/ut_translate_api/test_translate_builder.py‎

Lines changed: 125 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
importunittest
22
fromtextwrapimportdedent
33
importnumpyasnp
4+
importonnx.helperasoh
45
fromonnximportModelProto,TensorProto
56
fromonnx.checkerimportcheck_model
67
fromonnx.defsimportonnx_opset_version
@@ -29,37 +30,43 @@ def test_exp(self):
2930
self.assertEqualArray(np.exp(a),got)
3031

3132
code=translate(onx,api="builder")
32-
expected=dedent(
33-
"""
33+
expected= (
34+
dedent(
35+
"""
3436
def light_api(
3537
op: "GraphBuilder",
3638
X: "FLOAT[]",
3739
):
38-
Y = op.Exp(X)
40+
Y = op.Exp(X, outputs=['Y'])
3941
op.Identity(Y, outputs=["Y"])
4042
return Y
4143
4244
g = GraphBuilder({'': 19}, ir_version=10)
4345
g.make_tensor_input("X", TensorProto.FLOAT, ())
4446
light_api(g.op, "X")
45-
g.make_tensor_output("Y", TensorProto.FLOAT, ())
47+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
4648
model = g.to_onnx()
4749
"""
48-
).strip("\n")
50+
)
51+
.strip("\n")
52+
.replace("__SUFFIX__",", is_dimension=False, indexed=False")
53+
)
4954
self.assertEqual(expected,code.strip("\n"))
5055

5156
deflight_api(
5257
op:"GraphBuilder",
5358
X:"FLOAT[]",# noqa: F722
5459
):
55-
Y=op.Exp(X)
60+
Y=op.Exp(X,outputs=["Y"])
5661
op.Identity(Y,outputs=["Y"])
5762
returnY
5863

5964
g2=GraphBuilder({"":19})
6065
g2.make_tensor_input("X",TensorProto.FLOAT, ("A",))
6166
light_api(g2.op,"X")
62-
g2.make_tensor_output("Y",TensorProto.FLOAT, ("A",))
67+
g2.make_tensor_output(
68+
"Y",TensorProto.FLOAT, ("A",),is_dimension=False,indexed=False
69+
)
6370
onx2=g2.to_onnx()
6471

6572
ref=ReferenceEvaluator(onx2)
@@ -78,25 +85,29 @@ def test_zdoc(self):
7885
.to_onnx()
7986
)
8087
code=translate(onx,api="builder")
81-
expected=dedent(
82-
"""
88+
expected= (
89+
dedent(
90+
"""
8391
def light_api(
8492
op: "GraphBuilder",
8593
X: "FLOAT[]",
8694
):
8795
r = np.array([-1, 1], dtype=np.int64)
88-
r0_0 = op.Reshape(X, r)
89-
Y = op.Transpose(r0_0, perm=[1, 0])
96+
r0_0 = op.Reshape(X, r, outputs=['r0_0'])
97+
Y = op.Transpose(r0_0, perm=[1, 0], outputs=['Y'])
9098
op.Identity(Y, outputs=["Y"])
9199
return Y
92100
93101
g = GraphBuilder({'': 19}, ir_version=10)
94102
g.make_tensor_input("X", TensorProto.FLOAT, ())
95103
light_api(g.op, "X")
96-
g.make_tensor_output("Y", TensorProto.FLOAT, ())
104+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
97105
model = g.to_onnx()
98106
"""
99-
).strip("\n")
107+
)
108+
.strip("\n")
109+
.replace("__SUFFIX__",", is_dimension=False, indexed=False")
110+
)
100111
self.maxDiff=None
101112
self.assertEqual(expected,code.strip("\n"))
102113

@@ -130,13 +141,14 @@ def test_exp_f(self):
130141
tr=Translater(onx,emitter=BuilderEmitter("mm"))
131142
code=tr.export(as_str=True)
132143

133-
expected=dedent(
134-
"""
144+
expected= (
145+
dedent(
146+
"""
135147
def light_api(
136148
op: "GraphBuilder",
137149
X: "FLOAT[]",
138150
):
139-
Y = op.Exp(X)
151+
Y = op.Exp(X, outputs=['Y'])
140152
op.Identity(Y, outputs=["Y"])
141153
return Y
142154
@@ -145,14 +157,17 @@ def mm() -> "ModelProto":
145157
g = GraphBuilder({'': 19}, ir_version=10)
146158
g.make_tensor_input("X", TensorProto.FLOAT, ())
147159
light_api(g.op, "X")
148-
g.make_tensor_output("Y", TensorProto.FLOAT, ())
160+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
149161
model = g.to_onnx()
150162
return model
151163
152164
153165
model = mm()
154166
"""
155-
).strip("\n")
167+
)
168+
.strip("\n")
169+
.replace("__SUFFIX__",", is_dimension=False, indexed=False")
170+
)
156171
self.assertEqual(expected,code.strip("\n"))
157172

158173
deflight_api(
@@ -166,14 +181,105 @@ def light_api(
166181
g2=GraphBuilder({"":19})
167182
g2.make_tensor_input("X",TensorProto.FLOAT, ("A",))
168183
light_api(g2.op,"X")
169-
g2.make_tensor_output("Y",TensorProto.FLOAT, ("A",))
184+
g2.make_tensor_output(
185+
"Y",TensorProto.FLOAT, ("A",),is_dimension=False,indexed=False
186+
)
170187
onx2=g2.to_onnx()
171188

172189
ref=ReferenceEvaluator(onx2)
173190
a=np.arange(10).astype(np.float32)
174191
got=ref.run(None, {"X":a})[0]
175192
self.assertEqualArray(np.exp(a),got)
176193

194+
deftest_local_function(self):
195+
new_domain="custom"
196+
197+
linear_regression=oh.make_function(
198+
new_domain,
199+
"LinearRegression",
200+
["x","a","b"],
201+
["y"],
202+
[
203+
oh.make_node("MatMul", ["x","a"], ["xa"]),
204+
oh.make_node("Add", ["xa","b"], ["y"]),
205+
],
206+
[oh.make_opsetid("",14)],
207+
[],
208+
)
209+
210+
graph=oh.make_graph(
211+
[
212+
oh.make_node(
213+
"LinearRegression", ["X","A","B"], ["Y1"],domain=new_domain
214+
),
215+
oh.make_node("Abs", ["Y1"], ["Y"]),
216+
],
217+
"example",
218+
[
219+
oh.make_tensor_value_info("X",TensorProto.FLOAT, [None,None]),
220+
oh.make_tensor_value_info("A",TensorProto.FLOAT, [None,None]),
221+
oh.make_tensor_value_info("B",TensorProto.FLOAT, [None,None]),
222+
],
223+
[oh.make_tensor_value_info("Y",TensorProto.FLOAT,None)],
224+
)
225+
226+
onnx_model=oh.make_model(
227+
graph,
228+
opset_imports=[oh.make_opsetid("",14),oh.make_opsetid(new_domain,1)],
229+
functions=[linear_regression],
230+
ir_version=10,
231+
)
232+
tr=Translater(onnx_model,emitter=BuilderEmitter("mm"))
233+
code=tr.export(as_str=True)
234+
235+
expected= (
236+
dedent(
237+
"""
238+
def example(
239+
op: "GraphBuilder",
240+
X: "FLOAT[, ]",
241+
A: "FLOAT[, ]",
242+
B: "FLOAT[, ]",
243+
):
244+
Y1 = op.LinearRegression(X, A, B, domain='custom', outputs=['Y1'])
245+
Y = op.Abs(Y1, outputs=['Y'])
246+
op.Identity(Y, outputs=["Y"])
247+
return Y
248+
249+
250+
def make_custom_LinearRegression(g: "GraphBuilder"):
251+
gr = GraphBuilder({'': 14}, as_function=True)
252+
x = gr.make_tensor_input('x')
253+
a = gr.make_tensor_input('a')
254+
b = gr.make_tensor_input('b')
255+
op = gr.op
256+
xa = op.MatMul(x, a, outputs=['xa'])
257+
y = op.Add(xa, b, outputs=['y'])
258+
gr.make_tensor_output(y)
259+
g.add_function(builder=gr)
260+
return gr
261+
262+
263+
def mm() -> "ModelProto":
264+
g = GraphBuilder({'': 14, 'custom': 1}, ir_version=10)
265+
g.make_tensor_input("X", TensorProto.FLOAT, ('', ''))
266+
g.make_tensor_input("A", TensorProto.FLOAT, ('', ''))
267+
g.make_tensor_input("B", TensorProto.FLOAT, ('', ''))
268+
example(g.op, "X", "A", "B")
269+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
270+
make_custom_LinearRegression(g)
271+
model = g.to_onnx()
272+
return model
273+
274+
275+
model = mm()
276+
"""
277+
)
278+
.strip("\n")
279+
.replace("__SUFFIX__",", is_dimension=False, indexed=False")
280+
)
281+
self.assertEqual(expected,code.strip("\n"))
282+
177283

178284
if__name__=="__main__":
179285
unittest.main(verbosity=2)

‎onnx_array_api/graph_api/graph_builder.py‎

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def __init__(
194194
self._known_shapes= {}
195195
self._known_types= {}
196196
self.constants_= {}
197+
self.functions_= {}
197198
elifisinstance(target_opset_or_existing_proto,ModelProto):
198199
assert (
199200
notinput_names
@@ -223,6 +224,8 @@ def __init__(
223224
self.constants_[node.output[0]]=node
224225
self.set_shape(node.output[0],self._get_tensor_shape(node))
225226
self.set_type(node.output[0],self._get_tensor_type(node))
227+
forfinproto.functions:
228+
self.add_function(f)
226229
else:
227230
raiseNotImplementedError(
228231
f"{type(target_opset_or_existing_proto)} is not supported."
@@ -231,6 +234,14 @@ def __init__(
231234
self.op=Opset(self,self.opsets[""])if""inself.opsetselseNone
232235
self._cache_array= []
233236

237+
defadd_local_function(self,domain:str,name:str,gr:"GraphBuilder"):
238+
"Adds a local function."
239+
assert (
240+
domain,
241+
name,
242+
)notinself.functions_,f"Function{(domain,name)} was already added."
243+
self.functions_[domain,name]=gr
244+
234245
def_get_tensor_shape(
235246
self,proto:Union[NodeProto,TensorProto]
236247
)->Tuple[int, ...]:
@@ -417,6 +428,8 @@ def make_tensor_output(
417428
name:Union[str,List[str]],
418429
elem_type:Optional[int]=None,
419430
shape:Optional[Tuple[int, ...]]=None,
431+
is_dimension:bool=False,
432+
indexed:bool=False,
420433
)->Union[str,List[str]]:
421434
ifisinstance(name,list):
422435
res= []

‎onnx_array_api/translate_api/base_emitter.py‎

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ class EventType(IntEnum):
2525
END_SIGNATURE=16
2626
BEGIN_RETURN=17
2727
END_RETURN=18
28+
BEGIN_FUNCTION_SIGNATURE=19
29+
END_FUNCTION_SIGNATURE=20
30+
BEGIN_FUNCTION_RETURN=21
31+
END_FUNCTION_RETURN=22
2832

2933
@classmethod
3034
defto_str(cls,self)->str:
@@ -76,6 +80,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
7680
ifevent==EventType.BEGIN_FUNCTION:
7781
returnself._emit_begin_function(**kwargs)
7882

83+
ifevent==EventType.BEGIN_FUNCTION_SIGNATURE:
84+
returnself._emit_begin_function_signature(**kwargs)
85+
86+
ifevent==EventType.END_FUNCTION_SIGNATURE:
87+
returnself._emit_end_function_signature(**kwargs)
88+
7989
ifevent==EventType.END_FUNCTION:
8090
returnself._emit_end_function(**kwargs)
8191

@@ -100,6 +110,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
100110
ifevent==EventType.END_RETURN:
101111
returnself._emit_end_return(**kwargs)
102112

113+
ifevent==EventType.BEGIN_FUNCTION_RETURN:
114+
returnself._emit_begin_function_return(**kwargs)
115+
116+
ifevent==EventType.END_FUNCTION_RETURN:
117+
returnself._emit_end_function_return(**kwargs)
118+
103119
raiseValueError(f"Unexpected event{EventType.to_str(event)}.")
104120

105121
defrender_attribute_value(self,value:Any)->Tuple[List[str],str]:
@@ -224,6 +240,12 @@ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
224240
f"Method{inspect.currentframe().f_code.co_name!r} was not overloaded."
225241
)
226242

243+
def_emit_begin_function_signature(self,**kwargs:Dict[str,Any])->List[str]:
244+
return []
245+
246+
def_emit_end_function_signature(self,**kwargs:Dict[str,Any])->List[str]:
247+
return []
248+
227249
def_emit_function_input(self,**kwargs:Dict[str,Any])->List[str]:
228250
raiseNotImplementedError(
229251
f"Method{inspect.currentframe().f_code.co_name!r} was not overloaded."
@@ -250,3 +272,9 @@ def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
250272

251273
def_emit_end_return(self,**kwargs:Dict[str,Any])->List[str]:
252274
return []
275+
276+
def_emit_begin_function_return(self,**kwargs:Dict[str,Any])->List[str]:
277+
return []
278+
279+
def_emit_end_function_return(self,**kwargs:Dict[str,Any])->List[str]:
280+
return []

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp