
Device Management#


Background#

Device management covers basics such as querying how many devices are available and switching between them. Accelerator backends wrap their device‑runtime APIs and expose them to PyTorch.

Design#

Accelerator vendors should implement these core functions:

device_count()
    Description: Query the total number of available devices in the system.
    Application scenarios:
    - Application initialization
    - Multi-device workload distribution
    - Validating device indices before use

current_device()
    Description: Get the currently active device for the calling thread.
    Application scenarios:
    - Debugging and logging
    - Determining tensor placement
    - Guard implementations

set_device()
    Description: Change the active device for subsequent operations.
    Application scenarios:
    - Switching context between devices
    - Initializing specific device resources
    - Multi-GPU training loops

exchange_device()
    Description: Atomically swap the device and return the previous device.
    Application scenarios:
    - Implementing device guards
    - Temporarily switching device context
    - RAII-based device management

maybe_exchange_device()
    Description: Conditionally exchange the device only if the index is valid (−1 allowed).
    Application scenarios:
    - Safe device switching with optional indices
    - Guard implementations with nullable device values

These functions are the building blocks for streams, events, and memory management, so every implementation should validate its inputs and handle runtime errors properly.
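
To make the design concrete, here is a minimal sketch of how these five entry points might be declared in a backend's C++ layer. It is only an orientation aid: the names, namespace, and DeviceIndex type follow the OpenReg example used in the rest of this page, and a real backend may organize them differently.

// Sketch of a vendor-facing device-management API, modeled on the OpenReg example.
// DeviceIndex comes from c10/core/Device.h; none of this is mandated by PyTorch.
#include <c10/core/Device.h>

namespace c10::openreg {

DeviceIndex device_count() noexcept;       // total number of devices (0 if the driver fails)
DeviceIndex current_device();              // index of the active device for this thread
void set_device(DeviceIndex device);       // make `device` the active device
DeviceIndex exchange_device(DeviceIndex device);        // switch and return the previous index
DeviceIndex maybe_exchange_device(DeviceIndex device);  // as above, but -1 leaves the device unchanged

} // namespace c10::openreg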

Implementation#

This section illustrates device management using set_device as an example. The implementation requires:

  1. C++ wrappers around the device runtime

  2. Python bindings to expose the C++ functions

  3. User-friendly Python APIs

For illustration, this section uses OpenReg (Open Registration), a reference PyTorch integration that demonstrates out-of-tree accelerator backend integration. Its device-management implementation (OpenRegFunctions.h/cpp) shows how to wrap a third-party runtime cleanly, and these functions are reused across the backend for streams, events, generators, and the Python bindings.
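
The wrappers in the next section call into a small set of OpenReg runtime primitives (orGetDevice, orSetDevice, OPENREG_CHECK, and so on) that are not reproduced on this page. Purely for orientation, and assuming signatures inferred from how they are used below, that runtime surface looks roughly like this; orGetDeviceCount is hypothetical and only shown to complete the picture.

// Assumed shape of the OpenReg device runtime used by the C++ wrappers below.
// Exact definitions live in the OpenReg sources; treat this as a sketch only.
enum orError_t { orSuccess = 0, orErrorUnknown = 1 };

orError_t orGetDevice(int* device);      // write the index of the active device into *device
orError_t orSetDevice(int device);       // make `device` the active device
orError_t orGetDeviceCount(int* count);  // hypothetical: write the number of devices into *count
orError_t orDeviceSynchronize();         // block until all queued work on the device finishes

// OPENREG_CHECK(expr) is assumed to turn a non-orSuccess result into a C++ exception.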

C++ side#

Wrap the device‑runtime API and add error handling. The SetDevice helper and the exported set_device wrapper show this pattern:

orError_t SetDevice(DeviceIndex device) {
  int cur_device = -1;
  OPENREG_CHECK(orGetDevice(&cur_device));
  if (device == cur_device) {
    return orSuccess;
  }
  return orSetDevice(device);
}

OPENREG_EXPORT void set_device(DeviceIndex device) {
  check_device_index(device);
  OPENREG_CHECK(SetDevice(device));
}
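
The guard implementation later on this page calls an ExchangeDevice helper that returns the index of the previously active device. That helper is not part of the excerpt above; a plausible sketch that follows the same wrap-and-check pattern (and is only an assumption about the real OpenReg code) is:

// Sketch: switch to `device` and report which device was active before.
DeviceIndex ExchangeDevice(DeviceIndex device) {
  int current = -1;
  OPENREG_CHECK(orGetDevice(&current));      // remember the currently active device
  if (device != static_cast<DeviceIndex>(current)) {
    OPENREG_CHECK(orSetDevice(device));      // switch only if the index actually changes
  }
  return static_cast<DeviceIndex>(current);  // previous index, for the caller to restore later
}

A maybe_exchange_device variant would simply skip the switch and return the current index when it is handed -1.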

Bindings#

Expose the C++ functions to Python by registering them as methods of the C extension module:

PyObject* _setDevice(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
  auto device = THPUtils_unpackDeviceIndex(arg);
  torch::utils::device_lazy_init(at::kPrivateUse1);
  c10::openreg::set_device(device);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyMethodDef methods[] = {
    {"_init", _initExtension, METH_NOARGS, nullptr},
    {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr},
    {"_get_device", _getDevice, METH_NOARGS, nullptr},
    {"_set_device", _setDevice, METH_O, nullptr},
    {"_exchangeDevice", _exchangeDevice, METH_O, nullptr},
    {"_get_device_count", _getDeviceCount, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};
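
For completeness, a method table like the one above is normally attached to the torch_openreg._C extension module through the standard CPython module machinery. The snippet below is illustrative only; the actual OpenReg module setup may differ in its naming and initialization details.

// Illustrative CPython registration of the method table shown above.
static struct PyModuleDef _c_module = {
    PyModuleDef_HEAD_INIT,
    "torch_openreg._C", // module name as imported from Python
    nullptr,            // module docstring
    -1,                 // per-interpreter state size (-1: module keeps global state)
    methods};           // the PyMethodDef table defined above

PyMODINIT_FUNC PyInit__C(void) {
  return PyModule_Create(&_c_module);
}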

Python side#

Wrap the C++ bindings with user-friendly Python functions:

def set_device(device) -> None:
    if device >= 0:
        torch_openreg._C._set_device(device)

Here’s the complete mapping from C++ to Python:

_getDeviceCount
    C++ binding API: torch_openreg._C._get_device_count()
    Python user API: torch.openreg.device_count()
    Description: Returns the total number of devices.

_getDevice
    C++ binding API: torch_openreg._C._get_device()
    Python user API: torch.openreg.current_device()
    Description: Returns the index of the currently active device.

_setDevice
    C++ binding API: torch_openreg._C._set_device(idx)
    Python user API: torch.openreg.set_device(idx)
    Description: Sets the active device.

_exchangeDevice
    C++ binding API: torch_openreg._C._exchange_device(idx)
    Python user API: N/A (internal use only)
    Description: Atomically swaps the device and returns the previous index.

Guard#

Device guards provide automatic device switching with exception safety. They’re similar to C++ lock guards—they switch devices on construction and restore on destruction.
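
Conceptually, a guard is a small RAII wrapper over exchange_device: the constructor records the previous device while switching to the new one, and the destructor restores it. A minimal sketch, assuming the OpenReg-style helpers from the Implementation section (the class name is illustrative and not part of PyTorch):

// Illustrative RAII device switch built directly on the backend helpers.
class ScopedDeviceSwitch {
 public:
  explicit ScopedDeviceSwitch(c10::DeviceIndex index)
      : prev_(c10::openreg::ExchangeDevice(index)) {}  // switch now, remember the old index
  ~ScopedDeviceSwitch() {
    // A production guard restores with a non-throwing call; that is what
    // uncheckedSetDevice in the interface below exists for.
    c10::openreg::set_device(prev_);
  }
  ScopedDeviceSwitch(const ScopedDeviceSwitch&) = delete;
  ScopedDeviceSwitch& operator=(const ScopedDeviceSwitch&) = delete;

 private:
  c10::DeviceIndex prev_;
};

PyTorch's own guards follow this shape, but route the calls through a per-backend implementation of the interface below so that generic code never needs to know which backend it is running on.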

Implement c10::impl::DeviceGuardImplInterface to integrate with PyTorch's guard system:

/**
 * Return the type of device managed by this guard implementation.
 */
DeviceType type() const override {
  return static_type;
}

/**
 * Set the current device to device d, and return the previous Device.
 */
Device exchangeDevice(Device d) const override {
  TORCH_CHECK(d.is_privateuseone(), "Expected a PrivateUse1 device, but got ", d);

  auto old_device_index = ExchangeDevice(d.index());
  return Device(static_type, old_device_index);
}

/**
 * Get the current device.
 */
Device getDevice() const override {
  int device_index = current_device();
  return c10::Device(static_type, device_index);
}

/**
 * Get the device capability for a given device.
 * By default, OpenReg has 2 same devices with the same capability.
 */
DeviceCapability getDeviceCapability(Device /*unused*/) const override {
  return DeviceCapability();
}

/**
 * Set the current device to c10::Device.
 */
void setDevice(Device d) const override {
  TORCH_CHECK(d.is_privateuseone(), "Expected a PrivateUse1 device, but got ", d);

  set_device(d.index());
}

/**
 * Set the current device to device d, without checking for errors
 * (so, e.g., this can be called from a destructor).
 */
void uncheckedSetDevice(Device d) const noexcept override {
  set_device(d.index());
}

/**
 * Get the number of devices.
 *
 * WARNING: This is REQUIRED to not raise an exception.
 * If there is some sort of problem, e.g., driver error,
 * you should report that there are zero available devices.
 */
DeviceIndex deviceCount() const noexcept override {
  return device_count();
}

/**
 * Wait (by blocking the calling thread) until all the work has
 * completed running on the device.
 */
void synchronizeDevice(const DeviceIndex device_index) const override {
  OPENREG_CHECK(orDeviceSynchronize());
}

This makes the guard available in PyTorch for the PrivateUse1 device type; users can then use standard PyTorch device guards with the custom backend.
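
For example, once the guard implementation is registered for PrivateUse1, backend-agnostic C++ code can scope work to a device with the generic c10::DeviceGuard; the device index below is arbitrary.

#include <c10/core/DeviceGuard.h>

void run_on_second_device() {
  // Constructing the guard switches to PrivateUse1 device 1 through the
  // registered OpenReg guard implementation.
  c10::DeviceGuard guard(c10::Device(c10::DeviceType::PrivateUse1, 1));

  // ... enqueue work on device 1 here ...
}  // Destruction restores whichever device was active before, even on exceptions.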