Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitccf07e7

Browse files
committed
refactoring
1 parent4b5934c commitccf07e7

File tree

6 files changed

+28
-21
lines changed

6 files changed

+28
-21
lines changed

‎_doc/api/light_api.rst‎

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,6 @@ BaseEmitter
7878
..autoclass::onnx_array_api.light_api.base_emitter.BaseEmitter
7979
:members:
8080

81-
Emitter
82-
+++++++
83-
84-
..autoclass::onnx_array_api.light_api.emitter.Emitter
85-
:members:
86-
8781
EventType
8882
+++++++++
8983

@@ -96,6 +90,12 @@ InnerEmitter
9690
..autoclass::onnx_array_api.light_api.inner_emitter.InnerEmitter
9791
:members:
9892

93+
LightEmitter
94+
++++++++++++
95+
96+
..autoclass::onnx_array_api.light_api.emitter.LightEmitter
97+
:members:
98+
9999
Translater
100100
++++++++++
101101

‎_unittests/ut_light_api/test_translate.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
fromonnx.referenceimportReferenceEvaluator
77
fromonnx_array_api.ext_test_caseimportExtTestCase
88
fromonnx_array_api.light_apiimportstart,translate,g
9-
fromonnx_array_api.light_api.emitterimportEventType
9+
fromonnx_array_api.light_api.base_emitterimportEventType
1010

1111
OPSET_API=min(19,onnx_opset_version()-1)
1212

@@ -220,4 +220,5 @@ def test_aionnxml(self):
220220

221221

222222
if__name__=="__main__":
223+
TestTranslate().test_export_if()
223224
unittest.main(verbosity=2)

‎onnx_array_api/light_api/__init__.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
6767
:param single_line: as a single line or not
6868
:param api: API to export into,
6969
default is `"light"` and this is handle by class
70-
:class:`onnx_array_api.light_api.emitter.Emitter`,
70+
:class:`onnx_array_api.light_api.light_emitter.LightEmitter`,
7171
another value is `"onnx"` which is the inner API implemented
7272
in onnx package.
7373
:return: code

‎onnx_array_api/light_api/inner_emitter.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
fromtypingimportAny,Dict,List,Optional,Tuple
22
fromonnximportAttributeProto
33
from .annotationsimportELEMENT_TYPE_NAME
4-
from .emitterimportBaseEmitter
4+
from .base_emitterimportBaseEmitter
55
from .translateimportTranslater
66

77

‎onnx_array_api/light_api/emitter.py‎renamed to ‎onnx_array_api/light_api/light_emitter.py‎

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .base_emitterimportBaseEmitter
44

55

6-
classEmitter(BaseEmitter):
6+
classLightEmitter(BaseEmitter):
77
"""
88
Converts event into proper code.
99
"""
@@ -29,6 +29,9 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
2929
def_emit_to_onnx_model(self,**kwargs:Dict[str,Any])->List[str]:
3030
return ["to_onnx()"]
3131

32+
def_emit_to_onnx_function(self,**kwargs:Dict[str,Any])->List[str]:
33+
return []
34+
3235
def_emit_begin_graph(self,**kwargs:Dict[str,Any])->List[str]:
3336
return []
3437

‎onnx_array_api/light_api/translate.py‎

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
fromonnx.numpy_helperimportto_array
55
from ..referenceimportto_array_extended
66
from .base_emitterimportEventType
7-
from .emitterimportEmitter
7+
from .light_emitterimportLightEmitter
88

99

1010
classTranslater:
@@ -15,10 +15,10 @@ class Translater:
1515
def__init__(
1616
self,
1717
proto:Union[ModelProto,FunctionProto,GraphProto],
18-
emitter:Optional[Emitter]=None,
18+
emitter:Optional[LightEmitter]=None,
1919
):
2020
self.proto_=proto
21-
self.emitter=emitterorEmitter()
21+
self.emitter=emitterorLightEmitter()
2222

2323
def__repr__(self)->str:
2424
returnf"{self.__class__.__name__}(<{type(self.proto_)})"
@@ -43,6 +43,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
4343
sparse_initializers=self.proto_.graph.sparse_initializer
4444
attributes= []
4545
last_event=EventType.TO_ONNX_MODEL
46+
is_function=False
4647
elifisinstance(self.proto_, (FunctionProto,GraphProto)):
4748
inputs=self.proto_.input
4849
outputs=self.proto_.output
@@ -56,14 +57,17 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
5657
attributes= (
5758
self.proto_.attributeifhasattr(self.proto_,"attribute")else []
5859
)
59-
last_event=EventType.TO_ONNX_FUNCTION
60+
is_function=isinstance(self.proto_,FunctionProto)
61+
last_event= (
62+
EventType.TO_ONNX_FUNCTIONifis_functionelseEventType.TO_ONNX_MODEL
63+
)
6064
else:
6165
raiseValueError(f"Unexpected type{type(self.proto_)} for proto.")
6266

6367
ifsparse_initializers:
6468
raiseNotImplementedError("Sparse initializer not supported yet.")
6569

66-
ifisinstance(self.proto_,FunctionProto):
70+
ifis_function:
6771
rows.extend(
6872
self.emitter(
6973
EventType.BEGIN_FUNCTION,
@@ -85,7 +89,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
8589
)
8690

8791
foriininputs:
88-
ifisinstance(i,str):
92+
ifis_function:
8993
rows.extend(self.emitter(EventType.FUNCTION_INPUT,name=i))
9094
else:
9195
rows.extend(
@@ -100,7 +104,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
100104
)
101105
)
102106

103-
ifattributes:
107+
ifis_functionandattributes:
104108
rows.extend(
105109
self.emitter(EventType.FUNCTION_ATTRIBUTES,attributes=list(attributes))
106110
)
@@ -119,7 +123,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
119123
)
120124

121125
foroinoutputs:
122-
ifisinstance(o,str):
126+
ifis_function:
123127
rows.extend(self.emitter(EventType.FUNCTION_OUTPUT,name=o))
124128
else:
125129
rows.extend(
@@ -137,11 +141,10 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
137141
name=self.proto_.name
138142
else:
139143
name=self.proto_.graph.name
144+
140145
rows.extend(
141146
self.emitter(
142-
EventType.END_FUNCTION
143-
ifisinstance(self.proto_,FunctionProto)
144-
elseEventType.END_GRAPH,
147+
EventType.END_FUNCTIONifis_functionelseEventType.END_GRAPH,
145148
name=name,
146149
)
147150
)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp