Profiling device memory
Note
June 2025 update: we recommend using XProf profiling for device memory analysis. After taking a profile, open the memory_viewer tab of the TensorBoard profiler for a more detailed and understandable view of device memory usage.
The JAX device memory profiler allows us to explore how and why JAX programs are using GPU or TPU memory. For example, it can be used to:
Figure out which arrays and executables are in GPU memory at a given time, or
Track down memory leaks.
Installation
The JAX device memory profiler emits output that can be interpreted using pprof (google/pprof). Start by installing pprof, following its installation instructions. At the time of writing, installing pprof requires first installing Go (version 1.16 or later) and Graphviz, and then running
go install github.com/google/pprof@latest
which installs pprof as $GOPATH/bin/pprof, where GOPATH defaults to ~/go.
Note
The version of pprof from google/pprof is not the same as the older tool of the same name distributed as part of the gperftools package. The gperftools version of pprof will not work with JAX.
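As a quick sanity check after installation, the sketch below looks for a pprof binary from Python. The helper name find_pprof is hypothetical, and the lookup is simplified: it assumes GOBIN is unset and GOPATH holds a single directory. It only locates a binary called pprof; it cannot tell the google/pprof tool apart from the gperftools one.

```python
import os
import shutil


def find_pprof():
    """Return a path to a pprof binary, or None if none is found.

    Hypothetical helper: checks PATH first, then $GOPATH/bin/pprof,
    falling back to ~/go when GOPATH is not set.
    """
    on_path = shutil.which("pprof")
    if on_path is not None:
        return on_path
    gopath = os.environ.get("GOPATH") or os.path.join(os.path.expanduser("~"), "go")
    candidate = os.path.join(gopath, "bin", "pprof")
    if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
        return candidate
    return None


path = find_pprof()
print(path if path is not None else "pprof not found on PATH or in $GOPATH/bin")
```

If the binary is found only under $GOPATH/bin, adding that directory to PATH lets the pprof commands in the rest of this page run as written.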
Understanding how a JAX program is using GPU or TPU memory
A common use of the device memory profiler is to figure out why a JAX program is using a large amount of GPU or TPU memory, for example when trying to debug an out-of-memory problem.
To capture a device memory profile to disk, use jax.profiler.save_device_memory_profile(). For example, consider the following Python program:
import jax
import jax.numpy as jnp
import jax.profiler

def func1(x):
  return jnp.tile(x, 10) * 0.5

def func2(x):
  y = func1(x)
  return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)

z.block_until_ready()

jax.profiler.save_device_memory_profile("memory.prof")
If we first run the program above and then execute
pprof --web memory.prof
pprof opens a web browser containing the following visualization of the device memory profile in callgraph format:
The callgraph is a visualization of the Python stack at the point the allocation of each live buffer was made. For example, in this specific case, the visualization shows that func2 and its callees were responsible for allocating 76.30MB, of which 38.15MB was allocated inside the call from func2 to func1. For more information about how to interpret callgraph visualizations, see the pprof documentation.
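The 38.15MB figure can be cross-checked by hand. The sketch below is a back-of-the-envelope calculation rather than anything the profiler runs; it assumes the arrays use JAX's default float32 dtype (4 bytes per element) and that pprof's "MB" means 2**20 bytes, which is the assumption that matches the reported numbers. The helper buffer_mib is hypothetical.

```python
BYTES_PER_FLOAT32 = 4  # assumption: default float32 arrays
MIB = 2 ** 20          # assumption: pprof's "MB" is 2**20 bytes


def buffer_mib(shape, itemsize=BYTES_PER_FLOAT32):
    """Size of a dense array buffer with the given shape, in units of 2**20 bytes."""
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    return num_elements * itemsize / MIB


# x has shape (1000, 1000); jnp.tile(x, 10) repeats it along the last axis.
x_mib = buffer_mib((1000, 1000))
tiled_mib = buffer_mib((1000, 10000))

print(f"{x_mib:.2f}")          # 3.81
print(f"{tiled_mib:.2f}")      # 38.15
print(f"{2 * tiled_mib:.1f}")  # 76.3
```

Each of the two tiled results that func2 returns is therefore about 38.15MB, one allocated inside func1 and one in func2 itself, which together come to roughly the 76.30MB attributed to func2 and its callees.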
Functions compiled with jax.jit() are opaque to the device memory profiler. That is, any memory allocated inside a jit-compiled function will be attributed to the function as a whole.
In the example, the call to block_until_ready() ensures that func2 completes before the device memory profile is collected. See Asynchronous dispatch for more details.
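When a computation returns several arrays, one way to apply the same idea is to wait on all of them before saving the profile. The sketch below is a minimal illustration, not part of the profiler API: the helper save_profile_after is hypothetical, and it relies on jax.block_until_ready(), which waits on every array in a pytree. The array shapes and the temporary output path are arbitrary choices for the example.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import jax.profiler


def save_profile_after(outputs, path):
    """Wait for every array in `outputs` (any pytree), then save a profile.

    Hypothetical helper: blocking first ensures the buffers produced by
    asynchronously dispatched work exist when the profile is collected.
    """
    jax.block_until_ready(outputs)
    jax.profiler.save_device_memory_profile(path)


x = jnp.ones((256, 256))
y = jnp.tile(x, 4) * 0.5
z = jnp.tile(x, 4) + 1

# Write the profile into a fresh temporary directory.
profile_path = os.path.join(tempfile.mkdtemp(), "memory.prof")
save_profile_after((y, z), profile_path)
print(os.path.exists(profile_path))
```

The resulting file can be opened with pprof in the same way as memory.prof above.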
Debugging memory leaks
We can also use the JAX device memory profiler to track down memory leaks by using pprof to visualize the change in memory usage between two device memory profiles taken at different times. For example, consider the following program, which accumulates JAX arrays into a constantly growing Python list.
import jax
import jax.numpy as jnp
import jax.profiler

def afunction():
  return jax.random.normal(jax.random.key(77), (1000000,))

z = afunction()

def anotherfunc():
  arrays = []
  for i in range(1, 10):
    x = jax.random.normal(jax.random.key(42), (i, 10000))
    arrays.append(x)
    x.block_until_ready()
    jax.profiler.save_device_memory_profile(f"memory{i}.prof")

anotherfunc()
If we simply visualize the device memory profile at the end of execution (memory9.prof), it may not be obvious that each iteration of the loop in anotherfunc accumulates more device memory allocations:
pprof --web memory9.prof
The large but fixed allocation inside afunction dominates the profile but does not grow over time.
By using pprof's --diff_base feature to visualize the change in memory usage across loop iterations, we can identify why the memory usage of the program increases over time:
pprof --web --diff_base memory1.prof memory9.prof
The visualization shows that the memory growth can be attributed to the call to normal inside anotherfunc.
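When comparing many pairs of snapshots, it can be convenient to build the diff command from Python. The sketch below is one possible wrapper, not part of JAX or pprof: the helpers pprof_diff_command and run_pprof_diff are hypothetical, and the only pprof flags used are the --web and --diff_base flags shown above. It assumes pprof is on PATH.

```python
import subprocess


def pprof_diff_command(base_profile, new_profile, pprof="pprof"):
    """Build the argument list for diffing two device memory profiles."""
    return [pprof, "--web", "--diff_base", base_profile, new_profile]


def run_pprof_diff(base_profile, new_profile):
    """Launch pprof on the diff; this opens a web browser, so call it interactively."""
    subprocess.run(pprof_diff_command(base_profile, new_profile), check=True)


cmd = pprof_diff_command("memory1.prof", "memory9.prof")
print(" ".join(cmd))  # pprof --web --diff_base memory1.prof memory9.prof
```

Calling run_pprof_diff("memory1.prof", "memory9.prof") would then show the same diff as the command above. Diffing adjacent snapshots, such as memory8.prof against memory9.prof, shows the growth from a single iteration.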
