@@ -118,6 +118,7 @@ def enumerate_results(
118118self ,
119119output_names :Optional [List [str ]]= None ,
120120feed_inputs :Optional [Dict [str ,Any ]]= None ,
121+ raise_exc :bool = True ,
121122 )-> Iterator [Tuple [ResultType ,str ,Any ]]:
122123"""
123124 Executes the onnx model and enumerate all the intermediate results.
@@ -148,6 +149,7 @@ def enumerate_results(
148149yield ResultType .INPUT ,k ,v ,None
149150
150151# step 2: execute nodes
152+ yield_output = True
151153for node in self .evaluator .rt_nodes_ :
152154for i in node .input :
153155if i not in results :
@@ -160,39 +162,48 @@ def enumerate_results(
160162linked_attributes = {}
161163if node .has_linked_attribute and attributes :
162164linked_attributes ["linked_attributes" ]= attributes
163- if node .need_context ():
164- outputs = node .run (* inputs ,context = results ,** linked_attributes )
165- else :
166- outputs = node .run (* inputs ,** linked_attributes )
165+
166+ try :
167+ if node .need_context ():
168+ outputs = node .run (* inputs ,context = results ,** linked_attributes )
169+ else :
170+ outputs = node .run (* inputs ,** linked_attributes )
171+ except Exception :
172+ if raise_exc :
173+ raise
174+ yield_output = False
175+ break
176+
167177for name ,value in zip (node .output ,outputs ):
168178yield ResultType .RESULT ,name ,value ,node .op_type
169179results [name ]= value
170180
171181# step 3: outputs
172- for name in output_names :
173- if name not in results :
174- raise RuntimeError (
175- f"Unable to find output name{ name !r} in{ sorted (results )} , proto is\n { self .proto_ } "
176- )
177- yield ResultType .OUTPUT ,name ,results [name ],None
182+ if yield_output :
183+ for name in output_names :
184+ if name not in results :
185+ raise RuntimeError (
186+ f"Unable to find output name{ name !r} in{ sorted (results )} , proto is\n { self .proto_ } "
187+ )
188+ yield ResultType .OUTPUT ,name ,results [name ],None
178189
179190def enumerate_summarized (
180191self ,
181192output_names :Optional [List [str ]]= None ,
182193feed_inputs :Optional [Dict [str ,Any ]]= None ,
194+ raise_exc :bool = True ,
183195 )-> Iterator [ResultExecution ]:
184196"""
185197 Executes the onnx model and enumerate intermediate results without their names.
186198
187- Args:
188- output_names: requested outputs by names, None for all
189- feed_inputs: dictionary `{ input name: input value }`
190-
191- Returns:
192- iterator on tuple(result kind, node.type, dtype, shape, value, result name)
199+ :param output_names: requested outputs by names, None for all
200+ :param feed_inputs: dictionary `{ input name: input value }`
201+ :param raise_exc: raises an exception if the execution fails or stop
202+ where it is
203+ :return: iterator on ResultExecution
193204 """
194205for kind ,name ,value ,op_type in self .enumerate_results (
195- output_names ,feed_inputs
206+ output_names ,feed_inputs , raise_exc = raise_exc
196207 ):
197208summary = make_summary (value )
198209yield ResultExecution (
@@ -328,6 +339,7 @@ def to_str(
328339 """
329340rows = []
330341last = - 1 ,- 1
342+ row_index = 1
331343for i ,j in alignment :
332344assert i < len (s1 ),f"Unexpected value i={ i } >= len(s1)={ len (s1 )} "
333345assert j < len (s2 ),f"Unexpected value i={ j } >= len(s2)={ len (s2 )} "
@@ -338,20 +350,18 @@ def to_str(
338350d2 = s2 [j ]
339351d = self .distance_pair (d1 ,d2 )
340352symbol = "=" if d == 0 else "~"
341- rows .append (
342- f"{ symbol } |{ _align (str (d1 ),column_size )} |{ _align (str (d2 ),column_size )} "
343- )
353+ line = f"{ symbol } |{ _align (str (d1 ),column_size )} |{ _align (str (d2 ),column_size )} "
344354elif i == last [0 ]:
345355d2 = s2 [j ]
346- rows . append (
356+ line = (
347357f"+ |{ _align ('' ,column_size )} |{ _align (str (d2 ),column_size )} "
348358 )
349359else :
350360d1 = s1 [i ]
351- rows .append (
352- f"- |{ _align (str (d1 ),column_size )} |{ _align ('' ,column_size )} "
353- )
361+ line = f"- |{ _align (str (d1 ),column_size )} |{ _align ('' ,column_size )} "
362+ rows .append (f"{ row_index : 3d} { line } " )
354363last = i ,j
364+ row_index += 1
355365return "\n " .join (rows )
356366
357367
@@ -410,6 +420,7 @@ def compare_onnx_execution(
410420model2 :ModelProto ,
411421inputs :Optional [List [Any ]]= None ,
412422verbose :int = 0 ,
423+ raise_exc :bool = True ,
413424)-> Tuple [List [ResultExecution ],List [ResultExecution ],List [Tuple [int ,int ]]]:
414425"""
415426 Compares the execution of two onnx models.
@@ -421,6 +432,7 @@ def compare_onnx_execution(
421432 :param model2: second model
422433 :param inputs: inputs to use
423434 :param verbose: verbosity
435+ :param raise_exc: raise exception if the execution fails or stop at the error
424436 :return: four results, a sequence of results for the first model and the second model,
425437 the alignment between the two, DistanceExecution
426438 """
@@ -433,11 +445,15 @@ def compare_onnx_execution(
433445if verbose :
434446print (f"[compare_onnx_execution] got{ len (inputs )} inputs" )
435447print ("[compare_onnx_execution] execute first model" )
436- res1 = list (YieldEvaluator (model1 ).enumerate_summarized (None ,feeds1 ))
448+ res1 = list (
449+ YieldEvaluator (model1 ).enumerate_summarized (None ,feeds1 ,raise_exc = raise_exc )
450+ )
437451if verbose :
438452print (f"[compare_onnx_execution] got{ len (res1 )} results" )
439453print ("[compare_onnx_execution] execute second model" )
440- res2 = list (YieldEvaluator (model2 ).enumerate_summarized (None ,feeds2 ))
454+ res2 = list (
455+ YieldEvaluator (model2 ).enumerate_summarized (None ,feeds2 ,raise_exc = raise_exc )
456+ )
441457if verbose :
442458print (f"[compare_onnx_execution] got{ len (res2 )} results" )
443459print ("[compare_onnx_execution] compute edit distance" )