Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitb19f2fd

Browse files
committed
fix issues
1 parentfd5fc82 commitb19f2fd

File tree

4 files changed

+33
-17
lines changed

4 files changed

+33
-17
lines changed

‎_unittests/ut_light_api/test_backend_export.py‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
make_opsetid,
2020
make_tensor_value_info,
2121
)
22-
fromonnx.reference.op_runimportto_array_extended
22+
23+
try:
24+
fromonnx.reference.op_runimportto_array_extended
25+
exceptImportError:
26+
fromonnx.numpy_helperimportto_arrayasto_array_extended
2327
fromonnx.numpy_helperimportfrom_array,to_array
2428
fromonnx.backend.baseimportDevice,DeviceType
2529
fromonnx_array_api.referenceimportExtendedReferenceEvaluator

‎onnx_array_api/profiling.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def add_rows(rows, d):
438438
ifverboseandfLOGisnotNone:
439439
fLOG(
440440
"[pstats] %s=%r"
441-
% ((clean_text(k[0].replace("\\","/")),)+k[1:],v)
441+
% (clean_text(k[0].replace("\\","/"),*k[1:]),v)
442442
)
443443
iflen(v)<5:
444444
continue

‎onnx_array_api/reference/__init__.py‎

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22
importnumpyasnp
33
fromonnximportTensorProto
44
fromonnx.numpy_helperimportfrom_arrayasonnx_from_array
5-
fromonnx.reference.ops.op_castimport (
6-
bfloat16,
7-
float8e4m3fn,
8-
float8e4m3fnuz,
9-
float8e5m2,
10-
float8e5m2fnuz,
11-
)
5+
6+
try:
7+
fromonnx.reference.ops.op_castimport (
8+
bfloat16,
9+
float8e4m3fn,
10+
float8e4m3fnuz,
11+
float8e5m2,
12+
float8e5m2fnuz,
13+
)
14+
exceptImportError:
15+
bfloat16=None
1216
fromonnx.reference.op_runimportto_array_extended
1317
from .evaluatorimportExtendedReferenceEvaluator
1418
from .evaluator_yieldimport (
@@ -28,6 +32,8 @@ def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorP
2832
:param name: name
2933
:return: TensorProto
3034
"""
35+
ifbfloat16isNone:
36+
returnonnx_from_array(tensor,name)
3137
dt=tensor.dtype
3238
ifdt==float8e4m3fnanddt.descr[0][0]=="e4m3fn":
3339
to=TensorProto.FLOAT8E4M3FN

‎onnx_array_api/reference/ops/op_cast_like.py‎

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
fromonnx.helperimportnp_dtype_to_tensor_dtype
22
fromonnx.onnx_pbimportTensorProto
33
fromonnx.reference.op_runimportOpRun
4-
fromonnx.reference.ops.op_castimport (
5-
bfloat16,
6-
cast_to,
7-
float8e4m3fn,
8-
float8e4m3fnuz,
9-
float8e5m2,
10-
float8e5m2fnuz,
11-
)
4+
fromonnx.reference.ops.op_castimportcast_to
5+
6+
try:
7+
fromonnx.reference.ops.op_castimport (
8+
bfloat16,
9+
float8e4m3fn,
10+
float8e4m3fnuz,
11+
float8e5m2,
12+
float8e5m2fnuz,
13+
)
14+
exceptImportError:
15+
bfloat16=None
1216

1317

1418
def_cast_like(x,y,saturate):
19+
ifbfloat16isNone:
20+
return (cast_to(x,y.dtype,saturate),)
1521
ify.dtype==bfloat16andy.dtype.descr[0][0]=="bfloat16":
1622
# np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
1723
to=TensorProto.BFLOAT16

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp