11import unittest
22from textwrap import dedent
33import numpy as np
4+ import onnx .helper as oh
45from onnx import ModelProto ,TensorProto
56from onnx .checker import check_model
67from onnx .defs import onnx_opset_version
@@ -29,37 +30,43 @@ def test_exp(self):
2930self .assertEqualArray (np .exp (a ),got )
3031
3132code = translate (onx ,api = "builder" )
32- expected = dedent (
33- """
33+ expected = (
34+ dedent (
35+ """
3436 def light_api(
3537 op: "GraphBuilder",
3638 X: "FLOAT[]",
3739 ):
38- Y = op.Exp(X)
40+ Y = op.Exp(X, outputs=['Y'] )
3941 op.Identity(Y, outputs=["Y"])
4042 return Y
4143
4244 g = GraphBuilder({'': 19}, ir_version=10)
4345 g.make_tensor_input("X", TensorProto.FLOAT, ())
4446 light_api(g.op, "X")
45- g.make_tensor_output("Y", TensorProto.FLOAT, ())
47+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__ )
4648 model = g.to_onnx()
4749 """
48- ).strip ("\n " )
50+ )
51+ .strip ("\n " )
52+ .replace ("__SUFFIX__" ,", is_dimension=False, indexed=False" )
53+ )
4954self .assertEqual (expected ,code .strip ("\n " ))
5055
5156def light_api (
5257op :"GraphBuilder" ,
5358X :"FLOAT[]" ,# noqa: F722
5459 ):
55- Y = op .Exp (X )
60+ Y = op .Exp (X , outputs = [ "Y" ] )
5661op .Identity (Y ,outputs = ["Y" ])
5762return Y
5863
5964g2 = GraphBuilder ({"" :19 })
6065g2 .make_tensor_input ("X" ,TensorProto .FLOAT , ("A" ,))
6166light_api (g2 .op ,"X" )
62- g2 .make_tensor_output ("Y" ,TensorProto .FLOAT , ("A" ,))
67+ g2 .make_tensor_output (
68+ "Y" ,TensorProto .FLOAT , ("A" ,),is_dimension = False ,indexed = False
69+ )
6370onx2 = g2 .to_onnx ()
6471
6572ref = ReferenceEvaluator (onx2 )
@@ -78,25 +85,29 @@ def test_zdoc(self):
7885 .to_onnx ()
7986 )
8087code = translate (onx ,api = "builder" )
81- expected = dedent (
82- """
88+ expected = (
89+ dedent (
90+ """
8391 def light_api(
8492 op: "GraphBuilder",
8593 X: "FLOAT[]",
8694 ):
8795 r = np.array([-1, 1], dtype=np.int64)
88- r0_0 = op.Reshape(X, r)
89- Y = op.Transpose(r0_0, perm=[1, 0])
96+ r0_0 = op.Reshape(X, r, outputs=['r0_0'] )
97+ Y = op.Transpose(r0_0, perm=[1, 0], outputs=['Y'] )
9098 op.Identity(Y, outputs=["Y"])
9199 return Y
92100
93101 g = GraphBuilder({'': 19}, ir_version=10)
94102 g.make_tensor_input("X", TensorProto.FLOAT, ())
95103 light_api(g.op, "X")
96- g.make_tensor_output("Y", TensorProto.FLOAT, ())
104+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__ )
97105 model = g.to_onnx()
98106 """
99- ).strip ("\n " )
107+ )
108+ .strip ("\n " )
109+ .replace ("__SUFFIX__" ,", is_dimension=False, indexed=False" )
110+ )
100111self .maxDiff = None
101112self .assertEqual (expected ,code .strip ("\n " ))
102113
@@ -130,13 +141,14 @@ def test_exp_f(self):
130141tr = Translater (onx ,emitter = BuilderEmitter ("mm" ))
131142code = tr .export (as_str = True )
132143
133- expected = dedent (
134- """
144+ expected = (
145+ dedent (
146+ """
135147 def light_api(
136148 op: "GraphBuilder",
137149 X: "FLOAT[]",
138150 ):
139- Y = op.Exp(X)
151+ Y = op.Exp(X, outputs=['Y'] )
140152 op.Identity(Y, outputs=["Y"])
141153 return Y
142154
@@ -145,14 +157,17 @@ def mm() -> "ModelProto":
145157 g = GraphBuilder({'': 19}, ir_version=10)
146158 g.make_tensor_input("X", TensorProto.FLOAT, ())
147159 light_api(g.op, "X")
148- g.make_tensor_output("Y", TensorProto.FLOAT, ())
160+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__ )
149161 model = g.to_onnx()
150162 return model
151163
152164
153165 model = mm()
154166 """
155- ).strip ("\n " )
167+ )
168+ .strip ("\n " )
169+ .replace ("__SUFFIX__" ,", is_dimension=False, indexed=False" )
170+ )
156171self .assertEqual (expected ,code .strip ("\n " ))
157172
158173def light_api (
@@ -166,14 +181,105 @@ def light_api(
166181g2 = GraphBuilder ({"" :19 })
167182g2 .make_tensor_input ("X" ,TensorProto .FLOAT , ("A" ,))
168183light_api (g2 .op ,"X" )
169- g2 .make_tensor_output ("Y" ,TensorProto .FLOAT , ("A" ,))
184+ g2 .make_tensor_output (
185+ "Y" ,TensorProto .FLOAT , ("A" ,),is_dimension = False ,indexed = False
186+ )
170187onx2 = g2 .to_onnx ()
171188
172189ref = ReferenceEvaluator (onx2 )
173190a = np .arange (10 ).astype (np .float32 )
174191got = ref .run (None , {"X" :a })[0 ]
175192self .assertEqualArray (np .exp (a ),got )
176193
194+ def test_local_function (self ):
195+ new_domain = "custom"
196+
197+ linear_regression = oh .make_function (
198+ new_domain ,
199+ "LinearRegression" ,
200+ ["x" ,"a" ,"b" ],
201+ ["y" ],
202+ [
203+ oh .make_node ("MatMul" , ["x" ,"a" ], ["xa" ]),
204+ oh .make_node ("Add" , ["xa" ,"b" ], ["y" ]),
205+ ],
206+ [oh .make_opsetid ("" ,14 )],
207+ [],
208+ )
209+
210+ graph = oh .make_graph (
211+ [
212+ oh .make_node (
213+ "LinearRegression" , ["X" ,"A" ,"B" ], ["Y1" ],domain = new_domain
214+ ),
215+ oh .make_node ("Abs" , ["Y1" ], ["Y" ]),
216+ ],
217+ "example" ,
218+ [
219+ oh .make_tensor_value_info ("X" ,TensorProto .FLOAT , [None ,None ]),
220+ oh .make_tensor_value_info ("A" ,TensorProto .FLOAT , [None ,None ]),
221+ oh .make_tensor_value_info ("B" ,TensorProto .FLOAT , [None ,None ]),
222+ ],
223+ [oh .make_tensor_value_info ("Y" ,TensorProto .FLOAT ,None )],
224+ )
225+
226+ onnx_model = oh .make_model (
227+ graph ,
228+ opset_imports = [oh .make_opsetid ("" ,14 ),oh .make_opsetid (new_domain ,1 )],
229+ functions = [linear_regression ],
230+ ir_version = 10 ,
231+ )
232+ tr = Translater (onnx_model ,emitter = BuilderEmitter ("mm" ))
233+ code = tr .export (as_str = True )
234+
235+ expected = (
236+ dedent (
237+ """
238+ def example(
239+ op: "GraphBuilder",
240+ X: "FLOAT[, ]",
241+ A: "FLOAT[, ]",
242+ B: "FLOAT[, ]",
243+ ):
244+ Y1 = op.LinearRegression(X, A, B, domain='custom', outputs=['Y1'])
245+ Y = op.Abs(Y1, outputs=['Y'])
246+ op.Identity(Y, outputs=["Y"])
247+ return Y
248+
249+
250+ def make_custom_LinearRegression(g: "GraphBuilder"):
251+ gr = GraphBuilder({'': 14}, as_function=True)
252+ x = gr.make_tensor_input('x')
253+ a = gr.make_tensor_input('a')
254+ b = gr.make_tensor_input('b')
255+ op = gr.op
256+ xa = op.MatMul(x, a, outputs=['xa'])
257+ y = op.Add(xa, b, outputs=['y'])
258+ gr.make_tensor_output(y)
259+ g.add_function(builder=gr)
260+ return gr
261+
262+
263+ def mm() -> "ModelProto":
264+ g = GraphBuilder({'': 14, 'custom': 1}, ir_version=10)
265+ g.make_tensor_input("X", TensorProto.FLOAT, ('', ''))
266+ g.make_tensor_input("A", TensorProto.FLOAT, ('', ''))
267+ g.make_tensor_input("B", TensorProto.FLOAT, ('', ''))
268+ example(g.op, "X", "A", "B")
269+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
270+ make_custom_LinearRegression(g)
271+ model = g.to_onnx()
272+ return model
273+
274+
275+ model = mm()
276+ """
277+ )
278+ .strip ("\n " )
279+ .replace ("__SUFFIX__" ,", is_dimension=False, indexed=False" )
280+ )
281+ self .assertEqual (expected ,code .strip ("\n " ))
282+
177283
178284if __name__ == "__main__" :
179285unittest .main (verbosity = 2 )