Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit032aff5

Browse files
committed
2 parents4c12efd +a906010 commit032aff5

File tree

5 files changed

+65
-3
lines changed

5 files changed

+65
-3
lines changed

‎CHANGELOGS.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
*:pr:`77`: supports ConcatOfShape and Slice with the light API
78
*:pr:`76`: add a mode to compare models without execution
89
*:pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
910
*:pr:`71`: adds tools to compare two onnx graphs

‎_unittests/ut_light_api/test_light_api.py‎

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
importunittest
33
fromtypingimportCallable,Optional
44
importnumpyasnp
5-
fromonnximportGraphProto,ModelProto
5+
fromonnximportGraphProto,ModelProto,TensorProto
66
fromonnx.defsimport (
77
get_all_schemas_with_history,
88
onnx_opset_version,
@@ -526,7 +526,47 @@ def test_input_shape(self):
526526
i=str(model.graph.input[0]).replace("\n","").replace(" ","")
527527
self.assertNotIn("shape{}",i)
528528

529+
deftest_constant_of_shape(self):
530+
onx= (
531+
start()
532+
.vin("X",TensorProto.INT64,shape=[None,None])
533+
.ConstantOfShape()
534+
.vout(shape=[])
535+
.to_onnx()
536+
)
537+
ref=ReferenceEvaluator(onx)
538+
got=ref.run(None, {"X":np.array([2,3],dtype=np.int64)})[0]
539+
self.assertEqualArray(np.zeros((2,3),dtype=np.float32),got)
540+
541+
deftest_constant_of_shape_value(self):
542+
onx= (
543+
start()
544+
.vin("X",TensorProto.INT64,shape=[None,None])
545+
.ConstantOfShape(value=np.array([1],dtype=np.float32))
546+
.vout(shape=[])
547+
.to_onnx()
548+
)
549+
ref=ReferenceEvaluator(onx)
550+
got=ref.run(None, {"X":np.array([2,3],dtype=np.int64)})[0]
551+
self.assertEqualArray(np.ones((2,3),dtype=np.float32),got)
552+
553+
deftest_slice(self):
554+
onx= (
555+
start(opset=18,ir_version=9)
556+
.cst(np.array([1],dtype=np.int64),name="one")
557+
.cst(np.array([2],dtype=np.int64),name="two")
558+
.vin("X",TensorProto.INT64,shape=[None,None])
559+
.ConstantOfShape(value=np.array([1],dtype=np.float32))
560+
.rename("CX")
561+
.bring("CX","one","two","one")
562+
.Slice()
563+
.vout(shape=[])
564+
.to_onnx()
565+
)
566+
ref=ReferenceEvaluator(onx)
567+
got=ref.run(None, {"X":np.array([2,3],dtype=np.int64)})[0]
568+
self.assertEqualArray(np.ones((2,1),dtype=np.float32),got)
569+
529570

530571
if__name__=="__main__":
531-
TestLightApi().test_add()
532572
unittest.main(verbosity=2)

‎onnx_array_api/light_api/__init__.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
defstart(
99
opset:Optional[int]=None,
1010
opsets:Optional[Dict[str,int]]=None,
11+
ir_version:Optional[int]=None,
1112
)->OnnxGraph:
1213
"""
1314
Starts an onnx model.
1415
1516
:param opset: main opset version
1617
:param opsets: others opsets as a dictionary
18+
:param ir_version: specify the ir_version as well
1719
:return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`
1820
1921
A very simple model:
@@ -45,7 +47,7 @@ def start(
4547
)
4648
print(onx)
4749
"""
48-
returnOnnxGraph(opset=opset,opsets=opsets)
50+
returnOnnxGraph(opset=opset,opsets=opsets,ir_version=ir_version)
4951

5052

5153
defg()->OnnxGraph:

‎onnx_array_api/light_api/_op_var.py‎

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
fromtypingimportList,Optional,Union
2+
importnumpyasnp
3+
from ..referenceimportfrom_array_extended
24
from ..annotationsimportAI_ONNX_ML,domain
35

46

@@ -69,6 +71,11 @@ def Cast(self, saturate: int = 1, to: int = 0) -> "Var":
6971
defCelu(self,alpha:float=1.0)->"Var":
7072
returnself.make_node("Celu",self,alpha=alpha)
7173

74+
defConstantOfShape(self,value:Optional[np.array]=None)->"Var":
75+
ifvalueisNone:
76+
returnself.make_node("ConstantOfShape",self)
77+
returnself.make_node("ConstantOfShape",self,value=from_array_extended(value))
78+
7279
defDepthToSpace(self,blocksize:int=0,mode:str="DCR")->"Var":
7380
returnself.make_node("DepthToSpace",self,blocksize=blocksize,mode=mode)
7481

@@ -307,6 +314,13 @@ def Selu(
307314
defShrink(self,bias:float=0.0,lambd:float=0.5)->"Var":
308315
returnself.make_node("Shrink",self,bias=bias,lambd=lambd)
309316

317+
defSlice(
318+
self,starts:"Var",ends:"Var",axes:"Var",steps:Optional["Var"]=None
319+
)->"Var":
320+
ifstepsisNone:
321+
returnself.make_node("Slice",self,starts,ends,axes)
322+
returnself.make_node("Slice",self,starts,ends,axes,steps)
323+
310324
defSoftmax(self,axis:int=-1)->"Var":
311325
returnself.make_node("Softmax",self,axis=axis)
312326

‎onnx_array_api/light_api/model.py‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@ class OnnxGraph:
4242
4343
:param opset: main opset version
4444
:param opsets: other opsets as a dictionary
45+
:param ir_version: to specify an ir_version
4546
:param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto`
4647
"""
4748

4849
def__init__(
4950
self,
5051
opset:Optional[int]=None,
5152
opsets:Optional[Dict[str,int]]=None,
53+
ir_version:Optional[int]=None,
5254
proto_type:ProtoType=ProtoType.MODEL,
5355
):
5456
ifopsetsisnotNoneand""inopsets:
@@ -65,6 +67,7 @@ def __init__(
6567
self.proto_type=proto_type
6668
self.opsets=opsets
6769
self.opset=opset
70+
self.ir_version=ir_version
6871
self.nodes:List[Union[NodeProto,TensorProto]]= []
6972
self.inputs:List[ValueInfoProto]= []
7073
self.outputs:List[ValueInfoProto]= []
@@ -402,6 +405,8 @@ def to_onnx(self) -> GRAPH_PROTO:
402405
# If no opsets, it a subgraph, not a model.
403406
returngraph
404407
model=make_model(graph,opset_imports=opsets)
408+
ifself.ir_version:
409+
model.ir_version=self.ir_version
405410
ifnotis_windows()ornotis_azure():
406411
# check_model fails sometimes on Windows
407412
check_model(model)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp