Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitac4acc6

Browse files
authored
Fix as_tensor in onnx_text_plot_tree (#101)
* Fix as_tensor* fix issues* lint* fix clean* atol* fix issues
1 parent96eb50e commitac4acc6

File tree

9 files changed

+78
-86
lines changed

9 files changed

+78
-86
lines changed

‎CHANGELOGS.rst‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.3.2
5+
+++++
6+
7+
*:pr:`101`: fix as_tensor in onnx_text_plot_tree
8+
49
0.3.1
510
+++++
611

‎_unittests/ut_light_api/test_backend_export.py‎

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
make_opsetid,
2020
make_tensor_value_info,
2121
)
22-
fromonnx.reference.op_runimportto_array_extended
22+
23+
try:
24+
fromonnx.reference.op_runimportto_array_extended
25+
exceptImportError:
26+
fromonnx.numpy_helperimportto_arrayasto_array_extended
2327
fromonnx.numpy_helperimportfrom_array,to_array
2428
fromonnx.backend.baseimportDevice,DeviceType
2529
fromonnx_array_api.referenceimportExtendedReferenceEvaluator
@@ -240,7 +244,19 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
240244
raiseNotImplementedError("Unable to run the model node by node.")
241245

242246

243-
backend_test=onnx.backend.test.BackendTest(ExportBackend,__name__)
247+
dft_atol=1e-3ifsys.platform!="linux"else1e-5
248+
backend_test=onnx.backend.test.BackendTest(
249+
ExportBackend,
250+
__name__,
251+
test_kwargs={
252+
"test_dft": {"atol":dft_atol},
253+
"test_dft_axis": {"atol":dft_atol},
254+
"test_dft_axis_opset19": {"atol":dft_atol},
255+
"test_dft_inverse": {"atol":dft_atol},
256+
"test_dft_inverse_opset19": {"atol":dft_atol},
257+
"test_dft_opset19": {"atol":dft_atol},
258+
},
259+
)
244260

