HIP (ROCm) semantics#
ROCm™ is AMD’s open source software platform for GPU-accelerated high-performance computing and machine learning. HIP is ROCm’s C++ dialect designed to ease conversion of CUDA applications to portable C++ code. HIP is used when converting existing CUDA applications like PyTorch to portable C++ and for new projects that require portability between AMD and NVIDIA.
HIP Interfaces Reuse the CUDA Interfaces#
PyTorch for HIP intentionally reuses the existing torch.cuda interfaces. This helps to accelerate the porting of existing PyTorch code and models because very few code changes are necessary, if any.
The example from CUDA semantics will work exactly the same for HIP:
```python
import torch

cuda = torch.device('cuda')     # Default HIP device
cuda0 = torch.device('cuda:0')  # 'rocm' or 'hip' are not valid, use 'cuda'
cuda2 = torch.device('cuda:2')  # GPU 2 (these are 0-indexed)

x = torch.tensor([1., 2.], device=cuda0)
# x.device is device(type='cuda', index=0)
y = torch.tensor([1., 2.]).cuda()
# y.device is device(type='cuda', index=0)

with torch.cuda.device(1):
    # allocates a tensor on GPU 1
    a = torch.tensor([1., 2.], device=cuda)

    # transfers a tensor from CPU to GPU 1
    b = torch.tensor([1., 2.]).cuda()
    # a.device and b.device are device(type='cuda', index=1)

    # You can also use ``Tensor.to`` to transfer a tensor:
    b2 = torch.tensor([1., 2.]).to(device=cuda)
    # b.device and b2.device are device(type='cuda', index=1)

    c = a + b
    # c.device is device(type='cuda', index=1)

    z = x + y
    # z.device is device(type='cuda', index=0)

    # even within a context, you can specify the device
    # (or give a GPU index to the .cuda call)
    d = torch.randn(2, device=cuda2)
    e = torch.randn(2).to(cuda2)
    f = torch.randn(2).cuda(cuda2)
    # d.device, e.device, and f.device are all device(type='cuda', index=2)
```
Checking for HIP#
Whether you are using PyTorch for CUDA or HIP, the result of calling is_available() will be the same. If you are using a PyTorch that has been built with GPU support, it will return True. If you must check which version of PyTorch you are using, refer to the example below:
```python
import torch

if torch.cuda.is_available() and torch.version.hip:
    # do something specific for HIP
    print(f"Running with ROCm/HIP {torch.version.hip}")
elif torch.cuda.is_available() and torch.version.cuda:
    # do something specific for CUDA
    print(f"Running with CUDA {torch.version.cuda}")
```
TensorFloat-32 (TF32) on ROCm#
TF32 is not supported on ROCm.
Memory management#
PyTorch uses a caching memory allocator to speed up memory allocations. This allows fast memory deallocation without device synchronizations. However, the unused memory managed by the allocator will still show as if used in rocm-smi. You can use memory_allocated() and max_memory_allocated() to monitor memory occupied by tensors, and use memory_reserved() and max_memory_reserved() to monitor the total amount of memory managed by the caching allocator. Calling empty_cache() releases all unused cached memory from PyTorch so that it can be used by other GPU applications. However, GPU memory occupied by tensors will not be freed, so this cannot increase the amount of GPU memory available for PyTorch.
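As a minimal sketch of how these monitoring calls fit together (the tensor shape and device are arbitrary):

```python
import torch

device = torch.device('cuda')  # HIP devices also use the 'cuda' device type
x = torch.randn(1024, 1024, device=device)

print(torch.cuda.memory_allocated(device))      # bytes currently occupied by tensors
print(torch.cuda.max_memory_allocated(device))  # peak bytes occupied by tensors
print(torch.cuda.memory_reserved(device))       # bytes held by the caching allocator
print(torch.cuda.max_memory_reserved(device))   # peak bytes held by the caching allocator

del x
torch.cuda.empty_cache()  # return unused cached blocks so other GPU applications can use them
```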
For more advanced users, we offer more comprehensive memory benchmarking via memory_stats(). We also offer the capability to capture a complete snapshot of the memory allocator state via memory_snapshot(), which can help you understand the underlying allocation patterns produced by your code.
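For example (the statistics key shown is just one of the many entries returned by memory_stats()):

```python
import torch

stats = torch.cuda.memory_stats()            # dict of detailed allocator statistics
print(stats.get("allocation.all.current"))   # number of currently live allocations

snapshot = torch.cuda.memory_snapshot()      # per-segment view of the allocator state
print(len(snapshot), "memory segments tracked by the caching allocator")
```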
To debug memory errors, set PYTORCH_NO_HIP_MEMORY_CACHING=1 in your environment to disable caching. PYTORCH_NO_CUDA_MEMORY_CACHING=1 is also accepted for ease of porting.
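If you prefer to set the flag from Python rather than in the shell, a sketch like the following should work, assuming the variable is set before the first GPU allocation:

```python
import os

# Disable the caching allocator; the CUDA spelling is honored on ROCm builds as well.
os.environ["PYTORCH_NO_HIP_MEMORY_CACHING"] = "1"

import torch

# Allocations now bypass the cache, which is slower but easier to debug.
x = torch.randn(8, device="cuda")
```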
hipBLAS workspaces#
For each combination of hipBLAS handle and HIP stream, a hipBLAS workspace will be allocated if that handle and stream combination executes a hipBLAS kernel that requires a workspace. In order to avoid repeatedly allocating workspaces, these workspaces are not deallocated unless torch._C._cuda_clearCublasWorkspaces() is called; note that it’s the same function for CUDA and HIP. The workspace size per allocation can be specified via the environment variable HIPBLAS_WORKSPACE_CONFIG with the format :[SIZE]:[COUNT]. As an example, the environment variable HIPBLAS_WORKSPACE_CONFIG=:4096:2:16:8 specifies a total size of 2*4096+8*16 KiB, or about 8 MiB. The default workspace size is 32 MiB; MI300 and newer default to 128 MiB. To force hipBLAS to avoid using workspaces, set HIPBLAS_WORKSPACE_CONFIG=:0:0. For convenience, CUBLAS_WORKSPACE_CONFIG is also accepted.
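A sketch of how the workspace configuration and explicit cleanup might be combined (the matrix size is arbitrary, and the variable must be set before the first hipBLAS call):

```python
import os

# Two workspace chunks of 4096 KiB each, i.e. 8 MiB per handle/stream combination.
os.environ["HIPBLAS_WORKSPACE_CONFIG"] = ":4096:2"

import torch

a = torch.randn(512, 512, device="cuda")
b = a @ a  # the first GEMM on this handle/stream pair may allocate its workspace
torch.cuda.synchronize()

# Release the cached workspaces (same private helper on CUDA and HIP builds).
torch._C._cuda_clearCublasWorkspaces()
```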
hipFFT/rocFFT plan cache#
Setting the size of the cache for hipFFT/rocFFT plans is not supported.
torch.distributed backends#
Currently, only the “nccl” and “gloo” backends for torch.distributed are supported on ROCm.
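Initialization therefore looks the same as on CUDA. For instance, assuming the usual MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE environment variables are set by your launcher:

```python
import torch
import torch.distributed as dist

# On ROCm, the "nccl" backend is backed by RCCL; "gloo" operates on CPU tensors.
dist.init_process_group(backend="nccl", init_method="env://")

rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

t = torch.ones(1, device="cuda")
dist.all_reduce(t)  # sums the tensor across all ranks
```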
CUDA API to HIP API mappings in C++#
Please refer to: https://rocm.docs.amd.com/projects/HIP/en/latest/reference/api_syntax.html
NOTE: The CUDA_VERSION macro, cudaRuntimeGetVersion and cudaDriverGetVersion APIs do not semantically map to the same values as the HIP_VERSION macro, hipRuntimeGetVersion and hipDriverGetVersion APIs. Please do not use them interchangeably when doing version checks.
For example: Instead of using the following to implicitly exclude ROCm/HIP:

```cpp
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
```
use the following to not take the code path for ROCm/HIP:

```cpp
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(USE_ROCM)
```
Alternatively, if it is desired to take the code path for ROCm/HIP:

```cpp
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11000) || defined(USE_ROCM)
```
Or if it is desired to take the code path for ROCm/HIP only for specific HIP versions:

```cpp
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11000) || (defined(USE_ROCM) && ROCM_VERSION >= 40300)
```
Refer to CUDA Semantics doc#
For any sections not listed here, please refer to the CUDA semantics doc: CUDA semantics
Enabling kernel asserts#
Kernel asserts are supported on ROCm, but they are disabled due to performance overhead. They can be enabled by recompiling PyTorch from source.
Please add the line below as an argument to the cmake command parameters:
-DROCM_FORCE_ENABLE_GPU_ASSERTS:BOOL=ON
Enabling/Disabling ROCm Composable Kernel#
Enabling composable_kernel (CK) for both SDPA and GEMMs is a two-part process. First, the user must have built PyTorch while setting the corresponding environment variable to ‘1’:
SDPA: USE_ROCM_CK_SDPA=1
GEMMs: USE_ROCM_CK_GEMM=1
Second, the user must explicitly request that CK be used as the backend library via the corresponding Python call:
SDPA: setROCmFAPreferredBackend('<choice>')
GEMMs: setBlasPreferredBackend('<choice>')
To enable CK in either scenario, simply pass ‘ck’ to those functions.
In order to set the backend to CK, the user MUST have built PyTorch with the correct environment variable. If not, PyTorch will print a warning and use the “default” backend: for GEMMs this routes to hipBLAS, and for SDPA it routes to aotriton.
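As an illustration only, the GEMM case might look like the sketch below, assuming a PyTorch build with USE_ROCM_CK_GEMM=1 and a version whose torch.backends.cuda.preferred_blas_library selector accepts 'ck':

```python
import torch

# Request composable_kernel for GEMMs; if the build lacks USE_ROCM_CK_GEMM=1,
# PyTorch warns and keeps the default (hipBLAS) backend.
torch.backends.cuda.preferred_blas_library(backend="ck")

a = torch.randn(256, 256, device="cuda")
b = a @ a  # GEMMs are now dispatched through the requested backend
```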