
GPU memory allocation

JAX will preallocate 75% of the total GPU memory when the first JAX operation is run. Preallocating minimizes allocation overhead and memory fragmentation, but can sometimes cause out-of-memory (OOM) errors. If your JAX process fails with OOM, the following environment variables can be used to override the default behavior:

XLA_PYTHON_CLIENT_PREALLOCATE=false

This disables the preallocation behavior. JAX will instead allocate GPU memory as needed, potentially decreasing the overall memory usage. However, this behavior is more prone to GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory may OOM with preallocation disabled.
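Since this variable is read when JAX initializes its backend, it must be set before jax is imported. A minimal sketch of doing so from Python rather than from the shell:

```python
import os

# Must be set before JAX initializes its backend (i.e. before `import jax`);
# changing it after initialization has no effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# import jax  # safe to import JAX from here on; it will allocate on demand
```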

XLA_PYTHON_CLIENT_MEM_FRACTION=.XX

If preallocation is enabled, this makes JAX preallocate XX% of the total GPU memory, instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts.
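For example, to lower the preallocated fraction to 50% (an illustrative value; tune it to your workload), the variable can be set from Python before jax is imported:

```python
import os

# Preallocate 50% of total GPU memory instead of the default 75%.
# ".50" is an illustrative value, not a recommendation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

# import jax  # import JAX only after setting the variable
```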

XLA_PYTHON_CLIENT_ALLOCATOR=platform

This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed (note that this is the only configuration that will deallocate GPU memory, instead of reusing it). This is very slow, so it is not recommended for general use, but may be useful for running with the minimal possible GPU memory footprint or for debugging OOM failures.
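As with the other variables, this can be set from Python before jax is imported; a minimal sketch:

```python
import os

# Allocate exactly what is needed and free memory as soon as it is unused.
# Slow: intended for debugging OOMs or minimizing footprint, not for production.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

# import jax  # import JAX only after setting the variable
```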

Common causes of OOM failures

Running multiple JAX processes concurrently.

Either use XLA_PYTHON_CLIENT_MEM_FRACTION to give each process an appropriate amount of memory, or set XLA_PYTHON_CLIENT_PREALLOCATE=false.
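As a sketch of the first option, each of N processes could be launched with a fraction somewhat below 1/N, leaving headroom for the CUDA context and fragmentation (the numbers here are illustrative assumptions, not recommendations):

```python
import os

# Hypothetical setup: four JAX processes sharing one GPU.
# Budget 90% of the GPU across processes, leaving 10% headroom.
num_processes = 4
fraction = 0.9 / num_processes  # roughly 22% of the GPU per process

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"

# import jax  # import JAX only after setting the variable
```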

Running JAX and GPU TensorFlow concurrently.

TensorFlow also preallocates by default, so this is similar to running multiple JAX processes concurrently.

One solution is to use CPU-only TensorFlow (e.g. if you're only doing data loading with TF). You can prevent TensorFlow from using the GPU with the command tf.config.experimental.set_visible_devices([], "GPU").

Alternatively, use XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE. There are also similar options to configure TensorFlow's GPU memory allocation (gpu_memory_fraction and allow_growth in TF1, which should be set in a tf.ConfigProto passed to tf.Session; see Using GPUs: Limiting GPU memory growth for TF2).

Running JAX on the display GPU.

Use XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE.

Disabling the rematerialization HLO pass

Sometimes it is favorable to disable the automatic rematerialization HLO pass, to avoid poor remat choices made by the compiler. The pass can be enabled or disabled by setting jax.config.update('jax_compiler_enable_remat_pass', True) or jax.config.update('jax_compiler_enable_remat_pass', False) respectively. Enabling or disabling the automatic remat pass produces different trade-offs between compute and memory. Note, however, that the algorithm is basic, and you can often get a better trade-off between compute and memory by disabling the automatic remat pass and rematerializing manually with the jax.remat API.
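As a sketch of manual rematerialization, jax.remat can wrap a sub-computation so that its intermediates are recomputed during the backward pass instead of being stored; the function inner below is purely illustrative:

```python
import jax
import jax.numpy as jnp

# Wrapping `inner` in jax.remat means its intermediate values are not kept
# for the backward pass; they are recomputed, trading compute for memory.
@jax.remat
def inner(x):
    return jnp.sin(jnp.sin(x))

def loss(x):
    return jnp.sum(inner(x))

# Gradients are computed as usual; only peak memory behavior changes.
grads = jax.grad(loss)(jnp.ones((8,)))
```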

Experimental features

The features described here are experimental and should be used with caution.

TF_GPU_ALLOCATOR=cuda_malloc_async

This replaces XLA's own BFC memory allocator with cudaMallocAsync. This removes the big fixed preallocation and instead uses a memory pool that grows. The expected benefit is that there is no need to set XLA_PYTHON_CLIENT_MEM_FRACTION.

The risks are:

  • Memory fragmentation behaves differently, so if you are close to the limit, the exact OOM cases due to fragmentation will differ.

  • The allocation time is not all paid at the start, but is incurred whenever the memory pool needs to grow. So you could experience less speed stability at the start, and for benchmarks it becomes even more important to ignore the first few iterations.

These risks can be mitigated by preallocating a significant chunk up front while still getting the benefit of a growing memory pool. This can be done with TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=N. If N is -1, it will preallocate the same amount as would have been allocated by default. Otherwise, N is the size in bytes that you want to preallocate.
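A sketch of combining the two variables from Python, before JAX is imported; the 2 GiB figure is an illustrative assumption, not a recommendation:

```python
import os

# Enable the cudaMallocAsync-based allocator in place of XLA's BFC allocator.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

# Preallocate 2 GiB up front (illustrative); -1 would instead preallocate
# the same amount as the default behavior.
os.environ["TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC"] = str(2 * 1024**3)

# import jax  # import JAX only after setting the variables
```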

