Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commite29df50

Browse files
committed
Improves translation to GraphBuilder
1 parent689cc6f commite29df50

File tree

4 files changed

+37
-6
lines changed

4 files changed

+37
-6
lines changed

‎CHANGELOGS.rst‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.3.1
5+
+++++
6+
7+
*:pr:`94`: improves translation to GraphBuilder
8+
49
0.3.0
510
+++++
611

‎onnx_array_api/__init__.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
APIs to create ONNX Graphs.
33
"""
44

5-
__version__="0.3.0"
5+
__version__="0.3.1"
66
__author__="Xavier Dupré"

‎onnx_array_api/translate_api/builder_emitter.py‎

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ class BuilderEmitter(BaseEmitter):
2020
Converts event into proper code.
2121
"""
2222

23+
def__init__(self,make_model_function:str=""):
24+
super().__init__()
25+
self.make_model_function=make_model_function
26+
2327
defjoin(self,rows:List[str],single_line:bool=False)->str:
2428
"Join the rows"
2529
assert (
@@ -29,6 +33,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:
2933

3034
def_emit_start(self,**kwargs:Dict[str,Any])->List[str]:
3135
self.opsets=kwargs.get("opsets", {})
36+
self.ir_version=kwargs.get("ir_version",None)
3237
return []
3338

3439
def_emit_to_onnx_model(self,**kwargs:Dict[str,Any])->List[str]:
@@ -43,12 +48,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
4348
)
4449
rows= [
4550
"",
46-
f"g = GraphBuilder({self.opsets})",
51+
(
52+
f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})"
53+
ifself.ir_version
54+
elsef"GraphBuilder({self.opsets})"
55+
),
4756
*inputs,
4857
f"{self.name}({inps})",
4958
*outputs,
5059
"model = g.to_onnx()",
5160
]
61+
ifself.make_model_function:
62+
rows= [
63+
"",
64+
"",
65+
f'def{self.make_model_function}() -> "ModelProto":',
66+
*[" "+_for_inrows[1:]],
67+
" return model",
68+
"",
69+
"",
70+
f"model ={self.make_model_function}()",
71+
]
5272
returnrows
5373

5474
def_emit_begin_graph(self,**kwargs:Dict[str,Any])->List[str]:
@@ -79,12 +99,14 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
7999
itype=kwargs.get("elem_type",0)
80100
shape=kwargs.get("shape",None)
81101
ifitype==0:
82-
inp="X"
102+
inp=nameor"X"
83103
else:
84104
ifshapeisNone:
85-
inp=f'X: "{_itype_to_string(itype)}"'
105+
inp=f'{name}: "{_itype_to_string(itype)}"'
86106
else:
87-
inp=f'X: "{_itype_to_string(itype)}[{", ".join(map(str,shape))}]"'
107+
inp= (
108+
f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str,shape))}]"'
109+
)
88110
self.inputs_full.append(inp)
89111
self.inputs.append(name)
90112
self.inputs_full_.append((name,_itype_to_string(itype),shape))

‎onnx_array_api/translate_api/translate.py‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
3535
last_event=None
3636
ifisinstance(self.proto_,ModelProto):
3737
opsets= {d.domain:d.versionfordinself.proto_.opset_import}
38-
rows.extend(self.emitter(EventType.START,opsets=opsets))
38+
rows.extend(
39+
self.emitter(
40+
EventType.START,opsets=opsets,ir_version=self.proto_.ir_version
41+
)
42+
)
3943
inputs=self.proto_.graph.input
4044
outputs=self.proto_.graph.output
4145
nodes=self.proto_.graph.node

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp