Source code for arrayfire.device

######################################################## Copyright (c) 2015, ArrayFire# All rights reserved.## This file is distributed under 3-clause BSD license.# The complete license agreement can be obtained at:# http://arrayfire.com/licenses/BSD-3-Clause########################################################"""Functions to handle the available devices in the backend."""from.libraryimport*from.utilimport(safe_call,to_str,get_version)
def init():
    """
    Initialize the ArrayFire backend.

    Note
    -----
    Calling this explicitly may be required when ArrayFire is used
    together with other libraries.
    """
    af_lib = backend.get()
    safe_call(af_lib.af_init())
def info():
    """
    Print a summary of the ArrayFire setup via the backend's `af_info`.

    The printed text covers:
        - ArrayFire build and version number.
        - The number of devices available.
        - The names of the devices.
        - The device currently in use.
    """
    af_lib = backend.get()
    status = af_lib.af_info()
    safe_call(status)
def device_info():
    """
    Describe the current device.

    Returns
    -------
    dict
        A mapping with these string-valued keys:
            - 'device': Name of the current device.
            - 'backend': The backend currently in use.
            - 'toolkit': Toolkit version of the backend.
            - 'compute': Compute version of the device.
    """
    # Fixed-size output buffers (256 chars each) that af_device_info fills in.
    char_buffer = c_char_t * 256
    keys = ('device', 'backend', 'toolkit', 'compute')
    buffers = [char_buffer() for _ in keys]

    # Argument order must match keys: device, backend, toolkit, compute.
    safe_call(backend.get().af_device_info(*[c_pointer(buf) for buf in buffers]))

    return {key: to_str(buf) for key, buf in zip(keys, buffers)}
def get_device_count():
    """
    Return how many devices the active backend reports.

    Returns
    -------
    int
        Number of available devices.
    """
    count = c_int_t(0)
    af_lib = backend.get()
    safe_call(af_lib.af_get_device_count(c_pointer(count)))
    return count.value
def get_device():
    """
    Return the id of the device that is currently active.

    Returns
    -------
    int
        Id of the active device.
    """
    active = c_int_t(0)
    af_lib = backend.get()
    safe_call(af_lib.af_get_device(c_pointer(active)))
    return active.value
def set_device(num):
    """
    Make the device with the given id the active one.

    Parameters
    -----------
    num: int.
        id of the device to activate.
    """
    af_lib = backend.get()
    status = af_lib.af_set_device(num)
    safe_call(status)
def info_str(verbose=False):
    """
    Return information about the ArrayFire installation as a string.

    The text contains:
        - ArrayFire version number, backend name and interpreter architecture.
        - One line per available device with its name. The OpenCL backend
          also shows the toolkit, and non-CPU backends show the compute version.
        - The current device is marked ``[n]``; the other devices ``-n-``.

    Parameters
    ----------
    verbose: optional: bool. default: False.
        Currently unused; accepted for backward compatibility.

    Returns
    -------
    str
        Newline-terminated summary text.

    Note
    ----
    Each device is queried by temporarily making it the active device. The
    device that was active on entry is restored afterwards, even if a
    query raises.
    """
    import platform

    major, minor, patch = get_version()
    backend_str = device_info()['backend']
    backend_lower = backend_str.lower()

    lines = ['ArrayFire v%s.%s.%s (%s %s)\n'
             % (major, minor, patch, backend_str, platform.architecture()[0])]

    num_devices = get_device_count()
    curr_device_id = get_device()
    try:
        for n in range(num_devices):
            # Only switch when needed: per the original note, skipping the
            # redundant switch suppresses warnings on the CPU backend.
            if n != curr_device_id:
                set_device(n)
                entry = '-%d- ' % n
            else:
                entry = '[%d] ' % n

            dev_info = device_info()
            if backend_lower == 'opencl':
                entry += dev_info['toolkit']
            entry += ': ' + dev_info['device']
            if backend_lower != 'cpu':
                entry += ' (Compute ' + dev_info['compute'] + ')'
            lines.append(entry + '\n')
    finally:
        # Always hand the caller back the device that was active on entry,
        # even when querying another device failed part-way through.
        if curr_device_id != get_device():
            set_device(curr_device_id)

    return ''.join(lines)
def is_dbl_supported(device=None):
    """
    Report whether a device supports double precision.

    Parameters
    -----------
    device: optional: int. default: None.
        id of the device to query; None means the current device.

    Returns
    --------
        - True if double precision is supported.
        - False otherwise.
    """
    if device is None:
        device = get_device()
    supported = c_bool_t(False)
    safe_call(backend.get().af_get_dbl_support(c_pointer(supported), device))
    return supported.value
def is_half_supported(device=None):
    """
    Report whether a device supports half precision.

    Parameters
    -----------
    device: optional: int. default: None.
        id of the device to query; None means the current device.

    Returns
    --------
        - True if half precision is supported.
        - False otherwise.
    """
    if device is None:
        device = get_device()
    supported = c_bool_t(False)
    safe_call(backend.get().af_get_half_support(c_pointer(supported), device))
    return supported.value
def sync(device=None):
    """
    Wait until every queued operation on a device has finished.

    Parameters
    -----------
    device: optional: int. default: None.
        id of the device to wait on; None means the current device.
    """
    target = get_device() if device is None else device
    af_lib = backend.get()
    safe_call(af_lib.af_sync(target))
def __eval(*args):
    """
    Evaluate the given arrays, batching them into one backend call if possible.

    Internal helper for `eval`. Callers are expected to have already checked
    that every argument is an arrayfire Array (each must expose `.arr`).

    - No arguments: does nothing.
    - One argument: uses `af_eval`.
    - Several arguments: packs the handles into a C array and uses
      `af_eval_multiple`.
    """
    nargs = len(args)
    if nargs == 0:
        # Nothing to evaluate. af_eval_multiple would otherwise be handed a
        # zero-length handle buffer, which it cannot safely read from.
        return
    if nargs == 1:
        safe_call(backend.get().af_eval(args[0].arr))
        return

    # Contiguous C array of the arrays' handles, in argument order.
    handles = (c_void_ptr_t * nargs)(*[arg.arr for arg in args])
    safe_call(backend.get().af_eval_multiple(c_int_t(nargs), c_pointer(handles)))
def eval(*args):
    """
    Evaluate one or more arrays together.

    Parameters
    -----------
    args : arguments to be evaluated

    Raises
    ------
    RuntimeError
        If any argument is not an arrayfire.Array.

    Note
    -----
    All the input arrays to this function should be of the same size.

    Examples
    --------
    >>> a = af.constant(1, 3, 3)
    >>> b = af.constant(2, 3, 3)
    >>> c = a + b
    >>> d = a - b
    >>> af.eval(c, d) # A single kernel is launched here
    >>> c
    arrayfire.Array()
    Type: float
    [3 3 1 1]
        3.0000     3.0000     3.0000
        3.0000     3.0000     3.0000
        3.0000     3.0000     3.0000
    >>> d
    arrayfire.Array()
    Type: float
    [3 3 1 1]
       -1.0000    -1.0000    -1.0000
       -1.0000    -1.0000    -1.0000
       -1.0000    -1.0000    -1.0000
    """
    if not all(isinstance(item, Array) for item in args):
        raise RuntimeError("All inputs to eval must be of type arrayfire.Array")
    __eval(*args)
def set_manual_eval_flag(flag):
    """
    Enable or disable the JIT engine's heuristics for evaluating JIT trees.

    Parameters
    ----------
    flag : bool.
        True disables the heuristic evaluation of the JIT tree, so
        evaluation only happens when requested.

    Note
    ----
    Evaluation forced by a non-JIT function is not affected by this flag.
    """
    af_lib = backend.get()
    safe_call(af_lib.af_set_manual_eval_flag(flag))
