Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit9ccab70

Browse files
committed
fix code
1 parent5316964 commit9ccab70

File tree

5 files changed

+15
-7
lines changed

5 files changed

+15
-7
lines changed

‎_unittests/ut_translate_api/test_translate_classic.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def _run(cls, code):
406406
importonnx.helper
407407
importonnx.numpy_helper
408408
importonnx_array_api.translate_api.make_helper
409-
importonnx.reference.custom_element_types
409+
importml_dtypes
410410

411411
deffrom_array_extended(tensor,name=None):
412412
dt=tensor.dtype
@@ -433,7 +433,7 @@ def from_array_extended(tensor, name=None):
433433
globs.update(onnx.helper.__dict__)
434434
globs.update(onnx.numpy_helper.__dict__)
435435
globs.update(onnx_array_api.translate_api.make_helper.__dict__)
436-
globs.update(onnx.reference.custom_element_types.__dict__)
436+
globs.update(ml_dtypes.__dict__)
437437
globs["from_array_extended"]=from_array_extended
438438
locs= {}
439439
try:

‎onnx_array_api/translate_api/base_emitter.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,9 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
129129
ifvalue[0].type==AttributeProto.TENSOR:
130130
repl= {"bool":"bool_","object":"object_","str":"str_"}
131131
sdtype=repl.get(str(v.dtype),str(str(v.dtype)))
132+
package="np"ifhasattr(np,sdtype)else"ml_dtypes"
132133
return [], (
133-
f"from_array(np.array({v.tolist()}, dtype=np.{sdtype}), "
134+
f"from_array(np.array({v.tolist()}, dtype={package}.{sdtype}), "
134135
f"name={value[0].name!r})"
135136
)
136137
ifisinstance(v, (int,float,list)):

‎onnx_array_api/translate_api/builder_emitter.py‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
fromtypingimportAny,Dict,List
2+
importnumpyasnp
23
fromonnximportTensorProto
34
fromonnx.numpy_helperimportto_array
45
from .base_emitterimportBaseEmitter
@@ -135,7 +136,10 @@ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
135136
val=to_array(init)
136137
stype=str(val.dtype).split(".")[-1]
137138
name=self._clean_result_name(init.name)
138-
rows.append(f"{name} = np.array({val.tolist()}, dtype=np.{stype})")
139+
package="np"ifhasattr(np,stype)else"ml_dtypes"
140+
rows.append(
141+
f"{name} = np.array({val.tolist()}, dtype={package}.{stype})"
142+
)
139143
returnrows
140144

141145
def_emit_begin_return(self,**kwargs:Dict[str,Any])->List[str]:

‎onnx_array_api/translate_api/inner_emitter.py‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
fromtypingimportAny,Dict,List,Optional,Tuple
2+
importnumpyasnp
23
fromonnximportAttributeProto
34
from ..annotationsimportELEMENT_TYPE_NAME
45
from .base_emitterimportBaseEmitter
@@ -105,7 +106,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
105106
else:
106107
raiseNotImplementedError(f"Unexpected dtype={sdtype}.")
107108
else:
108-
sdtype=f"np.{sdtype}"
109+
sdtype=f"np.{sdtype}"ifhasattr(np,sdtype)elsef"ml_dtypes.{sdtype}"
109110

110111
return [
111112
"initializers.append(",
@@ -233,7 +234,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
233234
else:
234235
raiseNotImplementedError(f"Unexpected dtype={sdtype}.")
235236
else:
236-
sdtype=f"np.{sdtype}"
237+
sdtype=f"np.{sdtype}"ifhasattr(np,sdtype)elsef"ml_dtypes.{sdtype}"
237238
ifvalue.size<=16:
238239
return [
239240
"initializers.append(",

‎onnx_array_api/translate_api/light_emitter.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
fromtypingimportAny,Dict,List
2+
importnumpyasnp
23
from ..annotationsimportELEMENT_TYPE_NAME
34
from .base_emitterimportBaseEmitter
45

@@ -43,8 +44,9 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
4344
value=kwargs["value"]
4445
repl= {"bool":"bool_","object":"object_","str":"str_"}
4546
sdtype=repl.get(str(value.dtype),str(str(value.dtype)))
47+
package="np"ifhasattr(np,sdtype)else"ml_dtypes"
4648
return [
47-
f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))",
49+
f"cst(np.array({value.tolist()}, dtype={package}.{sdtype}))",
4850
f"rename({name!r})",
4951
]
5052

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2026 Movatter.jp