#######################################################
# Copyright (c) 2015, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################

"""
BLAS functions (matmul, dot, etc)
"""

from .library import *
from .array import *


def matmul(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
    """
    Generalized matrix multiplication for two matrices.

    Parameters
    ----------

    lhs : af.Array
          A 2 dimensional, real or complex arrayfire array.

    rhs : af.Array
          A 2 dimensional, real or complex arrayfire array.

    lhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
              Can be one of
               - af.MATPROP.NONE   - If no op should be done on `lhs`.
               - af.MATPROP.TRANS  - If `lhs` has to be transposed before multiplying.
               - af.MATPROP.CTRANS - If `lhs` has to be hermitian transposed before multiplying.

    rhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
              Can be one of
               - af.MATPROP.NONE   - If no op should be done on `rhs`.
               - af.MATPROP.TRANS  - If `rhs` has to be transposed before multiplying.
               - af.MATPROP.CTRANS - If `rhs` has to be hermitian transposed before multiplying.

    Returns
    -------

    out : af.Array
          Output of the matrix multiplication on `lhs` and `rhs`.

    Note
    ----

    - The data types of `lhs` and `rhs` should be the same.
    - Batches are not supported.

    """
    out = Array()
    safe_call(backend.get().af_matmul(c_pointer(out.arr), lhs.arr, rhs.arr,
                                      lhs_opts.value, rhs_opts.value))
    return out


def matmulTN(lhs, rhs):
    """
    Matrix multiplication after transposing the first matrix.

    Equivalent to ``matmul(lhs, rhs, af.MATPROP.TRANS, af.MATPROP.NONE)``.

    Parameters
    ----------

    lhs : af.Array
          A 2 dimensional, real or complex arrayfire array.

    rhs : af.Array
          A 2 dimensional, real or complex arrayfire array.

    Returns
    -------

    out : af.Array
          Output of the matrix multiplication on `transpose(lhs)` and `rhs`.

    Note
    ----

    - The data types of `lhs` and `rhs` should be the same.
    - Batches are not supported.

    """
    return matmul(lhs, rhs, MATPROP.TRANS, MATPROP.NONE)


def matmulNT(lhs, rhs):
    """
    Matrix multiplication after transposing the second matrix.

    Equivalent to ``matmul(lhs, rhs, af.MATPROP.NONE, af.MATPROP.TRANS)``.

    Parameters
    ----------

    lhs : af.Array
          A 2 dimensional, real or complex arrayfire array.

    rhs : af.Array
          A 2 dimensional, real or complex arrayfire array.

    Returns
    -------

    out : af.Array
          Output of the matrix multiplication on `lhs` and `transpose(rhs)`.

    Note
    ----

    - The data types of `lhs` and `rhs` should be the same.
    - Batches are not supported.

    """
    return matmul(lhs, rhs, MATPROP.NONE, MATPROP.TRANS)


def matmulTT(lhs, rhs):
    """
    Matrix multiplication after transposing both inputs.

    Equivalent to ``matmul(lhs, rhs, af.MATPROP.TRANS, af.MATPROP.TRANS)``.

    Parameters
    ----------

    lhs : af.Array
          A 2 dimensional, real or complex arrayfire array.

    rhs : af.Array
          A 2 dimensional, real or complex arrayfire array.

    Returns
    -------

    out : af.Array
          Output of the matrix multiplication on `transpose(lhs)` and `transpose(rhs)`.

    Note
    ----

    - The data types of `lhs` and `rhs` should be the same.
    - Batches are not supported.

    """
    return matmul(lhs, rhs, MATPROP.TRANS, MATPROP.TRANS)


def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, return_scalar=False):
    """
    Dot product of two input vectors.

    Parameters
    ----------

    lhs : af.Array
          A 1 dimensional, real or complex arrayfire array.

    rhs : af.Array
          A 1 dimensional, real or complex arrayfire array.

    lhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
              Can be one of
               - af.MATPROP.NONE   - If no op should be done on `lhs`.
               - No other options are currently supported.

    rhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
              Can be one of
               - af.MATPROP.NONE   - If no op should be done on `rhs`.
               - No other options are currently supported.

    return_scalar: optional: bool. default: False.
              - When set to true, the input arrays are flattened and the output is a scalar

    Returns
    -------

    out : af.Array or scalar
          Output of dot product of `lhs` and `rhs`.
          With `return_scalar=True` a Python float is returned when the
          imaginary part is exactly zero, otherwise a Python complex.

    Note
    ----

    - The data types of `lhs` and `rhs` should be the same.
    - Batches are not supported.

    """
    if return_scalar:
        # af_dot_all reports the result as separate real/imaginary doubles.
        real = c_double_t(0)
        imag = c_double_t(0)
        safe_call(backend.get().af_dot_all(c_pointer(real), c_pointer(imag),
                                           lhs.arr, rhs.arr,
                                           lhs_opts.value, rhs_opts.value))
        real = real.value
        imag = imag.value
        return real if imag == 0 else real + imag * 1j

    out = Array()
    safe_call(backend.get().af_dot(c_pointer(out.arr), lhs.arr, rhs.arr,
                                   lhs_opts.value, rhs_opts.value))
    return out


def _real_scalar_ptr(value, real_t):
    """Return a ``void*`` to ``value`` stored as the ctypes scalar type ``real_t``."""
    return c_cast(c_pointer(real_t(value)), c_void_ptr_t)


def _complex_scalar_ptr(value, complex_t):
    """Return a ``void*`` to ``value`` stored as the ctypes complex struct ``complex_t``.

    ``value`` may be a ``complex_t`` instance (used as-is), a ``(real, imag)``
    tuple, a Python ``complex``, or anything ``complex_t`` accepts as its
    single constructor argument (e.g. a real number).
    """
    if isinstance(value, complex_t):
        scalar = value
    elif isinstance(value, tuple):
        scalar = complex_t(value[0], value[1])
    elif isinstance(value, complex):
        # complex_t(1+2j) raises TypeError; split the parts explicitly.
        scalar = complex_t(value.real, value.imag)
    else:
        scalar = complex_t(value)
    return c_cast(c_pointer(scalar), c_void_ptr_t)


def gemm(lhs, rhs, alpha=1.0, beta=0.0, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, C=None):
    """
    BLAS general matrix multiply (GEMM) of two af_array objects.

    This provides a general interface to the BLAS level 3 general matrix
    multiply (GEMM), which is generally defined as:

    C = alpha * opA(A) opB(B) + beta * C

    where alpha and beta are both scalars; A and B are the matrix multiply
    operands; and opA and opB are noop (if AF_MAT_NONE) or transpose
    (if AF_MAT_TRANS) operations on A or B before the actual GEMM operation.
    Batched GEMM is supported if at least either A or B have more than two
    dimensions (see af::matmul for more details on broadcasting). However,
    only one alpha and one beta can be used for all of the batched matrix
    operands.

    Parameters
    ----------

    lhs : af.Array
          A real or complex arrayfire array (the `A` operand).

    rhs : af.Array
          A real or complex arrayfire array (the `B` operand).

    alpha : scalar. default: 1.0.
          Scale factor for `opA(A) opB(B)`. For complex inputs this may be an
          af_cfloat_t / af_cdouble_t, a `(real, imag)` tuple, a Python complex,
          or a real number.

    beta : scalar. default: 0.0.
          Scale factor for `C`. Accepts the same forms as `alpha`.

    lhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
              Can be one of
               - af.MATPROP.NONE   - If no op should be done on `lhs`.
               - af.MATPROP.TRANS  - If `lhs` has to be transposed before multiplying.
               - af.MATPROP.CTRANS - If `lhs` has to be hermitian transposed before multiplying.

    rhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
              Can be one of
               - af.MATPROP.NONE   - If no op should be done on `rhs`.
               - af.MATPROP.TRANS  - If `rhs` has to be transposed before multiplying.
               - af.MATPROP.CTRANS - If `rhs` has to be hermitian transposed before multiplying.

    C : optional: af.Array. default: None.
          If given, it is passed to af_gemm as the output array (the `C` in
          the formula above) and is the object returned. If None, a new
          af.Array is created for the result.

    Returns
    -------

    out : af.Array
          Output of the matrix multiplication on `lhs` and `rhs`.

    Raises
    ------

    TypeError
          If `lhs` is not f32, c32, f64 or c64 (f16 is currently unsupported).

    Note
    ----

    - The data types of `lhs` and `rhs` should be the same; the scalar
      type used for `alpha` and `beta` is chosen from `lhs`.
    - In batched mode a single `alpha` and `beta` apply to every batch.

    """
    out = Array() if C is None else C

    # af_gemm takes alpha/beta as void pointers whose pointee type must
    # match the element type of the operands.
    ltype = lhs.dtype()
    if ltype == Dtype.f32:
        aptr = _real_scalar_ptr(alpha, c_float_t)
        bptr = _real_scalar_ptr(beta, c_float_t)
    elif ltype == Dtype.c32:
        aptr = _complex_scalar_ptr(alpha, af_cfloat_t)
        bptr = _complex_scalar_ptr(beta, af_cfloat_t)
    elif ltype == Dtype.f64:
        aptr = _real_scalar_ptr(alpha, c_double_t)
        bptr = _real_scalar_ptr(beta, c_double_t)
    elif ltype == Dtype.c64:
        aptr = _complex_scalar_ptr(alpha, af_cdouble_t)
        bptr = _complex_scalar_ptr(beta, af_cdouble_t)
    elif ltype == Dtype.f16:
        raise TypeError("fp16 currently unsupported gemm() input type")
    else:
        raise TypeError("unsupported input type")

    safe_call(backend.get().af_gemm(c_pointer(out.arr),
                                    lhs_opts.value, rhs_opts.value,
                                    aptr, lhs.arr, rhs.arr, bptr))
    return out