Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commita914b02

Browse files
syurkevi9prady9
authored andcommitted
adds gemm functionality, complex ctypes
1 parente053bb5 commita914b02

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed

‎arrayfire/blas.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,109 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, return_scalar =
202202
safe_call(backend.get().af_dot(c_pointer(out.arr),lhs.arr,rhs.arr,
203203
lhs_opts.value,rhs_opts.value))
204204
returnout
205+
206+
defgemm(lhs,rhs,alpha=1.0,beta=0.0,lhs_opts=MATPROP.NONE,rhs_opts=MATPROP.NONE,C=None):
207+
"""
208+
BLAS general matrix multiply (GEMM) of two af_array objects.
209+
210+
This provides a general interface to the BLAS level 3 general matrix multiply (GEMM), which is generally defined as:
211+
212+
C = α ∗ opA(A) opB(B)+ β∗C
213+
214+
where α (alpha) and β (beta) are both scalars; A and B are the matrix multiply operands;
215+
and opA and opB are noop (if AF_MAT_NONE) or transpose (if AF_MAT_TRANS) operations
216+
on A or B before the actual GEMM operation.
217+
Batched GEMM is supported if at least either A or B have more than two dimensions
218+
(see af::matmul for more details on broadcasting).
219+
However, only one alpha and one beta can be used for all of the batched matrix operands.
220+
221+
Parameters
222+
----------
223+
224+
lhs : af.Array
225+
A 2 dimensional, real or complex arrayfire array.
226+
227+
rhs : af.Array
228+
A 2 dimensional, real or complex arrayfire array.
229+
230+
alpha : scalar
231+
232+
beta : scalar
233+
234+
lhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
235+
Can be one of
236+
- af.MATPROP.NONE - If no op should be done on `lhs`.
237+
- af.MATPROP.TRANS - If `lhs` has to be transposed before multiplying.
238+
- af.MATPROP.CTRANS - If `lhs` has to be hermitian transposed before multiplying.
239+
240+
rhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
241+
Can be one of
242+
- af.MATPROP.NONE - If no op should be done on `rhs`.
243+
- af.MATPROP.TRANS - If `rhs` has to be transposed before multiplying.
244+
- af.MATPROP.CTRANS - If `rhs` has to be hermitian transposed before multiplying.
245+
246+
Returns
247+
-------
248+
249+
out : af.Array
250+
Output of the matrix multiplication on `lhs` and `rhs`.
251+
252+
Note
253+
-----
254+
255+
- The data types of `lhs` and `rhs` should be the same.
256+
- Batches are not supported.
257+
258+
"""
259+
ifCisNone:
260+
out=Array()
261+
else:
262+
out=C
263+
264+
ltype=lhs.dtype()
265+
266+
ifltype==Dtype.f32:
267+
aptr=c_cast(c_pointer(c_float_t(alpha)),c_void_ptr_t)
268+
bptr=c_cast(c_pointer(c_float_t(beta)),c_void_ptr_t)
269+
elifltype==Dtype.c32:
270+
ifisinstance(alpha,af_cfloat_t):
271+
aptr=c_cast(c_pointer(alpha),c_void_ptr_t)
272+
elifisinstance(alpha,tuple):
273+
aptr=c_cast(c_pointer(af_cfloat_t(alpha[0],alpha[1])),c_void_ptr_t)
274+
else:
275+
aptr=c_cast(c_pointer(af_cfloat_t(alpha)),c_void_ptr_t)
276+
277+
ifisinstance(beta,af_cfloat_t):
278+
bptr=c_cast(c_pointer(beta),c_void_ptr_t)
279+
elifisinstance(beta,tuple):
280+
bptr=c_cast(c_pointer(af_cfloat_t(beta[0],beta[1])),c_void_ptr_t)
281+
else:
282+
bptr=c_cast(c_pointer(af_cfloat_t(beta)),c_void_ptr_t)
283+
284+
elifltype==Dtype.f64:
285+
aptr=c_cast(c_pointer(c_double_t(alpha)),c_void_ptr_t)
286+
bptr=c_cast(c_pointer(c_double_t(beta)),c_void_ptr_t)
287+
elifltype==Dtype.c64:
288+
ifisinstance(alpha,af_cdouble_t):
289+
aptr=c_cast(c_pointer(alpha),c_void_ptr_t)
290+
elifisinstance(alpha,tuple):
291+
aptr=c_cast(c_pointer(af_cdouble_t(alpha[0],alpha[1])),c_void_ptr_t)
292+
else:
293+
aptr=c_cast(c_pointer(af_cdouble_t(alpha)),c_void_ptr_t)
294+
295+
ifisinstance(beta,af_cdouble_t):
296+
bptr=c_cast(c_pointer(beta),c_void_ptr_t)
297+
elifisinstance(beta,tuple):
298+
bptr=c_cast(c_pointer(af_cdouble_t(beta[0],beta[1])),c_void_ptr_t)
299+
else:
300+
bptr=c_cast(c_pointer(af_cdouble_t(beta)),c_void_ptr_t)
301+
elifltype==Dtype.f16:
302+
raiseTypeError("fp16 currently unsupported gemm() input type")
303+
else:
304+
raiseTypeError("unsupported input type")
305+
306+
307+
safe_call(backend.get().af_gemm(c_pointer(out.arr),
308+
lhs_opts.value,rhs_opts.value,
309+
aptr,lhs.arr,rhs.arr,bptr))
310+
returnout

‎arrayfire/library.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131
c_void_ptr_t=ct.c_void_p
3232
c_char_ptr_t=ct.c_char_p
3333
c_size_t=ct.c_size_t
34+
c_cast=ct.cast
35+
36+
classaf_cfloat_t(ct.Structure):
37+
_fields_= [("real",ct.c_float), ("imag",ct.c_float)]
38+
39+
classaf_cdouble_t(ct.Structure):
40+
_fields_= [("real",ct.c_double), ("imag",ct.c_double)]
3441

3542

3643
AF_VER_MAJOR='3'

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp