Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitdcc2ddd

Browse files
authored
Add discrepancies when comparing the execution of two models (#79)
* update requirements* add discrepancies figures* fix command line* doc
1 parenta906010 commitdcc2ddd

File tree

4 files changed

+81
-9
lines changed

4 files changed

+81
-9
lines changed

‎CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Change Logs
55
+++++
66

77
*:pr:`77`: supports ConcatOfShape and Slice with the light API
8-
*:pr:`76`: add a mode to compare models without execution
8+
*:pr:`76`,:pr:`79`: add a mode to compare models without execution
99
*:pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
1010
*:pr:`71`: adds tools to compare two onnx graphs
1111
*:pr:`61`: adds function to plot onnx model as graphs

‎_unittests/ut_reference/test_evaluator_yield.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,31 @@ def test_compare_execution(self):
462462
self.assertIn("CAAA Constant",text)
463463
self.assertEqual(len(align),5)
464464

465+
deftest_compare_execution_discrepancies(self):
466+
m1=parse_model(
467+
"""
468+
<ir_version: 8, opset_import: [ "": 18]>
469+
agraph (float[N] x) => (float[N] z) {
470+
two = Constant <value_float=2.0> ()
471+
four = Add(two, two)
472+
z = Mul(x, x)
473+
}"""
474+
)
475+
m2=parse_model(
476+
"""
477+
<ir_version: 8, opset_import: [ "": 18]>
478+
agraph (float[N] x) => (float[N] z) {
479+
two = Constant <value_float=2.0> ()
480+
z = Mul(x, x)
481+
}"""
482+
)
483+
res1,res2,align,dc=compare_onnx_execution(m1,m2,keep_tensor=True)
484+
text=dc.to_str(res1,res2,align)
485+
print(text)
486+
self.assertIn("CAAA Constant",text)
487+
self.assertIn("| a=",text)
488+
self.assertIn(" r=",text)
489+
465490
deftest_no_execution(self):
466491
model=make_model(
467492
make_graph(

‎onnx_array_api/_command_lines_parser.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,15 @@ def get_parser_compare() -> ArgumentParser:
106106
parser.add_argument(
107107
"-c",
108108
"--column-size",
109-
default=50,
109+
default=60,
110110
help="column size when displaying the results",
111111
)
112+
parser.add_argument(
113+
"-d",
114+
"--discrepancies",
115+
default=0,
116+
help="show precise discrepancies when mode is execution",
117+
)
112118
returnparser
113119

114120

@@ -120,7 +126,11 @@ def _cmd_compare(argv: List[Any]):
120126
onx1=onnx.load(args.model1)
121127
onx2=onnx.load(args.model2)
122128
res1,res2,align,dc=compare_onnx_execution(
123-
onx1,onx2,verbose=args.verbose,mode=args.mode
129+
onx1,
130+
onx2,
131+
verbose=args.verbose,
132+
mode=args.mode,
133+
keep_tensor=args.discrepanciesin (1,"1","True",True),
124134
)
125135
text=dc.to_str(res1,res2,align,column_size=int(args.column_size))
126136
print(text)

‎onnx_array_api/reference/evaluator_yield.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class ResultExecution:
5757
summary:str
5858
op_type:str
5959
name:str
60+
value:Optional[Any]=None
6061

6162
def__len__(self)->int:
6263
return6
@@ -122,9 +123,11 @@ def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str:
122123
else:
123124
value2=value.flatten().astype(np.float64)
124125
value4=value2.reshape((4,-1)).sum(axis=1)
125-
value4i=value4.astype(np.int64)%modulo
126-
s="".join([chr(65+i)foriinvalue4i])
127-
returns
126+
value4=np.where(np.abs(value4)<1e10,value4,np.nan)
127+
s= []
128+
forvinvalue4:
129+
s.append("?"ifnp.isnan(v)else (chr(65+int(v)%modulo)))
130+
return"".join(s)
128131

129132

130133
classYieldEvaluator:
@@ -228,6 +231,7 @@ def enumerate_summarized(
228231
output_names:Optional[List[str]]=None,
229232
feed_inputs:Optional[Dict[str,Any]]=None,
230233
raise_exc:bool=True,
234+
keep_tensor:bool=False,
231235
)->Iterator[ResultExecution]:
232236
"""
233237
Executes the onnx model and enumerate intermediate results without their names.
@@ -236,17 +240,40 @@ def enumerate_summarized(
236240
:param feed_inputs: dictionary `{ input name: input value }`
237241
:param raise_exc: raises an exception if the execution fails or stop
238242
where it is
243+
:param keep_tensor:keep the tensor in order to compute precise distances
239244
:return: iterator on ResultExecution
240245
"""
241246
forkind,name,value,op_typeinself.enumerate_results(
242247
output_names,feed_inputs,raise_exc=raise_exc
243248
):
244249
summary=make_summary(value)
245250
yieldResultExecution(
246-
kind,value.dtype,value.shape,summary,op_type,name
251+
kind,
252+
value.dtype,
253+
value.shape,
254+
summary,
255+
op_type,
256+
name,
257+
value=valueifkeep_tensorelseNone,
247258
)
248259

249260

261+
defdiscrepancies(
262+
expected:np.ndarray,value:np.ndarray,eps:float=1e-7
263+
)->Dict[str,float]:
264+
"""
265+
Computes absolute error and relative error between two matrices.
266+
"""
267+
assert (
268+
expected.size==value.size
269+
),f"Incompatible shapes v1.shape={expected.shape}, v2.shape={value.shape}"
270+
expected=expected.ravel().astype(np.float32)
271+
value=value.ravel().astype(np.float32)
272+
diff=np.abs(expected-value)
273+
rel=diff/ (np.abs(expected)+eps)
274+
returndict(aerr=float(diff.max()),rerr=float(rel.max()))
275+
276+
250277
classDistanceExecution:
251278
"""
252279
Computes a distance between two results.
@@ -403,6 +430,14 @@ def to_str(
403430
d=self.distance_pair(d1,d2)
404431
symbol="="ifd==0else"~"
405432
line=f"{symbol} |{_align(str(d1),column_size)} |{_align(str(d2),column_size)}"
433+
if (
434+
d1.valueisnotNone
435+
andd2.valueisnotNone
436+
andd1.value.size==d2.value.size
437+
):
438+
disc=discrepancies(d1.value,d2.value)
439+
a,r=disc["aerr"],disc["rerr"]
440+
line+=f" | a={a:.3f} r={r:.3f}"
406441
elifi==last[0]:
407442
d2=s2[j]
408443
line= (
@@ -551,6 +586,7 @@ def compare_onnx_execution(
551586
verbose:int=0,
552587
raise_exc:bool=True,
553588
mode:str="execute",
589+
keep_tensor:bool=False,
554590
)->Tuple[List[ResultExecution],List[ResultExecution],List[Tuple[int,int]]]:
555591
"""
556592
Compares the execution of two onnx models.
@@ -566,6 +602,7 @@ def compare_onnx_execution(
566602
:param raise_exc: raise exception if the execution fails or stop at the error
567603
:param mode: the model should be executed but the function can be executed
568604
but the comparison may append on nodes only
605+
:param keep_tensor: keeps the tensor in order to compute a precise distance
569606
:return: four results, a sequence of results for the first model and the second model,
570607
the alignment between the two, DistanceExecution
571608
"""
@@ -589,15 +626,15 @@ def compare_onnx_execution(
589626
print("[compare_onnx_execution] execute first model")
590627
res1=list(
591628
YieldEvaluator(model1).enumerate_summarized(
592-
None,feeds1,raise_exc=raise_exc
629+
None,feeds1,raise_exc=raise_exc,keep_tensor=keep_tensor
593630
)
594631
)
595632
ifverbose:
596633
print(f"[compare_onnx_execution] got{len(res1)} results")
597634
print("[compare_onnx_execution] execute second model")
598635
res2=list(
599636
YieldEvaluator(model2).enumerate_summarized(
600-
None,feeds2,raise_exc=raise_exc
637+
None,feeds2,raise_exc=raise_exc,keep_tensor=keep_tensor
601638
)
602639
)
603640
elifmode=="nodes":

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp