- Notifications
You must be signed in to change notification settings - Fork0
Adds function to plot onnx model as graphs#61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Merged
Changes fromall commits
Commits
Show all changes
6 commits Select commitHold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Jump to file
Failed to load files.
Loading
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
1 change: 1 addition & 0 deletions.gitignore
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.Learn more about bidirectional Unicode characters
1 change: 1 addition & 0 deletionsCHANGELOGS.rst
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions_unittests/ut_plotting/test_graphviz.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| import os | ||
| import unittest | ||
| import onnx.parser | ||
| from onnx_array_api.ext_test_case import ( | ||
| ExtTestCase, | ||
| skipif_ci_windows, | ||
| skipif_ci_apple, | ||
| ) | ||
| from onnx_array_api.plotting.dot_plot import to_dot | ||
| from onnx_array_api.plotting.graphviz_helper import draw_graph_graphviz, plot_dot | ||
| class TestGraphviz(ExtTestCase): | ||
| @classmethod | ||
| def _get_graph(cls): | ||
| return onnx.parser.parse_model( | ||
| """ | ||
| <ir_version: 8, opset_import: [ "": 18]> | ||
| agraph (float[N] x) => (float[N] z) { | ||
| two = Constant <value_float=2.0> () | ||
| four = Add(two, two) | ||
| z = Mul(x, x) | ||
| }""" | ||
| ) | ||
| @skipif_ci_windows("graphviz not installed") | ||
| @skipif_ci_apple("graphviz not installed") | ||
| def test_draw_graph_graphviz(self): | ||
| fout = "test_draw_graph_graphviz.png" | ||
| dot = to_dot(self._get_graph()) | ||
| draw_graph_graphviz(dot, image=fout) | ||
| self.assertExists(os.path.exists(fout)) | ||
| @skipif_ci_windows("graphviz not installed") | ||
| @skipif_ci_apple("graphviz not installed") | ||
| def test_draw_graph_graphviz_proto(self): | ||
| fout = "test_draw_graph_graphviz_proto.png" | ||
| dot = self._get_graph() | ||
| draw_graph_graphviz(dot, image=fout) | ||
| self.assertExists(os.path.exists(fout)) | ||
| @skipif_ci_windows("graphviz not installed") | ||
| @skipif_ci_apple("graphviz not installed") | ||
| def test_plot_dot(self): | ||
| dot = to_dot(self._get_graph()) | ||
| ax = plot_dot(dot) | ||
| ax.get_figure().savefig("test_plot_dot.png") | ||
| if __name__ == "__main__": | ||
| unittest.main(verbosity=2) |
18 changes: 18 additions & 0 deletionsonnx_array_api/ext_test_case.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.Learn more about bidirectional Unicode characters
236 changes: 236 additions & 0 deletionsonnx_array_api/plotting/graphviz_helper.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,236 @@ | ||
| import os | ||
| import subprocess | ||
| import sys | ||
| import tempfile | ||
| from typing import List, Optional, Tuple, Union | ||
| import numpy as np | ||
| from onnx import ModelProto | ||
| def _find_in_PATH(prog: str) -> Optional[str]: | ||
| """ | ||
| Looks into every path mentioned in ``%PATH%`` a specific file, | ||
| it raises an exception if not found. | ||
| :param prog: program to look for | ||
| :return: path | ||
| """ | ||
| sep = ";" if sys.platform.startswith("win") else ":" | ||
| path = os.environ["PATH"] | ||
| for p in path.split(sep): | ||
| f = os.path.join(p, prog) | ||
| if os.path.exists(f): | ||
| return p | ||
| return None | ||
| def _find_graphviz_dot(exc: bool = True) -> str: | ||
| """ | ||
| Determines the path to graphviz (on Windows), | ||
| the function tests the existence of versions 34 to 45 | ||
| assuming it was installed in a standard folder: | ||
| ``C:\\Program Files\\MiKTeX 2.9\\miktex\\bin\\x64``. | ||
| :param exc: raise exception of be silent | ||
| :return: path to dot | ||
| :raises FileNotFoundError: if graphviz not found | ||
| """ | ||
| if sys.platform.startswith("win"): | ||
| version = list(range(34, 60)) | ||
| version.extend([f"{v}.1" for v in version]) | ||
| for v in version: | ||
| graphviz_dot = f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe" | ||
| if os.path.exists(graphviz_dot): | ||
| return graphviz_dot | ||
| extra = ["build/update_modules/Graphviz/bin"] | ||
| for ext in extra: | ||
| graphviz_dot = os.path.join(ext, "dot.exe") | ||
| if os.path.exists(graphviz_dot): | ||
| return graphviz_dot | ||
| p = _find_in_PATH("dot.exe") | ||
| if p is None: | ||
| if exc: | ||
| raise FileNotFoundError( | ||
| f"Unable to find graphviz, look into paths such as {graphviz_dot}." | ||
| ) | ||
| return None | ||
| return os.path.join(p, "dot.exe") | ||
| # linux | ||
| return "dot" | ||
| def _run_subprocess( | ||
| args: List[str], | ||
| cwd: Optional[str] = None, | ||
| ): | ||
| assert not isinstance( | ||
| args, str | ||
| ), "args should be a sequence of strings, not a string." | ||
| p = subprocess.Popen( | ||
| args, | ||
| cwd=cwd, | ||
| shell=False, | ||
| env=os.environ, | ||
| stdout=subprocess.PIPE, | ||
| stderr=subprocess.STDOUT, | ||
| ) | ||
| raise_exception = False | ||
| output = "" | ||
| while True: | ||
| output = p.stdout.readline().decode(errors="ignore") | ||
| if output == "" and p.poll() is not None: | ||
| break | ||
| if output: | ||
| if ( | ||
| "fatal error" in output | ||
| or "CMake Error" in output | ||
| or "gmake: ***" in output | ||
| or "): error C" in output | ||
| or ": error: " in output | ||
| ): | ||
| raise_exception = True | ||
| p.poll() | ||
| p.stdout.close() | ||
| if raise_exception: | ||
| raise RuntimeError( | ||
| "An error was found in the output. The build is stopped.\n{output}" | ||
| ) | ||
| return output | ||
| def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str: | ||
| """ | ||
| Run :epkg:`Graphviz`. | ||
| :param filename: filename which contains the graph definition | ||
| :param image: output image | ||
| :param engine: *dot* or *neato* | ||
| :return: output of graphviz | ||
| """ | ||
| ext = os.path.splitext(image)[-1] | ||
| assert ext in { | ||
| ".png", | ||
| ".bmp", | ||
| ".fig", | ||
| ".gif", | ||
| ".ico", | ||
| ".jpg", | ||
| ".jpeg", | ||
| ".pdf", | ||
| ".ps", | ||
| ".svg", | ||
| ".vrml", | ||
| ".tif", | ||
| ".tiff", | ||
| ".wbmp", | ||
| }, f"Unexpected extension {ext!r} for {image!r}." | ||
| if sys.platform.startswith("win"): | ||
| bin_ = os.path.dirname(_find_graphviz_dot()) | ||
| # if bin not in os.environ["PATH"]: | ||
| # os.environ["PATH"] = os.environ["PATH"] + ";" + bin | ||
| exe = os.path.join(bin_, engine) | ||
| else: | ||
| exe = engine | ||
| if os.path.exists(image): | ||
| os.remove(image) | ||
| output = _run_subprocess([exe, f"-T{ext[1:]}", filename, "-o", image]) | ||
| assert os.path.exists(image), f"Graphviz failed due to {output}" | ||
| return output | ||
| def draw_graph_graphviz( | ||
| dot: Union[str, ModelProto], | ||
| image: str, | ||
| engine: str = "dot", | ||
| ) -> str: | ||
| """ | ||
| Draws a graph using :epkg:`Graphviz`. | ||
| :param dot: dot graph or ModelProto | ||
| :param image: output image, None, just returns the output | ||
| :param engine: *dot* or *neato* | ||
| :return: :epkg:`Graphviz` output or | ||
| the dot text if *image* is None | ||
| The function creates a temporary file to store the dot file if *image* is not None. | ||
| """ | ||
| if isinstance(dot, ModelProto): | ||
| from .dot_plot import to_dot | ||
| sdot = to_dot(dot) | ||
| else: | ||
| sdot = dot | ||
| with tempfile.NamedTemporaryFile(delete=False) as fp: | ||
| fp.write(sdot.encode("utf-8")) | ||
| fp.close() | ||
| filename = fp.name | ||
| assert os.path.exists( | ||
| filename | ||
| ), f"File {filename!r} cannot be created to store the graph." | ||
| out = _run_graphviz(filename, image, engine=engine) | ||
| assert os.path.exists( | ||
| image | ||
| ), f"Graphviz failed with no reason, {image!r} not found, output is {out}." | ||
| os.remove(filename) | ||
| return out | ||
| def plot_dot( | ||
| dot: Union[str, ModelProto], | ||
| ax: Optional["matplotlib.axis.Axis"] = None, # noqa: F821 | ||
| engine: str = "dot", | ||
| figsize: Optional[Tuple[int, int]] = None, | ||
| ) -> "matplotlib.axis.Axis": # noqa: F821 | ||
| """ | ||
| Draws a dot graph into a matplotlib graph. | ||
| :param dot: dot graph or ModelProto | ||
| :param image: output image, None, just returns the output | ||
| :param engine: *dot* or *neato* | ||
| :param figsize: figsize of ax is None | ||
| :return: :epkg:`Graphviz` output or | ||
| the dot text if *image* is None | ||
| .. plot:: | ||
| import matplotlib.pyplot as plt | ||
| import onnx.parser | ||
| model = onnx.parser.parse_model( | ||
| ''' | ||
| <ir_version: 8, opset_import: [ "": 18]> | ||
| agraph (float[N] x) => (float[N] z) { | ||
| two = Constant <value_float=2.0> () | ||
| four = Add(two, two) | ||
| z = Mul(four, four) | ||
| }''') | ||
| ax = plot_dot(dot) | ||
| ax.set_title("Dummy graph") | ||
| plt.show() | ||
| """ | ||
| if ax is None: | ||
| import matplotlib.pyplot as plt | ||
| _, ax = plt.subplots(1, 1, figsize=figsize) | ||
| clean = True | ||
| else: | ||
| clean = False | ||
| from PIL import Image | ||
| with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp: | ||
| fp.close() | ||
| draw_graph_graphviz(dot, fp.name, engine=engine) | ||
| img = np.asarray(Image.open(fp.name)) | ||
| os.remove(fp.name) | ||
| ax.imshow(img) | ||
| if clean: | ||
| ax.get_xaxis().set_visible(False) | ||
| ax.get_yaxis().set_visible(False) | ||
| ax.get_figure().tight_layout() | ||
| return ax |
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.