33"""
44from typing import Any ,Optional
55import numpy as np
6- from onnx import TensorProto
76from ..npx .npx_functions import (
87all ,
98abs ,
1615reshape ,
1716take ,
1817)
18+ from ..npx .npx_functions import full as generic_full
1919from ..npx .npx_functions import ones as generic_ones
2020from ..npx .npx_functions import zeros as generic_zeros
2121from ..npx .npx_numpy_tensors import EagerNumpyTensor
22- from ..npx .npx_types import DType ,ElemType ,TensorType ,OptParType
22+ from ..npx .npx_types import DType ,ElemType ,TensorType ,OptParType , ParType , Scalar
2323from ._onnx_common import template_asarray
2424from .import _finalize_array_api
2525
3131"astype" ,
3232"empty" ,
3333"equal" ,
34+ "full" ,
3435"isdtype" ,
3536"isfinite" ,
3637"isnan" ,
@@ -58,7 +59,7 @@ def asarray(
5859
5960def ones (
6061shape :TensorType [ElemType .int64 ,"I" , (None ,)],
61- dtype :OptParType [DType ]= DType ( TensorProto . FLOAT ) ,
62+ dtype :OptParType [DType ]= None ,
6263order :OptParType [str ]= "C" ,
6364)-> TensorType [ElemType .numerics ,"T" ]:
6465if isinstance (shape ,tuple ):
@@ -76,7 +77,7 @@ def ones(
7677
7778def empty (
7879shape :TensorType [ElemType .int64 ,"I" , (None ,)],
79- dtype :OptParType [DType ]= DType ( TensorProto . FLOAT ) ,
80+ dtype :OptParType [DType ]= None ,
8081order :OptParType [str ]= "C" ,
8182)-> TensorType [ElemType .numerics ,"T" ]:
8283raise RuntimeError (
@@ -87,7 +88,7 @@ def empty(
8788
8889def zeros (
8990shape :TensorType [ElemType .int64 ,"I" , (None ,)],
90- dtype :OptParType [DType ]= DType ( TensorProto . FLOAT ) ,
91+ dtype :OptParType [DType ]= None ,
9192order :OptParType [str ]= "C" ,
9293)-> TensorType [ElemType .numerics ,"T" ]:
9394if isinstance (shape ,tuple ):
@@ -103,6 +104,32 @@ def zeros(
103104return generic_zeros (shape ,dtype = dtype ,order = order )
104105
105106
107+ def full (
108+ shape :TensorType [ElemType .int64 ,"I" , (None ,)],
109+ fill_value :ParType [Scalar ]= None ,
110+ dtype :OptParType [DType ]= None ,
111+ order :OptParType [str ]= "C" ,
112+ )-> TensorType [ElemType .numerics ,"T" ]:
113+ if fill_value is None :
114+ raise TypeError ("fill_value cannot be None" )
115+ value = fill_value
116+ if isinstance (shape ,tuple ):
117+ return generic_full (
118+ EagerNumpyTensor (np .array (shape ,dtype = np .int64 )),
119+ fill_value = value ,
120+ dtype = dtype ,
121+ order = order ,
122+ )
123+ if isinstance (shape ,int ):
124+ return generic_full (
125+ EagerNumpyTensor (np .array ([shape ],dtype = np .int64 )),
126+ fill_value = value ,
127+ dtype = dtype ,
128+ order = order ,
129+ )
130+ return generic_full (shape ,fill_value = value ,dtype = dtype ,order = order )
131+
132+
106133def _finalize ():
107134"""
108135 Adds common attributes to Array API defined in this modules