Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitb11db3c

Browse files
committed
Export evaluator type in compare_onnx_execution
1 parent07c3683 commitb11db3c

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

‎LICENSE.txt‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Copyright (c) 2023-2024, Xavier Dupré
1+
Copyright (c) 2023-2025, Xavier Dupré
22

33
Permission is hereby granted, free of charge, to any person obtaining a copy
44
of this software and associated documentation files (the "Software"), to deal

‎onnx_array_api/reference/evaluator_yield.py‎

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
fromenumimportIntEnum
44
importnumpyasnp
55
fromonnximportModelProto,TensorProto,ValueInfoProto,load
6+
fromonnx.referenceimportReferenceEvaluator
67
fromonnx.helperimporttensor_dtype_to_np_dtype
78
fromonnx.shape_inferenceimportinfer_shapes
89
from .importto_array_extended
@@ -166,9 +167,9 @@ def enumerate_results(
166167
Returns:
167168
iterator on tuple(result kind, name, value, node.op_type or None)
168169
"""
169-
assertisinstance(self.evaluator,ExtendedReferenceEvaluator), (
170+
assertisinstance(self.evaluator,ReferenceEvaluator), (
170171
f"This implementation only works with "
171-
f"ExtendedReferenceEvaluator not{type(self.evaluator)}"
172+
f"ReferenceEvaluator not{type(self.evaluator)}"
172173
)
173174
attributes= {}
174175
ifoutput_namesisNone:
@@ -595,6 +596,7 @@ def compare_onnx_execution(
595596
raise_exc:bool=True,
596597
mode:str="execute",
597598
keep_tensor:bool=False,
599+
cls:Optional[type[ReferenceEvaluator]]=None,
598600
)->Tuple[List[ResultExecution],List[ResultExecution],List[Tuple[int,int]]]:
599601
"""
600602
Compares the execution of two onnx models.
@@ -611,6 +613,7 @@ def compare_onnx_execution(
611613
:param mode: the model should be executed but the function can be executed
612614
but the comparison may append on nodes only
613615
:param keep_tensor: keeps the tensor in order to compute a precise distance
616+
:param cls: evaluator class to use
614617
:return: four results, a sequence of results
615618
for the first model and the second model,
616619
the alignment between the two, DistanceExecution
@@ -634,15 +637,15 @@ def compare_onnx_execution(
634637
print(f"[compare_onnx_execution] execute with{len(inputs)} inputs")
635638
print("[compare_onnx_execution] execute first model")
636639
res1=list(
637-
YieldEvaluator(model1).enumerate_summarized(
640+
YieldEvaluator(model1,cls=cls).enumerate_summarized(
638641
None,feeds1,raise_exc=raise_exc,keep_tensor=keep_tensor
639642
)
640643
)
641644
ifverbose:
642645
print(f"[compare_onnx_execution] got{len(res1)} results")
643646
print("[compare_onnx_execution] execute second model")
644647
res2=list(
645-
YieldEvaluator(model2).enumerate_summarized(
648+
YieldEvaluator(model2,cls=cls).enumerate_summarized(
646649
None,feeds2,raise_exc=raise_exc,keep_tensor=keep_tensor
647650
)
648651
)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp