|
1 | 1 | fromtypingimportAny,Optional,Tuple,Union |
2 | 2 |
|
| 3 | +importarray_api_compat.numpyasnp_array_api |
3 | 4 | importnumpyasnp |
4 | 5 | fromonnximportFunctionProto,ModelProto,NodeProto,TensorProto |
5 | 6 | fromonnx.helperimportnp_dtype_to_tensor_dtype |
6 | 7 | fromonnx.numpy_helperimportfrom_array |
7 | 8 |
|
8 | 9 | from .npx_constantsimportFUNCTION_DOMAIN |
9 | | -from .npx_core_apiimportcst,make_tuple,npxapi_inline,var |
| 10 | +from .npx_core_apiimportcst,make_tuple,npxapi_inline,npxapi_no_inline,var |
10 | 11 | from .npx_tensorsimportArrayApi |
11 | 12 | from .npx_typesimport ( |
| 13 | +DType, |
12 | 14 | ElemType, |
13 | 15 | OptParType, |
14 | 16 | ParType, |
@@ -397,6 +399,17 @@ def identity(n: ParType[int], dtype=None) -> TensorType[ElemType.numerics, "T"]: |
397 | 399 | returnv |
398 | 400 |
|
399 | 401 |
|
| 402 | +@npxapi_no_inline |
| 403 | +defisdtype( |
| 404 | +dtype:DType,kind:Union[DType,str,Tuple[Union[DType,str], ...]] |
| 405 | +)->bool: |
| 406 | +""" |
| 407 | + See :epkg:`ArrayAPI:isdtype`. |
| 408 | + This function is not converted into an onnx graph. |
| 409 | + """ |
| 410 | +returnnp_array_api.isdtype(dtype,kind) |
| 411 | + |
| 412 | + |
400 | 413 | @npxapi_inline |
401 | 414 | defisnan(x:TensorType[ElemType.numerics,"T"])->TensorType[ElemType.bool_,"T"]: |
402 | 415 | "See :func:`numpy.isnan`." |
@@ -460,9 +473,23 @@ def relu(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, |
460 | 473 |
|
461 | 474 | @npxapi_inline |
462 | 475 | defreshape( |
463 | | -x:TensorType[ElemType.numerics,"T"],shape:TensorType[ElemType.int64,"I"] |
| 476 | +x:TensorType[ElemType.numerics,"T"], |
| 477 | +shape:TensorType[ElemType.int64,"I", (None,)], |
464 | 478 | )->TensorType[ElemType.numerics,"T"]: |
465 | | -"See :func:`numpy.reshape`." |
| 479 | +""" |
| 480 | + See :func:`numpy.reshape`. |
| 481 | +
|
| 482 | + .. warning:: |
| 483 | +
|
| 484 | + Numpy definition is tricky because onnxruntime does not handle well |
| 485 | + dimensions with an undefined number of dimensions. |
| 486 | + However the array API defines a more stricly signature for |
| 487 | + `reshape <https://data-apis.org/array-api/2022.12/ |
| 488 | + API_specification/generated/array_api.reshape.html>`_. |
| 489 | + :epkg:`scikit-learn` updated its code to follow the Array API in |
| 490 | + `PR 26030 ENH Forces shape to be tuple when using Array API's reshape |
| 491 | + <https://github.com/scikit-learn/scikit-learn/pull/26030>`_. |
| 492 | + """ |
466 | 493 | ifisinstance(shape,int): |
467 | 494 | shape=cst(np.array([shape],dtype=np.int64)) |
468 | 495 | shape_reshaped=var(shape,cst(np.array([-1],dtype=np.int64)),op="Reshape") |
|