@@ -13,7 +13,7 @@ class ArrayApiError(RuntimeError):
1313pass
1414
1515
16- class ArrayApi :
16+ class BaseArrayApi :
1717"""
1818 List of supported method by a tensor.
1919 """
@@ -36,145 +36,145 @@ def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any:
3636f"for class{ self .__class__ .__name__ !r} . "
3737f"Method 'generic_method' can be overwritten "
3838f"as well to change the behaviour "
39- f"for all methods supported by classArrayApi ."
39+ f"for all methods supported by classBaseArrayApi ."
4040 )
4141
4242def numpy (self )-> np .ndarray :
4343return self .generic_method ("numpy" )
4444
45- def __neg__ (self )-> "ArrayApi " :
45+ def __neg__ (self )-> "BaseArrayApi " :
4646return self .generic_method ("__neg__" )
4747
48- def __invert__ (self )-> "ArrayApi " :
48+ def __invert__ (self )-> "BaseArrayApi " :
4949return self .generic_method ("__invert__" )
5050
51- def __add__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
51+ def __add__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
5252return self .generic_method ("__add__" ,ov )
5353
54- def __radd__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
54+ def __radd__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
5555return self .generic_method ("__radd__" ,ov )
5656
57- def __sub__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
57+ def __sub__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
5858return self .generic_method ("__sub__" ,ov )
5959
60- def __rsub__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
60+ def __rsub__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
6161return self .generic_method ("__rsub__" ,ov )
6262
63- def __mul__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
63+ def __mul__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
6464return self .generic_method ("__mul__" ,ov )
6565
66- def __rmul__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
66+ def __rmul__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
6767return self .generic_method ("__rmul__" ,ov )
6868
69- def __matmul__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
69+ def __matmul__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
7070return self .generic_method ("__matmul__" ,ov )
7171
72- def __truediv__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
72+ def __truediv__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
7373return self .generic_method ("__truediv__" ,ov )
7474
75- def __rtruediv__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
75+ def __rtruediv__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
7676return self .generic_method ("__rtruediv__" ,ov )
7777
78- def __mod__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
78+ def __mod__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
7979return self .generic_method ("__mod__" ,ov )
8080
81- def __rmod__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
81+ def __rmod__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
8282return self .generic_method ("__rmod__" ,ov )
8383
84- def __pow__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
84+ def __pow__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
8585return self .generic_method ("__pow__" ,ov )
8686
87- def __rpow__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
87+ def __rpow__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
8888return self .generic_method ("__rpow__" ,ov )
8989
90- def __lt__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
90+ def __lt__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
9191return self .generic_method ("__lt__" ,ov )
9292
93- def __le__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
93+ def __le__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
9494return self .generic_method ("__le__" ,ov )
9595
96- def __gt__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
96+ def __gt__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
9797return self .generic_method ("__gt__" ,ov )
9898
99- def __ge__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
99+ def __ge__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
100100return self .generic_method ("__ge__" ,ov )
101101
102- def __eq__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
102+ def __eq__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
103103return self .generic_method ("__eq__" ,ov )
104104
105- def __ne__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
105+ def __ne__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
106106return self .generic_method ("__ne__" ,ov )
107107
108- def __lshift__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
108+ def __lshift__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
109109return self .generic_method ("__lshift__" ,ov )
110110
111- def __rshift__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
111+ def __rshift__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
112112return self .generic_method ("__rshift__" ,ov )
113113
114- def __and__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
114+ def __and__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
115115return self .generic_method ("__and__" ,ov )
116116
117- def __rand__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
117+ def __rand__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
118118return self .generic_method ("__rand__" ,ov )
119119
120- def __or__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
120+ def __or__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
121121return self .generic_method ("__or__" ,ov )
122122
123- def __ror__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
123+ def __ror__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
124124return self .generic_method ("__ror__" ,ov )
125125
126- def __xor__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
126+ def __xor__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
127127return self .generic_method ("__xor__" ,ov )
128128
129- def __rxor__ (self ,ov :"ArrayApi " )-> "ArrayApi " :
129+ def __rxor__ (self ,ov :"BaseArrayApi " )-> "BaseArrayApi " :
130130return self .generic_method ("__rxor__" ,ov )
131131
132132@property
133- def T (self )-> "ArrayApi " :
133+ def T (self )-> "BaseArrayApi " :
134134return self .generic_method ("T" )
135135
136- def astype (self ,dtype :Any )-> "ArrayApi " :
136+ def astype (self ,dtype :Any )-> "BaseArrayApi " :
137137return self .generic_method ("astype" ,dtype )
138138
139139@property
140- def shape (self )-> "ArrayApi " :
140+ def shape (self )-> "BaseArrayApi " :
141141return self .generic_method ("shape" )
142142
143- def reshape (self ,shape :"ArrayApi " )-> "ArrayApi " :
143+ def reshape (self ,shape :"BaseArrayApi " )-> "BaseArrayApi " :
144144return self .generic_method ("reshape" ,shape )
145145
146146def sum (
147147self ,axis :OptParType [TupleType [int ]]= None ,keepdims :ParType [int ]= 0
148- )-> "ArrayApi " :
148+ )-> "BaseArrayApi " :
149149return self .generic_method ("sum" ,axis = axis ,keepdims = keepdims )
150150
151151def mean (
152152self ,axis :OptParType [TupleType [int ]]= None ,keepdims :ParType [int ]= 0
153- )-> "ArrayApi " :
153+ )-> "BaseArrayApi " :
154154return self .generic_method ("mean" ,axis = axis ,keepdims = keepdims )
155155
156156def min (
157157self ,axis :OptParType [TupleType [int ]]= None ,keepdims :ParType [int ]= 0
158- )-> "ArrayApi " :
158+ )-> "BaseArrayApi " :
159159return self .generic_method ("min" ,axis = axis ,keepdims = keepdims )
160160
161161def max (
162162self ,axis :OptParType [TupleType [int ]]= None ,keepdims :ParType [int ]= 0
163- )-> "ArrayApi " :
163+ )-> "BaseArrayApi " :
164164return self .generic_method ("max" ,axis = axis ,keepdims = keepdims )
165165
166166def prod (
167167self ,axis :OptParType [TupleType [int ]]= None ,keepdims :ParType [int ]= 0
168- )-> "ArrayApi " :
168+ )-> "BaseArrayApi " :
169169return self .generic_method ("prod" ,axis = axis ,keepdims = keepdims )
170170
171- def copy (self )-> "ArrayApi " :
171+ def copy (self )-> "BaseArrayApi " :
172172return self .generic_method ("copy" )
173173
174- def flatten (self )-> "ArrayApi " :
174+ def flatten (self )-> "BaseArrayApi " :
175175return self .generic_method ("flatten" )
176176
177- def __getitem__ (self ,index :Any )-> "ArrayApi " :
177+ def __getitem__ (self ,index :Any )-> "BaseArrayApi " :
178178return self .generic_method ("__getitem__" ,index )
179179
180180def __setitem__ (self ,index :Any ,values :Any ):