11from dataclasses import dataclass
2- from typing import Any ,Dict ,List ,Iterator ,Optional ,Tuple
2+ from typing import Any ,Dict ,List ,Iterator ,Optional ,Tuple , Union
33from enum import IntEnum
44import numpy as np
55from onnx import ModelProto ,TensorProto ,ValueInfoProto
@@ -77,6 +77,12 @@ def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str:
7777 :param module: discretization parameter
7878 :return: short string
7979 """
80+ if isinstance (value ,np .float32 ):
81+ # This should not happen.
82+ value = np .array (value )
83+ assert isinstance (
84+ value ,np .ndarray
85+ ),f"Unexpected type{ type (value )} for value, it must be a numpy array."
8086value4 = np .zeros (length ,dtype = np .float64 )
8187if value .size <= length :
8288value4 [:value .size ]= value .flatten ().astype (np .float64 )
@@ -170,6 +176,9 @@ def enumerate_results(
170176outputs = node .run (* inputs ,** linked_attributes )
171177except Exception :
172178if raise_exc :
179+ # ExtendedReferenceEvaluator(self.onnx_model, verbose=10).run(
180+ # None, feed_inputs
181+ # )
173182raise
174183yield_output = False
175184break
@@ -286,12 +295,12 @@ def distance_sequence(
286295 :param s2: second sequence
287296 :return: distance and alignment
288297 """
289- delay = self .max_lag
298+ delay = max ( self .max_lag , abs ( len ( s2 ) - len ( s1 )) + 1 )
290299distance = {(- 1 ,- 1 ):0 }
291300predecessor = {(- 1 ,- 1 ):None }
292301for i in range (len (s1 )):
293302for j in range (max (0 ,i - delay ),min (len (s2 ),i + delay )):
294- best = 1e100
303+ best = distance . get (( i , j ), 1e100 )
295304pred = None
296305ki ,kj = i - 1 ,j - 1
297306if (ki ,kj )in distance :
@@ -418,7 +427,7 @@ def generate_inputs(model: ModelProto) -> List[np.ndarray]:
418427def compare_onnx_execution (
419428model1 :ModelProto ,
420429model2 :ModelProto ,
421- inputs :Optional [List [Any ]]= None ,
430+ inputs :Optional [Union [ List [Any ], Tuple [ Dict [ str , Any ]] ]]= None ,
422431verbose :int = 0 ,
423432raise_exc :bool = True ,
424433)-> Tuple [List [ResultExecution ],List [ResultExecution ],List [Tuple [int ,int ]]]:
@@ -430,7 +439,8 @@ def compare_onnx_execution(
430439
431440 :param model1: first model
432441 :param model2: second model
433- :param inputs: inputs to use
442+ :param inputs: inputs to use, a list of inputs if both models have
443+ the same number of inputs or two dictionaries, one for each model
434444 :param verbose: verbosity
435445 :param raise_exc: raise exception if the execution fails or stop at the error
436446 :return: four results, a sequence of results for the first model and the second model,
@@ -440,8 +450,14 @@ def compare_onnx_execution(
440450print ("[compare_onnx_execution] generate inputs" )
441451if inputs is None :
442452inputs = generate_inputs (model1 )
443- feeds1 = {i .name :v for i ,v in zip (model1 .graph .input ,inputs )}
444- feeds2 = {i .name :v for i ,v in zip (model2 .graph .input ,inputs )}
453+ if isinstance (inputs ,tuple ):
454+ assert len (inputs )== 2 ,f"Unexpected number{ len (inputs )} of inputs."
455+ feeds1 ,feeds2 = inputs
456+ else :
457+ feeds1 = {i .name :v for i ,v in zip (model1 .graph .input ,inputs )}
458+ feeds2 = {i .name :v for i ,v in zip (model2 .graph .input ,inputs )}
459+ assert isinstance (feeds1 ,dict ),f"Unexpected type{ type (feeds1 )} for inputs"
460+ assert isinstance (feeds2 ,dict ),f"Unexpected type{ type (feeds2 )} for inputs"
445461if verbose :
446462print (f"[compare_onnx_execution] got{ len (inputs )} inputs" )
447463print ("[compare_onnx_execution] execute first model" )