Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitbaa25d8

Browse files
committed
fix initializer
1 parent092dfa2 commitbaa25d8

File tree

3 files changed

+63
-13
lines changed

3 files changed

+63
-13
lines changed

‎_unittests/ut_translate_api/test_translate_builder.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
fromtextwrapimportdedent
33
importnumpyasnp
44
fromonnximportModelProto,TensorProto
5+
fromonnx.checkerimportcheck_model
56
fromonnx.defsimportonnx_opset_version
67
fromonnx.referenceimportReferenceEvaluator
78
fromonnx_array_api.ext_test_caseimportExtTestCase
@@ -39,7 +40,7 @@ def light_api(
3940
4041
g = GraphBuilder({'': 19})
4142
g.make_tensor_input("X", TensorProto.FLOAT, ())
42-
light_api(g.op,X)
43+
light_api(g.op,"X")
4344
g.make_tensor_output("Y", TensorProto.FLOAT, ())
4445
model = g.to_onnx()
4546
"""
@@ -78,18 +79,43 @@ def test_zdoc(self):
7879
code=translate(onx,api="builder")
7980
expected=dedent(
8081
"""
81-
(
82-
start()
83-
.vin("X")
84-
.reshape((-1, 1))
85-
.Transpose(perm=[1, 0])
86-
.rename("Y")
87-
.vout()
88-
.to_onnx()
89-
)"""
82+
def light_api(
83+
op: "GraphBuilder",
84+
X: "FLOAT[]",
85+
):
86+
r = np.array([-1, 1], dtype=np.int64)
87+
r0_0 = op.Reshape(X, r)
88+
Y = op.Transpose(r0_0, perm=[1, 0])
89+
op.Identity(Y, outputs=["Y"])
90+
return Y
91+
92+
g = GraphBuilder({'': 21})
93+
g.make_tensor_input("X", TensorProto.FLOAT, ())
94+
light_api(g.op, "X")
95+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
96+
model = g.to_onnx()
97+
"""
9098
).strip("\n")
9199
self.maxDiff=None
92-
self.assertEqual(expected,code)
100+
self.assertEqual(expected,code.strip("\n"))
101+
102+
deflight_api(
103+
op:"GraphBuilder",
104+
X:"FLOAT[]",# noqa: F722
105+
):
106+
r=np.array([-1,1],dtype=np.int64)
107+
r0_0=op.Reshape(X,r)
108+
Y=op.Transpose(r0_0,perm=[1,0])
109+
op.Identity(Y,outputs=["Y"])
110+
returnY
111+
112+
g=GraphBuilder({"":21})
113+
X=g.make_tensor_input("X",TensorProto.FLOAT, ())
114+
light_api(g.op,X)
115+
g.make_tensor_output("Y",TensorProto.FLOAT, ())
116+
model=g.to_onnx()
117+
self.assertNotEmpty(model)
118+
check_model(model)
93119

94120

95121
if__name__=="__main__":

‎onnx_array_api/graph_api/graph_builder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,18 @@ def __getattr__(self, name):
119119
exceptAttributeErrorase:
120120
raiseAttributeError(f"Unable to access attribute{name!r}.")frome
121121

122+
defInitializer(
123+
self,init:Union[TensorProto,np.ndarray],name:Optional[str]=None
124+
)->str:
125+
"""
126+
Creates an initializer.
127+
128+
:param init: value
129+
:param name: name if value is not a TensorProto
130+
:return: its name
131+
"""
132+
returnself.builder.make_initializer(init,name=name,exists=True)
133+
122134
defmake_node(
123135
self,
124136
op_type:str,

‎onnx_array_api/translate_api/builder_emitter.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
fromtypingimportAny,Dict,List
22
fromonnximportTensorProto
3+
fromonnx.numpy_helperimportto_array
34
from .base_emitterimportBaseEmitter
45

56
_types= {
@@ -31,7 +32,7 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
3132
return []
3233

3334
def_emit_to_onnx_model(self,**kwargs:Dict[str,Any])->List[str]:
34-
inps=", ".join(["g.op",*self.inputs])
35+
inps=", ".join(["g.op",*[f'"{i}"'foriinself.inputs]])
3536
inputs= []
3637
forinp,stype,shapeinself.inputs_full_:
3738
inputs.append(f'g.make_tensor_input("{inp}", TensorProto.{stype},{shape})')
@@ -64,7 +65,14 @@ def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
6465
return []
6566

6667
def_emit_initializer(self,**kwargs:Dict[str,Any])->List[str]:
67-
assertFalse,f"not implemented yet with{kwargs}"
68+
init=kwargs["init"]
69+
ifisinstance(init,TensorProto):
70+
assert (
71+
kwargs["name"]==init.name
72+
),f"Name mismatch init.name={init.name!r}, name={kwargs['name']!r}"
73+
self.inits.append(init)
74+
return []
75+
raiseAssertionError(f"Unsupported type for an initializer{type(init)}")
6876

6977
def_emit_input(self,**kwargs:Dict[str,Any])->List[str]:
7078
name=kwargs["name"]
@@ -90,6 +98,10 @@ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
9098
foriinself.inputs_full:
9199
rows.append(f"{i},")
92100
rows.append("):")
101+
forinitinself.inits:
102+
val=to_array(init)
103+
stype=str(val.dtype).split(".")[-1]
104+
rows.append(f"{init.name} = np.array({val.tolist()}, dtype=np.{stype})")
93105
returnrows
94106

95107
def_emit_begin_return(self,**kwargs:Dict[str,Any])->List[str]:

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp