|
| 1 | +importunittest |
| 2 | +importnumpyasnp |
| 3 | +importonnx |
| 4 | +importonnx.helperasoh |
| 5 | +importonnx.numpy_helperasonh |
| 6 | +fromonnximportTensorProto |
| 7 | +fromonnx_array_api.ext_test_caseimportExtTestCase |
| 8 | +fromonnx_array_api.referenceimport ( |
| 9 | +ExtendedReferenceEvaluatorasReferenceEvaluator, |
| 10 | +) |
| 11 | +fromonnx_array_api.tools.replace_constantsimport ( |
| 12 | +replace_initializer_by_constant_of_shape, |
| 13 | +) |
| 14 | + |
| 15 | + |
| 16 | +classTestReplaceConstants(ExtTestCase): |
| 17 | + |
| 18 | +deftest_replace_initializer(self): |
| 19 | +dtype=np.float32 |
| 20 | +value=np.random.randn(2,100).astype(dtype) |
| 21 | +A=onh.from_array(value,name="A") |
| 22 | +value=np.array([1],dtype=dtype) |
| 23 | +C=onh.from_array(value,name="C") |
| 24 | + |
| 25 | +X=oh.make_tensor_value_info("X",TensorProto.FLOAT, [None,None]) |
| 26 | +Y=oh.make_tensor_value_info("Y",TensorProto.FLOAT, [None]) |
| 27 | +node1=oh.make_node("MatMul", ["X","A"], ["AX"]) |
| 28 | +node2=oh.make_node("Sub", ["AX","C"], ["Y"]) |
| 29 | +graph=oh.make_graph([node1,node2],"lr", [X], [Y], [A,C]) |
| 30 | +model_def=oh.make_model(graph) |
| 31 | + |
| 32 | +x=np.array([1,2,4,5,5,4]).astype(np.float32).reshape((3,2)) |
| 33 | +oinf1=ReferenceEvaluator(model_def) |
| 34 | +y1=oinf1.run(None, {"X":x})[0]# type: ignore[index] |
| 35 | +repl=replace_initializer_by_constant_of_shape(model_def) |
| 36 | +node_types= {n.op_typeforninrepl.graph.node} |
| 37 | +self.assertIn("ConstantOfShape",node_types) |
| 38 | +oinf2=ReferenceEvaluator(repl) |
| 39 | +y1[:, :]=3.5 |
| 40 | +y1[0, :]=0.5 |
| 41 | +y2=oinf2.run(None, {"X":x})[0]# type: ignore[index] |
| 42 | +self.assertEqualArray(y1,y2) |
| 43 | + |
| 44 | +deftest_replace_constant(self): |
| 45 | +dtype=np.float32 |
| 46 | +value=np.random.randn(2,10).astype(dtype) |
| 47 | +A=onh.from_array(value,name="A") |
| 48 | +value=np.array([1],dtype=dtype) |
| 49 | +C=onh.from_array(value,name="C") |
| 50 | + |
| 51 | +X=oh.make_tensor_value_info("X",TensorProto.FLOAT, [None,None]) |
| 52 | +Y=oh.make_tensor_value_info("Y",TensorProto.FLOAT, [None]) |
| 53 | +node0=oh.make_node("Constant", [], ["A"],value=A) |
| 54 | +node1=oh.make_node("MatMul", ["X","A"], ["AX"]) |
| 55 | +node2=oh.make_node("Sub", ["AX","C"], ["Y"]) |
| 56 | +graph=oh.make_graph([node0,node1,node2],"lr", [X], [Y], [C]) |
| 57 | +model_def=oh.make_model(graph) |
| 58 | + |
| 59 | +x=np.array([1,2,4,5,5,4]).astype(np.float32).reshape((3,2)) |
| 60 | +oinf1=ReferenceEvaluator(model_def) |
| 61 | +y1=oinf1.run(None, {"X":x})[0]# type: ignore[index] |
| 62 | +repl=replace_initializer_by_constant_of_shape(model_def,threshold=0) |
| 63 | +node_types= {n.op_typeforninrepl.graph.node} |
| 64 | +self.assertIn("ConstantOfShape",node_types) |
| 65 | +oinf2=ReferenceEvaluator(repl) |
| 66 | +y1[:, :]=4 |
| 67 | +y1[0, :]=1 |
| 68 | +y2=oinf2.run(None, {"X":x})[0]# type: ignore[index] |
| 69 | +self.assertEqualArray(y1,y2) |
| 70 | + |
| 71 | +deftest_replace_constant_function(self): |
| 72 | +dtype=np.float32 |
| 73 | +value=np.random.randn(2,100).astype(dtype) |
| 74 | +A=onh.from_array(value,name="A") |
| 75 | +value=np.array([1],dtype=dtype) |
| 76 | +C=onh.from_array(value,name="C") |
| 77 | + |
| 78 | +X=oh.make_tensor_value_info("X",TensorProto.FLOAT, [None,None]) |
| 79 | +Y=oh.make_tensor_value_info("Y",TensorProto.FLOAT, [None]) |
| 80 | +nodeC=oh.make_node("Constant", [], ["C"],value=C) |
| 81 | +node0=oh.make_node("Constant", [], ["A"],value=A) |
| 82 | +node1=oh.make_node("MatMul", ["X","A"], ["AX"]) |
| 83 | +node2=oh.make_node("Sub", ["AX","C"], ["Y"]) |
| 84 | +opset_imports= [ |
| 85 | +oh.make_opsetid("",onnx.defs.onnx_opset_version()), |
| 86 | +oh.make_opsetid("custom",1), |
| 87 | + ] |
| 88 | +fct=oh.make_function( |
| 89 | +"custom", |
| 90 | +"unittest", |
| 91 | + ["X"], |
| 92 | + ["Y"], |
| 93 | + [nodeC,node0,node1,node2], |
| 94 | +opset_imports, |
| 95 | + ) |
| 96 | + |
| 97 | +node=oh.make_node("unittest", ["X"], ["Y"],domain="custom") |
| 98 | +graph=oh.make_graph([node],"lr", [X], [Y], [C]) |
| 99 | +model_def=oh.make_model(graph,functions=[fct],opset_imports=opset_imports) |
| 100 | + |
| 101 | +x=np.array([1,2,4,5,5,4]).astype(np.float32).reshape((3,2)) |
| 102 | +oinf1=ReferenceEvaluator(model_def) |
| 103 | +y1=oinf1.run(None, {"X":x})[0]# type: ignore[index] |
| 104 | +repl=replace_initializer_by_constant_of_shape(model_def) |
| 105 | +node_types= {n.op_typeforninrepl.functions[0].node} |
| 106 | +self.assertIn("ConstantOfShape",node_types) |
| 107 | +oinf2=ReferenceEvaluator(repl) |
| 108 | +y1[:, :]=3.5 |
| 109 | +y1[0, :]=0.5 |
| 110 | +y2=oinf2.run(None, {"X":x})[0]# type: ignore[index] |
| 111 | +self.assertEqualArray(y1,y2) |
| 112 | + |
| 113 | +deftest_replace_constant_graph(self): |
| 114 | +value=np.array([0],dtype=np.float32) |
| 115 | +zero=onh.from_array(value,name="zero") |
| 116 | + |
| 117 | +X=oh.make_tensor_value_info("X",onnx.TensorProto.FLOAT, [None,None]) |
| 118 | +Y=oh.make_tensor_value_info("Y",onnx.TensorProto.FLOAT, [None]) |
| 119 | + |
| 120 | +rsum=oh.make_node("ReduceSum", ["X"], ["rsum"]) |
| 121 | +cond=oh.make_node("Greater", ["rsum","zero"], ["cond"]) |
| 122 | + |
| 123 | +then_out=oh.make_tensor_value_info("then_out",onnx.TensorProto.FLOAT,None) |
| 124 | +then_cst=onh.from_array(np.array([1]*129).astype(np.float32)) |
| 125 | + |
| 126 | +then_const_node=oh.make_node( |
| 127 | +"Constant",inputs=[],outputs=["then_out"],value=then_cst,name="cst1" |
| 128 | + ) |
| 129 | +then_body=oh.make_graph([then_const_node],"then_body", [], [then_out]) |
| 130 | + |
| 131 | +else_out=oh.make_tensor_value_info("else_out",onnx.TensorProto.FLOAT,None) |
| 132 | +else_cst=onh.from_array(np.array([-1]*129).astype(np.float32)) |
| 133 | +else_const_node=oh.make_node( |
| 134 | +"Constant",inputs=[],outputs=["else_out"],value=else_cst,name="cst2" |
| 135 | + ) |
| 136 | +else_body=oh.make_graph([else_const_node],"else_body", [], [else_out]) |
| 137 | + |
| 138 | +if_node=oh.make_node( |
| 139 | +"If", ["cond"], ["Y"],then_branch=then_body,else_branch=else_body |
| 140 | + ) |
| 141 | +graph=oh.make_graph([rsum,cond,if_node],"if", [X], [Y], [zero]) |
| 142 | +onnx_model=oh.make_model( |
| 143 | +graph,opset_imports=[oh.make_opsetid("",onnx.defs.onnx_opset_version())] |
| 144 | + ) |
| 145 | +self.assertNotIn("ConstantOfShape",str(onnx_model)) |
| 146 | + |
| 147 | +x=np.ones((3,2),dtype=np.float32) |
| 148 | +oinf1=ReferenceEvaluator(onnx_model) |
| 149 | +y1=oinf1.run(None, {"X":x})[0]# type: ignore[index] |
| 150 | +repl=replace_initializer_by_constant_of_shape(onnx_model) |
| 151 | +self.assertIn("ConstantOfShape",str(repl)) |
| 152 | +oinf2=ReferenceEvaluator(repl) |
| 153 | +y2=oinf2.run(None, {"X":x})[0]# type: ignore[index] |
| 154 | +y1=y1.copy() |
| 155 | +y1[:]=0.5 |
| 156 | +self.assertEqualArray(y1,y2) |
| 157 | + |
| 158 | + |
| 159 | +if__name__=="__main__": |
| 160 | +unittest.main(verbosity=2) |