Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitebafa26

Browse files
authored
Adds function to plot onnx model as graphs (#61)
* Add methods to draw onnx plots* improve versatility* doc* disable test when graphviz not installed* documentation* add missing function
1 parent7895c27 commitebafa26

File tree

5 files changed

+307
-0
lines changed

5 files changed

+307
-0
lines changed

‎.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ build/*
1414
*egg-info/*
1515
onnxruntime_profile*
1616
prof
17+
test*.png
1718
_doc/sg_execution_times.rst
1819
_doc/auto_examples/*
1920
_doc/examples/_cache/*

‎CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
*:pr:`61`: adds function to plot onnx model as graphs
78
*:pr:`60`: supports translation of local functions
89
*:pr:`59`: add methods to update nodes in GraphAPI
910

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
importos
2+
importunittest
3+
importonnx.parser
4+
fromonnx_array_api.ext_test_caseimport (
5+
ExtTestCase,
6+
skipif_ci_windows,
7+
skipif_ci_apple,
8+
)
9+
fromonnx_array_api.plotting.dot_plotimportto_dot
10+
fromonnx_array_api.plotting.graphviz_helperimportdraw_graph_graphviz,plot_dot
11+
12+
13+
classTestGraphviz(ExtTestCase):
14+
@classmethod
15+
def_get_graph(cls):
16+
returnonnx.parser.parse_model(
17+
"""
18+
<ir_version: 8, opset_import: [ "": 18]>
19+
agraph (float[N] x) => (float[N] z) {
20+
two = Constant <value_float=2.0> ()
21+
four = Add(two, two)
22+
z = Mul(x, x)
23+
}"""
24+
)
25+
26+
@skipif_ci_windows("graphviz not installed")
27+
@skipif_ci_apple("graphviz not installed")
28+
deftest_draw_graph_graphviz(self):
29+
fout="test_draw_graph_graphviz.png"
30+
dot=to_dot(self._get_graph())
31+
draw_graph_graphviz(dot,image=fout)
32+
self.assertExists(os.path.exists(fout))
33+
34+
@skipif_ci_windows("graphviz not installed")
35+
@skipif_ci_apple("graphviz not installed")
36+
deftest_draw_graph_graphviz_proto(self):
37+
fout="test_draw_graph_graphviz_proto.png"
38+
dot=self._get_graph()
39+
draw_graph_graphviz(dot,image=fout)
40+
self.assertExists(os.path.exists(fout))
41+
42+
@skipif_ci_windows("graphviz not installed")
43+
@skipif_ci_apple("graphviz not installed")
44+
deftest_plot_dot(self):
45+
dot=to_dot(self._get_graph())
46+
ax=plot_dot(dot)
47+
ax.get_figure().savefig("test_plot_dot.png")
48+
49+
50+
if__name__=="__main__":
51+
unittest.main(verbosity=2)

‎onnx_array_api/ext_test_case.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def is_windows() -> bool:
1919
returnsys.platform=="win32"
2020

2121

22+
defis_apple()->bool:
23+
returnsys.platform=="darwin"
24+
25+
2226
defskipif_ci_windows(msg)->Callable:
2327
"""
2428
Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
@@ -29,6 +33,16 @@ def skipif_ci_windows(msg) -> Callable:
2933
returnlambdax:x
3034

3135

36+
defskipif_ci_apple(msg)->Callable:
37+
"""
38+
Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
39+
"""
40+
ifis_apple()andis_azure():
41+
msg=f"Test does not work on azure pipeline (Apple).{msg}"
42+
returnunittest.skip(msg)
43+
returnlambdax:x
44+
45+
3246
defignore_warnings(warns:List[Warning])->Callable:
3347
"""
3448
Catches warnings.
@@ -230,6 +244,10 @@ def assertEmpty(self, value: Any):
230244
return
231245
raiseAssertionError(f"value is not empty:{value!r}.")
232246

247+
defassertExists(self,name):
248+
ifnotos.path.exists(name):
249+
raiseAssertionError(f"File or folder{name!r} does not exists.")
250+
233251
defassertHasAttr(self,cls:type,name:str):
234252
ifnothasattr(cls,name):
235253
raiseAssertionError(f"Class{cls} has no attribute{name!r}.")
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
importos
2+
importsubprocess
3+
importsys
4+
importtempfile
5+
fromtypingimportList,Optional,Tuple,Union
6+
importnumpyasnp
7+
fromonnximportModelProto
8+
9+
10+
def_find_in_PATH(prog:str)->Optional[str]:
11+
"""
12+
Looks into every path mentioned in ``%PATH%`` a specific file,
13+
it raises an exception if not found.
14+
15+
:param prog: program to look for
16+
:return: path
17+
"""
18+
sep=";"ifsys.platform.startswith("win")else":"
19+
path=os.environ["PATH"]
20+
forpinpath.split(sep):
21+
f=os.path.join(p,prog)
22+
ifos.path.exists(f):
23+
returnp
24+
returnNone
25+
26+
27+
def_find_graphviz_dot(exc:bool=True)->str:
28+
"""
29+
Determines the path to graphviz (on Windows),
30+
the function tests the existence of versions 34 to 45
31+
assuming it was installed in a standard folder:
32+
``C:\\Program Files\\MiKTeX 2.9\\miktex\\bin\\x64``.
33+
34+
:param exc: raise exception of be silent
35+
:return: path to dot
36+
:raises FileNotFoundError: if graphviz not found
37+
"""
38+
ifsys.platform.startswith("win"):
39+
version=list(range(34,60))
40+
version.extend([f"{v}.1"forvinversion])
41+
forvinversion:
42+
graphviz_dot=f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe"
43+
ifos.path.exists(graphviz_dot):
44+
returngraphviz_dot
45+
extra= ["build/update_modules/Graphviz/bin"]
46+
forextinextra:
47+
graphviz_dot=os.path.join(ext,"dot.exe")
48+
ifos.path.exists(graphviz_dot):
49+
returngraphviz_dot
50+
p=_find_in_PATH("dot.exe")
51+
ifpisNone:
52+
ifexc:
53+
raiseFileNotFoundError(
54+
f"Unable to find graphviz, look into paths such as{graphviz_dot}."
55+
)
56+
returnNone
57+
returnos.path.join(p,"dot.exe")
58+
# linux
59+
return"dot"
60+
61+
62+
def_run_subprocess(
63+
args:List[str],
64+
cwd:Optional[str]=None,
65+
):
66+
assertnotisinstance(
67+
args,str
68+
),"args should be a sequence of strings, not a string."
69+
70+
p=subprocess.Popen(
71+
args,
72+
cwd=cwd,
73+
shell=False,
74+
env=os.environ,
75+
stdout=subprocess.PIPE,
76+
stderr=subprocess.STDOUT,
77+
)
78+
raise_exception=False
79+
output=""
80+
whileTrue:
81+
output=p.stdout.readline().decode(errors="ignore")
82+
ifoutput==""andp.poll()isnotNone:
83+
break
84+
ifoutput:
85+
if (
86+
"fatal error"inoutput
87+
or"CMake Error"inoutput
88+
or"gmake: ***"inoutput
89+
or"): error C"inoutput
90+
or": error: "inoutput
91+
):
92+
raise_exception=True
93+
p.poll()
94+
p.stdout.close()
95+
ifraise_exception:
96+
raiseRuntimeError(
97+
"An error was found in the output. The build is stopped.\n{output}"
98+
)
99+
returnoutput
100+
101+
102+
def_run_graphviz(filename:str,image:str,engine:str="dot")->str:
103+
"""
104+
Run :epkg:`Graphviz`.
105+
106+
:param filename: filename which contains the graph definition
107+
:param image: output image
108+
:param engine: *dot* or *neato*
109+
:return: output of graphviz
110+
"""
111+
ext=os.path.splitext(image)[-1]
112+
assertextin {
113+
".png",
114+
".bmp",
115+
".fig",
116+
".gif",
117+
".ico",
118+
".jpg",
119+
".jpeg",
120+
".pdf",
121+
".ps",
122+
".svg",
123+
".vrml",
124+
".tif",
125+
".tiff",
126+
".wbmp",
127+
},f"Unexpected extension{ext!r} for{image!r}."
128+
ifsys.platform.startswith("win"):
129+
bin_=os.path.dirname(_find_graphviz_dot())
130+
# if bin not in os.environ["PATH"]:
131+
# os.environ["PATH"] = os.environ["PATH"] + ";" + bin
132+
exe=os.path.join(bin_,engine)
133+
else:
134+
exe=engine
135+
ifos.path.exists(image):
136+
os.remove(image)
137+
output=_run_subprocess([exe,f"-T{ext[1:]}",filename,"-o",image])
138+
assertos.path.exists(image),f"Graphviz failed due to{output}"
139+
returnoutput
140+
141+
142+
defdraw_graph_graphviz(
143+
dot:Union[str,ModelProto],
144+
image:str,
145+
engine:str="dot",
146+
)->str:
147+
"""
148+
Draws a graph using :epkg:`Graphviz`.
149+
150+
:param dot: dot graph or ModelProto
151+
:param image: output image, None, just returns the output
152+
:param engine: *dot* or *neato*
153+
:return: :epkg:`Graphviz` output or
154+
the dot text if *image* is None
155+
156+
The function creates a temporary file to store the dot file if *image* is not None.
157+
"""
158+
ifisinstance(dot,ModelProto):
159+
from .dot_plotimportto_dot
160+
161+
sdot=to_dot(dot)
162+
else:
163+
sdot=dot
164+
withtempfile.NamedTemporaryFile(delete=False)asfp:
165+
fp.write(sdot.encode("utf-8"))
166+
fp.close()
167+
168+
filename=fp.name
169+
assertos.path.exists(
170+
filename
171+
),f"File{filename!r} cannot be created to store the graph."
172+
out=_run_graphviz(filename,image,engine=engine)
173+
assertos.path.exists(
174+
image
175+
),f"Graphviz failed with no reason,{image!r} not found, output is{out}."
176+
os.remove(filename)
177+
returnout
178+
179+
180+
defplot_dot(
181+
dot:Union[str,ModelProto],
182+
ax:Optional["matplotlib.axis.Axis"]=None,# noqa: F821
183+
engine:str="dot",
184+
figsize:Optional[Tuple[int,int]]=None,
185+
)->"matplotlib.axis.Axis":# noqa: F821
186+
"""
187+
Draws a dot graph into a matplotlib graph.
188+
189+
:param dot: dot graph or ModelProto
190+
:param image: output image, None, just returns the output
191+
:param engine: *dot* or *neato*
192+
:param figsize: figsize of ax is None
193+
:return: :epkg:`Graphviz` output or
194+
the dot text if *image* is None
195+
196+
.. plot::
197+
198+
import matplotlib.pyplot as plt
199+
import onnx.parser
200+
201+
model = onnx.parser.parse_model(
202+
'''
203+
<ir_version: 8, opset_import: [ "": 18]>
204+
agraph (float[N] x) => (float[N] z) {
205+
two = Constant <value_float=2.0> ()
206+
four = Add(two, two)
207+
z = Mul(four, four)
208+
}''')
209+
ax = plot_dot(dot)
210+
ax.set_title("Dummy graph")
211+
plt.show()
212+
"""
213+
ifaxisNone:
214+
importmatplotlib.pyplotasplt
215+
216+
_,ax=plt.subplots(1,1,figsize=figsize)
217+
clean=True
218+
else:
219+
clean=False
220+
221+
fromPILimportImage
222+
223+
withtempfile.NamedTemporaryFile(suffix=".png",delete=False)asfp:
224+
fp.close()
225+
226+
draw_graph_graphviz(dot,fp.name,engine=engine)
227+
img=np.asarray(Image.open(fp.name))
228+
os.remove(fp.name)
229+
230+
ax.imshow(img)
231+
232+
ifclean:
233+
ax.get_xaxis().set_visible(False)
234+
ax.get_yaxis().set_visible(False)
235+
ax.get_figure().tight_layout()
236+
returnax

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp