Source code for arrayfire.util

######################################################## Copyright (c) 2015, ArrayFire# All rights reserved.## This file is distributed under 3-clause BSD license.# The complete license agreement can be obtained at:# http://arrayfire.com/licenses/BSD-3-Clause########################################################"""Utility functions to help with Array metadata."""from.libraryimport*importnumbers
def dim4(d0=1, d1=1, d2=1, d3=1):
    """
    Pack up to four dimension sizes into a ctypes ``c_dim_t[4]`` array.

    Every slot starts at 1; an argument given as ``None`` leaves its slot at 1,
    any other value is wrapped in ``c_dim_t`` and stored in that slot.
    """
    out = (c_dim_t * 4)(1, 1, 1, 1)
    for axis, extent in zip(range(4), (d0, d1, d2, d3)):
        if extent is None:
            continue
        out[axis] = c_dim_t(extent)
    return out
def_is_number(a):returnisinstance(a,numbers.Number)
def number_dtype(a):
    """
    Return the ArrayFire ``Dtype`` matching the scalar ``a``.

    Python ``bool``/``int``/``float``/``complex`` map to ``b8``/``s64``/``f64``/``c64``.
    Anything else is looked up by its ``dtype.char`` in ``to_dtype``
    (presumably a NumPy scalar -- an object without ``dtype`` raises AttributeError).
    """
    # Order matters: bool is a subclass of int, so it must be tested first.
    python_scalars = ((bool, Dtype.b8),
                      (int, Dtype.s64),
                      (float, Dtype.f64),
                      (complex, Dtype.c64))
    for py_type, af_type in python_scalars:
        if isinstance(a, py_type):
            return af_type
    return to_dtype[a.dtype.char]
def implicit_dtype(number, a_dtype):
    """
    Pick the dtype a scalar ``number`` should take when combined with an array.

    ``a_dtype`` is compared against ``Dtype.<x>.value``, so it is expected to be
    the integer value of the array's dtype. Double-precision scalars are demoted
    when the array is single precision (f32 or c32):
    f64 -> ``Dtype.f32`` and c64 -> ``Dtype.c32``. Otherwise the scalar's own
    dtype from ``number_dtype`` is returned.
    """
    scalar_dtype = number_dtype(number)
    scalar_value = scalar_dtype.value
    single_precision = (Dtype.f32.value, Dtype.c32.value)

    if a_dtype in single_precision:
        if scalar_value == Dtype.f64.value:
            return Dtype.f32
        if scalar_value == Dtype.c64.value:
            return Dtype.c32
    return scalar_dtype
def dim4_to_tuple(dims, default=1):
    """
    Pad a dimension tuple out to exactly four entries.

    Parameters
    ----------
    dims : tuple
        Leading dimension sizes. More than four entries raises IndexError.
    default : number or None
        Value used for the missing trailing entries.

    Returns
    -------
    tuple
        A 4-tuple whose first ``len(dims)`` items come from ``dims``.
    """
    assert isinstance(dims, tuple)
    if default is not None:
        assert _is_number(default)

    padded = [default] * 4
    for axis, extent in enumerate(dims):
        # Overflows (IndexError) if dims has more than four entries.
        padded[axis] = extent
    return tuple(padded)
def to_str(c_str):
    """
    Decode a C string coming back from ArrayFire into a Python ``str``.

    Parameters
    ----------
    c_str : ctypes c_char_p-like, bytes, bytearray, str or None
        Objects exposing ``.value`` (e.g. ``c_char_p``, ``create_string_buffer``)
        are unwrapped first. A NULL pointer (``.value is None``) or ``None``
        yields an empty string rather than an AttributeError.

    Returns
    -------
    str
        The UTF-8 decoded text.

    Raises
    ------
    TypeError
        If the (unwrapped) value is not bytes, bytearray, str or None.
    """
    value = c_str.value if hasattr(c_str, 'value') else c_str
    if value is None:
        # NULL char* -- nothing to decode.
        return ''
    if isinstance(value, (bytes, bytearray)):
        return value.decode('utf-8')
    if isinstance(value, str):
        return value
    raise TypeError("to_str: cannot convert %r to str" % (type(value).__name__,))
def safe_call(af_error):
    """
    Check an ArrayFire status code and raise on failure.

    Parameters
    ----------
    af_error : int
        Value returned by an ArrayFire C function.

    Raises
    ------
    RuntimeError
        When ``af_error`` differs from ``ERR.NONE.value``. The message is the
        text reported by ``af_get_last_error`` (empty if none is available).
    """
    if af_error == ERR.NONE.value:
        return

    lib = backend.get()
    err_str = c_char_ptr_t(0)
    err_len = c_dim_t(0)
    lib.af_get_last_error(c_pointer(err_str), c_pointer(err_len))
    try:
        message = '' if err_str.value is None else to_str(err_str)
    finally:
        # The ArrayFire C API documents that the buffer handed out by
        # af_get_last_error must be released with af_free_host; the original
        # code never freed it, leaking one allocation per error.
        if err_str.value is not None:
            lib.af_free_host(err_str)
    raise RuntimeError(message)
def get_version():
    """
    Function to get the version of arrayfire.

    Returns
    -------
    tuple of int
        ``(major, minor, patch)`` as filled in by ``af_get_version``.
        Raises RuntimeError (via ``safe_call``) if the call fails.
    """
    parts = [c_int_t(0) for _ in range(3)]
    status = backend.get().af_get_version(*[c_pointer(part) for part in parts])
    safe_call(status)
    return tuple(part.value for part in parts)
def get_reversion():
    """
    Function to get the revision hash of the library.

    Returns
    -------
    str
        The revision string reported by ``af_get_revision`` ('' if NULL).
    """
    af_get_revision = backend.get().af_get_revision
    # af_get_revision returns a ``const char*``. ctypes assumes a C ``int``
    # return type unless ``restype`` is set, which truncates the pointer on
    # 64-bit platforms and hands an ``int`` to to_str. Declare it explicitly.
    af_get_revision.restype = c_char_ptr_t
    revision = af_get_revision()
    if revision is None:
        return ''
    if isinstance(revision, bytes):
        # A plain c_char_p restype is auto-converted by ctypes to bytes.
        return revision.decode('utf-8')
    # Fallback for c_char_p-like objects exposing ``.value``.
    return to_str(revision)
# Type-code character (NumPy ``dtype.char`` / ``array`` module typecode) -> ArrayFire Dtype.
# NOTE(review): 'b' is signed char (int8) in NumPy and the array module but is
# mapped to b8 here; kept as-is for backward compatibility.
to_dtype = {'f' : Dtype.f32,
            'd' : Dtype.f64,
            'b' : Dtype.b8,
            'B' : Dtype.u8,
            'h' : Dtype.s16,
            'H' : Dtype.u16,
            'i' : Dtype.s32,
            'I' : Dtype.u32,
            'l' : Dtype.s64,
            'L' : Dtype.u64,
            'F' : Dtype.c32,
            'D' : Dtype.c64,
            'hf': Dtype.f16,
            # Added: codes missing above, which made number_dtype raise KeyError
            # for NumPy bool, long long (int64 on Windows) and float16 scalars.
            '?' : Dtype.b8,
            'q' : Dtype.s64,
            'Q' : Dtype.u64,
            'e' : Dtype.f16}

# ArrayFire Dtype value -> canonical type-code character (inverse of the original to_dtype entries).
to_typecode = {Dtype.f32.value : 'f',
               Dtype.f64.value : 'd',
               Dtype.b8.value : 'b',
               Dtype.u8.value : 'B',
               Dtype.s16.value : 'h',
               Dtype.u16.value : 'H',
               Dtype.s32.value : 'i',
               Dtype.u32.value : 'I',
               Dtype.s64.value : 'l',
               Dtype.u64.value : 'L',
               Dtype.c32.value : 'F',
               Dtype.c64.value : 'D',
               Dtype.f16.value : 'hf'}

# ArrayFire Dtype value -> ctypes type for one element
# (complex types are two-element float/double arrays; f16 is stored as raw 16-bit ushort).
to_c_type = {Dtype.f32.value : c_float_t,
             Dtype.f64.value : c_double_t,
             Dtype.b8.value : c_char_t,
             Dtype.u8.value : c_uchar_t,
             Dtype.s16.value : c_short_t,
             Dtype.u16.value : c_ushort_t,
             Dtype.s32.value : c_int_t,
             Dtype.u32.value : c_uint_t,
             Dtype.s64.value : c_longlong_t,
             Dtype.u64.value : c_ulonglong_t,
             Dtype.c32.value : c_float_t * 2,
             Dtype.c64.value : c_double_t * 2,
             Dtype.f16.value : c_ushort_t}

# ArrayFire Dtype value -> human-readable C type name.
to_typename = {Dtype.f32.value : 'float',
               Dtype.f64.value : 'double',
               Dtype.b8.value : 'bool',
               Dtype.u8.value : 'unsigned char',
               Dtype.s16.value : 'short int',
               Dtype.u16.value : 'unsigned short int',
               Dtype.s32.value : 'int',
               Dtype.u32.value : 'unsigned int',
               Dtype.s64.value : 'long int',
               Dtype.u64.value : 'unsigned long int',
               Dtype.c32.value : 'float complex',
               Dtype.c64.value : 'double complex',
               Dtype.f16.value : 'half'}