Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit19e8a9e

Browse files
authored
Add line number to the diff report (#72)
* update requirements* add line numbers* doc
1 parent8835156 commit19e8a9e

File tree

2 files changed

+50
-33
lines changed

2 files changed

+50
-33
lines changed

‎_unittests/ut_reference/test_evaluator_yield.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -422,18 +422,19 @@ def test_distance_sequence_str(self):
422422
text=dc.to_str(s1,s2,align)
423423
self.assertIn("OUTPUT",text)
424424
expected="""
425-
=|INPUTfloat322x2ABCDA|INPUTfloat322x2ABCDA
426-
=|INPUTfloat322x2ABCDB|INPUTfloat322x2ABCDB
427-
~|INPUTfloat322x3ABCDX|INPUTfloat322x2ABCDX
428-
-|RESULTfloat322x2CEIOExpH|
429-
=|RESULTfloat322x2CEIOLinearRegrY1|RESULTfloat322x2CEIOLinearRegrY1
430-
~|RESULTfloat322x2CEIOAbsY|RESULTfloat322x3CEIPAbsZ
431-
~|OUTPUTfloat322x2CEIOY|OUTPUTfloat322x2CEIPY
425+
1=|INPUTfloat322x2ABCDA|INPUTfloat322x2ABCDA
426+
2=|INPUTfloat322x2ABCDB|INPUTfloat322x2ABCDB
427+
3~|INPUTfloat322x3ABCDX|INPUTfloat322x2ABCDX
428+
4-|RESULTfloat322x2CEIOExpH|
429+
5=|RESULTfloat322x2CEIOLinearRegrY1|RESULTfloat322x2CEIOLinearRegrY1
430+
6~|RESULTfloat322x2CEIOAbsY|RESULTfloat322x3CEIPAbsZ
431+
7~|OUTPUTfloat322x2CEIOY|OUTPUTfloat322x2CEIPY
432432
""".replace(
433433
" ",""
434434
).strip(
435435
"\n "
436436
)
437+
self.maxDiff=None
437438
self.assertEqual(expected,text.replace(" ","").strip("\n"))
438439

439440
deftest_compare_execution(self):

‎onnx_array_api/reference/evaluator_yield.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def enumerate_results(
118118
self,
119119
output_names:Optional[List[str]]=None,
120120
feed_inputs:Optional[Dict[str,Any]]=None,
121+
raise_exc:bool=True,
121122
)->Iterator[Tuple[ResultType,str,Any]]:
122123
"""
123124
Executes the onnx model and enumerate all the intermediate results.
@@ -148,6 +149,7 @@ def enumerate_results(
148149
yieldResultType.INPUT,k,v,None
149150

150151
# step 2: execute nodes
152+
yield_output=True
151153
fornodeinself.evaluator.rt_nodes_:
152154
foriinnode.input:
153155
ifinotinresults:
@@ -160,39 +162,48 @@ def enumerate_results(
160162
linked_attributes= {}
161163
ifnode.has_linked_attributeandattributes:
162164
linked_attributes["linked_attributes"]=attributes
163-
ifnode.need_context():
164-
outputs=node.run(*inputs,context=results,**linked_attributes)
165-
else:
166-
outputs=node.run(*inputs,**linked_attributes)
165+
166+
try:
167+
ifnode.need_context():
168+
outputs=node.run(*inputs,context=results,**linked_attributes)
169+
else:
170+
outputs=node.run(*inputs,**linked_attributes)
171+
exceptException:
172+
ifraise_exc:
173+
raise
174+
yield_output=False
175+
break
176+
167177
forname,valueinzip(node.output,outputs):
168178
yieldResultType.RESULT,name,value,node.op_type
169179
results[name]=value
170180

171181
# step 3: outputs
172-
fornameinoutput_names:
173-
ifnamenotinresults:
174-
raiseRuntimeError(
175-
f"Unable to find output name{name!r} in{sorted(results)}, proto is\n{self.proto_}"
176-
)
177-
yieldResultType.OUTPUT,name,results[name],None
182+
ifyield_output:
183+
fornameinoutput_names:
184+
ifnamenotinresults:
185+
raiseRuntimeError(
186+
f"Unable to find output name{name!r} in{sorted(results)}, proto is\n{self.proto_}"
187+
)
188+
yieldResultType.OUTPUT,name,results[name],None
178189

179190
defenumerate_summarized(
180191
self,
181192
output_names:Optional[List[str]]=None,
182193
feed_inputs:Optional[Dict[str,Any]]=None,
194+
raise_exc:bool=True,
183195
)->Iterator[ResultExecution]:
184196
"""
185197
Executes the onnx model and enumerate intermediate results without their names.
186198
187-
Args:
188-
output_names: requested outputs by names, None for all
189-
feed_inputs: dictionary `{ input name: input value }`
190-
191-
Returns:
192-
iterator on tuple(result kind, node.type, dtype, shape, value, result name)
199+
:param output_names: requested outputs by names, None for all
200+
:param feed_inputs: dictionary `{ input name: input value }`
201+
:param raise_exc: raises an exception if the execution fails or stop
202+
where it is
203+
:return: iterator on ResultExecution
193204
"""
194205
forkind,name,value,op_typeinself.enumerate_results(
195-
output_names,feed_inputs
206+
output_names,feed_inputs,raise_exc=raise_exc
196207
):
197208
summary=make_summary(value)
198209
yieldResultExecution(
@@ -328,6 +339,7 @@ def to_str(
328339
"""
329340
rows= []
330341
last=-1,-1
342+
row_index=1
331343
fori,jinalignment:
332344
asserti<len(s1),f"Unexpected value i={i} >= len(s1)={len(s1)}"
333345
assertj<len(s2),f"Unexpected value i={j} >= len(s2)={len(s2)}"
@@ -338,20 +350,18 @@ def to_str(
338350
d2=s2[j]
339351
d=self.distance_pair(d1,d2)
340352
symbol="="ifd==0else"~"
341-
rows.append(
342-
f"{symbol} |{_align(str(d1),column_size)} |{_align(str(d2),column_size)}"
343-
)
353+
line=f"{symbol} |{_align(str(d1),column_size)} |{_align(str(d2),column_size)}"
344354
elifi==last[0]:
345355
d2=s2[j]
346-
rows.append(
356+
line=(
347357
f"+ |{_align('',column_size)} |{_align(str(d2),column_size)} "
348358
)
349359
else:
350360
d1=s1[i]
351-
rows.append(
352-
f"- |{_align(str(d1),column_size)} |{_align('',column_size)}"
353-
)
361+
line=f"- |{_align(str(d1),column_size)} |{_align('',column_size)}"
362+
rows.append(f"{row_index: 3d}{line}")
354363
last=i,j
364+
row_index+=1
355365
return"\n".join(rows)
356366

357367

@@ -410,6 +420,7 @@ def compare_onnx_execution(
410420
model2:ModelProto,
411421
inputs:Optional[List[Any]]=None,
412422
verbose:int=0,
423+
raise_exc:bool=True,
413424
)->Tuple[List[ResultExecution],List[ResultExecution],List[Tuple[int,int]]]:
414425
"""
415426
Compares the execution of two onnx models.
@@ -421,6 +432,7 @@ def compare_onnx_execution(
421432
:param model2: second model
422433
:param inputs: inputs to use
423434
:param verbose: verbosity
435+
:param raise_exc: raise exception if the execution fails or stop at the error
424436
:return: four results, a sequence of results for the first model and the second model,
425437
the alignment between the two, DistanceExecution
426438
"""
@@ -433,11 +445,15 @@ def compare_onnx_execution(
433445
ifverbose:
434446
print(f"[compare_onnx_execution] got{len(inputs)} inputs")
435447
print("[compare_onnx_execution] execute first model")
436-
res1=list(YieldEvaluator(model1).enumerate_summarized(None,feeds1))
448+
res1=list(
449+
YieldEvaluator(model1).enumerate_summarized(None,feeds1,raise_exc=raise_exc)
450+
)
437451
ifverbose:
438452
print(f"[compare_onnx_execution] got{len(res1)} results")
439453
print("[compare_onnx_execution] execute second model")
440-
res2=list(YieldEvaluator(model2).enumerate_summarized(None,feeds2))
454+
res2=list(
455+
YieldEvaluator(model2).enumerate_summarized(None,feeds2,raise_exc=raise_exc)
456+
)
441457
ifverbose:
442458
print(f"[compare_onnx_execution] got{len(res2)} results")
443459
print("[compare_onnx_execution] compute edit distance")

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp