22import unittest
33import numpy as np
44from onnx import TensorProto
5- from onnx_array_api .ext_test_case import ExtTestCase
5+ from onnx_array_api .ext_test_case import ExtTestCase , ignore_warnings
66from onnx_array_api .array_api import onnx_numpy as xp
77from onnx_array_api .npx .npx_types import DType
88from onnx_array_api .npx .npx_numpy_tensors import EagerNumpyTensor as EagerTensor
9+ from onnx_array_api .npx .npx_functions import linspace as linspace_inline
10+ from onnx_array_api .npx .npx_types import Float64 ,Int64
11+ from onnx_array_api .npx .npx_var import Input
12+ from onnx_array_api .reference import ExtendedReferenceEvaluator
913
1014
1115class TestOnnxNumpy (ExtTestCase ):
@@ -22,6 +26,7 @@ def test_zeros(self):
2226a = xp .absolute (mat )
2327self .assertEqualArray (np .absolute (mat .numpy ()),a .numpy ())
2428
29+ @ignore_warnings (DeprecationWarning )
2530def test_arange_default (self ):
2631a = EagerTensor (np .array ([0 ],dtype = np .int64 ))
2732b = EagerTensor (np .array ([2 ],dtype = np .int64 ))
@@ -30,6 +35,7 @@ def test_arange_default(self):
3035self .assertEqual (matnp .shape , (2 ,))
3136self .assertEqualArray (matnp ,np .arange (0 ,2 ).astype (np .int64 ))
3237
38+ @ignore_warnings (DeprecationWarning )
3339def test_arange_step (self ):
3440a = EagerTensor (np .array ([4 ],dtype = np .int64 ))
3541s = EagerTensor (np .array ([2 ],dtype = np .int64 ))
@@ -78,6 +84,7 @@ def test_full_bool(self):
7884self .assertNotEmpty (matnp [0 ,0 ])
7985self .assertEqualArray (matnp ,np .full ((4 ,5 ),False ))
8086
87+ @ignore_warnings (DeprecationWarning )
8188def test_arange_int00a (self ):
8289a = EagerTensor (np .array ([0 ],dtype = np .int64 ))
8390b = EagerTensor (np .array ([0 ],dtype = np .int64 ))
@@ -89,6 +96,7 @@ def test_arange_int00a(self):
8996expected = expected .astype (np .int64 )
9097self .assertEqualArray (matnp ,expected )
9198
99+ @ignore_warnings (DeprecationWarning )
92100def test_arange_int00 (self ):
93101mat = xp .arange (0 ,0 )
94102matnp = mat .numpy ()
@@ -160,10 +168,94 @@ def test_eye_k(self):
160168got = xp .eye (nr ,k = 1 )
161169self .assertEqualArray (expected ,got .numpy ())
162170
171+ def test_linspace_int (self ):
172+ a = EagerTensor (np .array ([0 ],dtype = np .int64 ))
173+ b = EagerTensor (np .array ([6 ],dtype = np .int64 ))
174+ c = EagerTensor (np .array (3 ,dtype = np .int64 ))
175+ mat = xp .linspace (a ,b ,c )
176+ matnp = mat .numpy ()
177+ expected = np .linspace (a .numpy (),b .numpy (),c .numpy ()).astype (np .int64 )
178+ self .assertEqualArray (expected ,matnp )
179+
180+ def test_linspace_int5 (self ):
181+ a = EagerTensor (np .array ([0 ],dtype = np .int64 ))
182+ b = EagerTensor (np .array ([5 ],dtype = np .int64 ))
183+ c = EagerTensor (np .array (3 ,dtype = np .int64 ))
184+ mat = xp .linspace (a ,b ,c )
185+ matnp = mat .numpy ()
186+ expected = np .linspace (a .numpy (),b .numpy (),c .numpy ()).astype (np .int64 )
187+ self .assertEqualArray (expected ,matnp )
188+
189+ def test_linspace_float (self ):
190+ a = EagerTensor (np .array ([0.5 ],dtype = np .float64 ))
191+ b = EagerTensor (np .array ([5.5 ],dtype = np .float64 ))
192+ c = EagerTensor (np .array (2 ,dtype = np .int64 ))
193+ mat = xp .linspace (a ,b ,c )
194+ matnp = mat .numpy ()
195+ expected = np .linspace (a .numpy (),b .numpy (),c .numpy ())
196+ self .assertEqualArray (expected ,matnp )
197+
198+ def test_linspace_float_noendpoint (self ):
199+ a = EagerTensor (np .array ([0.5 ],dtype = np .float64 ))
200+ b = EagerTensor (np .array ([5.5 ],dtype = np .float64 ))
201+ c = EagerTensor (np .array (2 ,dtype = np .int64 ))
202+ mat = xp .linspace (a ,b ,c ,endpoint = 0 )
203+ matnp = mat .numpy ()
204+ expected = np .linspace (a .numpy (),b .numpy (),c .numpy (),endpoint = 0 )
205+ self .assertEqualArray (expected ,matnp )
206+
207+ @ignore_warnings ((RuntimeWarning ,DeprecationWarning ))# division by zero
208+ def test_linspace_zero (self ):
209+ expected = np .linspace (0.0 ,0.0 ,0 ,endpoint = False )
210+ mat = xp .linspace (0.0 ,0.0 ,0 ,endpoint = False )
211+ matnp = mat .numpy ()
212+ self .assertEqualArray (expected ,matnp )
213+
214+ @ignore_warnings ((RuntimeWarning ,DeprecationWarning ))# division by zero
215+ def test_linspace_zero_one (self ):
216+ expected = np .linspace (0.0 ,0.0 ,1 ,endpoint = True )
217+
218+ f = linspace_inline (Input ("start" ),Input ("stop" ),Input ("num" ))
219+ onx = f .to_onnx (
220+ constraints = {
221+ "start" :Float64 [None ],
222+ "stop" :Float64 [None ],
223+ "num" :Int64 [None ],
224+ (0 ,False ):Float64 [None ],
225+ }
226+ )
227+ ref = ExtendedReferenceEvaluator (onx )
228+ got = ref .run (
229+ None ,
230+ {
231+ "start" :np .array (0 ,dtype = np .float64 ),
232+ "stop" :np .array (0 ,dtype = np .float64 ),
233+ "num" :np .array (1 ,dtype = np .int64 ),
234+ },
235+ )
236+ self .assertEqualArray (expected ,got [0 ])
237+
238+ mat = xp .linspace (0.0 ,0.0 ,1 ,endpoint = True )
239+ matnp = mat .numpy ()
240+
241+ self .assertEqualArray (expected ,matnp )
242+
243+ def test_slice_minus_one (self ):
244+ g = EagerTensor (np .array ([0.0 ]))
245+ expected = g .numpy ()[:- 1 ]
246+ got = g [:- 1 ]
247+ self .assertEqualArray (expected ,got .numpy ())
248+
249+ def test_linspace_bug1 (self ):
250+ expected = np .linspace (16777217.0 ,0.0 ,1 )
251+ mat = xp .linspace (16777217.0 ,0.0 ,1 )
252+ matnp = mat .numpy ()
253+ self .assertEqualArray (expected ,matnp )
254+
163255
164256if __name__ == "__main__" :
165257# import logging
166258
167259# logging.basicConfig(level=logging.DEBUG)
168- TestOnnxNumpy ().test_eye ()
260+ TestOnnxNumpy ().test_linspace_float_noendpoint ()
169261unittest .main (verbosity = 2 )