Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitec86afa

Browse files
authored
Add 3.7 features to python wrapper (#221)
* adds af_pad to python wrapper* adds meanvar to python wrapper* adds inverse square root to python wrapper* adds pinverse to python wrapper* adds NN convolve and gradient functions to wrapper* adds reduce by key to python wrappermissing convolve gradient function* adds confidenceCC to python wrapper* adds fp16 support to python wrapper* update version* remove stray print statements* adds axes_label_format to python wrapper, removes mistakenly copied code
1 parentaead039 commitec86afa

23 files changed

+636
-7
lines changed

‎__af_version__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@
99
# http://arrayfire.com/licenses/BSD-3-Clause
1010
########################################################
1111

12-
version="3.5"
13-
release="20170718"
12+
version="3.7"
13+
release="20200213"
1414
full_version=version+"."+release

‎arrayfire/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
from .timerimport*
7575
from .randomimport*
7676
from .sparseimport*
77+
from .mlimport*
7778

7879
# do not export default modules as part of arrayfire
7980
delct

‎arrayfire/algorithm.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,31 @@ def _nan_reduce_all(a, c_func, nan_val):
4444
imag=imag.value
4545
returnrealifimag==0elsereal+imag*1j
4646

47+
def_FNSD(dim,dims):
48+
ifdim>=0:
49+
returnint(dim)
50+
51+
fnsd=0
52+
fori,dinenumerate(dims):
53+
ifd>1:
54+
fnsd=i
55+
break
56+
returnint(fnsd)
57+
58+
def_rbk_dim(keys,vals,dim,c_func):
59+
keys_out=Array()
60+
vals_out=Array()
61+
rdim=_FNSD(dim,vals.dims())
62+
safe_call(c_func(c_pointer(keys_out.arr),c_pointer(vals_out.arr),keys.arr,vals.arr,c_int_t(rdim)))
63+
returnkeys_out,vals_out
64+
65+
def_nan_rbk_dim(a,dim,c_func,nan_val):
66+
keys_out=Array()
67+
vals_out=Array()
68+
rdim=_FNSD(dim,vals.dims())
69+
safe_call(c_func(c_pointer(keys_out.arr),c_pointer(vals_out.arr),keys.arr,vals.arr,c_int_t(rdim),c_double_t(nan_val)))
70+
returnkeys_out,vals_out
71+
4772
defsum(a,dim=None,nan_val=None):
4873
"""
4974
Calculate the sum of all the elements along a specified dimension.
@@ -74,6 +99,34 @@ def sum(a, dim=None, nan_val=None):
7499
else:
75100
return_reduce_all(a,backend.get().af_sum_all)
76101

102+
103+
defsumByKey(keys,vals,dim=-1,nan_val=None):
104+
"""
105+
Calculate the sum of elements along a specified dimension according to a key.
106+
107+
Parameters
108+
----------
109+
keys : af.Array
110+
One dimensional arrayfire array with reduction keys.
111+
vals : af.Array
112+
Multi dimensional arrayfire array that will be reduced.
113+
dim: optional: int. default: -1
114+
Dimension along which the sum will occur.
115+
nan_val: optional: scalar. default: None
116+
The value that replaces NaN in the array
117+
118+
Returns
119+
-------
120+
keys: af.Array or scalar number
121+
The reduced keys of all elements in `vals` along dimension `dim`.
122+
values: af.Array or scalar number
123+
The sum of all elements in `vals` along dimension `dim` according to keys
124+
"""
125+
if (nan_valisnotNone):
126+
return_nan_rbk_dim(keys,vals,dim,backend.get().af_sum_by_key_nan,nan_val)
127+
else:
128+
return_rbk_dim(keys,vals,dim,backend.get().af_sum_by_key)
129+
77130
defproduct(a,dim=None,nan_val=None):
78131
"""
79132
Calculate the product of all the elements along a specified dimension.
@@ -104,6 +157,33 @@ def product(a, dim=None, nan_val=None):
104157
else:
105158
return_reduce_all(a,backend.get().af_product_all)
106159

160+
defproductByKey(keys,vals,dim=-1,nan_val=None):
161+
"""
162+
Calculate the product of elements along a specified dimension according to a key.
163+
164+
Parameters
165+
----------
166+
keys : af.Array
167+
One dimensional arrayfire array with reduction keys.
168+
vals : af.Array
169+
Multi dimensional arrayfire array that will be reduced.
170+
dim: optional: int. default: -1
171+
Dimension along which the product will occur.
172+
nan_val: optional: scalar. default: None
173+
The value that replaces NaN in the array
174+
175+
Returns
176+
-------
177+
keys: af.Array or scalar number
178+
The reduced keys of all elements in `vals` along dimension `dim`.
179+
values: af.Array or scalar number
180+
The product of all elements in `vals` along dimension `dim` according to keys
181+
"""
182+
if (nan_valisnotNone):
183+
return_nan_rbk_dim(keys,vals,dim,backend.get().af_product_by_key_nan,nan_val)
184+
else:
185+
return_rbk_dim(keys,vals,dim,backend.get().af_product_by_key)
186+
107187
defmin(a,dim=None):
108188
"""
109189
Find the minimum value of all the elements along a specified dimension.
@@ -126,6 +206,28 @@ def min(a, dim=None):
126206
else:
127207
return_reduce_all(a,backend.get().af_min_all)
128208

209+
defminByKey(keys,vals,dim=-1):
210+
"""
211+
Calculate the min of elements along a specified dimension according to a key.
212+
213+
Parameters
214+
----------
215+
keys : af.Array
216+
One dimensional arrayfire array with reduction keys.
217+
vals : af.Array
218+
Multi dimensional arrayfire array that will be reduced.
219+
dim: optional: int. default: -1
220+
Dimension along which the min will occur.
221+
222+
Returns
223+
-------
224+
keys: af.Array or scalar number
225+
The reduced keys of all elements in `vals` along dimension `dim`.
226+
values: af.Array or scalar number
227+
The min of all elements in `vals` along dimension `dim` according to keys
228+
"""
229+
return_rbk_dim(keys,vals,dim,backend.get().af_min_by_key)
230+
129231
defmax(a,dim=None):
130232
"""
131233
Find the maximum value of all the elements along a specified dimension.
@@ -148,6 +250,28 @@ def max(a, dim=None):
148250
else:
149251
return_reduce_all(a,backend.get().af_max_all)
150252

253+
defmaxByKey(keys,vals,dim=-1):
254+
"""
255+
Calculate the max of elements along a specified dimension according to a key.
256+
257+
Parameters
258+
----------
259+
keys : af.Array
260+
One dimensional arrayfire array with reduction keys.
261+
vals : af.Array
262+
Multi dimensional arrayfire array that will be reduced.
263+
dim: optional: int. default: -1
264+
Dimension along which the max will occur.
265+
266+
Returns
267+
-------
268+
keys: af.Array or scalar number
269+
The reduced keys of all elements in `vals` along dimension `dim`.
270+
values: af.Array or scalar number
271+
The max of all elements in `vals` along dimension `dim` according to keys.
272+
"""
273+
return_rbk_dim(keys,vals,dim,backend.get().af_max_by_key)
274+
151275
defall_true(a,dim=None):
152276
"""
153277
Check if all the elements along a specified dimension are true.
@@ -170,6 +294,28 @@ def all_true(a, dim=None):
170294
else:
171295
return_reduce_all(a,backend.get().af_all_true_all)
172296

297+
defallTrueByKey(keys,vals,dim=-1):
298+
"""
299+
Calculate if all elements are true along a specified dimension according to a key.
300+
301+
Parameters
302+
----------
303+
keys : af.Array
304+
One dimensional arrayfire array with reduction keys.
305+
vals : af.Array
306+
Multi dimensional arrayfire array that will be reduced.
307+
dim: optional: int. default: -1
308+
Dimension along which the all true check will occur.
309+
310+
Returns
311+
-------
312+
keys: af.Array or scalar number
313+
The reduced keys of all true check in `vals` along dimension `dim`.
314+
values: af.Array or scalar number
315+
Booleans denoting if all elements are true in `vals` along dimension `dim` according to keys
316+
"""
317+
return_rbk_dim(keys,vals,dim,backend.get().af_all_true_by_key)
318+
173319
defany_true(a,dim=None):
174320
"""
175321
Check if any the elements along a specified dimension are true.
@@ -192,6 +338,28 @@ def any_true(a, dim=None):
192338
else:
193339
return_reduce_all(a,backend.get().af_any_true_all)
194340

341+
defanyTrueByKey(keys,vals,dim=-1):
342+
"""
343+
Calculate if any elements are true along a specified dimension according to a key.
344+
345+
Parameters
346+
----------
347+
keys : af.Array
348+
One dimensional arrayfire array with reduction keys.
349+
vals : af.Array
350+
Multi dimensional arrayfire array that will be reduced.
351+
dim: optional: int. default: -1
352+
Dimension along which the any true check will occur.
353+
354+
Returns
355+
-------
356+
keys: af.Array or scalar number
357+
The reduced keys of any true check in `vals` along dimension `dim`.
358+
values: af.Array or scalar number
359+
Booleans denoting if any elements are true in `vals` along dimension `dim` according to keys.
360+
"""
361+
return_rbk_dim(keys,vals,dim,backend.get().af_any_true_by_key)
362+
195363
defcount(a,dim=None):
196364
"""
197365
Count the number of non zero elements in an array along a specified dimension.
@@ -214,6 +382,28 @@ def count(a, dim=None):
214382
else:
215383
return_reduce_all(a,backend.get().af_count_all)
216384

385+
defcountByKey(keys,vals,dim=-1):
386+
"""
387+
Counts non-zero elements along a specified dimension according to a key.
388+
389+
Parameters
390+
----------
391+
keys : af.Array
392+
One dimensional arrayfire array with reduction keys.
393+
vals : af.Array
394+
Multi dimensional arrayfire array that will be reduced.
395+
dim: optional: int. default: -1
396+
Dimension along which to count elements.
397+
398+
Returns
399+
-------
400+
keys: af.Array or scalar number
401+
The reduced keys of count in `vals` along dimension `dim`.
402+
values: af.Array or scalar number
403+
Count of non-zero elements in `vals` along dimension `dim` according to keys.
404+
"""
405+
return_rbk_dim(keys,vals,dim,backend.get().af_count_by_key)
406+
217407
defimin(a,dim=None):
218408
"""
219409
Find the value and location of the minimum value along a specified dimension

‎arrayfire/arith.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,26 @@ def sqrt(a):
958958
"""
959959
return_arith_unary_func(a,backend.get().af_sqrt)
960960

961+
defrsqrt(a):
962+
"""
963+
Reciprocal or inverse square root of each element in the array.
964+
965+
Parameters
966+
----------
967+
a : af.Array
968+
Multi dimensional arrayfire array.
969+
970+
Returns
971+
--------
972+
out : af.Array
973+
array containing the inverse square root of each value from `a`.
974+
975+
Note
976+
-------
977+
`a` must not be complex.
978+
"""
979+
return_arith_unary_func(a,backend.get().af_rsqrt)
980+
961981
defcbrt(a):
962982
"""
963983
Cube root of each element in the array.

‎arrayfire/array.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,14 @@ def is_single(self):
783783
safe_call(backend.get().af_is_single(c_pointer(res),self.arr))
784784
returnres.value
785785

786+
defis_half(self):
787+
"""
788+
Check if the array is of half floating point type (fp16).
789+
"""
790+
res=c_bool_t(False)
791+
safe_call(backend.get().af_is_half(c_pointer(res),self.arr))
792+
returnres.value
793+
786794
defis_real_floating(self):
787795
"""
788796
Check if the array is real and of floating point type.

‎arrayfire/data.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,58 @@ def replace(lhs, cond, rhs):
799799
else:
800800
safe_call(backend.get().af_replace_scalar(lhs.arr,cond.arr,c_double_t(rhs)))
801801

802+
defpad(a,beginPadding,endPadding,padFillType=PAD.ZERO):
803+
"""
804+
Pad an array
805+
806+
This function will pad an array with the specified border size.
807+
Newly padded values can be filled in several different ways.
808+
809+
Parameters
810+
----------
811+
812+
a: af.Array
813+
A multi dimensional input arrayfire array.
814+
815+
beginPadding: tuple of ints. default: (0, 0, 0, 0).
816+
817+
endPadding: tuple of ints. default: (0, 0, 0, 0).
818+
819+
padFillType: optional af.PAD default: af.PAD.ZERO
820+
specifies type of values to fill padded border with
821+
822+
Returns
823+
-------
824+
output: af.Array
825+
A padded array
826+
827+
Examples
828+
---------
829+
>>> import arrayfire as af
830+
>>> a = af.randu(3,3)
831+
>>> af.display(a)
832+
[3 3 1 1]
833+
0.4107 0.1794 0.3775
834+
0.8224 0.4198 0.3027
835+
0.9518 0.0081 0.6456
836+
837+
>>> padded = af.pad(a, (1, 1), (1, 1), af.ZERO)
838+
>>> af.display(padded)
839+
[5 5 1 1]
840+
0.0000 0.0000 0.0000 0.0000 0.0000
841+
0.0000 0.4107 0.1794 0.3775 0.0000
842+
0.0000 0.8224 0.4198 0.3027 0.0000
843+
0.0000 0.9518 0.0081 0.6456 0.0000
844+
0.0000 0.0000 0.0000 0.0000 0.0000
845+
"""
846+
out=Array()
847+
begin_dims=dim4(beginPadding[0],beginPadding[1],beginPadding[2],beginPadding[3])
848+
end_dims=dim4(endPadding[0],endPadding[1],endPadding[2],endPadding[3])
849+
850+
safe_call(backend.get().af_pad(c_pointer(out.arr),a.arr,4,c_pointer(begin_dims),4,c_pointer(end_dims),padFillType.value))
851+
returnout
852+
853+
802854
deflookup(a,idx,dim=0):
803855
"""
804856
Lookup the values of input array based on index.

‎arrayfire/device.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,25 @@ def is_dbl_supported(device=None):
150150
safe_call(backend.get().af_get_dbl_support(c_pointer(res),dev))
151151
returnres.value
152152

153+
defis_half_supported(device=None):
154+
"""
155+
Check if half precision is supported on specified device.
156+
157+
Parameters
158+
-----------
159+
device: optional: int. default: None.
160+
id of the desired device.
161+
162+
Returns
163+
--------
164+
- True if half precision supported.
165+
- False if half precision not supported.
166+
"""
167+
dev=deviceifdeviceisnotNoneelseget_device()
168+
res=c_bool_t(False)
169+
safe_call(backend.get().af_get_half_support(c_pointer(res),dev))
170+
returnres.value
171+
153172
defsync(device=None):
154173
"""
155174
Block until all the functions on the device have completed execution.

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp