44from .base_emitter import BaseEmitter
55
66_types = {
7+ TensorProto .DOUBLE :"DOUBLE" ,
78TensorProto .FLOAT :"FLOAT" ,
89TensorProto .FLOAT16 :"FLOAT16" ,
910TensorProto .INT64 :"INT64" ,
1011TensorProto .INT32 :"INT32" ,
12+ TensorProto .INT16 :"INT16" ,
13+ TensorProto .UINT64 :"UINT64" ,
14+ TensorProto .UINT32 :"UINT32" ,
15+ TensorProto .UINT16 :"UINT16" ,
16+ TensorProto .STRING :"STRING" ,
17+ TensorProto .BOOL :"BOOL" ,
1118}
1219
1320
@@ -20,6 +27,10 @@ class BuilderEmitter(BaseEmitter):
2027 Converts event into proper code.
2128 """
2229
30+ def __init__ (self ,make_model_function :str = "" ):
31+ super ().__init__ ()
32+ self .make_model_function = make_model_function
33+
2334def join (self ,rows :List [str ],single_line :bool = False )-> str :
2435"Join the rows"
2536assert (
@@ -29,6 +40,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:
2940
3041def _emit_start (self ,** kwargs :Dict [str ,Any ])-> List [str ]:
3142self .opsets = kwargs .get ("opsets" , {})
43+ self .ir_version = kwargs .get ("ir_version" ,None )
3244return []
3345
3446def _emit_to_onnx_model (self ,** kwargs :Dict [str ,Any ])-> List [str ]:
@@ -43,12 +55,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
4355 )
4456rows = [
4557"" ,
46- f"g = GraphBuilder({ self .opsets } )" ,
58+ (
59+ f"g = GraphBuilder({ self .opsets } , ir_version={ self .ir_version } )"
60+ if self .ir_version
61+ else f"GraphBuilder({ self .opsets } )"
62+ ),
4763* inputs ,
4864f"{ self .name } ({ inps } )" ,
4965* outputs ,
5066"model = g.to_onnx()" ,
5167 ]
68+ if self .make_model_function :
69+ rows = [
70+ "" ,
71+ "" ,
72+ f'def{ self .make_model_function } () -> "ModelProto":' ,
73+ * [" " + _ for _ in rows [1 :]],
74+ " return model" ,
75+ "" ,
76+ "" ,
77+ f"model ={ self .make_model_function } ()" ,
78+ ]
5279return rows
5380
5481def _emit_begin_graph (self ,** kwargs :Dict [str ,Any ])-> List [str ]:
@@ -78,13 +105,16 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
78105name = kwargs ["name" ]
79106itype = kwargs .get ("elem_type" ,0 )
80107shape = kwargs .get ("shape" ,None )
108+ name = self ._clean_result_name (name )
81109if itype == 0 :
82- inp = "X"
110+ inp = name or "X"
83111else :
84112if shape is None :
85- inp = f'X : "{ _itype_to_string (itype )} "'
113+ inp = f'{ name } : "{ _itype_to_string (itype )} "'
86114else :
87- inp = f'X: "{ _itype_to_string (itype )} [{ ", " .join (map (str ,shape ))} ]"'
115+ inp = (
116+ f'{ name } : "{ _itype_to_string (itype )} [{ ", " .join (map (str ,shape ))} ]"'
117+ )
88118self .inputs_full .append (inp )
89119self .inputs .append (name )
90120self .inputs_full_ .append ((name ,_itype_to_string (itype ),shape ))
@@ -113,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
113143
114144def _emit_output (self ,** kwargs :Dict [str ,Any ])-> List [str ]:
115145name = kwargs ["name" ]
146+ name = self ._clean_result_name (name )
116147itype = kwargs .get ("elem_type" ,0 )
117148shape = kwargs .get ("shape" ,None )
118149self .outputs .append (name )
@@ -126,6 +157,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
126157if kwargs .get ("domain" ,"" )!= "" :
127158domain = kwargs ["domain" ]
128159op_type = f"{ domain } .{ op_type } "
160+ else :
161+ domain = ""
129162atts = kwargs .get ("atts" , {})
130163args = []
131164for k ,v in atts .items ():
@@ -134,11 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
134167raise NotImplementedError ("Graph attribute not supported yet." )
135168args .append (f"{ k } ={ vatt } " )
136169
137- outs = ", " .join (outputs )
138- inps = ", " .join (inputs )
170+ outs = ", " .join (map (self ._clean_result_name ,outputs ))
171+ inps = ", " .join (map (self ._clean_result_name ,inputs ))
172+ op_type = self ._emit_node_type (op_type ,domain )
173+ sdomain = "" if not domain else f", domain={ domain !r} "
139174if args :
140175sargs = ", " .join (args )
141- row = f"{ outs } = op.{ op_type } ({ inps } ,{ sargs } )"
176+ if inps :
177+ row = f"{ outs } = op.{ op_type } ({ inps } ,{ sargs } { sdomain } )"
178+ else :
179+ row = f"{ outs } = op.{ op_type } ({ sargs } { sdomain } )"
142180else :
143- row = f"{ outs } = op.{ op_type } ({ inps } )"
181+ row = f"{ outs } = op.{ op_type } ({ inps } { sdomain } )"
144182return [row ]
183+
184+ def _clean_result_name (self ,name ):
185+ return name
186+
187+ def _emit_node_type (self ,op_type ,domain ):
188+ return op_type