245261
# The following tests are too slow with the reference implementation (Conv).
246262
backend_test.exclude(

‎_unittests/ut_reference/test_backend_extended_reference_evaluator.py‎

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
importos
22
importplatform
3+
importsys
34
importunittest
45
fromtypingimportAny
56
importnumpy
@@ -78,10 +79,21 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
7879
raiseNotImplementedError("Unable to run the model node by node.")
7980

8081

82+
dft_atol=1e-3ifsys.platform!="linux"else1e-5
8183
backend_test=onnx.backend.test.BackendTest(
82-
ExtendedReferenceEvaluatorBackend,__name__
84+
ExtendedReferenceEvaluatorBackend,
85+
__name__,
86+
test_kwargs={
87+
"test_dft": {"atol":dft_atol},
88+
"test_dft_axis": {"atol":dft_atol},
89+
"test_dft_axis_opset19": {"atol":dft_atol},
90+
"test_dft_inverse": {"atol":dft_atol},
91+
"test_dft_inverse_opset19": {"atol":dft_atol},
92+
"test_dft_opset19": {"atol":dft_atol},
93+
},
8394
)
8495

96+
8597
ifos.getenv("APPVEYOR"):
8698
backend_test.exclude("(test_vgg19|test_zfnet)")
8799
ifplatform.architecture()[0]=="32bit":

‎azure-pipelines.yml‎

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -93,63 +93,6 @@ jobs:
9393
python -m pytest
9494
displayName: 'Runs Unit Tests'
9595
96-
-job:'TestLinuxArrayApi'
97-
pool:
98-
vmImage:'ubuntu-latest'
99-
strategy:
100-
matrix:
101-
Python310-Linux:
102-
python.version:'3.10'
103-
maxParallel:3
104-
105-
steps:
106-
-task:UsePythonVersion@0
107-
inputs:
108-
versionSpec:'$(python.version)'
109-
architecture:'x64'
110-
-script:sudo apt-get update
111-
displayName:'AptGet Update'
112-
-script:python -m pip install --upgrade pip setuptools wheel
113-
displayName:'Install tools'
114-
-script:pip install -r requirements.txt
115-
displayName:'Install Requirements'
116-
-script:pip install onnxruntime
117-
displayName:'Install onnxruntime'
118-
-script:python setup.py install
119-
displayName:'Install onnx_array_api'
120-
-script:|
121-
git clone https://github.com/data-apis/array-api-tests.git
122-
displayName: 'clone array-api-tests'
123-
-script:|
124-
cd array-api-tests
125-
git submodule update --init --recursive
126-
cd ..
127-
displayName: 'get submodules for array-api-tests'
128-
-script:pip install -r array-api-tests/requirements.txt
129-
displayName:'Install Requirements dev'
130-
-script:|
131-
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
132-
cd array-api-tests
133-
displayName: 'Set API'
134-
-script:|
135-
python -m pip freeze
136-
displayName: 'pip freeze'
137-
-script:|
138-
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
139-
cd array-api-tests
140-
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt --hypothesis-explain
141-
displayName: "numpy test_creation_functions.py"
142-
# - script: |
143-
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort
144-
# cd array-api-tests
145-
# python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt --hypothesis-explain
146-
# displayName: "ort test_creation_functions.py"
147-
#- script: |
148-
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
149-
# cd array-api-tests
150-
# python -m pytest -x array_api_tests
151-
# displayName: "all tests"
152-
15396
-job:'TestLinux'
15497
pool:
15598
vmImage:'ubuntu-latest'

‎onnx_array_api/__init__.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
APIs to create ONNX Graphs.
33
"""
44

5-
__version__="0.3.1"
5+
__version__="0.3.2"
66
__author__="Xavier Dupré"

‎onnx_array_api/plotting/text_plot.py‎

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def __init__(self, i, atts):
6464
self.nodes_missing_value_tracks_true=None
6565
fork,vinatts.items():
6666
ifk.startswith("nodes"):
67-
setattr(self,k,v[i])
67+
ifk.endswith("_as_tensor"):
68+
setattr(self,k.replace("_as_tensor",""),v[i])
69+
else:
70+
setattr(self,k,v[i])
6871
self.depth=0
6972
self.true_false=""
7073
self.targets= []
@@ -120,10 +123,7 @@ def process_tree(atts, treeid):
120123
]
121124
fork,vinatts.items():
122125
ifk.startswith(prefix):
123-
if"classlabels"ink:
124-
short[k]=list(v)
125-
else:
126-
short[k]= [v[i]foriinidx]
126+
short[k]=list(v)if"classlabels"inkelse [v[i]foriinidx]
127127

128128
nodes=OrderedDict()
129129
foriinrange(len(short["nodes_treeids"])):
@@ -132,9 +132,10 @@ def process_tree(atts, treeid):
132132
foriinrange(len(short[f"{prefix}_treeids"])):
133133
idn=short[f"{prefix}_nodeids"][i]
134134
node=nodes[idn]
135-
node.append_target(
136-
tid=short[f"{prefix}_ids"][i],weight=short[f"{prefix}_weights"][i]
137-
)
135+
key=f"{prefix}_weights"
136+
ifkeynotinshort:
137+
key=f"{prefix}_weights_as_tensor"
138+
node.append_target(tid=short[f"{prefix}_ids"][i],weight=short[key][i])
138139

139140
defiterate(nodes,node,depth=0,true_false=""):
140141
node.depth=depth

‎onnx_array_api/profiling.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def add_rows(rows, d):
438438
ifverboseandfLOGisnotNone:
439439
fLOG(
440440
"[pstats] %s=%r"
441-
% ((clean_text(k[0].replace("\\","/")),)+k[1:],v)
441+
% ((clean_text(k[0].replace("\\","/")),*k[1:]),v)
442442
)
443443
iflen(v)<5:
444444
continue

‎onnx_array_api/reference/__init__.py‎

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,21 @@
22
importnumpyasnp
33
fromonnximportTensorProto
44
fromonnx.numpy_helperimportfrom_arrayasonnx_from_array
5-
fromonnx.reference.ops.op_castimport (
6-
bfloat16,
7-
float8e4m3fn,
8-
float8e4m3fnuz,
9-
float8e5m2,
10-
float8e5m2fnuz,
11-
)
12-
fromonnx.reference.op_runimportto_array_extended
5+
6+
try:
7+
fromonnx.reference.ops.op_castimport (
8+
bfloat16,
9+
float8e4m3fn,
10+
float8e4m3fnuz,
11+
float8e5m2,
12+
float8e5m2fnuz,
13+
)
14+
exceptImportError:
15+
bfloat16=None
16+
try:
17+
fromonnx.reference.op_runimportto_array_extended
18+
exceptImportError:
19+
fromonnx.numpy_helperimportto_arrayasto_array_extended
1320
from .evaluatorimportExtendedReferenceEvaluator
1421
from .evaluator_yieldimport (
1522
DistanceExecution,
@@ -28,6 +35,8 @@ def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorP
2835
:param name: name
2936
:return: TensorProto
3037
"""
38+
ifbfloat16isNone:
39+
returnonnx_from_array(tensor,name)
3140
dt=tensor.dtype
3241
ifdt==float8e4m3fnanddt.descr[0][0]=="e4m3fn":
3342
to=TensorProto.FLOAT8E4M3FN

‎onnx_array_api/reference/ops/op_cast_like.py‎

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
fromonnx.helperimportnp_dtype_to_tensor_dtype
22
fromonnx.onnx_pbimportTensorProto
33
fromonnx.reference.op_runimportOpRun
4-
fromonnx.reference.ops.op_castimport (
5-
bfloat16,
6-
cast_to,
7-
float8e4m3fn,
8-
float8e4m3fnuz,
9-
float8e5m2,
10-
float8e5m2fnuz,
11-
)
4+
fromonnx.reference.ops.op_castimportcast_to
5+
6+
try:
7+
fromonnx.reference.ops.op_castimport (
8+
bfloat16,
9+
float8e4m3fn,
10+
float8e4m3fnuz,
11+
float8e5m2,
12+
float8e5m2fnuz,
13+
)
14+
exceptImportError:
15+
bfloat16=None
1216

1317

1418
def_cast_like(x,y,saturate):
19+
ifbfloat16isNone:
20+
return (cast_to(x,y.dtype,saturate),)
1521
ify.dtype==bfloat16andy.dtype.descr[0][0]=="bfloat16":
1622
# np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
1723
to=TensorProto.BFLOAT16

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp