@@ -57,6 +57,7 @@ class ResultExecution:
5757summary :str
5858op_type :str
5959name :str
60+ value :Optional [Any ]= None
6061
6162def __len__ (self )-> int :
6263return 6
@@ -122,9 +123,11 @@ def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str:
122123else :
123124value2 = value .flatten ().astype (np .float64 )
124125value4 = value2 .reshape ((4 ,- 1 )).sum (axis = 1 )
125- value4i = value4 .astype (np .int64 )% modulo
126- s = "" .join ([chr (65 + i )for i in value4i ])
127- return s
126+ value4 = np .where (np .abs (value4 )< 1e10 ,value4 ,np .nan )
127+ s = []
128+ for v in value4 :
129+ s .append ("?" if np .isnan (v )else (chr (65 + int (v )% modulo )))
130+ return "" .join (s )
128131
129132
130133class YieldEvaluator :
@@ -228,6 +231,7 @@ def enumerate_summarized(
228231output_names :Optional [List [str ]]= None ,
229232feed_inputs :Optional [Dict [str ,Any ]]= None ,
230233raise_exc :bool = True ,
234+ keep_tensor :bool = False ,
231235 )-> Iterator [ResultExecution ]:
232236"""
233237 Executes the onnx model and enumerate intermediate results without their names.
@@ -236,17 +240,40 @@ def enumerate_summarized(
236240 :param feed_inputs: dictionary `{ input name: input value }`
237241 :param raise_exc: raises an exception if the execution fails or stop
238242 where it is
243+ :param keep_tensor:keep the tensor in order to compute precise distances
239244 :return: iterator on ResultExecution
240245 """
241246for kind ,name ,value ,op_type in self .enumerate_results (
242247output_names ,feed_inputs ,raise_exc = raise_exc
243248 ):
244249summary = make_summary (value )
245250yield ResultExecution (
246- kind ,value .dtype ,value .shape ,summary ,op_type ,name
251+ kind ,
252+ value .dtype ,
253+ value .shape ,
254+ summary ,
255+ op_type ,
256+ name ,
257+ value = value if keep_tensor else None ,
247258 )
248259
249260
261+ def discrepancies (
262+ expected :np .ndarray ,value :np .ndarray ,eps :float = 1e-7
263+ )-> Dict [str ,float ]:
264+ """
265+ Computes absolute error and relative error between two matrices.
266+ """
267+ assert (
268+ expected .size == value .size
269+ ),f"Incompatible shapes v1.shape={ expected .shape } , v2.shape={ value .shape } "
270+ expected = expected .ravel ().astype (np .float32 )
271+ value = value .ravel ().astype (np .float32 )
272+ diff = np .abs (expected - value )
273+ rel = diff / (np .abs (expected )+ eps )
274+ return dict (aerr = float (diff .max ()),rerr = float (rel .max ()))
275+
276+
250277class DistanceExecution :
251278"""
252279 Computes a distance between two results.
@@ -403,6 +430,14 @@ def to_str(
403430d = self .distance_pair (d1 ,d2 )
404431symbol = "=" if d == 0 else "~"
405432line = f"{ symbol } |{ _align (str (d1 ),column_size )} |{ _align (str (d2 ),column_size )} "
433+ if (
434+ d1 .value is not None
435+ and d2 .value is not None
436+ and d1 .value .size == d2 .value .size
437+ ):
438+ disc = discrepancies (d1 .value ,d2 .value )
439+ a ,r = disc ["aerr" ],disc ["rerr" ]
440+ line += f" | a={ a :.3f} r={ r :.3f} "
406441elif i == last [0 ]:
407442d2 = s2 [j ]
408443line = (
@@ -551,6 +586,7 @@ def compare_onnx_execution(
551586verbose :int = 0 ,
552587raise_exc :bool = True ,
553588mode :str = "execute" ,
589+ keep_tensor :bool = False ,
554590)-> Tuple [List [ResultExecution ],List [ResultExecution ],List [Tuple [int ,int ]]]:
555591"""
556592 Compares the execution of two onnx models.
@@ -566,6 +602,7 @@ def compare_onnx_execution(
566602 :param raise_exc: raise exception if the execution fails or stop at the error
567603 :param mode: the model should be executed but the function can be executed
568604 but the comparison may append on nodes only
605+ :param keep_tensor: keeps the tensor in order to compute a precise distance
569606 :return: four results, a sequence of results for the first model and the second model,
570607 the alignment between the two, DistanceExecution
571608 """
@@ -589,15 +626,15 @@ def compare_onnx_execution(
589626print ("[compare_onnx_execution] execute first model" )
590627res1 = list (
591628YieldEvaluator (model1 ).enumerate_summarized (
592- None ,feeds1 ,raise_exc = raise_exc
629+ None ,feeds1 ,raise_exc = raise_exc , keep_tensor = keep_tensor
593630 )
594631 )
595632if verbose :
596633print (f"[compare_onnx_execution] got{ len (res1 )} results" )
597634print ("[compare_onnx_execution] execute second model" )
598635res2 = list (
599636YieldEvaluator (model2 ).enumerate_summarized (
600- None ,feeds2 ,raise_exc = raise_exc
637+ None ,feeds2 ,raise_exc = raise_exc , keep_tensor = keep_tensor
601638 )
602639 )
603640elif mode == "nodes" :