Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Improves translation to GraphBuilder#95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
sdpython merged 6 commits intomainfrombugr
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions.github/workflows/check-urls.yml
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -42,6 +42,6 @@ jobs:
print_all:false
timeout:2
retry_count# : 2
exclude_urls:https://hal.archives-ouvertes.fr/hal-00990252/document
exclude_patterns:https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/
exclude_urls:https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx
exclude_patterns:https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx
# force_pass : true
5 changes: 5 additions & 0 deletionsCHANGELOGS.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
Change Logs
===========

0.3.1
+++++

*:pr:`95`: improves translation to GraphBuilder

0.3.0
+++++

Expand Down
67 changes: 62 additions & 5 deletions_unittests/ut_translate_api/test_translate_builder.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -8,7 +8,8 @@
fromonnx_array_api.ext_test_caseimportExtTestCase
fromonnx_array_api.light_apiimportstart
fromonnx_array_api.graph_apiimportGraphBuilder
fromonnx_array_api.translate_apiimporttranslate
fromonnx_array_api.translate_apiimporttranslate,Translater
fromonnx_array_api.translate_api.builder_emitterimportBuilderEmitter


OPSET_API=min(19,onnx_opset_version()-1)
Expand All@@ -19,7 +20,7 @@ def setUp(self):
self.maxDiff=None

