33from enum import IntEnum
44import numpy as np
55from onnx import ModelProto ,TensorProto ,ValueInfoProto ,load
6+ from onnx .reference import ReferenceEvaluator
67from onnx .helper import tensor_dtype_to_np_dtype
78from onnx .shape_inference import infer_shapes
89from .import to_array_extended
@@ -166,9 +167,9 @@ def enumerate_results(
166167 Returns:
167168 iterator on tuple(result kind, name, value, node.op_type or None)
168169 """
169- assert isinstance (self .evaluator ,ExtendedReferenceEvaluator ), (
170+ assert isinstance (self .evaluator ,ReferenceEvaluator ), (
170171f"This implementation only works with "
171- f"ExtendedReferenceEvaluator not{ type (self .evaluator )} "
172+ f"ReferenceEvaluator not{ type (self .evaluator )} "
172173 )
173174attributes = {}
174175if output_names is None :
@@ -595,6 +596,7 @@ def compare_onnx_execution(
595596raise_exc :bool = True ,
596597mode :str = "execute" ,
597598keep_tensor :bool = False ,
599+ cls :Optional [type [ReferenceEvaluator ]]= None ,
598600)-> Tuple [List [ResultExecution ],List [ResultExecution ],List [Tuple [int ,int ]]]:
599601"""
600602 Compares the execution of two onnx models.
@@ -611,6 +613,7 @@ def compare_onnx_execution(
611613 :param mode: the model should be executed but the function can be executed
612614 but the comparison may append on nodes only
613615 :param keep_tensor: keeps the tensor in order to compute a precise distance
616+ :param cls: evaluator class to use
614617 :return: four results, a sequence of results
615618 for the first model and the second model,
616619 the alignment between the two, DistanceExecution
@@ -634,15 +637,15 @@ def compare_onnx_execution(
634637print (f"[compare_onnx_execution] execute with{ len (inputs )} inputs" )
635638print ("[compare_onnx_execution] execute first model" )
636639res1 = list (
637- YieldEvaluator (model1 ).enumerate_summarized (
640+ YieldEvaluator (model1 , cls = cls ).enumerate_summarized (
638641None ,feeds1 ,raise_exc = raise_exc ,keep_tensor = keep_tensor
639642 )
640643 )
641644if verbose :
642645print (f"[compare_onnx_execution] got{ len (res1 )} results" )
643646print ("[compare_onnx_execution] execute second model" )
644647res2 = list (
645- YieldEvaluator (model2 ).enumerate_summarized (
648+ YieldEvaluator (model2 , cls = cls ).enumerate_summarized (
646649None ,feeds2 ,raise_exc = raise_exc ,keep_tensor = keep_tensor
647650 )
648651 )