|
| 1 | +importunittest |
| 2 | +importnumpyasnp |
| 3 | +fromonnximportTensorProto |
| 4 | +fromonnx.helperimportmake_graph,make_model,make_node,make_tensor_value_info |
| 5 | +fromonnx_array_api.ext_test_caseimportExtTestCase |
| 6 | +fromonnx_array_api.referenceimport ( |
| 7 | +to_array_extended, |
| 8 | +from_array_extended, |
| 9 | +ExtendedReferenceEvaluator, |
| 10 | +) |
| 11 | + |
| 12 | + |
| 13 | +classTestArrayTensor(ExtTestCase): |
| 14 | +deftest_from_array(self): |
| 15 | +fordtin (np.float32,np.float16,np.uint16,np.uint8): |
| 16 | +withself.subTest(dtype=dt): |
| 17 | +a=np.array([0,1,2],dtype=dt) |
| 18 | +t=from_array_extended(a,"a") |
| 19 | +b=to_array_extended(t) |
| 20 | +self.assertEqualArray(a,b) |
| 21 | +t2=from_array_extended(b,"a") |
| 22 | +self.assertEqual(t.SerializeToString(),t2.SerializeToString()) |
| 23 | + |
| 24 | +deftest_from_array_f8(self): |
| 25 | +defmake_model_f8(fr,to): |
| 26 | +model=make_model( |
| 27 | +make_graph( |
| 28 | + [make_node("Cast", ["X"], ["Y"],to=to)], |
| 29 | +"cast", |
| 30 | + [make_tensor_value_info("X",fr,None)], |
| 31 | + [make_tensor_value_info("Y",to,None)], |
| 32 | + ) |
| 33 | + ) |
| 34 | +returnmodel |
| 35 | + |
| 36 | +fordtin (np.float32,np.float16,np.uint16,np.uint8): |
| 37 | +withself.subTest(dtype=dt): |
| 38 | +a=np.array([0,1,2],dtype=dt) |
| 39 | +b=from_array_extended(a,"a") |
| 40 | +fortoin [ |
| 41 | +TensorProto.FLOAT8E4M3FN, |
| 42 | +TensorProto.FLOAT8E4M3FNUZ, |
| 43 | +TensorProto.FLOAT8E5M2, |
| 44 | +TensorProto.FLOAT8E5M2FNUZ, |
| 45 | +TensorProto.BFLOAT16, |
| 46 | + ]: |
| 47 | +withself.subTest(fr=b.data_type,to=to): |
| 48 | +model=make_model_f8(b.data_type,to) |
| 49 | +ref=ExtendedReferenceEvaluator(model) |
| 50 | +got=ref.run(None, {"X":a})[0] |
| 51 | +back=from_array_extended(got,"a") |
| 52 | +self.assertEqual(to,back.data_type) |
| 53 | + |
| 54 | + |
| 55 | +if__name__=="__main__": |
| 56 | +unittest.main(verbosity=2) |