|
3 | 3 | importunittest |
4 | 4 | importnumpyasnp |
5 | 5 | importonnx |
6 | | -fromonnx.referenceimportReferenceEvaluator |
7 | 6 | fromonnx_array_api.ext_test_caseimportExtTestCase |
8 | | -fromonnx_array_api.graph_api.graph_builderimportGraphBuilder |
| 7 | +fromonnx_array_api.graph_api.graph_builderimportGraphBuilder,OptimizationOptions |
| 8 | +fromonnx_array_api.referenceimport ( |
| 9 | +from_array_extended, |
| 10 | +ExtendedReferenceEvaluatorasReferenceEvaluator, |
| 11 | +) |
9 | 12 |
|
10 | 13 |
|
11 | 14 | classTestGraphBuilder(ExtTestCase): |
@@ -130,6 +133,35 @@ def test_constant_folding(self): |
130 | 133 | got=ref.run(None,feeds) |
131 | 134 | self.assertEqualArray(expected,got[0]) |
132 | 135 |
|
| 136 | +deftest_constant_folding2(self): |
| 137 | +g=GraphBuilder( |
| 138 | +optimization_options=OptimizationOptions(constant_folding=True) |
| 139 | + ) |
| 140 | + |
| 141 | +shape= (10,4) |
| 142 | +w=np.random.randn(*shape).astype(np.float32) |
| 143 | +x=g.make_tensor_input("X",np.float32,shape) |
| 144 | +weight=g.make_initializer(w) |
| 145 | +cst=g.get_constant(weight) |
| 146 | +self.assertEqualArray(w,cst) |
| 147 | +one=g.make_initializer(np.array([-1,1],dtype=np.int64)) |
| 148 | +transposed=g.make_node("Transpose", [weight],perm=[1,0]) |
| 149 | +res=g.op.MatMul(x,transposed) |
| 150 | +g.op.Reshape(res,one,outputs="y") |
| 151 | +g.make_tensor_output("y",np.float32, (10,1)) |
| 152 | + |
| 153 | +g.optimize() |
| 154 | + |
| 155 | +onx=g.to_onnx() |
| 156 | +node_types= [n.op_typeforninonx.graph.node] |
| 157 | +self.assertNotIn("Transpose",node_types) |
| 158 | +ref=ReferenceEvaluator(onx) |
| 159 | +x=np.random.randn(*shape).astype(np.float32) |
| 160 | +expected= (x @w.T).reshape((-1,1)) |
| 161 | +feeds= {"X":x} |
| 162 | +got=ref.run(None,feeds) |
| 163 | +self.assertEqualArray(expected,got[0]) |
| 164 | + |
133 | 165 | deftest_remove_identity(self): |
134 | 166 | withcontextlib.redirect_stdout(io.StringIO()): |
135 | 167 | g=GraphBuilder(verbose=10) |
@@ -238,6 +270,112 @@ def test_remove_unused_nodes_simple(self): |
238 | 270 | got=ref.run(None,feeds) |
239 | 271 | self.assertEqualArray(expected,got[0]) |
240 | 272 |
|
| 273 | +deftest_constant_array(self): |
| 274 | +withcontextlib.redirect_stdout(io.StringIO()): |
| 275 | +g=GraphBuilder(verbose=10) |
| 276 | + |
| 277 | +shape= (10,4) |
| 278 | +w=np.random.randn(*shape).astype(np.float32) |
| 279 | + |
| 280 | +x=g.make_tensor_input("X",np.float32,shape) |
| 281 | +one=g.make_initializer(np.array([-1,1],dtype=np.int64)) |
| 282 | +res=g.op.MatMul(x,w.T) |
| 283 | +g.op.Reshape(res,one,outputs="y") |
| 284 | +g.make_tensor_output("y",np.float32, (10,1)) |
| 285 | +onx=g.to_onnx() |
| 286 | +ref=ReferenceEvaluator(onx) |
| 287 | +x=np.random.randn(*shape).astype(np.float32) |
| 288 | +expected= (x @w.T).reshape((-1,1)) |
| 289 | +feeds= {"X":x} |
| 290 | +got=ref.run(None,feeds) |
| 291 | +self.assertEqualArray(expected,got[0]) |
| 292 | + |
| 293 | +deftest_constant_array_2(self): |
| 294 | +withcontextlib.redirect_stdout(io.StringIO()): |
| 295 | +g=GraphBuilder(verbose=10) |
| 296 | + |
| 297 | +shape= (10,4) |
| 298 | +w=np.random.randn(*shape).astype(np.float32) |
| 299 | + |
| 300 | +x=g.make_tensor_input("X",np.float32,shape) |
| 301 | +one=g.make_initializer(np.array([-1,1],dtype=np.int64)) |
| 302 | +opc=g.op.Constant(value=from_array_extended(w.T)) |
| 303 | +res=g.op.MatMul(x,opc) |
| 304 | +g.op.Reshape(res,one,outputs="y") |
| 305 | +g.make_tensor_output("y",np.float32, (10,1)) |
| 306 | +self.assertTrue(g.has_shape("X")) |
| 307 | +self.assertTrue(g.has_type("X")) |
| 308 | +self.assertEqual(g.get_type("X"),1) |
| 309 | +self.assertEqual(g.get_shape("X"), (10,4)) |
| 310 | +self.assertEqual(g.rank("X"),2) |
| 311 | +onx=g.to_onnx() |
| 312 | +ref=ReferenceEvaluator(onx) |
| 313 | +x=np.random.randn(*shape).astype(np.float32) |
| 314 | +expected= (x @w.T).reshape((-1,1)) |
| 315 | +feeds= {"X":x} |
| 316 | +got=ref.run(None,feeds) |
| 317 | +self.assertEqualArray(expected,got[0]) |
| 318 | + |
| 319 | +deftest_get_type(self): |
| 320 | +g=GraphBuilder() |
| 321 | +self.assertEqual(g._get_type(np.float32),onnx.TensorProto.FLOAT) |
| 322 | +self.assertEqual(g._get_type(np.int64),onnx.TensorProto.INT64) |
| 323 | +self.assertEqual(g._get_type(None),onnx.TensorProto.UNDEFINED) |
| 324 | + |
| 325 | +deftest_make_nodes_prefix(self): |
| 326 | +g1=GraphBuilder() |
| 327 | +g1.make_tensor_input("X",np.float32,shape=None) |
| 328 | +g1.op.Add("X",np.array([1],dtype=np.float32),outputs=["y"]) |
| 329 | +g1.make_tensor_output("y",np.float32,shape=None) |
| 330 | + |
| 331 | +g=GraphBuilder() |
| 332 | + |
| 333 | +shape= (10,4) |
| 334 | +w=np.random.randn(*shape).astype(np.float32) |
| 335 | + |
| 336 | +x=g.make_tensor_input("X",np.float32,shape) |
| 337 | +weight=g.make_initializer(w) |
| 338 | +one=g.make_initializer(np.array([-1,1],dtype=np.int64)) |
| 339 | +transposed=g.make_node("Transpose", [weight],perm=[1,0]) |
| 340 | +res=g.op.MatMul(x,transposed) |
| 341 | +res2=g.make_nodes(g1, [res], ["k"],prefix="J") |
| 342 | +g.op.Reshape(res2,one,outputs="y") |
| 343 | +g.make_tensor_output("y",np.float32, (10,1)) |
| 344 | +onx=g.to_onnx() |
| 345 | +ref=ReferenceEvaluator(onx) |
| 346 | +x=np.random.randn(*shape).astype(np.float32) |
| 347 | +expected= (x @w.T).reshape((-1,1))+1 |
| 348 | +feeds= {"X":x} |
| 349 | +got=ref.run(None,feeds) |
| 350 | +self.assertEqualArray(expected,got[0]) |
| 351 | + |
| 352 | +deftest_make_nodes_noprefix(self): |
| 353 | +g1=GraphBuilder() |
| 354 | +g1.make_tensor_input("X",np.float32,shape=None) |
| 355 | +g1.op.Add("X",np.array([1],dtype=np.float32),outputs=["y"]) |
| 356 | +g1.make_tensor_output("y",np.float32,shape=None) |
| 357 | + |
| 358 | +g=GraphBuilder() |
| 359 | + |
| 360 | +shape= (10,4) |
| 361 | +w=np.random.randn(*shape).astype(np.float32) |
| 362 | + |
| 363 | +x=g.make_tensor_input("X",np.float32,shape) |
| 364 | +weight=g.make_initializer(w) |
| 365 | +one=g.make_initializer(np.array([-1,1],dtype=np.int64)) |
| 366 | +transposed=g.make_node("Transpose", [weight],perm=[1,0]) |
| 367 | +res=g.op.MatMul(x,transposed) |
| 368 | +res2=g.make_nodes(g1, [res], ["k"]) |
| 369 | +g.op.Reshape(res2,one,outputs="y") |
| 370 | +g.make_tensor_output("y",np.float32, (10,1)) |
| 371 | +onx=g.to_onnx() |
| 372 | +ref=ReferenceEvaluator(onx) |
| 373 | +x=np.random.randn(*shape).astype(np.float32) |
| 374 | +expected= (x @w.T).reshape((-1,1))+1 |
| 375 | +feeds= {"X":x} |
| 376 | +got=ref.run(None,feeds) |
| 377 | +self.assertEqualArray(expected,got[0]) |
| 378 | + |
241 | 379 |
|
242 | 380 | if__name__=="__main__": |
243 | 381 | unittest.main(verbosity=2) |