deftest_exp(self):
onx=start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
onx=start(opset=19,ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
self.assertIsInstance(onx,ModelProto)
self.assertIn("Exp",str(onx))
ref=ReferenceEvaluator(onx)
Expand All@@ -38,7 +39,7 @@ def light_api(
op.Identity(Y, outputs=["Y"])
return Y
g = GraphBuilder({'': 19})
g = GraphBuilder({'': 19}, ir_version=10)
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
Expand DownExpand Up@@ -68,7 +69,7 @@ def light_api(

deftest_zdoc(self):
onx= (
start(opset=19)
start(opset=19,ir_version=10)
.vin("X")
.reshape((-1,1))
.Transpose(perm=[1,0])
Expand All@@ -89,7 +90,7 @@ def light_api(
op.Identity(Y, outputs=["Y"])
return Y
g = GraphBuilder({'': 19})
g = GraphBuilder({'': 19}, ir_version=10)
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
Expand DownExpand Up@@ -117,6 +118,62 @@ def light_api(
self.assertNotEmpty(model)
check_model(model)

deftest_exp_f(self):
onx=start(opset=19,ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
self.assertIsInstance(onx,ModelProto)
self.assertIn("Exp",str(onx))
ref=ReferenceEvaluator(onx)
a=np.arange(10).astype(np.float32)
got=ref.run(None, {"X":a})[0]
self.assertEqualArray(np.exp(a),got)

tr=Translater(onx,emitter=BuilderEmitter("mm"))
code=tr.export(as_str=True)

expected=dedent(
"""
def light_api(
op: "GraphBuilder",
X: "FLOAT[]",
):
Y = op.Exp(X)
op.Identity(Y, outputs=["Y"])
return Y
def mm() -> "ModelProto":
g = GraphBuilder({'': 19}, ir_version=10)
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
model = g.to_onnx()
return model
model = mm()
"""
).strip("\n")
self.assertEqual(expected,code.strip("\n"))

deflight_api(
op:"GraphBuilder",
X:"FLOAT[]",# noqa: F722
):
Y=op.Exp(X)
op.Identity(Y,outputs=["Y"])
returnY

g2=GraphBuilder({"":19})
g2.make_tensor_input("X",TensorProto.FLOAT, ("A",))
light_api(g2.op,"X")
g2.make_tensor_output("Y",TensorProto.FLOAT, ("A",))
onx2=g2.to_onnx()

ref=ReferenceEvaluator(onx2)
a=np.arange(10).astype(np.float32)
got=ref.run(None, {"X":a})[0]
self.assertEqualArray(np.exp(a),got)


if__name__=="__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletiononnx_array_api/__init__.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -2,5 +2,5 @@
APIs to create ONNX Graphs.
"""

__version__="0.3.0"
__version__="0.3.1"
__author__="Xavier Dupré"
60 changes: 52 additions & 8 deletionsonnx_array_api/translate_api/builder_emitter.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,10 +4,17 @@
from .base_emitterimportBaseEmitter

_types= {
TensorProto.DOUBLE:"DOUBLE",
TensorProto.FLOAT:"FLOAT",
TensorProto.FLOAT16:"FLOAT16",
TensorProto.INT64:"INT64",
TensorProto.INT32:"INT32",
TensorProto.INT16:"INT16",
TensorProto.UINT64:"UINT64",
TensorProto.UINT32:"UINT32",
TensorProto.UINT16:"UINT16",
TensorProto.STRING:"STRING",
TensorProto.BOOL:"BOOL",
}


Expand All@@ -20,6 +27,10 @@ class BuilderEmitter(BaseEmitter):
Converts event into proper code.
"""

def__init__(self,make_model_function:str=""):
super().__init__()
self.make_model_function=make_model_function

defjoin(self,rows:List[str],single_line:bool=False)->str:
"Join the rows"
assert (
Expand All@@ -29,6 +40,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:

def_emit_start(self,**kwargs:Dict[str,Any])->List[str]:
self.opsets=kwargs.get("opsets", {})
self.ir_version=kwargs.get("ir_version",None)
return []

def_emit_to_onnx_model(self,**kwargs:Dict[str,Any])->List[str]:
Expand All@@ -43,12 +55,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
)
rows= [
"",
f"g = GraphBuilder({self.opsets})",
(
f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})"
ifself.ir_version
elsef"GraphBuilder({self.opsets})"
),
*inputs,
f"{self.name}({inps})",
*outputs,
"model = g.to_onnx()",
]
ifself.make_model_function:
rows= [
"",
"",
f'def{self.make_model_function}() -> "ModelProto":',
*[" "+_for_inrows[1:]],
" return model",
"",
"",
f"model ={self.make_model_function}()",
]
returnrows

def_emit_begin_graph(self,**kwargs:Dict[str,Any])->List[str]:
Expand DownExpand Up@@ -78,13 +105,16 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
name=kwargs["name"]
itype=kwargs.get("elem_type",0)
shape=kwargs.get("shape",None)
name=self._clean_result_name(name)
ifitype==0:
inp="X"
inp=nameor"X"
else:
ifshapeisNone:
inp=f'X: "{_itype_to_string(itype)}"'
inp=f'{name}: "{_itype_to_string(itype)}"'
else:
inp=f'X: "{_itype_to_string(itype)}[{", ".join(map(str,shape))}]"'
inp= (
f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str,shape))}]"'
)
self.inputs_full.append(inp)
self.inputs.append(name)
self.inputs_full_.append((name,_itype_to_string(itype),shape))
Expand DownExpand Up@@ -113,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:

def_emit_output(self,**kwargs:Dict[str,Any])->List[str]:
name=kwargs["name"]
name=self._clean_result_name(name)
itype=kwargs.get("elem_type",0)
shape=kwargs.get("shape",None)
self.outputs.append(name)
Expand All@@ -126,6 +157,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
ifkwargs.get("domain","")!="":
domain=kwargs["domain"]
op_type=f"{domain}.{op_type}"
else:
domain=""
atts=kwargs.get("atts", {})
args= []
fork,vinatts.items():
Expand All@@ -134,11 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
raiseNotImplementedError("Graph attribute not supported yet.")
args.append(f"{k}={vatt}")

outs=", ".join(outputs)
inps=", ".join(inputs)
outs=", ".join(map(self._clean_result_name,outputs))
inps=", ".join(map(self._clean_result_name,inputs))
op_type=self._emit_node_type(op_type,domain)
sdomain=""ifnotdomainelsef", domain={domain!r}"
ifargs:
sargs=", ".join(args)
row=f"{outs} = op.{op_type}({inps},{sargs})"
ifinps:
row=f"{outs} = op.{op_type}({inps},{sargs}{sdomain})"
else:
row=f"{outs} = op.{op_type}({sargs}{sdomain})"
else:
row=f"{outs} = op.{op_type}({inps})"
row=f"{outs} = op.{op_type}({inps}{sdomain})"
return [row]

def_clean_result_name(self,name):
returnname

def_emit_node_type(self,op_type,domain):
returnop_type
6 changes: 5 additions & 1 deletiononnx_array_api/translate_api/translate.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -35,7 +35,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
last_event=None
ifisinstance(self.proto_,ModelProto):
opsets= {d.domain:d.versionfordinself.proto_.opset_import}
rows.extend(self.emitter(EventType.START,opsets=opsets))
rows.extend(
self.emitter(
EventType.START,opsets=opsets,ir_version=self.proto_.ir_version
)
)
inputs=self.proto_.graph.input
outputs=self.proto_.graph.output
nodes=self.proto_.graph.node
Expand Down
Loading

[8]ページ先頭

©2009-2025 Movatter.jp