|
| 1 | +fromtypingimportAny,Dict,List |
| 2 | +fromonnximportTensorProto |
| 3 | +from .base_emitterimportBaseEmitter |
| 4 | + |
| 5 | +_types= { |
| 6 | +TensorProto.FLOAT:"FLOAT", |
| 7 | +TensorProto.FLOAT16:"FLOAT16", |
| 8 | +TensorProto.INT64:"INT64", |
| 9 | +TensorProto.INT32:"INT32", |
| 10 | +} |
| 11 | + |
| 12 | + |
| 13 | +def_itype_to_string(itype:int)->str: |
| 14 | +return_types[itype] |
| 15 | + |
| 16 | + |
| 17 | +classBuilderEmitter(BaseEmitter): |
| 18 | +""" |
| 19 | + Converts event into proper code. |
| 20 | + """ |
| 21 | + |
| 22 | +defjoin(self,rows:List[str],single_line:bool=False)->str: |
| 23 | +"Join the rows" |
| 24 | +assert ( |
| 25 | +notsingle_line |
| 26 | + ),f"The emitter{type(self)} does not work with single_line=True." |
| 27 | +return"\n".join(rows) |
| 28 | + |
| 29 | +def_emit_start(self,**kwargs:Dict[str,Any])->List[str]: |
| 30 | +self.opsets=kwargs.get("opsets", {}) |
| 31 | +return [] |
| 32 | + |
| 33 | +def_emit_to_onnx_model(self,**kwargs:Dict[str,Any])->List[str]: |
| 34 | +inps=", ".join(["g.op",*self.inputs]) |
| 35 | +inputs= [] |
| 36 | +forinp,stype,shapeinself.inputs_full_: |
| 37 | +inputs.append(f'g.make_tensor_input("{inp}", TensorProto.{stype},{shape})') |
| 38 | +outputs= [] |
| 39 | +forinp,stype,shapeinself.outputs_full_: |
| 40 | +outputs.append( |
| 41 | +f'g.make_tensor_output("{inp}", TensorProto.{stype},{shape})' |
| 42 | + ) |
| 43 | +rows= [ |
| 44 | +"", |
| 45 | +f"g = GraphBuilder({self.opsets})", |
| 46 | +*inputs, |
| 47 | +f"{self.name}({inps})", |
| 48 | +*outputs, |
| 49 | +"model = g.to_onnx()", |
| 50 | + ] |
| 51 | +returnrows |
| 52 | + |
| 53 | +def_emit_begin_graph(self,**kwargs:Dict[str,Any])->List[str]: |
| 54 | +self.inputs= [] |
| 55 | +self.inputs_full= [] |
| 56 | +self.outputs= [] |
| 57 | +self.inits= [] |
| 58 | +self.inputs_full_= [] |
| 59 | +self.outputs_full_= [] |
| 60 | +self.name=kwargs.get("name","make_graph") |
| 61 | +return [] |
| 62 | + |
| 63 | +def_emit_end_graph(self,**kwargs:Dict[str,Any])->List[str]: |
| 64 | +return [] |
| 65 | + |
| 66 | +def_emit_initializer(self,**kwargs:Dict[str,Any])->List[str]: |
| 67 | +assertFalse,f"not implemented yet with{kwargs}" |
| 68 | + |
| 69 | +def_emit_input(self,**kwargs:Dict[str,Any])->List[str]: |
| 70 | +name=kwargs["name"] |
| 71 | +itype=kwargs.get("elem_type",0) |
| 72 | +shape=kwargs.get("shape",None) |
| 73 | +ifitype==0: |
| 74 | +inp="X" |
| 75 | +else: |
| 76 | +ifshapeisNone: |
| 77 | +inp=f'X: "{_itype_to_string(itype)}"' |
| 78 | +else: |
| 79 | +inp=f'X: "{_itype_to_string(itype)}[{", ".join(map(str,shape))}]"' |
| 80 | +self.inputs_full.append(inp) |
| 81 | +self.inputs.append(name) |
| 82 | +self.inputs_full_.append((name,_itype_to_string(itype),shape)) |
| 83 | +return [] |
| 84 | + |
| 85 | +def_emit_begin_signature(self,**kwargs:Dict[str,Any])->List[str]: |
| 86 | +return [] |
| 87 | + |
| 88 | +def_emit_end_signature(self,**kwargs:Dict[str,Any])->List[str]: |
| 89 | +rows= ["",f"def{self.name}(",' op: "GraphBuilder",'] |
| 90 | +foriinself.inputs_full: |
| 91 | +rows.append(f"{i},") |
| 92 | +rows.append("):") |
| 93 | +returnrows |
| 94 | + |
| 95 | +def_emit_begin_return(self,**kwargs:Dict[str,Any])->List[str]: |
| 96 | +return [] |
| 97 | + |
| 98 | +def_emit_end_return(self,**kwargs:Dict[str,Any])->List[str]: |
| 99 | +outs=", ".join(self.outputs) |
| 100 | +return [f" return{outs}"] |
| 101 | + |
| 102 | +def_emit_output(self,**kwargs:Dict[str,Any])->List[str]: |
| 103 | +name=kwargs["name"] |
| 104 | +itype=kwargs.get("elem_type",0) |
| 105 | +shape=kwargs.get("shape",None) |
| 106 | +self.outputs.append(name) |
| 107 | +self.outputs_full_.append((name,_itype_to_string(itype),shape)) |
| 108 | +return [f' op.Identity({name}, outputs=["{name}"])'] |
| 109 | + |
| 110 | +def_emit_node(self,**kwargs:Dict[str,Any])->List[str]: |
| 111 | +op_type=kwargs["op_type"] |
| 112 | +inputs=kwargs["inputs"] |
| 113 | +outputs=kwargs["outputs"] |
| 114 | +ifkwargs.get("domain","")!="": |
| 115 | +domain=kwargs["domain"] |
| 116 | +op_type=f"{domain}.{op_type}" |
| 117 | +atts=kwargs.get("atts", {}) |
| 118 | +args= [] |
| 119 | +fork,vinatts.items(): |
| 120 | +before,vatt=self.render_attribute_value(v) |
| 121 | +ifbefore: |
| 122 | +raiseNotImplementedError("Graph attribute not supported yet.") |
| 123 | +args.append(f"{k}={vatt}") |
| 124 | + |
| 125 | +outs=", ".join(outputs) |
| 126 | +inps=", ".join(inputs) |
| 127 | +ifargs: |
| 128 | +sargs=", ".join(args) |
| 129 | +row=f"{outs} = op.{op_type}({inps},{sargs})" |
| 130 | +else: |
| 131 | +row=f"{outs} = op.{op_type}({inps})" |
| 132 | +return [row] |