Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit538f844

Browse files
committed
Replaces long initiliazer by rando values
1 parenta868dd3 commit538f844

File tree

4 files changed

+136
-2
lines changed

4 files changed

+136
-2
lines changed

‎_doc/api/translate_api.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ InnerEmitter
3939
..autoclass::onnx_array_api.translate_api.inner_emitter.InnerEmitter
4040
:members:
4141

42+
InnerEmitterShortInitializer
43+
++++++++++++++++++++++++++++
44+
45+
..autoclass::onnx_array_api.translate_api.inner_emitter.InnerEmitterShortInitializer
46+
:members:
47+
4248
LightEmitter
4349
++++++++++++
4450

‎_unittests/ut_translate_api/test_translate_classic.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,75 @@ def test_transpose(self):
178178
self.maxDiff=None
179179
self.assertEqual(expected,code)
180180

181+
deftest_transpose_short(self):
182+
onx= (
183+
start(opset=19)
184+
.vin("X")
185+
.reshape((-1,1))
186+
.Transpose(perm=[1,0])
187+
.rename("Y")
188+
.vout()
189+
.to_onnx()
190+
)
191+
self.assertIsInstance(onx,ModelProto)
192+
self.assertIn("Transpose",str(onx))
193+
ref=ReferenceEvaluator(onx)
194+
a=np.arange(10).astype(np.float32)
195+
got=ref.run(None, {"X":a})[0]
196+
self.assertEqualArray(a.reshape((-1,1)).T,got)
197+
198+
code=translate(onx,api="onnx-short")
199+
expected=dedent(
200+
"""
201+
opset_imports = [
202+
make_opsetid('', 19),
203+
]
204+
inputs = []
205+
outputs = []
206+
nodes = []
207+
initializers = []
208+
sparse_initializers = []
209+
functions = []
210+
initializers.append(
211+
from_array(
212+
np.array([-1, 1], dtype=np.int64),
213+
name='r'
214+
)
215+
)
216+
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
217+
nodes.append(
218+
make_node_extended(
219+
'Reshape',
220+
['X', 'r'],
221+
['r0_0']
222+
)
223+
)
224+
nodes.append(
225+
make_node_extended(
226+
'Transpose',
227+
['r0_0'],
228+
['Y'],
229+
perm=[1, 0]
230+
)
231+
)
232+
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
233+
graph = make_graph(
234+
nodes,
235+
'light_api',
236+
inputs,
237+
outputs,
238+
initializers,
239+
sparse_initializer=sparse_initializers,
240+
)
241+
model = make_model(
242+
graph,
243+
functions=functions,
244+
opset_imports=opset_imports
245+
)"""
246+
).strip("\n")
247+
self.maxDiff=None
248+
self.assertEqual(expected,code)
249+
181250
deftest_topk_reverse(self):
182251
onx= (
183252
start(opset=19)

‎onnx_array_api/translate_api/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
fromonnximportModelProto
22
from .translateimportTranslater
3-
from .inner_emitterimportInnerEmitter
3+
from .inner_emitterimportInnerEmitter,InnerEmitterShortInitializer
44
from .builder_emitterimportBuilderEmitter
55

66

@@ -16,7 +16,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
1616
:class:`onnx_array_api.translate_api.light_emitter.LightEmitter`,
1717
another value is `"onnx"` which is the inner API implemented
1818
in onnx package, `"builder"` follows the syntax for the
19-
class :class:`onnx_array_api.graph_api.GraphBuilder`
19+
class :class:`onnx_array_api.graph_api.GraphBuilder`,
20+
`"onnx-short"` replaces long initializer with random values
2021
:return: code
2122
2223
.. runpython::
@@ -84,6 +85,9 @@ class :class:`onnx_array_api.graph_api.GraphBuilder`
8485
ifapi=="onnx":
8586
tr=Translater(proto,emitter=InnerEmitter())
8687
returntr.export(as_str=True)
88+
ifapi=="onnx-short":
89+
tr=Translater(proto,emitter=InnerEmitterShortInitializer())
90+
returntr.export(as_str=True)
8791
ifapi=="builder":
8892
tr=Translater(proto,emitter=BuilderEmitter())
8993
returntr.export(as_str=True)

‎onnx_array_api/translate_api/inner_emitter.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
106106
raiseNotImplementedError(f"Unexpected dtype={sdtype}.")
107107
else:
108108
sdtype=f"np.{sdtype}"
109+
109110
return [
110111
"initializers.append(",
111112
f"{fra}(",
@@ -209,3 +210,57 @@ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]:
209210
")",
210211
]
211212
returnlines
213+
214+
215+
classInnerEmitterShortInitializer(InnerEmitter):
216+
"""
217+
Converts event into proper code.
218+
Initializer are replaced by random values if too big.
219+
"""
220+
221+
def_emit_initializer(self,**kwargs:Dict[str,Any])->List[str]:
222+
name=kwargs["name"]
223+
value=kwargs["value"]
224+
repl= {"bool":"bool_","object":"object_","str":"str_"}
225+
fra="from_array"
226+
sdtype=repl.get(str(value.dtype),str(value.dtype))
227+
ifsdtype.startswith("("):
228+
fromonnx.reference.custom_element_typesimportfloat8e4m3fn
229+
230+
ifsdtype==str(float8e4m3fn):
231+
sdtype="float8e4m3fn"
232+
fra="from_array_extended"
233+
else:
234+
raiseNotImplementedError(f"Unexpected dtype={sdtype}.")
235+
else:
236+
sdtype=f"np.{sdtype}"
237+
ifvalue.size<=16:
238+
return [
239+
"initializers.append(",
240+
f"{fra}(",
241+
f" np.array({value.tolist()}, dtype={sdtype}),",
242+
f" name={name!r}",
243+
" )",
244+
")",
245+
]
246+
if"int"insdtype:
247+
return [
248+
f"value = np.random.randint(0, 10, size={value.shape})"
249+
f".astype({sdtype})",
250+
"initializers.append(",
251+
f"{fra}(",
252+
f" np.array({value.tolist()}, dtype={sdtype}),",
253+
f" name={name!r}",
254+
" )",
255+
")",
256+
]
257+
return [
258+
f"value = np.random.randn({', '.join(map(str,value.shape))})"
259+
f".astype({sdtype})",
260+
"initializers.append(",
261+
f"{fra}(",
262+
f" np.array({value.tolist()}, dtype={sdtype}),",
263+
f" name={name!r}",
264+
" )",
265+
")",
266+
]

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp