# Profiling computation
## Viewing program traces with Perfetto
We can use the JAX profiler to generate traces of a JAX program that can be visualized using the Perfetto visualizer. Currently, this method blocks the program until a link is clicked and the Perfetto UI loads the trace. If you wish to get profiling information without any interaction, check out the XProf profiler below.
```python
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    # Run the operations to be profiled
    key = jax.random.key(0)
    x = jax.random.normal(key, (5000, 5000))
    y = x @ x
    y.block_until_ready()
```
After this computation is done, the program will prompt you to open a link to ui.perfetto.dev. When you open the link, the Perfetto UI will load the trace file and open a visualizer.

Program execution will continue after loading the link. The link is no longer valid after opening once, but it will redirect to a new URL that remains valid. You can then click the “Share” button in the Perfetto UI to create a permalink to the trace that can be shared with others.
### Remote profiling
When profiling code that is running remotely (for example on a hosted VM), you need to establish an SSH tunnel on port 9001 for the link to work. You can do that with this command:
```shell
$ ssh -L 9001:127.0.0.1:9001 <user>@<host>
```

or if you’re using Google Cloud:
```shell
$ gcloud compute ssh <machine-name> -- -L 9001:127.0.0.1:9001
```

### Manual capture
Instead of capturing traces programmatically using `jax.profiler.trace`, you can instead start a profiling server in the script of interest by calling `jax.profiler.start_server(<port>)`. If you only need the profiler server to be active for a portion of your script, you can shut it down by calling `jax.profiler.stop_server()`.
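A minimal sketch of this pattern (the port number and array sizes are arbitrary choices for illustration):

```python
import jax

# Start a profiler server on port 9999 so that an external
# `python -m jax.collect_profile` invocation can attach to this process.
jax.profiler.start_server(9999)

# ... the workload you want to be able to capture ...
key = jax.random.key(0)
x = jax.random.normal(key, (512, 512))
y = (x @ x).block_until_ready()

# Shut the server down once profiling is no longer needed.
jax.profiler.stop_server()
```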
Once the script is running and after the profiler server has started, we can manually capture and trace by running:
```shell
$ python -m jax.collect_profile <port> <duration_in_ms>
```
By default, the resulting trace information is dumped into a temporary directory, but this can be overridden by passing in `--log_dir=<directory of choice>`. Also, by default, the program will prompt you to open a link to ui.perfetto.dev. When you open the link, the Perfetto UI will load the trace file and open a visualizer. This feature is disabled by passing `--no_perfetto_link` into the command. Alternatively, you can also point TensorBoard to the `log_dir` to analyze the trace (see the “XProf (TensorBoard profiling)” section below).
## XProf (TensorBoard profiling)
XProf can be used to profile JAX programs. XProf is a great way to acquire and visualize performance traces and profiles of your program, including activity on GPU and TPU. The end result looks something like this:

### Installation
XProf is available as a plugin to TensorBoard, as well as an independently run program.
```shell
pip install xprof
```
If you have TensorBoard installed, the `xprof` pip package will also install the TensorBoard Profiler plugin. Be careful to only install one version of TensorFlow or TensorBoard, otherwise you may encounter the “duplicate plugins” error described below. See https://www.tensorflow.org/guide/profiler for more information on installing TensorBoard.
Profiling with the nightly version of TensorBoard requires the nightly XProf.
```shell
pip install tb-nightly xprof-nightly
```
### Programmatic capture
You can instrument your code to capture a profiler trace via the `jax.profiler.start_trace()` and `jax.profiler.stop_trace()` methods. Call `start_trace()` with the directory to write trace files to. This should be the same `--logdir` directory used to start XProf. Then, you can use XProf to view the traces.
For example, to take a profiler trace:
```python
import jax

jax.profiler.start_trace("/tmp/profile-data")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()
```
Note the `block_until_ready()` call. We use this to make sure on-device execution is captured by the trace. See Asynchronous dispatch for details on why this is necessary.
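To see why, here is a rough timing sketch (the array size is arbitrary): because dispatch is asynchronous, the matrix multiply returns almost immediately, and only `block_until_ready()` waits for the device work to finish.

```python
import time
import jax

key = jax.random.key(0)
x = jax.random.normal(key, (2000, 2000))

start = time.perf_counter()
y = x @ x                      # dispatched asynchronously; returns quickly
dispatch_time = time.perf_counter() - start

y.block_until_ready()          # waits for on-device execution to finish
total_time = time.perf_counter() - start

# A trace stopped before block_until_ready() could miss the
# device work that is still in flight.
```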
You can also use the `jax.profiler.trace()` context manager as an alternative to `start_trace` and `stop_trace`:
```python
import jax

with jax.profiler.trace("/tmp/profile-data"):
    key = jax.random.key(0)
    x = jax.random.normal(key, (5000, 5000))
    y = x @ x
    y.block_until_ready()
```
### Viewing the trace
After capturing a trace, you can view it using the XProf UI.
You can launch the profiler UI directly using the standalone XProf command by pointing it to your log directory:
```shell
$ xprof --port 8791 /tmp/profile-data
Attempting to start XProf server:
  Log Directory: /tmp/profile-data
  Port: 8791
XProf at http://localhost:8791/ (Press CTRL+C to quit)
```
Navigate to the provided URL (e.g., http://localhost:8791/) in your browser to view the profile.
Available traces appear in the “Runs” dropdown menu on the left. Select the run you’re interested in, and then under the “Tools” dropdown, select `trace_viewer`. You should now see a timeline of the execution. You can use the WASD keys to navigate the trace, and click or drag to select events for more details. See these TensorFlow docs for more details on using the trace viewer.
### Manual capture via XProf
The following are instructions for capturing a manually triggered N-second trace from a running program.
Start an XProf server:
```shell
xprof --logdir /tmp/profile-data/
```
You should be able to load XProf at http://localhost:8791/. You can specify a different port with the `--port` flag. See Profiling on a remote machine below if running JAX on a remote server.

In the Python program or process you’d like to profile, add the following somewhere near the beginning:
```python
import jax.profiler

jax.profiler.start_server(9999)
```
This starts the profiler server that XProf connects to. The profiler server must be running before you move on to the next step. When you’re done using the server, you can call `jax.profiler.stop_server()` to shut it down.

If you’d like to profile a snippet of a long-running program (e.g. a long training loop), you can put this at the beginning of the program and start your program as usual. If you’d like to profile a short program (e.g. a microbenchmark), one option is to start the profiler server in an IPython shell, and run the short program with `%run` after starting the capture in the next step. Another option is to start the profiler server at the beginning of the program and use `time.sleep()` to give you enough time to start the capture.

Open http://localhost:8791/, and click the “CAPTURE PROFILE” button in the upper left. Enter “localhost:9999” as the profile service URL (this is the address of the profiler server you started in the previous step). Enter the number of milliseconds you’d like to profile for, and click “CAPTURE”.
If the code you’d like to profile isn’t already running (e.g. if you started the profiler server in a Python shell), run it while the capture is running.
After the capture finishes, XProf should automatically refresh. (Not all of the XProf profiling features are hooked up with JAX, so it may initially look like nothing was captured.) On the left under “Tools”, select `trace_viewer`.
You should now see a timeline of the execution. You can use the WASD keys to navigate the trace, and click or drag to select events to see more details at the bottom. See these XProf docs for more details on using the trace viewer.
Other analysis tools are also available under the “Tools” dropdown.
### XProf and TensorBoard
XProf is the underlying tool that powers the profiling and trace capturing functionality in TensorBoard. As long as `xprof` is installed, a “Profile” tab will be present within TensorBoard. Using this is identical to launching XProf independently, as long as it is launched pointing to the same log directory. This includes profile capture, analysis, and viewing functionality. XProf supplants the `tensorboard_plugin_profile` functionality that was previously recommended.
```shell
$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
```
## Adding custom trace events
By default, the events in the trace viewer are mostly low-level internal JAX functions. You can add your own events and functions by using `jax.profiler.TraceAnnotation` and `jax.profiler.annotate_function()` in your code.
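For example, a sketch of both APIs (the label `"my-block"` and the function `dense_layer` are illustrative names, not part of the JAX API):

```python
import jax

# annotate_function labels every call to the wrapped function in the trace.
@jax.profiler.annotate_function
def dense_layer(x, w):
    return x @ w

key = jax.random.key(0)
x = jax.random.normal(key, (128, 128))
w = jax.random.normal(key, (128, 128))

# TraceAnnotation labels an arbitrary block of code in the trace.
with jax.profiler.TraceAnnotation("my-block"):
    y = dense_layer(x, w)
    y.block_until_ready()
```

These annotations only show up in the trace viewer when a trace is being captured; outside of a trace they add negligible overhead.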
## Configuring profiler options
The `start_trace` method accepts an optional `profiler_options` parameter, which allows for fine-grained control over the profiler’s behavior. This parameter should be an instance of `jax.profiler.ProfileOptions`.

For example, to disable all Python and host traces:
```python
import jax

options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0

jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()
```
### General options
`host_tracer_level`: Sets the trace level for host-side activities. Supported values:

- `0`: Disables host (CPU) tracing entirely.
- `1`: Enables tracing of only user-instrumented TraceMe events.
- `2`: Includes level 1 traces plus high-level program execution details like expensive XLA operations (default).
- `3`: Includes level 2 traces plus more verbose, low-level program execution details such as cheap XLA operations.

`device_tracer_level`: Controls whether device tracing is enabled. Supported values:

- `0`: Disables device tracing.
- `1`: Enables device tracing (default).

`python_tracer_level`: Controls whether Python tracing is enabled. Supported values:

- `0`: Disables Python function call tracing (default).
- `1`: Enables Python tracing.
### Advanced configuration options
#### TPU options
`tpu_trace_mode`: Specifies the mode for TPU tracing. Supported values:

- `TRACE_ONLY_HOST`: Only host-side (CPU) activities are traced, and no device (TPU/GPU) traces are collected.
- `TRACE_ONLY_XLA`: Only XLA-level operations on the device are traced.
- `TRACE_COMPUTE`: Traces compute operations on the device.
- `TRACE_COMPUTE_AND_SYNC`: Traces both compute operations and synchronization events on the device.

If `tpu_trace_mode` is not provided, the trace mode defaults to `TRACE_ONLY_XLA`.
- `tpu_num_sparse_cores_to_trace`: Specifies the number of sparse cores to trace on the TPU.
- `tpu_num_sparse_core_tiles_to_trace`: Specifies the number of tiles within each sparse core to trace on the TPU.
- `tpu_num_chips_to_profile_per_task`: Specifies the number of TPU chips to profile per task.
#### GPU options
The following options are available for GPU profiling:
- `gpu_max_callback_api_events`: Sets the maximum number of events collected by the CUPTI callback API. Defaults to `2 * 1024 * 1024`.
- `gpu_max_activity_api_events`: Sets the maximum number of events collected by the CUPTI activity API. Defaults to `2 * 1024 * 1024`.
- `gpu_max_annotation_strings`: Sets the maximum number of annotation strings that can be collected. Defaults to `1024 * 1024`.
- `gpu_enable_nvtx_tracking`: Enables NVTX tracking in CUPTI. Defaults to `False`.
- `gpu_enable_cupti_activity_graph_trace`: Enables CUPTI activity graph tracing for CUDA graphs. Defaults to `False`.
- `gpu_pm_sample_counters`: A comma-separated string of GPU Performance Monitoring metrics to collect using CUPTI’s PM sampling feature (e.g. `"sm__cycles_active.avg.pct_of_peak_sustained_elapsed"`). PM sampling is disabled by default. For available metrics, see NVIDIA’s CUPTI documentation.
- `gpu_pm_sample_interval_us`: Sets the sampling interval in microseconds for CUPTI PM sampling. Defaults to `500`.
- `gpu_pm_sample_buffer_size_per_gpu_mb`: Sets the system memory buffer size per device in MB for CUPTI PM sampling. Defaults to 64 MB. The maximum supported value is 4 GB.
- `gpu_num_chips_to_profile_per_task`: Specifies the number of GPU devices to profile per task. If not specified, set to 0, or set to an invalid value, all available GPUs will be profiled. This can be used to decrease the trace collection size.
- `gpu_dump_graph_node_mapping`: If enabled, dumps CUDA graph node mapping information into the trace. Defaults to `False`.
For example:
```python
from jax.profiler import ProfileOptions

options = ProfileOptions()
options.advanced_configuration = {
    "tpu_trace_mode": "TRACE_ONLY_HOST",
    "tpu_num_sparse_cores_to_trace": 2,
}
```
An `InvalidArgumentError` is raised if any unrecognized keys or option values are found.
## Troubleshooting
### GPU profiling
Programs running on GPU should produce traces for the GPU streams near the top of the trace viewer. If you’re only seeing the host traces, check your program logs and/or output for the following error messages.
If you get an error like: `Could not load dynamic library 'libcupti.so.10.1'`
Full error:
```
W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcupti.so.10.1'; dlerror: libcupti.so.10.1: cannot open shared object file: No such file or directory
2020-06-12 13:19:59.822799: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1422] function cupti_interface_->Subscribe(&subscriber_, (CUpti_CallbackFunc)ApiCallback, this) failed with error CUPTI could not be loaded or symbol could not be found.
```
Add the path to `libcupti.so` to the environment variable `LD_LIBRARY_PATH`. (Try `locate libcupti.so` to find the path.) For example:
```shell
export LD_LIBRARY_PATH=/usr/local/cuda-10.1/extras/CUPTI/lib64/:$LD_LIBRARY_PATH
```
If you still get the `Could not load dynamic library` message after doing this, check if the GPU trace shows up in the trace viewer anyway. This message sometimes occurs even when everything is working, since it looks for the `libcupti` library in multiple places.
If you get an error like: `failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES`
Full error:
```
E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1445] function cupti_interface_->EnableCallback(0, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid) failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
2020-06-12 14:31:54.097791: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1487] function cupti_interface_->ActivityDisable(activity) failed with error CUPTI_ERROR_NOT_INITIALIZED
```
Run the following commands (note this requires a reboot):
```shell
echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"' | sudo tee -a /etc/modprobe.d/nvidia-kernel-common.conf
sudo update-initramfs -u
sudo reboot now
```
See NVIDIA’s documentation on this error for more information.
### Profiling on a remote machine
If the JAX program you’d like to profile is running on a remote machine, one option is to run all the instructions above on the remote machine (in particular, start the TensorBoard server on the remote machine), then use SSH local port forwarding to access the TensorBoard web UI from your local machine. Use the following SSH command to forward the default TensorBoard port 6006 from the local to the remote machine:
```shell
ssh -L 6006:localhost:6006 <remote server address>
```

or if you’re using Google Cloud:
```shell
$ gcloud compute ssh <machine-name> -- -L 6006:localhost:6006
```

### Multiple TensorBoard installs
If starting TensorBoard fails with an error like: `ValueError: Duplicate plugins for name projector`
It’s often because there are two versions of TensorBoard and/or TensorFlow installed (e.g. the `tensorflow`, `tf-nightly`, `tensorboard`, and `tb-nightly` pip packages all include TensorBoard). Uninstalling a single pip package can result in the `tensorboard` executable being removed, which is then hard to replace, so it may be necessary to uninstall everything and reinstall a single version:
```shell
pip uninstall tensorflow tf-nightly tensorboard tb-nightly xprof xprof-nightly tensorboard-plugin-profile tbp-nightly
pip install tensorboard xprof
```
## Nsight
NVIDIA’s Nsight tools can be used to trace and profile JAX code on GPU. For details, see the Nsight documentation.
