GPU performance tips#

This document focuses on performance tips for neural network workloads.

Matmul precision#

On recent GPU generations, such as the Nvidia A100 generation or later, it can be a good idea to perform most computations in bfloat16 precision. For example, if using Flax, instantiate Dense layers using flax.linen.Dense(..., dtype=jax.numpy.bfloat16).
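The same idea can be sketched without Flax. The snippet below is a minimal illustration, not the Flax implementation: dense_bf16 and the shapes are hypothetical stand-ins for a dense layer, written as a plain matrix multiply whose operands are cast to bfloat16.

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes, chosen only for illustration.
BATCH, IN_FEATURES, OUT_FEATURES = 8, 128, 64

key_x, key_w = jax.random.split(jax.random.key(0))
x = jax.random.normal(key_x, (BATCH, IN_FEATURES), dtype=jnp.float32)
w = jax.random.normal(key_w, (IN_FEATURES, OUT_FEATURES), dtype=jnp.float32)


@jax.jit
def dense_bf16(x, w):
    # Cast both operands so that the matmul itself runs in bfloat16.
    return jnp.dot(x.astype(jnp.bfloat16), w.astype(jnp.bfloat16))


y = dense_bf16(x, w)
print(y.dtype, y.shape)
```

Whether bfloat16 is accurate enough depends on the model, so it is worth comparing against a float32 run before adopting it everywhere.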

XLA performance flags#

Note

JAX-Toolbox also has a page on NVIDIA XLA performance FLAGS.

The existence and exact behavior of XLA flags may be jaxlib-version dependent.

As of jaxlib==0.4.18 (released Oct 6 2023), setting these XLA flags can improve performance. Some are related to communication between GPUs, and so are only relevant when running computations on multiple devices, while others are related to code generation on each device.

Some of these may be set by default in future releases.

These flags can be set via the XLA_FLAGS shell environment variable. For example, we can add this to the top of a Python file:

    import os
    os.environ['XLA_FLAGS'] = (
        '--xla_gpu_triton_gemm_any=True '
        '--xla_gpu_enable_latency_hiding_scheduler=true '
    )
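Assigning to os.environ['XLA_FLAGS'] directly overwrites any flags that were already exported in the shell. A minimal sketch of a helper that appends instead is given here; add_xla_flags is a hypothetical name, not part of JAX, and it assumes flags are separated by whitespace and contain no spaces themselves.

```python
import os


def add_xla_flags(*flags):
    """Append flags to XLA_FLAGS, keeping any that are already set."""
    current = os.environ.get("XLA_FLAGS", "").split()
    for flag in flags:
        if flag not in current:
            current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]


# Call this at the top of the file, before JAX is imported and used.
add_xla_flags(
    "--xla_gpu_triton_gemm_any=True",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
)
print(os.environ["XLA_FLAGS"])
```

Calling the helper twice with the same flag leaves a single copy in XLA_FLAGS, which keeps the variable readable when several modules contribute flags.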

For more examples, see also XLA Flags recommended for Pax training on Nvidia GPUs.

Code generation flags#

  • --xla_gpu_triton_gemm_any Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports. The default value is False.

Communication tips#

Auto and manual PGLE#

The Profile Guided Latency Estimator (PGLE) workflow measures the actual running time of compute and collectives. The profile information is then fed back into the XLA compiler so that it can make better scheduling decisions.

The Profile Guided Latency Estimator can be used manually or automatically. In auto mode, JAX collects profile information and recompiles a module in a single run. In manual mode, you need to run a task twice: the first time to collect and save profiles, and the second time to compile and run with the provided data.

Important: the JAX profiler, which is used by both of the PGLE workflows documented below, cannot co-exist with the NVIDIA Nsight Systems profiler. This limitation can be avoided by using the JAX compilation cache, as described below.

Auto PGLE#

The auto PGLE can be turned on by setting the following environment variables:

Mandatory:

    export JAX_ENABLE_PGLE=true

    # For JAX version <= 0.5.0 make sure to include:
    export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"

Optional:

    export JAX_PGLE_PROFILING_RUNS=3
    export JAX_PGLE_AGGREGATION_PERCENTILE=85

    # Right now the auto PGLE profile collection doesn't work with command buffer.
    # If the command buffer is enabled, Auto PGLE will disable it during profile
    # collection and enable it back after the recompilation. If you need to have a
    # consistent command buffer logic with and without PGLE profile you can disable it
    # manually:
    export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''"

Alternatively, this can be set in JAX as follows:

    import jax
    from jax._src import config

    with config.enable_pgle(True), config.pgle_profiling_runs(1):
      # Run with the profiler collecting performance information.
      train_step()
      # Automatically re-compile with PGLE profile results
      train_step()
      ...

You can control the number of reruns used to collect profile data by changing JAX_PGLE_PROFILING_RUNS. Increasing this parameter leads to better profile information, but it also increases the number of non-optimized training steps.

Decreasing the JAX_PGLE_AGGREGATION_PERCENTILE parameter might help when performance between steps is too noisy, because it filters out non-relevant measurements.
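To build intuition for why a lower percentile filters out noise, here is a small stand-alone sketch. It is not the JAX implementation: the nearest-rank percentile, the function name, and the latency values are all assumptions made for illustration.

```python
import math


def nearest_rank_percentile(samples, percentile):
    """Return the nearest-rank percentile of a list of numbers."""
    ordered = sorted(samples)
    rank = max(1, math.ceil(percentile / 100 * len(ordered)))
    return ordered[rank - 1]


# Hypothetical per-step latencies (microseconds) of one instruction.
# The last measurement is an outlier from a noisy step.
latencies_us = [100, 102, 98, 101, 250]

print(nearest_rank_percentile(latencies_us, 85))  # 250, the outlier is picked
print(nearest_rank_percentile(latencies_us, 50))  # 101, the outlier is ignored
```

With only a few profiling runs, a high percentile can land on a single noisy measurement, which is why lowering the percentile or raising the number of runs can both help.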

Attention: Auto PGLE doesn’t work for pre-compiled modules. Since JAX needs to recompile the module during execution, auto PGLE will work neither for AoT compilation nor for the following case:

    import jax
    from jax._src import config

    # Assumes train_step is a jax.jit-wrapped function that takes no arguments.
    train_step_compiled = train_step.lower().compile()

    with config.enable_pgle(True), config.pgle_profiling_runs(1):
      train_step_compiled()
      # No effect since module was pre-compiled.
      train_step_compiled()

Collecting NVIDIA Nsight Systems profiles when using AutoPGLE#

jax#24910 (JAX v0.5.1 and newer) added a new JAX configuration option, JAX_COMPILATION_CACHE_EXPECT_PGLE, which tells JAX to attempt to load PGLE-optimized compiled functions from the persistent compilation cache.

This allows a two-step process, where the first step writes a PGLE-optimized function to the cache:

    export JAX_ENABLE_COMPILATION_CACHE=yes          # not strictly needed, on by default
    export JAX_COMPILATION_CACHE_DIR=/root/jax_cache
    JAX_ENABLE_PGLE=yes python my-model.py

And the second step uses Nsight Systems and loads the PGLE-optimized function from the cache:

    JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py

See also this page for more information about the persistent compilation cache and possible pitfalls.

Manual PGLE#

If you still want to use a manual Profile Guided Latency Estimator, the workflow in XLA/GPU is:

    1. Run your workload once, with async collectives and latency hiding scheduler enabled.

You could do so by setting:

    export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"

    2. Collect and post process a profile by using the JAX profiler, saving the extracted instruction latencies into a binary protobuf file.

    import logging
    import os

    from etils import epath
    import jax
    from jax.experimental import profiler as exp_profiler

    # Define your profile directory
    profile_dir = 'gs://my_bucket/profile'
    jax.profiler.start_trace(profile_dir)

    # run your workflow
    # for i in range(10):
    #   train_step()

    # Stop trace
    jax.profiler.stop_trace()
    profile_dir = epath.Path(profile_dir)
    directories = profile_dir.glob('plugins/profile/*/')
    directories = [d for d in directories if d.is_dir()]
    rundir = directories[-1]
    logging.info('rundir: %s', rundir)

    # Post process the profile
    fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))

    # Save the profile proto to a file.
    dump_dir = rundir / 'profile.pb'
    dump_dir.parent.mkdir(parents=True, exist_ok=True)
    dump_dir.write_bytes(fdo_profile)

After this step, you will get a profile.pb file under the rundir printed in the code.

    3. Run the workload again, feeding that file into the compilation.

You need to pass the profile.pb file to the --xla_gpu_pgle_profile_file_or_directory_path flag.

    export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"
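When the second run is launched from Python, the flag string can be assembled from the newest profile run. The following is a sketch under stated assumptions: pgle_flags is a hypothetical helper, the profile lives on a local filesystem, and run directories carry timestamp names that sort chronologically. The temporary directory only stands in for a real profile directory.

```python
import pathlib
import tempfile


def pgle_flags(profile_root):
    """Build XLA_FLAGS pointing at the newest profile.pb under profile_root."""
    run_dirs = sorted(
        d for d in pathlib.Path(profile_root).glob("plugins/profile/*")
        if d.is_dir()
    )
    if not run_dirs:
        raise FileNotFoundError(f"no profile runs under {profile_root}")
    profile_pb = run_dirs[-1] / "profile.pb"
    if not profile_pb.is_file():
        raise FileNotFoundError(f"{profile_pb} does not exist")
    return (
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        f"--xla_gpu_pgle_profile_file_or_directory_path={profile_pb}"
    )


# Demonstration with a throwaway directory containing two fake runs.
with tempfile.TemporaryDirectory() as root:
    for run in ("2023_07_20_18_29_30", "2023_07_21_09_00_00"):
        run_dir = pathlib.Path(root, "plugins", "profile", run)
        run_dir.mkdir(parents=True)
        (run_dir / "profile.pb").write_bytes(b"")
    flags = pgle_flags(root)

print(flags)
```

The returned string can be assigned to os.environ["XLA_FLAGS"] before JAX is imported in the second run.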

To enable logging in XLA and check whether the profile is good, set the logging level to include INFO:

    export TF_CPP_MIN_LOG_LEVEL=0

Run the real workflow. If you find these messages in the log, the profile is being used by the latency hiding scheduler:

    2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
    2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator
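For long logs, the check can be automated. This is a minimal sketch that searches captured log text for the two messages quoted above; pgle_profile_used is a hypothetical helper, and the exact message wording may differ between XLA versions.

```python
import re

# Patterns taken from the two example log messages quoted in the text.
PGLE_MARKERS = (
    re.compile(r"Using PGLE profile from (\S+)"),
    re.compile(r"Found profile, using profile guided latency estimator"),
)


def pgle_profile_used(log_text):
    """Return True if both PGLE messages appear in the log text."""
    return all(pattern.search(log_text) for pattern in PGLE_MARKERS)


sample_log = (
    "2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/"
    "gpu_hlo_schedule.cc:478] Using PGLE profile from "
    "/tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb\n"
    "2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/"
    "gpu_hlo_schedule.cc:573] Found profile, using profile guided "
    "latency estimator\n"
)

print(pgle_profile_used(sample_log))
```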

Flags#

  • --xla_gpu_enable_latency_hiding_scheduler This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False.

  • --xla_gpu_memory_limit_slop_factor This flag serves as a multiplier applied to the total available memory, creating a threshold that guides the Latency Hiding Scheduler (LHS) in balancing memory reduction and latency hiding optimizations. The default value is 95.

    This factor effectively establishes a memory limit for compiler passes, determining when the scheduler should prioritize:

    1. Memory reduction: When memory usage approaches or exceeds the calculated threshold.

    2. Latency hiding: When memory usage is below the threshold, allowing for more aggressive optimizations that may temporarily increase memory usage but improve overall performance.

    By adjusting this factor, users can fine-tune the trade-off between memory efficiency and performance optimizations.

  • --xla_gpu_all_gather_combine_threshold_bytes --xla_gpu_reduce_scatter_combine_threshold_bytes --xla_gpu_all_reduce_combine_threshold_bytes These flags tune when to combine multiple small AllGather / ReduceScatter / AllReduce into one big AllGather / ReduceScatter / AllReduce to reduce time spent on cross-device communication. For example, for the AllGather / ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough so as to combine at least a Transformer Layer’s weight AllGather / ReduceScatter. By default, the combine_threshold_bytes is set to 256.
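The arithmetic behind the memory limit and the combine thresholds can be sketched as follows. Both helpers are hypothetical and rest on assumptions: the slop factor is read as a percentage of total memory, and the Transformer layer is approximated by four attention projections plus a feed-forward block with a 4x expansion, stored in bfloat16. Treat the result as a starting point for tuning, not a recommended value.

```python
GIB = 1024**3


def memory_limit_bytes(total_memory_bytes, slop_factor=95):
    # Assumption: the factor is interpreted as a percentage of total memory.
    return total_memory_bytes * slop_factor // 100


def layer_weight_bytes(d_model, bytes_per_param=2, ffn_multiplier=4):
    # Hypothetical layer: 4 attention projections of d_model x d_model, plus
    # two feed-forward matrices of d_model x (ffn_multiplier * d_model).
    params = 4 * d_model**2 + 2 * ffn_multiplier * d_model**2
    return params * bytes_per_param


limit = memory_limit_bytes(80 * GIB)          # an 80 GiB device
threshold = layer_weight_bytes(d_model=4096)  # one layer in bfloat16

flags = (
    f"--xla_gpu_all_gather_combine_threshold_bytes={threshold} "
    f"--xla_gpu_reduce_scatter_combine_threshold_bytes={threshold}"
)
print(limit / GIB)  # 76.0
print(threshold)    # 402653184
print(flags)
```

Under these assumptions, a slop factor of 95 on an 80 GiB device gives a 76 GiB limit, and combining one layer's weight collectives needs a threshold of roughly 384 MiB, far above the default of 256 bytes.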

Pipeline Parallelism on GPU#

Using XLA Flags#

XLA implements SPMD-based pipeline parallelism optimizations. This is a scaling technique where the forward and backward pass are split into multiple pipeline stages. Each device (or device group) processes the result of the previous pipeline stage (or the pipeline input) and sends its partial result to the next stage until the end of the pipeline is reached. This optimization works best when the latency of the computation is larger than communication. At compile time, the operations will be rearranged to overlap communication with computation.

For an optimized schedule, we recommend these XLA flags:

    --xla_gpu_enable_latency_hiding_scheduler=true
    --xla_gpu_enable_command_buffer=''
    --xla_disable_hlo_passes=collective-permute-motion
    --xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE

The following JAX example demonstrates a pattern where communication operations are scheduled to overlap with computations. In this example we will illustrate how to set up an optimized pipeline parallelism scheduling using 4 GPUs that form a communication ring (device 0 -> device 1 -> device 2 -> device 3 -> device 0). We refer to the pattern 0->1->2->3 as the forward edge, and 3->0 as the back edge.

    # Imports and setup
    import functools
    import jax
    from jax import sharding
    from jax.experimental import mesh_utils
    import jax.numpy as jnp
    import jax.random

    NUM_DEVICES = 4
    NUM_MICROBATCHES = 5
    NUM_CIRC_REPEATS = 2
    CONTRACTING_DIM_SIZE = 4096
    NON_CONTRACTING_DIM_SIZE = 8192
    COMPUTE_INTENSITY = 32


    # Creates a collective permute for the "forward edge".
    # 0->1, 1->2, ... (N-2)->(N-1)
    def shift_right(arr):
      padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
      # Use lax.slice to guarantee the gradient is a pad.
      return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)


    # Creates a collective permute for the "back edge".
    # (N-1)->0
    def cycle_back(arr):
      padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1)
      return jax.lax.slice(
          jnp.pad(arr, padding),
          [NUM_DEVICES - 1] + [0] * (arr.ndim - 1),
          (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:],
      )


    def select_on_first_device(then_value, else_value):
      assert then_value.shape == else_value.shape
      is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0
      return jnp.where(is_first_device, then_value, else_value)


    def select_on_last_device(then_value, else_value):
      assert then_value.shape == else_value.shape
      is_last_device = (
          jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1
      )
      return jnp.where(is_last_device, then_value, else_value)


    def select_on_first_cycle(i, then_value, else_value):
      assert then_value.shape == else_value.shape
      is_first_cycle = i < NUM_MICROBATCHES
      return jnp.where(is_first_cycle, then_value, else_value)


    def while_body(carry, i):
      """Body of the pipeline while loop."""
      weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry

      # Read input data from input buffer.
      input_data = jax.lax.dynamic_slice(
          input_buffer,
          (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
          (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
      )

      # Collective permute on the "forward edge" shifts data to the next stage.
      fwd_edge_data = shift_right(fwd_edge_data)

      # Select compute argument based on device and pipeline cycle.
      compute_argument = select_on_first_device(
          select_on_first_cycle(i, input_data, bwd_edge_data),
          fwd_edge_data,
      ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))

      # A few matmuls to simulate compute.
      tmp = compute_argument
      for _ in range(COMPUTE_INTENSITY):
        tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
      compute_result = tmp.reshape(
          (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
      )

      # Read data from buffer to pass it to the first device of the pipeline on the
      # "back edge".
      bwd_edge_data = jax.lax.dynamic_slice(
          output_buffer,
          (0, (1 + i) % NUM_MICROBATCHES, 0, 0),
          (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
      )

      # Collective permute on the "back edge" passes data to the first device.
      bwd_edge_data = cycle_back(bwd_edge_data)

      # Update output buffer. We do this after reading from it to avoid the data
      # dependency.
      output_buffer = jax.lax.dynamic_update_slice(
          output_buffer,
          compute_result,
          (0, (2 + i) % NUM_MICROBATCHES, 0, 0),
      )

      fwd_edge_data = compute_result
      carry = (
          weights,
          input_buffer,
          output_buffer,
          fwd_edge_data,
          bwd_edge_data,
      )
      return carry, i


    @functools.partial(jax.jit, static_argnames=["mesh"])
    def entry_computation(weights, input_buffer, mesh):

      # Init output buffer.
      output_buffer = jnp.zeros_like(input_buffer)

      # Init dummy data for forward and backward edge passed through the while loop.
      dummy_data = jnp.zeros(
          shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
      ).astype(jnp.float32)
      dummy_data = jax.device_put(
          dummy_data,
          sharding.NamedSharding(mesh, sharding.PartitionSpec("x")),
      )

      # Start pipeline.
      carry = weights, input_buffer, output_buffer, dummy_data, dummy_data
      num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
      carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))
      _, _, output_buffer, _, _ = carry

      return output_buffer


    def main(_):

      # Expect constant number of devices.
      assert NUM_DEVICES == jax.local_device_count()

      # Create mesh.
      mesh = sharding.Mesh(
          mesh_utils.create_device_mesh([NUM_DEVICES]),
          axis_names=["x"],
      )

      # Init weights.
      weights = 1.0 / CONTRACTING_DIM_SIZE
      weights = jax.lax.broadcast_in_dim(
          weights,
          shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
          broadcast_dimensions=(),
      )
      weights = jax.device_put(
          weights,
          sharding.NamedSharding(mesh, sharding.PartitionSpec("x")),
      )

      # Init random input and replicate it across all devices.
      random_key = jax.random.key(0)
      input_buffer = jax.random.uniform(
          random_key,
          shape=(
              NUM_MICROBATCHES,
              CONTRACTING_DIM_SIZE,
              NON_CONTRACTING_DIM_SIZE,
          ),
      )
      input_buffer = jax.lax.broadcast_in_dim(
          input_buffer,
          shape=(
              NUM_DEVICES,
              NUM_MICROBATCHES,
              CONTRACTING_DIM_SIZE,
              NON_CONTRACTING_DIM_SIZE,
          ),
          broadcast_dimensions=[1, 2, 3],
      )
      input_buffer = jax.device_put(
          input_buffer,
          sharding.NamedSharding(mesh, sharding.PartitionSpec("x")),
      )

      # Run computation.
      output_buffer = entry_computation(weights, input_buffer, mesh)
      print(f"output_buffer = \n{output_buffer}")
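To see what the shift_right and cycle_back helpers compute, the sketch below applies copies of them to a small unsharded array whose leading axis plays the role of the device axis. This is only an illustration that runs on a single device: without sharding, no communication takes place, and the values are made up.

```python
import jax
import jax.numpy as jnp

NUM_DEVICES = 4
NUM_MICROBATCHES = 5
NUM_CIRC_REPEATS = 2


def shift_right(arr):
    # Pad one zero at the front of axis 0, then drop the last element.
    padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
    return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)


def cycle_back(arr):
    # Pad NUM_DEVICES - 1 zeros at the end of axis 0, then keep the tail.
    padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1)
    return jax.lax.slice(
        jnp.pad(arr, padding),
        [NUM_DEVICES - 1] + [0] * (arr.ndim - 1),
        (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:],
    )


# One value per pipeline stage (stage k holds the value k + 1).
stage_values = jnp.array([1, 2, 3, 4])

print(shift_right(stage_values))  # [0 1 2 3]: stage k receives from stage k - 1
print(cycle_back(stage_values))   # [4 0 0 0]: stage 0 receives from the last stage

# Number of scan iterations used by the example: every microbatch goes around
# the ring NUM_CIRC_REPEATS times, plus the steps needed to fill the pipeline.
num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
print(num_iterations)  # 13
```

When the leading axis is sharded across the 4 devices, the same pad-and-slice pattern moves data between devices, which is where the forward edge and back edge communication comes from.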

Using psend and precv#

The JAX example above lowers to collective-permute HLO instructions, which are implemented through ncclSend and ncclRecv on GPU. Users who want more granular control over the ordering of collectives can use jax.lax.psend and jax.lax.precv directly. Syntactically, these two functions are analogous to their HLO counterparts. Users should keep in mind that their program will deadlock when the source-target pairs in a single psend or precv form a cycle, and when psend is not matched by precv and vice-versa.

If cycles are required in the device communication pattern, deadlocks can be avoided by making sure that (1) no single psend or precv function’s source-target pairs contain a cycle, and that (2) a fake data dependency is inserted to sequentialize the send/recv pairs. No collective can be scheduled between psend/precv pairs, which can only be controlled through jax.lax.optimization_barrier at the JAX level. The test case test_psend_precv_basic_with_no_deadlock_cycle in the file shard_map_test.py is one such example.
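Condition (1) can be checked before a perm list is passed to psend or precv. The helper below is a hypothetical sketch, not part of JAX, and assumes every source device appears at most once in the list.

```python
def perm_has_cycle(perm):
    """Return True if the (source, target) pairs contain a directed cycle."""
    next_device = dict(perm)  # each source sends to exactly one target
    for start in next_device:
        seen = {start}
        current = start
        while current in next_device:
            current = next_device[current]
            if current in seen:
                return True
            seen.add(current)
    return False


forward_edge = [(0, 1), (1, 2), (2, 3)]
back_edge = [(3, 0)]

print(perm_has_cycle(forward_edge))              # False: safe for one psend/precv
print(perm_has_cycle(back_edge))                 # False: safe on its own
print(perm_has_cycle(forward_edge + back_edge))  # True: would deadlock if combined
```

Splitting the ring into a forward edge and a back edge, as done here, keeps each individual list acyclic. Condition (2), the ordering between the pairs, still has to be enforced separately.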

The pipeline parallelism example in the previous section uses the --xla_gpu_experimental_pipeline_parallelism_opt_level XLA flag. The same program can be rewritten using psend and precv without the flag, if manually pipelined.

    ## same setup and imports, plus the names used in main() below:
    from functools import partial
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


    def while_body(carry, i):
      (
          weights,
          input_buffer,
          output_buffer,
          prev_compute_res,
          prev_stage_slice_fwd,
          prev_stage_slice_bwd,
      ) = carry

      # Read input data from input buffer.
      input_slice = jax.lax.dynamic_slice(
          input_buffer,
          (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
          (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
      )

      # send_fwd
      fwd_send_token = jax.lax.psend(
          prev_compute_res,
          axis_name="x",
          perm=[(0, 1), (1, 2), (2, 3)],
      )

      # Select compute argument based on device and pipeline cycle
      compute_argument = select_on_first_device(
          select_on_first_cycle(i, input_slice, prev_stage_slice_bwd),
          prev_stage_slice_fwd,
      ).reshape((1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))

      tmp = compute_argument
      for _ in range(COMPUTE_INTENSITY):
        tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
      compute_result = tmp.reshape(
          (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
      )

      buffer_slice_for_bwd_ppermute = jax.lax.dynamic_slice(
          output_buffer,
          (0, (i + 1) % NUM_MICROBATCHES, 0, 0),
          (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
      )

      # make sure ppermute is scheduled after send_fwd
      buffer_slice_for_bwd_ppermute_after_send_fwd, _ = (
          jax.lax.optimization_barrier(
              (buffer_slice_for_bwd_ppermute, fwd_send_token)
          )
      )

      # ppermute_bwd
      ppermute_bwd_data = jax.lax.ppermute(
          buffer_slice_for_bwd_ppermute_after_send_fwd,
          axis_name="x",
          perm=[(3, 0)],
      )

      # make sure recv is scheduled after ppermute
      precv_token, _ = jax.lax.optimization_barrier(
          (jax.lax.create_token(), ppermute_bwd_data)
      )

      # recv_fwd, matches the send_fwd in the next iteration
      fwd_recv_data = jax.lax.precv(
          precv_token,
          out_shape=jax.ShapeDtypeStruct(input_slice.shape, input_slice.dtype),
          axis_name="x",
          perm=[(0, 1), (1, 2), (2, 3)],
      )

      update_output_buffer = jax.lax.dynamic_update_slice(
          output_buffer,
          compute_result,
          (0, (i + 2) % NUM_MICROBATCHES, 0, 0),
      )

      carry = (
          weights,
          input_buffer,
          update_output_buffer,
          compute_result,
          fwd_recv_data,
          ppermute_bwd_data,
      )
      return carry, i


    def entry_computation(weights, input_buffer, dummy_data, mesh):

      # Init output buffer.
      output_buffer = jnp.zeros_like(input_buffer)

      # Start pipeline.
      dummy_slice_fwd = jax.lax.precv(
          jax.lax.create_token(),
          jax.ShapeDtypeStruct(dummy_data.shape, dummy_data.dtype),
          axis_name="x",
          perm=[(0, 1), (1, 2), (2, 3)],
      )

      carry = (
          weights,
          input_buffer,
          output_buffer,
          dummy_slice_fwd,
          dummy_data,
          dummy_data,
      )

      num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
      carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))

      # Matches the last recv_fwd issued inside the loop.
      _ = jax.lax.psend(
          carry[3],
          axis_name="x",
          perm=[(0, 1), (1, 2), (2, 3)],
      )

      _, _, output_buffer, _, _, _ = carry

      return output_buffer


    def main(_):

      # Expect constant number of devices.
      assert NUM_DEVICES == jax.local_device_count()

      # Create mesh.
      mesh = Mesh(
          mesh_utils.create_device_mesh([NUM_DEVICES]),
          axis_names=["x"],
      )

      # Init weights.
      weights = 1.0 / CONTRACTING_DIM_SIZE
      weights = jax.lax.broadcast_in_dim(
          weights,
          shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
          broadcast_dimensions=(),
      )
      weights = jax.device_put(weights, NamedSharding(mesh, P("x")))

      # Init input.
      random_key = jax.random.key(0)
      input_buffer = jax.random.uniform(
          random_key,
          shape=(
              NUM_MICROBATCHES,
              CONTRACTING_DIM_SIZE,
              NON_CONTRACTING_DIM_SIZE,
          ),
      )
      input_buffer = jax.lax.broadcast_in_dim(
          input_buffer,
          shape=(
              NUM_DEVICES,
              NUM_MICROBATCHES,
              CONTRACTING_DIM_SIZE,
              NON_CONTRACTING_DIM_SIZE,
          ),
          broadcast_dimensions=[1, 2, 3],
      )
      input_buffer = jax.device_put(
          input_buffer,
          NamedSharding(mesh, P("x")),
      )

      # Init dummy data for forward and backward edge passed through the while
      # loop.
      dummy_slice = jnp.zeros(
          shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
      ).astype(jnp.float32)
      dummy_data = jax.device_put(
          dummy_slice,
          NamedSharding(mesh, P("x")),
      )

      entry = partial(entry_computation, mesh=mesh)

      output_buffer = jax.jit(
          jax.shard_map(
              entry,
              mesh=mesh,
              in_specs=P("x"),
              out_specs=P("x"),
              check_vma=False,
          )
      )(weights, input_buffer, dummy_data)

      print(f"output_buffer = \n{output_buffer}")

NCCL flags#

These Nvidia NCCL flag values may be useful for single-host multi-device computations on Nvidia GPUs:

    os.environ.update({
        "NCCL_LL128_BUFFSIZE": "-2",
        "NCCL_LL_BUFFSIZE": "-2",
        "NCCL_PROTO": "SIMPLE,LL,LL128",
    })

These NCCL flags could improve single-host communication speed. These flags don’t seem useful for multi-host communication yet.

Multi-Process#

We recommend using one process per GPU and not one per node. In some cases, this can speed up jitted computation. The jax.distributed.initialize() API will automatically understand that configuration when run under SLURM. However, this is only a rule of thumb, and it may be useful to test both one process per GPU and one process per node on your use case.