def get_manual_eval_flag():
    """
    Report whether heuristic evaluation of the JIT tree is disabled.

    Returns
    -------
    bool
        The current manual-evaluation flag of the backend JIT engine.

    Note
    ----
    Evaluation forced by a non-JIT function is not affected by this flag.
    """
    manual = c_bool_t(False)
    af_lib = backend.get()
    safe_call(af_lib.af_get_manual_eval_flag(c_pointer(manual)))
    return manual.value
def device_mem_info():
    """
    Return memory-manager statistics for the current device.

    Returns
    -------
    dict
        - 'alloc': {'buffers': total buffers allocated by the memory manager,
                    'bytes': total bytes allocated by the memory manager}
        - 'lock':  {'buffers': total buffers currently in scope,
                    'bytes': total bytes currently in scope}

    Note
    -----
    ArrayFire does not free memory when an array goes out of scope; the
    memory is marked for reuse instead.
        - alloc buffers minus lock buffers is the number of free buffers.
        - alloc bytes minus lock bytes is the number of free bytes.
    """
    # Order must match the af_device_mem_info signature:
    # alloc_bytes, alloc_buffers, lock_bytes, lock_buffers.
    counters = [c_size_t(0) for _ in range(4)]
    alloc_bytes, alloc_buffers, lock_bytes, lock_buffers = counters
    safe_call(backend.get().af_device_mem_info(*[c_pointer(c) for c in counters]))
    return {
        'alloc': {'buffers': alloc_buffers.value, 'bytes': alloc_bytes.value},
        'lock': {'buffers': lock_buffers.value, 'bytes': lock_bytes.value},
    }
def print_mem_info(title="Memory Info", device_id=None):
    """
    Print the memory used on the specified device.

    Parameters
    ----------
    title: optional. Default: "Memory Info"
        - Title to display before printing the memory info.
    device_id: optional. Default: None
        - Specifies the device for which the memory info should be displayed.
        - If None, uses the current device. Device id 0 is honored as device 0.

    Examples
    --------
    >>> a = af.randu(5,5)
    >>> af.print_mem_info()
    Memory Info
    ---------------------------------------------------------
    |     POINTER      |    SIZE    |  AF LOCK  | USER LOCK |
    ---------------------------------------------------------
    |     0x706400000  |       1 KB |       Yes |        No |
    ---------------------------------------------------------
    >>> b = af.randu(5,5)
    >>> af.print_mem_info()
    Memory Info
    ---------------------------------------------------------
    |     POINTER      |    SIZE    |  AF LOCK  | USER LOCK |
    ---------------------------------------------------------
    |     0x706400400  |       1 KB |       Yes |        No |
    |     0x706400000  |       1 KB |       Yes |        No |
    ---------------------------------------------------------
    >>> a = af.randu(1000,1000)
    >>> af.print_mem_info()
    Memory Info
    ---------------------------------------------------------
    |     POINTER      |    SIZE    |  AF LOCK  | USER LOCK |
    ---------------------------------------------------------
    |     0x706500000  |   3.815 MB |       Yes |        No |
    |     0x706400400  |       1 KB |       Yes |        No |
    |     0x706400000  |       1 KB |        No |        No |
    ---------------------------------------------------------
    """
    # Compare against None explicitly: a plain truthiness test would treat
    # device 0 as "not specified" and silently report the current device.
    if device_id is None:
        device_id = get_device()
    safe_call(backend.get().af_print_mem_info(title.encode('utf-8'), device_id))
def device_gc():
    """
    Ask the garbage collector to free all unlocked memory.
    """
    af_lib = backend.get()
    status = af_lib.af_device_gc()
    safe_call(status)
def get_device_ptr(a):
    """
    Return the raw device pointer backing an array.

    Parameters
    ----------
    a: af.Array
        - A multi dimensional arrayfire array.

    Returns
    -------
    The internal device pointer held by `a`, as a pointer object (not an int).

    Note
    -----
    - The memory manager will not free the device pointer of `a` until
      `unlock_device_ptr()` (or its replacement `unlock_array()`) is called.
    - Useful for interoperating with other CUDA/OpenCL/C libraries.
    """
    raw = c_void_ptr_t(0)
    af_lib = backend.get()
    safe_call(af_lib.af_get_device_ptr(c_pointer(raw), a.arr))
    return raw
def lock_device_ptr(a):
    """
    Deprecated alias of `lock_array`.

    Emits a DeprecationWarning, then forwards `a` to `lock_array`.
    """
    import warnings
    # stacklevel=2 attributes the warning to the caller's line, which is the
    # code that needs updating, rather than to this wrapper.
    warnings.warn("This function is deprecated. Use lock_array instead.",
                  DeprecationWarning, stacklevel=2)
    lock_array(a)
def lock_array(a):
    """
    Prevent garbage collection of the raw data held by an array.

    Parameters
    ----------
    a: af.Array
        - A multi dimensional arrayfire array.

    Note
    -----
    - The memory manager will not free the device pointer of `a` until
      `unlock_array()` is called.
    """
    af_lib = backend.get()
    safe_call(af_lib.af_lock_array(a.arr))
def is_locked_array(a):
    """
    Check whether the user has locked an array.

    Parameters
    ----------
    a: af.Array
        - A multi dimensional arrayfire array.

    Returns
    -----------
    bool
        True if the array is locked by the user.
    """
    locked = c_bool_t(False)
    af_lib = backend.get()
    safe_call(af_lib.af_is_locked_array(c_pointer(locked), a.arr))
    return locked.value
def unlock_device_ptr(a):
    """
    Deprecated alias of `unlock_array`.

    Emits a DeprecationWarning, then forwards `a` to `unlock_array`.
    """
    import warnings
    # stacklevel=2 attributes the warning to the caller's line, which is the
    # code that needs updating, rather than to this wrapper.
    warnings.warn("This function is deprecated. Use unlock_array instead.",
                  DeprecationWarning, stacklevel=2)
    unlock_array(a)
def unlock_array(a):
    """
    Let the memory manager resume garbage collection of an array's raw data.

    Parameters
    ----------
    a: af.Array
        - A multi dimensional arrayfire array.
    """
    af_lib = backend.get()
    safe_call(af_lib.af_unlock_array(a.arr))
def alloc_device(num_bytes):
    """
    Allocate a device buffer of `num_bytes` bytes.

    Returns
    -------
    The `.value` of the pointer filled in by the backend. Release it with
    `free_device`.
    """
    size = c_dim_t(num_bytes)
    out_ptr = c_void_ptr_t(0)
    safe_call(backend.get().af_alloc_device(c_pointer(out_ptr), size))
    return out_ptr.value
def alloc_host(num_bytes):
    """
    Allocate a host buffer of `num_bytes` bytes.

    Returns
    -------
    The `.value` of the pointer filled in by the backend. Release it with
    `free_host`.
    """
    size = c_dim_t(num_bytes)
    out_ptr = c_void_ptr_t(0)
    safe_call(backend.get().af_alloc_host(c_pointer(out_ptr), size))
    return out_ptr.value
def alloc_pinned(num_bytes):
    """
    Allocate a pinned host buffer of `num_bytes` bytes.

    Returns
    -------
    The `.value` of the pointer filled in by the backend. Release it with
    `free_pinned`.
    """
    size = c_dim_t(num_bytes)
    out_ptr = c_void_ptr_t(0)
    safe_call(backend.get().af_alloc_pinned(c_pointer(out_ptr), size))
    return out_ptr.value
def free_device(ptr):
    """
    Release device memory obtained from `alloc_device`.
    """
    safe_call(backend.get().af_free_device(c_void_ptr_t(ptr)))
def free_host(ptr):
    """
    Release host memory obtained from `alloc_host`.
    """
    safe_call(backend.get().af_free_host(c_void_ptr_t(ptr)))
def free_pinned(ptr):
    """
    Release pinned host memory obtained from `alloc_pinned`.
    """
    safe_call(backend.get().af_free_pinned(c_void_ptr_t(ptr)))
from.arrayimportArray