|
1 | 1 | importunittest |
2 | 2 | fromcontextlibimportredirect_stdout |
3 | 3 | fromioimportStringIO |
4 | | - |
5 | 4 | importnumpyasnp |
6 | 5 | fromonnx.defsimportonnx_opset_version |
7 | 6 | fromonnx.referenceimportReferenceEvaluator |
8 | 7 | fromonnxruntimeimportInferenceSession |
9 | | - |
10 | 8 | fromonnx_array_api.ext_test_caseimportExtTestCase |
11 | 9 | fromonnx_array_api.npximporteager_onnx,jit_onnx |
12 | 10 | fromonnx_array_api.npx.npx_functionsimportabsoluteasabsolute_inline |
13 | 11 | fromonnx_array_api.npx.npx_functionsimportcdistascdist_inline |
14 | 12 | fromonnx_array_api.npx.npx_functions_testimportabsolute |
15 | | -fromonnx_array_api.npx.npx_typesimportFloat32,Float64 |
| 13 | +fromonnx_array_api.npx.npx_functionsimportcopyascopy_inline |
| 14 | +fromonnx_array_api.npx.npx_typesimportFloat32,Float64,DType |
16 | 15 | fromonnx_array_api.npx.npx_varimportInput |
17 | 16 | fromonnx_array_api.ort.ort_tensorsimportEagerOrtTensor,JitOrtTensor,OrtTensor |
18 | 17 |
|
@@ -193,6 +192,49 @@ def impl(xa, xb): |
193 | 192 | iflen(pieces)>2: |
194 | 193 | raiseAssertionError(f"Function is not using argument:\n{onx}") |
195 | 194 |
|
| 195 | +deftest_astype(self): |
| 196 | +f=absolute_inline(copy_inline(Input("A")).astype(np.float32)) |
| 197 | +onx=f.to_onnx(constraints={"A":Float64[None]}) |
| 198 | +x=np.array([[-5,6]],dtype=np.float64) |
| 199 | +z=np.abs(x.astype(np.float32)) |
| 200 | +ref=InferenceSession( |
| 201 | +onx.SerializeToString(),providers=["CPUExecutionProvider"] |
| 202 | + ) |
| 203 | +got=ref.run(None, {"A":x}) |
| 204 | +self.assertEqualArray(z,got[0]) |
| 205 | + |
| 206 | +deftest_astype0(self): |
| 207 | +f=absolute_inline(copy_inline(Input("A")).astype(np.float32)) |
| 208 | +onx=f.to_onnx(constraints={"A":Float64[None]}) |
| 209 | +x=np.array(-5,dtype=np.float64) |
| 210 | +z=np.abs(x.astype(np.float32)) |
| 211 | +ref=InferenceSession( |
| 212 | +onx.SerializeToString(),providers=["CPUExecutionProvider"] |
| 213 | + ) |
| 214 | +got=ref.run(None, {"A":x}) |
| 215 | +self.assertEqualArray(z,got[0]) |
| 216 | + |
| 217 | +deftest_eager_ort_cast(self): |
| 218 | +defimpl(A): |
| 219 | +returnA.astype(DType("FLOAT")) |
| 220 | + |
| 221 | +e=eager_onnx(impl) |
| 222 | +self.assertEqual(len(e.versions),0) |
| 223 | + |
| 224 | +# Float64 |
| 225 | +x=np.array([0,1,-2],dtype=np.float64) |
| 226 | +z=x.astype(np.float32) |
| 227 | +res=e(x) |
| 228 | +self.assertEqualArray(z,res) |
| 229 | +self.assertEqual(res.dtype,np.float32) |
| 230 | + |
| 231 | +# again |
| 232 | +x=np.array(1,dtype=np.float64) |
| 233 | +z=x.astype(np.float32) |
| 234 | +res=e(x) |
| 235 | +self.assertEqualArray(z,res) |
| 236 | +self.assertEqual(res.dtype,np.float32) |
| 237 | + |
196 | 238 |
|
197 | 239 | if__name__=="__main__": |
198 | 240 | # TestNpx().test_eager_numpy() |
|