Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commiteb106e2

Browse files
authored
Export evaluator type in compare_onnx_execution (#93)
* Export evaluator type in compare_onnx_execution* doc* doc
1 parent07c3683 commiteb106e2

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

‎LICENSE.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Copyright (c) 2023-2024, Xavier Dupré
1+
Copyright (c) 2023-2025, Xavier Dupré
22

33
Permission is hereby granted, free of charge, to any person obtaining a copy
44
of this software and associated documentation files (the "Software"), to deal

‎onnx_array_api/reference/evaluator_yield.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
fromenumimportIntEnum
44
importnumpyasnp
55
fromonnximportModelProto,TensorProto,ValueInfoProto,load
6+
fromonnx.referenceimportReferenceEvaluator
67
fromonnx.helperimporttensor_dtype_to_np_dtype
78
fromonnx.shape_inferenceimportinfer_shapes
89
from .importto_array_extended
@@ -138,17 +139,23 @@ class YieldEvaluator:
138139
139140
:param onnx_model: model to run
140141
:param recursive: dig into subgraph and functions as well
142+
:param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator
143+
<onnx_array_api.reference.ExtendedReferenceEvaluator>`
141144
"""
142145

143146
def__init__(
144147
self,
145148
onnx_model:ModelProto,
146149
recursive:bool=False,
147-
cls=ExtendedReferenceEvaluator,
150+
cls:Optional[type[ExtendedReferenceEvaluator]]=None,
148151
):
149152
assertnotrecursive,"recursive=True is not yet implemented"
150153
self.onnx_model=onnx_model
151-
self.evaluator=cls(onnx_model)ifclsisnotNoneelseNone
154+
self.evaluator= (
155+
cls(onnx_model)
156+
ifclsisnotNone
157+
elseExtendedReferenceEvaluator(onnx_model)
158+
)
152159

153160
defenumerate_results(
154161
self,
@@ -166,9 +173,9 @@ def enumerate_results(
166173
Returns:
167174
iterator on tuple(result kind, name, value, node.op_type or None)
168175
"""
169-
assertisinstance(self.evaluator,ExtendedReferenceEvaluator), (
176+
assertisinstance(self.evaluator,ReferenceEvaluator), (
170177
f"This implementation only works with "
171-
f"ExtendedReferenceEvaluator not{type(self.evaluator)}"
178+
f"ReferenceEvaluator not{type(self.evaluator)}"
172179
)
173180
attributes= {}
174181
ifoutput_namesisNone:
@@ -595,6 +602,7 @@ def compare_onnx_execution(
595602
raise_exc:bool=True,
596603
mode:str="execute",
597604
keep_tensor:bool=False,
605+
cls:Optional[type[ReferenceEvaluator]]=None,
598606
)->Tuple[List[ResultExecution],List[ResultExecution],List[Tuple[int,int]]]:
599607
"""
600608
Compares the execution of two onnx models.
@@ -611,6 +619,7 @@ def compare_onnx_execution(
611619
:param mode: the model should be executed but the function can be executed
612620
but the comparison may append on nodes only
613621
:param keep_tensor: keeps the tensor in order to compute a precise distance
622+
:param cls: evaluator class to use
614623
:return: four results, a sequence of results
615624
for the first model and the second model,
616625
the alignment between the two, DistanceExecution
@@ -634,15 +643,15 @@ def compare_onnx_execution(
634643
print(f"[compare_onnx_execution] execute with{len(inputs)} inputs")
635644
print("[compare_onnx_execution] execute first model")
636645
res1=list(
637-
YieldEvaluator(model1).enumerate_summarized(
646+
YieldEvaluator(model1,cls=cls).enumerate_summarized(
638647
None,feeds1,raise_exc=raise_exc,keep_tensor=keep_tensor
639648
)
640649
)
641650
ifverbose:
642651
print(f"[compare_onnx_execution] got{len(res1)} results")
643652
print("[compare_onnx_execution] execute second model")
644653
res2=list(
645-
YieldEvaluator(model2).enumerate_summarized(
654+
YieldEvaluator(model2,cls=cls).enumerate_summarized(
646655
None,feeds2,raise_exc=raise_exc,keep_tensor=keep_tensor
647656
)
648657
)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp