Adds tools to compare models #11

Merged: sdpython merged 2 commits into main from diff on Apr 25, 2023
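
The pull request adds helpers to compare two ONNX models (text and HTML diffs), an onnxruntime-based optimizer helper, and a few test utilities. The sketch below shows roughly how the new pieces fit together; it mirrors the example and tests added in this PR, and the model path is only an illustration.

from onnx import load
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from onnx_array_api.validation.diff import text_diff, html_diff
from onnx_array_api.ort.ort_optimizers import ort_optimized_model

# Load a model and build its onnxruntime-optimized counterpart
# ("data/small.onnx" stands in for any ONNX file).
with open("data/small.onnx", "rb") as f:
    model = load(f)
optimized = ort_optimized_model(model)

# Render both models as text and compare them line by line.
text1 = onnx_simple_text_plot(model, indent=False)
text2 = onnx_simple_text_plot(optimized, indent=False)
print(text_diff(text1, text2))

# Or write an HTML view of the differences.
with open("diff.html", "w", encoding="utf-8") as f:
    f.write(html_diff(text1, text2))
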
.gitignore (4 changes: 4 additions & 0 deletions)
@@ -15,3 +15,7 @@ _doc/examples/plot_*.png
_doc/_static/require.js
_doc/_static/viz.js
_unittests/ut__main/*.png
_doc/examples/data/*.optimized.onnx
_doc/examples/*.html
_unittests/ut__main/_cache/*
_unittests/ut__main/*.html
_doc/api/index.rst (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ API
npx_jit
npx_annot
npx_numpy
onnx_tools
ort
plotting
tools

_doc/api/onnx_tools.rst (16 changes: 16 additions & 0 deletions)
@@ -0,0 +1,16 @@
.. _l-api-onnx-tools:

onnx tools
==========

Differences
+++++++++++

.. autofunction:: onnx_array_api.validation.diff.html_diff

.. autofunction:: onnx_array_api.validation.diff.text_diff

Protos
++++++

.. autofunction:: onnx_array_api.validation.tools.randomize_proto
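
A minimal sketch of the randomize_proto helper documented above, following the new unit test further down (the model path is illustrative):

from onnx import load
from onnx.checker import check_model
from onnx_array_api.validation.tools import randomize_proto

# Build a randomized copy of a model; according to the new test,
# the result has the same serialized size and still passes the checker.
with open("small.onnx", "rb") as f:
    model = load(f)
rnd = randomize_proto(model)
check_model(rnd)
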
_doc/api/tools.rst (7 changes: 5 additions & 2 deletions)
@@ -8,6 +8,11 @@ Benchmark

.. autofunction:: onnx_array_api.ext_test_case.measure_time

Examples
++++++++

.. autofunction:: onnx_array_api.ext_test_case.example_path

Profiling
+++++++++

@@ -25,5 +30,3 @@ Unit tests

.. autoclass:: onnx_array_api.ext_test_case.ExtTestCase
:members:


Binary file added (not shown): _doc/examples/data/small.onnx
_doc/examples/plot_optimization.py (121 changes: 121 additions & 0 deletions)
@@ -0,0 +1,121 @@
"""

.. _l-onnx-array-onnxruntime-optimization:

Optimization with onnxruntime
=============================


Optimize a model with onnxruntime
+++++++++++++++++++++++++++++++++
"""
import os
from pprint import pprint
import numpy
from pandas import DataFrame
import matplotlib.pyplot as plt
from onnx import load
from onnx_array_api.ext_test_case import example_path
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from onnx_array_api.validation.diff import text_diff, html_diff
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
from onnx_array_api.ext_test_case import measure_time
from onnx_array_api.ort.ort_optimizers import ort_optimized_model


filename = example_path("data/small.onnx")
optimized = filename + ".optimized.onnx"

if not os.path.exists(optimized):
    ort_optimized_model(filename, output=optimized)
print(optimized)

#############################
# Output comparison
# +++++++++++++++++

so = SessionOptions()
so.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)

sess = InferenceSession(filename, so)
sess_opt = InferenceSession(optimized, so)
input_name = sess.get_inputs()[0].name
out = sess.run(None, {input_name: img})[0]
out_opt = sess_opt.run(None, {input_name: img})[0]
if out.shape != out_opt.shape:
    print(f"ERROR: shapes are different {out.shape} != {out_opt.shape}")
diff = numpy.abs(out - out_opt).max()
print(f"Differences: {diff}")

####################################
# Difference
# ++++++++++
#
# Unoptimized model.

with open(filename, "rb") as f:
    model = load(f)
print("first model to text...")
text1 = onnx_simple_text_plot(model, indent=False)
print(text1)

#####################################
# Optimized model.


with open(optimized, "rb") as f:
    model = load(f)
print("second model to text...")
text2 = onnx_simple_text_plot(model, indent=False)
print(text2)

########################################
# Differences

print("differences...")
print(text_diff(text1, text2))

#####################################
# HTML version.

print("html differences...")
output = html_diff(text1, text2)
with open("diff_html.html", "w", encoding="utf-8") as f:
f.write(output)
print("done.")

#####################################
# Benchmark
# +++++++++

img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)

t1 = measure_time(lambda: sess.run(None, {input_name: img}), repeat=25, number=25)
t1["name"] = "original"
print("Original model")
pprint(t1)

t2 = measure_time(lambda: sess_opt.run(None, {input_name: img}), repeat=25, number=25)
t2["name"] = "optimized"
print("Optimized")
pprint(t2)


############################
# Plots
# +++++


fig, ax = plt.subplots(1, 1, figsize=(12, 4))

df = DataFrame([t1, t2]).set_index("name")
print(df)

print(df["average"].values)
print((df["average"] - df["deviation"]).values)

ax.bar(df.index, df["average"].values, yerr=df["deviation"].values, capsize=6)
ax.set_title("Measure performance of optimized model\nlower is better")
plt.grid()
fig.savefig("plot_optimization.png")
Binary file added (not shown): _unittests/ut_validation/data/small.onnx
_unittests/ut_validation/test_diff.py (23 changes: 23 additions & 0 deletions)
@@ -0,0 +1,23 @@
import unittest
from onnx import load
from onnx.checker import check_model
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.ort.ort_optimizers import ort_optimized_model
from onnx_array_api.validation.diff import text_diff, html_diff


class TestDiff(ExtTestCase):
    def test_diff_optimized(self):
        data = self.relative_path(__file__, "data", "small.onnx")
        with open(data, "rb") as f:
            model = load(f)
        optimized = ort_optimized_model(model)
        check_model(optimized)
        diff = text_diff(model, optimized)
        self.assertIn("^^^^^^^^^^^^^^^^", diff)
        ht = html_diff(model, optimized)
        self.assertIn("<html><body>", ht)


if __name__ == "__main__":
unittest.main(verbosity=2)
_unittests/ut_validation/test_tools.py (20 changes: 20 additions & 0 deletions)
@@ -0,0 +1,20 @@
import unittest
from onnx import load
from onnx.checker import check_model
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.validation.tools import randomize_proto


class TestTools(ExtTestCase):
    def test_randomize_proto(self):
        data = self.relative_path(__file__, "data", "small.onnx")
        with open(data, "rb") as f:
            model = load(f)
        check_model(model)
        rnd = randomize_proto(model)
        self.assertEqual(len(model.SerializeToString()), len(rnd.SerializeToString()))
        check_model(rnd)


if __name__ == "__main__":
unittest.main(verbosity=2)
onnx_array_api/ext_test_case.py (30 changes: 30 additions & 0 deletions)
@@ -1,3 +1,4 @@
import os
import sys
import unittest
import warnings
@@ -30,6 +31,20 @@ def call_f(self):
    return wrapper


def example_path(path: str) -> str:
    """
    Fixes a path for the examples.
    Helps running the example within a unit test.
    """
    if os.path.exists(path):
        return path
    this = os.path.abspath(os.path.dirname(__file__))
    full = os.path.join(this, "..", "_doc", "examples", path)
    if os.path.exists(full):
        return full
    raise FileNotFoundError(f"Unable to find path {path!r} or {full!r}.")


def measure_time(
    stmt: Callable,
    context: Optional[Dict[str, Any]] = None,
@@ -207,3 +222,18 @@ def capture(self, fct: Callable):
            with redirect_stderr(serr):
                res = fct()
        return res, sout.getvalue(), serr.getvalue()

    def relative_path(self, filename: str, *names: List[str]) -> str:
        """
        Returns a path relative to the folder *filename*
        is in. The function checks the path existence.

        :param filename: filename
        :param names: additional path pieces
        :return: new path
        """
        dir = os.path.abspath(os.path.dirname(filename))
        name = os.path.join(dir, *names)
        if not os.path.exists(name):
            raise FileNotFoundError(f"Path {name!r} does not exist.")
        return name
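
A minimal sketch of the two new path helpers, matching how the new example and tests use them (paths are illustrative):

import os
from onnx_array_api.ext_test_case import example_path

# example_path returns the path unchanged when it exists, otherwise it
# looks for it under _doc/examples, so the example also runs from a unit test.
filename = example_path("data/small.onnx")
assert os.path.exists(filename)

# Inside an ExtTestCase subclass, relative_path resolves a file next to the
# test module and checks that it exists:
#     data = self.relative_path(__file__, "data", "small.onnx")
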
onnx_array_api/ort/ort_optimizers.py (16 changes: 12 additions & 4 deletions)
@@ -1,12 +1,14 @@
from typing import Union
from typing import Union, Optional
from onnx import ModelProto, load
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime.capi._pybind_state import GraphOptimizationLevel
from ..cache import get_cache_file


def ort_optimized_model(
    onx: Union[str, ModelProto], level: str = "ORT_ENABLE_ALL"
    onx: Union[str, ModelProto],
    level: str = "ORT_ENABLE_ALL",
    output: Optional[str] = None,
) -> Union[str, ModelProto]:
    """
    Returns the optimized model used by onnxruntime before
@@ -15,6 +17,7 @@ def ort_optimized_model(
    :param onx: ModelProto
    :param level: optimization level, `'ORT_ENABLE_BASIC'`,
        `'ORT_ENABLE_EXTENDED'`, `'ORT_ENABLE_ALL'`
    :param output: output file if the proposed cache is not wanted
    :return: optimized model
    """
    glevel = getattr(GraphOptimizationLevel, level, None)
@@ -23,13 +26,18 @@
f"Unrecognized level {level!r} among {dir(GraphOptimizationLevel)}."
)

    cache = get_cache_file("ort_optimized_model.onnx", remove=True)
    if output is not None:
        cache = output
    else:
        cache = get_cache_file("ort_optimized_model.onnx", remove=True)
    so = SessionOptions()
    so.graph_optimization_level = glevel
    so.optimized_model_filepath = str(cache)
    InferenceSession(onx if isinstance(onx, str) else onx.SerializeToString(), so)
    if not cache.exists():
    if output is None and not cache.exists():
        raise RuntimeError(f"The optimized model {str(cache)!r} not found.")
    if output is not None:
        return output
    if isinstance(onx, str):
        return str(cache)
    opt_onx = load(str(cache))
(remaining lines of onnx_array_api/ort/ort_optimizers.py not shown)
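
A minimal sketch of the new output argument of ort_optimized_model (file names are illustrative):

from onnx_array_api.ort.ort_optimizers import ort_optimized_model

# Without output, the optimized model goes through the package cache and is
# returned as a ModelProto (or as the cache path when the input is a path).
# With output, it is written to the requested file and that path is returned.
path = ort_optimized_model(
    "model.onnx", level="ORT_ENABLE_ALL", output="model.optimized.onnx"
)
print(path)  # model.optimized.onnx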
