Device Management
Created On: Nov 14, 2025 | Last Updated On: Dec 09, 2025
Background
Device management covers basics such as querying how many devices are available and switching between them. Accelerator backends wrap their device‑runtime APIs and expose them to PyTorch.
Design
Accelerator vendors should implement these core functions:
| Function name | Description | Application scenarios |
|---|---|---|
| | Query the total number of available devices in the system | Application initialization |
| | Get the currently active device for the calling thread | Debugging and logging |
| | Change the active device for subsequent operations | Switching context between devices |
| | Atomically swap device and return the previous device | Implementing device guards |
| | Conditionally exchange device only if the index is valid (−1 allowed) | Safe device switching with optional indices |
These functions are the building blocks for streams, events, and memory management. Implementations should validate device indices and report runtime errors instead of silently ignoring them.
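The semantics described in the table can be sketched in plain Python. The block below is a minimal illustration, not OpenReg's implementation: `FakeRuntime`, its fixed count of two devices, and the snake_case function names are hypothetical stand-ins, and a real backend would make these calls into its device runtime from C++. Treating −1 as "keep the current device" is also an assumption about what "−1 allowed" means.

```python
import threading


class FakeRuntime:
    """Hypothetical stand-in for a vendor device runtime (not a real API)."""

    def __init__(self, num_devices: int = 2) -> None:
        self.num_devices = num_devices
        # The active device is tracked per thread; each thread starts on device 0.
        self._tls = threading.local()

    def get_device(self) -> int:
        return getattr(self._tls, "device", 0)

    def set_device(self, index: int) -> None:
        self._tls.device = index


_runtime = FakeRuntime()


def device_count() -> int:
    return _runtime.num_devices


def current_device() -> int:
    return _runtime.get_device()


def set_device(index: int) -> None:
    # Validate before touching the runtime so bad input fails loudly.
    if not 0 <= index < device_count():
        raise ValueError(f"invalid device index {index}")
    _runtime.set_device(index)


def exchange_device(index: int) -> int:
    """Switch to `index` and return the device that was active before."""
    previous = current_device()
    set_device(index)
    return previous


def maybe_exchange_device(index: int) -> int:
    """Like exchange_device, but treat -1 as 'keep the current device' (assumed)."""
    if index == -1:
        return current_device()
    return exchange_device(index)
```

Because the active device is stored per thread in this sketch, switching devices in one thread does not affect another, which matches the "for the calling thread" wording in the table.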
Implementation
This section illustrates device management using `set_device` as an example. The implementation requires:
- C++ wrappers around the device runtime
- Python bindings to expose the C++ functions
- User-friendly Python APIs
For illustration, OpenReg (Open Registration) is a PyTorch integration example that fills the gap for out‑of‑tree accelerator backend integration. Its implementation (OpenRegFunctions.h/cpp) demonstrates how to wrap a third‑party runtime cleanly. These functions are reused across the backend—for streams, events, generators, and Python bindings.
C++ side
Wrap the device-runtime API and add error handling. The `SetDevice` function shows this pattern:
```cpp
orError_t SetDevice(DeviceIndex device) {
  int cur_device = -1;
  OPENREG_CHECK(orGetDevice(&cur_device));
  if (device == cur_device) {
    return orSuccess;
  }
  return orSetDevice(device);
}
```

```cpp
OPENREG_EXPORT void set_device(DeviceIndex device) {
  check_device_index(device);
  OPENREG_CHECK(SetDevice(device));
}
```
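Bindings
Expose the C++ functions to Python. The excerpt below uses the CPython C API, where each function takes and returns `PyObject*` and is registered in a `PyMethodDef` table: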
```cpp
PyObject* _setDevice(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
  auto device = THPUtils_unpackDeviceIndex(arg);
  torch::utils::device_lazy_init(at::kPrivateUse1);
  c10::openreg::set_device(device);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
```

```cpp
static PyMethodDef methods[] = {
    {"_init", _initExtension, METH_NOARGS, nullptr},
    {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr},
    {"_get_device", _getDevice, METH_NOARGS, nullptr},
    {"_set_device", _setDevice, METH_O, nullptr},
    {"_exchangeDevice", _exchangeDevice, METH_O, nullptr},
    {"_get_device_count", _getDeviceCount, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};
```
Python side
Wrap the C++ bindings with user-friendly Python functions:
```python
def set_device(device) -> None:
    if device >= 0:
        torch_openreg._C._set_device(device)
```
Here’s the complete mapping from C++ to Python:
| C++ binding function | C++ binding API | Python user API | Description |
|---|---|---|---|
| | | | Returns the total number of devices |
| | | | Returns the current active device index |
| | | | Sets the active device |
| | | N/A (internal use only) | Atomically swaps device and returns previous |
Guard
Device guards provide automatic device switching with exception safety. They’re similar to C++ lock guards—they switch devices on construction and restore on destruction.
Implement `DeviceGuardImplInterface` to integrate with PyTorch’s guard system:
```cpp
/**
 * Return the type of device managed by this guard implementation.
 */
DeviceType type() const override {
  return static_type;
}

/**
 * Set the current device to device d, and return the previous Device.
 */
// LITERALINCLUDE START: OPENREG GUARD DEVICE MANAGEMENT
Device exchangeDevice(Device d) const override {
  TORCH_CHECK(d.is_privateuseone(), "Expected a PrivateUse1 device, but got ", d);

  auto old_device_index = ExchangeDevice(d.index());
  return Device(static_type, old_device_index);
}
// LITERALINCLUDE END: OPENREG GUARD DEVICE MANAGEMENT

/**
 * Get the current device.
 */
Device getDevice() const override {
  int device_index = current_device();
  return c10::Device(static_type, device_index);
}

/**
 * Get the device capability for a given device.
 * By default, OpenReg has 2 same devices with the same capability.
 */
DeviceCapability getDeviceCapability(Device /*unused*/) const override {
  return DeviceCapability();
}

/**
 * Set the current device to c10::Device.
 */
void setDevice(Device d) const override {
  TORCH_CHECK(d.is_privateuseone(), "Expected a PrivateUse1 device, but got ", d);

  set_device(d.index());
}

/**
 * Set the current device to device d, without checking for errors
 * (so, e.g., this can be called from a destructor).
 */
void uncheckedSetDevice(Device d) const noexcept override {
  set_device(d.index());
}

/**
 * Get the number of devices.
 *
 * WARNING: This is REQUIRED to not raise an exception.
 * If there is some sort of problem, e.g., driver error,
 * you should report that there are zero available devices.
 */
DeviceIndex deviceCount() const noexcept override {
  return device_count();
}

/**
 * Wait (by blocking the calling thread) until all the work has
 * completed running on the device.
 */
void synchronizeDevice(const DeviceIndex device_index) const override {
  OPENREG_CHECK(orDeviceSynchronize());
}
```
Registering this implementation with PyTorch (for example, through the `C10_REGISTER_GUARD_IMPL` macro) makes the guard available for the `PrivateUse1` device type; users can then use standard PyTorch device guards with the custom backend.