Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Add ConstantOfShape to light API#77

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
sdpython merged 11 commits intomainfromlig2
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
Show all changes
11 commits
Select commitHold shift + click to select a range
6eb6adf
update requirements
xadupreFeb 2, 2024
4f0a994
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 2, 2024
c7bb055
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 2, 2024
a3d4ccf
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 6, 2024
7ed1385
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 6, 2024
bab2a6b
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 7, 2024
014404b
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 7, 2024
00e2a1c
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 14, 2024
d1aff97
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 15, 2024
4c12efd
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 22, 2024
ee55645
Add ConstantOfShape to light API
xadupreFeb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion_unittests/ut_light_api/test_light_api.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -2,7 +2,7 @@
import unittest
from typing import Callable, Optional
import numpy as np
from onnx import GraphProto, ModelProto
from onnx import GraphProto, ModelProto, TensorProto
from onnx.defs import (
get_all_schemas_with_history,
onnx_opset_version,
Expand DownExpand Up@@ -526,6 +526,18 @@ def test_input_shape(self):
i = str(model.graph.input[0]).replace("\n", "").replace(" ", "")
self.assertNotIn("shape{}", i)

def test_constant_of_shape(self):
onx = (
start()
.vin("X", TensorProto.INT64, shape=[None, None])
.ConstantOfShape()
.vout(shape=[])
.to_onnx()
)
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got)


if __name__ == "__main__":
TestLightApi().test_add()
Expand Down
4 changes: 3 additions & 1 deletiononnx_array_api/light_api/__init__.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -8,12 +8,14 @@
def start(
opset: Optional[int] = None,
opsets: Optional[Dict[str, int]] = None,
ir_version: Optional[int] = None,
) -> OnnxGraph:
"""
Starts an onnx model.

:param opset: main opset version
:param opsets: others opsets as a dictionary
:param ir_version: specify the ir_version as well
:return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`

A very simple model:
Expand DownExpand Up@@ -45,7 +47,7 @@ def start(
)
print(onx)
"""
return OnnxGraph(opset=opset, opsets=opsets)
return OnnxGraph(opset=opset, opsets=opsets, ir_version=ir_version)


def g() -> OnnxGraph:
Expand Down
7 changes: 7 additions & 0 deletionsonnx_array_api/light_api/_op_var.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
from typing import List, Optional, Union
import numpy as np
from ..reference import from_array_extended
from ..annotations import AI_ONNX_ML, domain


Expand DownExpand Up@@ -69,6 +71,11 @@ def Cast(self, saturate: int = 1, to: int = 0) -> "Var":
def Celu(self, alpha: float = 1.0) -> "Var":
return self.make_node("Celu", self, alpha=alpha)

def ConstantOfShape(self, value: Optional[np.array] = None) -> "Var":
if value is None:
return self.make_node("ConstantOfShape", self)
return self.make_node("ConstantOfShape", self, value=from_array_extended(value))

def DepthToSpace(self, blocksize: int = 0, mode: str = "DCR") -> "Var":
return self.make_node("DepthToSpace", self, blocksize=blocksize, mode=mode)

Expand Down
5 changes: 5 additions & 0 deletionsonnx_array_api/light_api/model.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -42,13 +42,15 @@ class OnnxGraph:

:param opset: main opset version
:param opsets: other opsets as a dictionary
:param ir_version: to specify an ir_version
:param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto`
"""

def __init__(
self,
opset: Optional[int] = None,
opsets: Optional[Dict[str, int]] = None,
ir_version: Optional[int] = None,
proto_type: ProtoType = ProtoType.MODEL,
):
if opsets is not None and "" in opsets:
Expand All@@ -65,6 +67,7 @@ def __init__(
self.proto_type = proto_type
self.opsets = opsets
self.opset = opset
self.ir_version = ir_version
self.nodes: List[Union[NodeProto, TensorProto]] = []
self.inputs: List[ValueInfoProto] = []
self.outputs: List[ValueInfoProto] = []
Expand DownExpand Up@@ -402,6 +405,8 @@ def to_onnx(self) -> GRAPH_PROTO:
# If no opsets, it a subgraph, not a model.
return graph
model = make_model(graph, opset_imports=opsets)
if self.ir_version:
model.ir_version = ir_version
if not is_windows() or not is_azure():
# check_model fails sometimes on Windows
check_model(model)
Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp