33from enum import IntEnum
44import numpy as np
55from onnx import ModelProto ,TensorProto ,ValueInfoProto ,load
6+ from onnx .reference import ReferenceEvaluator
67from onnx .helper import tensor_dtype_to_np_dtype
78from onnx .shape_inference import infer_shapes
89from .import to_array_extended
@@ -138,17 +139,23 @@ class YieldEvaluator:
138139
139140 :param onnx_model: model to run
140141 :param recursive: dig into subgraph and functions as well
142+ :param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator
143+ <onnx_array_api.reference.ExtendedReferenceEvaluator>`
141144 """
142145
143146def __init__ (
144147self ,
145148onnx_model :ModelProto ,
146149recursive :bool = False ,
147- cls = ExtendedReferenceEvaluator ,
150+ cls : Optional [ type [ ExtendedReferenceEvaluator ]] = None ,
148151 ):
149152assert not recursive ,"recursive=True is not yet implemented"
150153self .onnx_model = onnx_model
151- self .evaluator = cls (onnx_model )if cls is not None else None
154+ self .evaluator = (
155+ cls (onnx_model )
156+ if cls is not None
157+ else ExtendedReferenceEvaluator (onnx_model )
158+ )
152159
153160def enumerate_results (
154161self ,
@@ -166,9 +173,9 @@ def enumerate_results(
166173 Returns:
167174 iterator on tuple(result kind, name, value, node.op_type or None)
168175 """
169- assert isinstance (self .evaluator ,ExtendedReferenceEvaluator ), (
176+ assert isinstance (self .evaluator ,ReferenceEvaluator ), (
170177f"This implementation only works with "
171- f"ExtendedReferenceEvaluator not{ type (self .evaluator )} "
178+ f"ReferenceEvaluator not{ type (self .evaluator )} "
172179 )
173180attributes = {}
174181if output_names is None :
@@ -595,6 +602,7 @@ def compare_onnx_execution(
595602raise_exc :bool = True ,
596603mode :str = "execute" ,
597604keep_tensor :bool = False ,
605+ cls :Optional [type [ReferenceEvaluator ]]= None ,
598606)-> Tuple [List [ResultExecution ],List [ResultExecution ],List [Tuple [int ,int ]]]:
599607"""
600608 Compares the execution of two onnx models.
@@ -611,6 +619,7 @@ def compare_onnx_execution(
611619 :param mode: the model should be executed but the function can be executed
612620 but the comparison may append on nodes only
613621 :param keep_tensor: keeps the tensor in order to compute a precise distance
622+ :param cls: evaluator class to use
614623 :return: four results, a sequence of results
615624 for the first model and the second model,
616625 the alignment between the two, DistanceExecution
@@ -634,15 +643,15 @@ def compare_onnx_execution(
634643print (f"[compare_onnx_execution] execute with{ len (inputs )} inputs" )
635644print ("[compare_onnx_execution] execute first model" )
636645res1 = list (
637- YieldEvaluator (model1 ).enumerate_summarized (
646+ YieldEvaluator (model1 , cls = cls ).enumerate_summarized (
638647None ,feeds1 ,raise_exc = raise_exc ,keep_tensor = keep_tensor
639648 )
640649 )
641650if verbose :
642651print (f"[compare_onnx_execution] got{ len (res1 )} results" )
643652print ("[compare_onnx_execution] execute second model" )
644653res2 = list (
645- YieldEvaluator (model2 ).enumerate_summarized (
654+ YieldEvaluator (model2 , cls = cls ).enumerate_summarized (
646655None ,feeds2 ,raise_exc = raise_exc ,keep_tensor = keep_tensor
647656 )
648657 )