Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commite002bf4

Browse files
authored
Improves onnx_simple_text_plot (#91)
* Improves onnx_simple_text_plot* add doc_string* improve display* add complex* add missing line* complex* complex* fix unwanted code
1 parentd83ff4e commite002bf4

File tree

11 files changed

+78
-9
lines changed

11 files changed

+78
-9
lines changed

‎onnx_array_api/_helpers.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def np_dtype_to_tensor_dtype(dtype: Any):
4040
dt=TensorProto.INT64
4141
elifdtypeisfloat:
4242
dt=TensorProto.DOUBLE
43+
elifdtype==np.complex64:
44+
dt=TensorProto.COMPLEX64
45+
elifdtype==np.complex128:
46+
dt=TensorProto.COMPLEX128
4347
else:
4448
raiseKeyError(f"Unable to guess type for dtype={dtype}.")# noqa: B904
4549
returndt

‎onnx_array_api/annotations.py‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
6464
np.uint64:TensorProto.UINT64,
6565
np.bool_:TensorProto.BOOL,
6666
np.str_:TensorProto.STRING,
67+
np.complex64:TensorProto.COMPLEX64,
68+
np.complex128:TensorProto.COMPLEX128,
6769
}
6870

6971

‎onnx_array_api/array_api/__init__.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def _finfo(dtype):
4747
continue
4848
ifisinstance(v, (np.float32,np.float64,np.float16)):
4949
d[k]=float(v)
50+
elifisinstance(v, (np.complex128,np.complex64)):
51+
d[k]=complex(v)
5052
else:
5153
d[k]=v
5254
d["dtype"]=DType(np_dtype_to_tensor_dtype(dt))
@@ -124,6 +126,8 @@ def _finalize_array_api(module, function_names, TEagerTensor):
124126
module.float16=DType(TensorProto.FLOAT16)
125127
module.float32=DType(TensorProto.FLOAT)
126128
module.float64=DType(TensorProto.DOUBLE)
129+
module.complex64=DType(TensorProto.COMPLEX64)
130+
module.complex128=DType(TensorProto.COMPLEX128)
127131
module.int8=DType(TensorProto.INT8)
128132
module.int16=DType(TensorProto.INT16)
129133
module.int32=DType(TensorProto.INT32)

‎onnx_array_api/array_api/_onnx_common.py‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def asarray(
9393
v=TEagerTensor(va)
9494
elifisinstance(a,float):
9595
v=TEagerTensor(np.array(a,dtype=np.float64))
96+
elifisinstance(a,complex):
97+
v=TEagerTensor(np.array(a,dtype=np.complex128))
9698
elifisinstance(a,bool):
9799
v=TEagerTensor(np.array(a,dtype=np.bool_))
98100
elifisinstance(a,str):

‎onnx_array_api/graph_api/graph_builder.py‎

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,6 @@ def make_nodes(
536536
ifisinstance(value,TensorProto):
537537
value.name=name
538538
self.initializers_dict[name]=value
539-
540539
self.constants_[name]=None
541540
self.set_shape(name,builder._known_shapes[init])
542541
self.set_type(name,builder._known_types[init])

‎onnx_array_api/npx/npx_jit_eager.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def make_key(self, *values: List[Any], **kwargs: Dict[str, Any]) -> Tuple[Any, .
167167
f"to the attribute list, v={v}."
168168
)
169169
res.append(v.key)
170-
elifisinstance(v, (int,float,bool,DType)):
170+
elifisinstance(v, (int,float,bool,complex,DType)):
171171
ifivinself.kwargs_to_input_:
172172
res.append(self.kwargs_to_input_[iv])
173173
res.append(type(v))
@@ -204,7 +204,7 @@ def make_key(self, *values: List[Any], **kwargs: Dict[str, Any]) -> Tuple[Any, .
204204
ifkinself.kwargs_to_input_:
205205
res.append(type(v))
206206
res.append(v)
207-
elifisinstance(v, (int,float,str,type,bool,DType)):
207+
elifisinstance(v, (int,float,str,type,bool,complex,DType)):
208208
res.append(k)
209209
res.append(type(v))
210210
res.append(v)

‎onnx_array_api/npx/npx_numpy_tensors.py‎

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,13 +265,35 @@ def __float__(self):
265265
DType(TensorProto.DOUBLE),
266266
DType(TensorProto.FLOAT16),
267267
DType(TensorProto.BFLOAT16),
268+
DType(TensorProto.COMPLEX64),
269+
DType(TensorProto.COMPLEX128),
268270
}:
269271
raiseTypeError(
270272
f"Conversion to float only works for float scalar, "
271273
f"not for dtype={self.dtype}."
272274
)
273275
returnfloat(self._tensor)
274276

277+
def__complex__(self):
278+
"Implicit conversion to complex."
279+
ifself.shape:
280+
raiseValueError(
281+
f"Conversion to bool only works for scalar, not for{self!r}."
282+
)
283+
ifself.dtypenotin {
284+
DType(TensorProto.FLOAT),
285+
DType(TensorProto.DOUBLE),
286+
DType(TensorProto.FLOAT16),
287+
DType(TensorProto.BFLOAT16),
288+
DType(TensorProto.COMPLEX64),
289+
DType(TensorProto.COMPLEX128),
290+
}:
291+
raiseTypeError(
292+
f"Conversion to float only works for float scalar, "
293+
f"not for dtype={self.dtype}."
294+
)
295+
returncomplex(self._tensor)
296+
275297
def__iter__(self):
276298
"""
277299
The :epkg:`Array API` does not define this function (2022/12).

‎onnx_array_api/npx/npx_var.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,13 +1171,17 @@ def __init__(self, cst: Any):
11711171
Var.__init__(self,np.array(cst,dtype=np.int64),op="Identity")
11721172
elifisinstance(cst,float):
11731173
Var.__init__(self,np.array(cst,dtype=np.float64),op="Identity")
1174+
elifisinstance(cst,complex):
1175+
Var.__init__(self,np.array(cst,dtype=np.complex128),op="Identity")
11741176
elifisinstance(cst,list):
11751177
ifall(isinstance(t,bool)fortincst):
11761178
Var.__init__(self,np.array(cst,dtype=np.bool_),op="Identity")
11771179
elifall(isinstance(t, (int,bool))fortincst):
11781180
Var.__init__(self,np.array(cst,dtype=np.int64),op="Identity")
11791181
elifall(isinstance(t, (float,int,bool))fortincst):
11801182
Var.__init__(self,np.array(cst,dtype=np.float64),op="Identity")
1183+
elifall(isinstance(t, (float,int,bool,complex))fortincst):
1184+
Var.__init__(self,np.array(cst,dtype=np.complex128),op="Identity")
11811185
else:
11821186
raiseValueError(
11831187
f"Unable to convert cst (type={type(cst)}), value={cst}."

‎onnx_array_api/plotting/text_plot.py‎

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,10 @@ def str_node(indent, node):
824824
rows.append(f"opset: domain={opset.domain!r} version={opset.version!r}")
825825
ifhasattr(model,"graph"):
826826
ifmodel.doc_string:
827-
rows.append(f"doc_string:{model.doc_string}")
827+
iflen(model.doc_string)<55:
828+
rows.append(f"doc_string:{model.doc_string}")
829+
else:
830+
rows.append(f"doc_string:{model.doc_string[:55]}...")
828831
main_model=model
829832
model=model.graph
830833
else:
@@ -861,9 +864,16 @@ def str_node(indent, node):
861864
else:
862865
content=""
863866
line_name_new[init.name]=len(rows)
867+
ifinit.doc_string:
868+
t= (
869+
f"init: name={init.name!r} type={_get_type(init)} "
870+
f"shape={_get_shape(init)}{content}"
871+
)
872+
rows.append(f"{t}{' '*max(0,70-len(t))}--{init.doc_string}")
873+
continue
864874
rows.append(
865-
"init: name=%r type=%r shape=%r%s"
866-
% (init.name,_get_type(init),_get_shape(init),content)
875+
f"init: name={init.name!r} type={_get_type(init)}"
876+
f"shape={_get_shape(init)}{content}"
867877
)
868878
iflevel==0:
869879
rows.append("----- main graph ----")
@@ -1044,7 +1054,10 @@ def _mark_link(rows, lengths, r1, r2, d):
10441054
forfctinmain_model.functions:
10451055
rows.append(f"----- function name={fct.name} domain={fct.domain}")
10461056
iffct.doc_string:
1047-
rows.append(f"----- doc_string:{fct.doc_string}")
1057+
iflen(fct.doc_string)<55:
1058+
rows.append(f"----- doc_string:{fct.doc_string}")
1059+
else:
1060+
rows.append(f"----- doc_string:{fct.doc_string[:55]}...")
10481061
res=onnx_simple_text_plot(
10491062
fct,
10501063
verbose=verbose,
@@ -1103,10 +1116,19 @@ def onnx_text_plot_io(model, verbose=False, att_display=None):
11031116
)
11041117
# initializer
11051118
forinitinmodel.initializer:
1119+
1120+
ifinit.doc_string:
1121+
t= (
1122+
f"init: name={init.name!r} type={_get_type(init)} "
1123+
f"shape={_get_shape(init)}"
1124+
)
1125+
rows.append(f"{t}{' '*max(0,70-len(t))}--{init.doc_string}")
1126+
continue
11061127
rows.append(
1107-
"init: name=%r type=%r shape=%r"
1108-
% (init.name,_get_type(init),_get_shape(init))
1128+
f"init: name={init.name!r} type={_get_type(init)}"
1129+
f"shape={_get_shape(init)}"
11091130
)
1131+
11101132
# outputs
11111133
foroutinmodel.output:
11121134
rows.append(

‎onnx_array_api/reference/evaluator_yield.py‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,12 @@ def generate_input(info: ValueInfoProto) -> np.ndarray:
485485
return (value.astype(np.float16)/p).astype(np.float16).reshape(new_shape)
486486
ifelem_type==TensorProto.DOUBLE:
487487
return (value.astype(np.float64)/p).astype(np.float64).reshape(new_shape)
488+
ifelem_type==TensorProto.COMPLEX64:
489+
return (value.astype(np.complex64)/p).astype(np.complex64).reshape(new_shape)
490+
ifelem_type==TensorProto.COMPLEX128:
491+
return (
492+
(value.astype(np.complex128)/p).astype(np.complex128).reshape(new_shape)
493+
)
488494
raiseRuntimeError(f"Unexpected element_type{elem_type} for info={info}")
489495

490496

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp