Movatterモバイル変換


[0]ホーム

URL:


Skip to main content
Ctrl+K
JAX  documentation - Home

JAX Memories and Host Offloading#

This tutorial provides a practical introduction to host offloading techniques in JAX, focusing on:

  • Activation offloading

  • Parameter offloading

  • Optimizer state offloading

By applying offloading strategies, developers can better manage memory resources and reduce memory pressure on devices. To implement these strategies effectively, understanding JAX’s core mechanisms for data placement and movement is essential.

Building Blocks for Offloading#

JAX provides several key components for controlling where and how data are stored and moved between the host and the device memory. The following sections explore:

  • How to specify data distribution with sharding

  • How to control memory placement between host and device

  • How to manage data movement in jitted functions

NamedSharding and Memory Kinds#

NamedSharding defines how data are distributed across devices. It includes:

  • Basic data distribution configuration

  • memory_kind parameter for specifying memory type (device orpinned_host)

  • By default,memory_kind is set todevice memory

  • with_memory_kind method for creating new sharding with modified memory type

importjaximportjax.numpyasjnpfromjax.shardingimportMesh,NamedSharding,PartitionSpecasPimportnumpyasnp# Create mesh# 1x1 mesh represents a single device with two named dimensions (x and y)mesh=Mesh(np.array(jax.devices()[0]).reshape(1,1),('x','y'))# Device sharding - partitions data along x and y dimensionss_dev=NamedSharding(mesh,P('x','y'),memory_kind="device")# Host sharding - same partitioning but in pinned host memorys_host=s_dev.with_memory_kind('pinned_host')print(s_dev)# Shows device memory shardingprint(s_host)# Shows pinned host memory sharding
NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=device)NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=pinned_host)

Data Placement with device_put#

jax.device_put() is a function that explicitly transfers arrays to a specified memory location according to a sharding specification.

# Create a 2x4 arrayarr=jnp.arange(8.0).reshape(2,4)# Move arrays to different memory locations based on sharding objectsarr_host=jax.device_put(arr,s_host)# Places in pinned host memoryarr_dev=jax.device_put(arr,s_dev)# Places in device memory# Verify memory locationsprint(arr_host.sharding.memory_kind)# Output: pinned_hostprint(arr_dev.sharding.memory_kind)# Output: device
pinned_hostdevice

Output Sharding Controls#

Shardings determine how data is split across devices. JAX providesout_shardings to control how output arrays are partitioned when leaving a jitted function.

Key Features:

  • Can differ from input sharding

  • Allows different memory kinds for outputs

Examples:

Device Output Sharding#

f=jax.jit(lambdax:x,out_shardings=s_dev)out_dev=f(arr_host)print("Result value of H2D:\n",out_dev)
Result value of H2D:  [[0. 1. 2. 3.] [4. 5. 6. 7.]]

Moving data from host to device memory when needed for computation is the essence of host offloading. Usejax.device_put() to perform this transfer in this example to optimize performance.

# Instead of the lambda function, add_func can be defined explicitly# move data to device before computationdefadd_func(x):# Move data to device and add onex=jax.device_put(x,s_dev)returnx+1f=jax.jit(add_func,out_shardings=s_dev)out_dev=f(arr_host)print("Result value of H2D and add 1 in device memory:\n",out_dev)
Result value of H2D and add 1 in device memory:  [[1. 2. 3. 4.] [5. 6. 7. 8.]]

Host Output Sharding#

f=jax.jit(lambdax:x,out_shardings=s_host)out_host=f(arr_dev)# Input arrays in the device memory while output arrays in the host memoryprint("Result value of D2H:\n",out_host)
Result value of D2H:  [[0. 1. 2. 3.] [4. 5. 6. 7.]]

Activation Offloading#

Before diving into activation offloading, let’s first take a look at the baseline code.

This code implements a simple neural network with 10 layers, each consisting of two linear transformations. The code demonstrates basic memory usage patterns and provides a foundation for comparing offloading optimization techniques.

Key components:

  • Each layer consists of two sequential linear operations:

    1. First multiplication:x@w1

    2. Second multiplication:y@w2

  • 10-layer network using JAX’s scan operation

  • Memory usage analysis

  • Gradient computation with JIT compilation

To analyze memory usage in JAX, thejax.stages.Compiled.memory_analysis() method can be used on a compiled function. This provides detailed statistics about memory consumption during computation. The key metrics include temporary memory size, argument size, output size, and alias size. To calculate the total memory usage, sum the temporary, argument, and output sizes, then subtract the alias size to avoid double-counting the same memory multiple times. This provides a summarized view of how the device memory is utilized across different aspects of the computation.

# Initialize input and weights with small values (0.0001)input=jnp.ones((256,256),dtype=jnp.float32)*0.001# Input matrix: 256 x 256w1=jnp.ones((10,256,1024),dtype=jnp.float32)*0.001# 10 layers of 256 x 1024 matricesw2=jnp.ones((10,1024,256),dtype=jnp.float32)*0.001# 10 layers of 1024 x 256 matricesdeftwo_layers(x,w):# Simple two-layer linear transformationw1,w2=wy=x@w1returny@w2,Nonedefscanned(w,x):# Applies the layer function 10 times using JAX's scan operation# Input: w (tuple of weight matrices), x (input matrix)# Output: sum of the final layer's outputresult=jax.lax.scan(two_layers,x,w)[0]returnjnp.sum(result)# Compile and compute gradients of the scanned functionf=jax.jit(jax.grad(scanned))# Apply JIT compilation to gradient computation# Analyze memory usagecompiled_step=f.lower((w1,w2),input).compile()compiled_stats=compiled_step.memory_analysis()ifcompiled_statsisnotNone:# Calculate total memory usage including temporary storage, arguments, and outputs# Subtract alias size to avoid double-counting memory shared between different componentstotal=compiled_stats.temp_size_in_bytes+compiled_stats.argument_size_in_bytes \+compiled_stats.output_size_in_bytes-compiled_stats.alias_size_in_bytesprint(f"Temp size:{compiled_stats.temp_size_in_bytes/(1024**2):.2f} MB")print(f"Argument size:{compiled_stats.argument_size_in_bytes/(1024**2):.2f} MB")print(f"Total size:{total/(1024**2):.2f} MB")# Execute the function and print sample resultsresult=f((w1,w2),input)# Execute the function with weights and inputprint("Sample of results: ",result[0][0,0,:5])
Temp size: 17.25 MBArgument size: 20.25 MBTotal size: 57.50 MBSample of results:  [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]

The detailed coverage of activation offloading can be found in theGradient checkpointing with jax.checkpoint (jax.remat) tutorial. Activation offloading helps manage memory by moving intermediate activations to host memory after the forward pass, and bringing them back to device memory during the backward pass when needed for gradient computation.

To implement activation offloading effectively, it is important to understand checkpoint names and policies. Here’s how they work in a simple example:

Checkpoint Names#

Thecheckpoint_name() function allows labeling activations for memory management during computation. Here’s a simple example that a checkpoint namex is specified.

fromjax.ad_checkpointimportcheckpoint_namedeflayer_name(x,w):w1,w2=wx=checkpoint_name(x,"x")y=x@w1returny@w2,None

The checkpoint name helps the system decide whether to:

  • Keep the activation in device memory or

  • Offload it to host memory during computation

This pattern is common in neural networks, where multiple transformations are applied sequentially to input data.

Checkpoint Policies#

This checkpoint policy implements a memory management strategy that optimizes memory usage during computation. It manages memory by handling intermediate values through three strategies:

  1. Recomputing during backward pass (default behavior)

  2. Storing on device

  3. Offloading to host memory after forward pass and loading back during backward pass

fromjaximportcheckpoint_policiesascppolicy=cp.save_and_offload_only_these_names(names_which_can_be_saved=[],# No values stored on devicenames_which_can_be_offloaded=["x"],# Offload activations labeled "x"offload_src="device",# Move from device memoryoffload_dst="pinned_host"# To pinned host memory)

jax.lax.scan() is commonly used in JAX for handling sequential operations (like RNNs or transformers). It can be integrated with JAX’s rematerialization to process sequential data.

Key components:

  • jax.remat() creates a rematerialized version of the layer function usingjax.remat() and applies the checkpoint policy to the layer function

  • prevent_cse=False enables XLA’s common subexpression elimination for better performance

  • jax.lax.scan() iterates the rematerialized layer along an axis

defscanned(w,x):remat_layer=jax.remat(layer_name,policy=policy,# Use our offloading policyprevent_cse=False)# Allow CSE optimizationsresult=jax.lax.scan(remat_layer,x,w)[0]returnjnp.sum(result)# Initialize input and weights with small values (0.0001)input=jnp.ones((256,256),dtype=jnp.float32)*0.001# Input matrix: 256 x 256w1=jnp.ones((10,256,1024),dtype=jnp.float32)*0.001# 10 layers of 256 x 1024 matricesw2=jnp.ones((10,1024,256),dtype=jnp.float32)*0.001# 10 layers of 1024 x 256 matrices# Compile and compute gradients of the scanned functionf=jax.jit(jax.grad(scanned))# Apply JIT compilation to gradient computation# Analyze memory usagecompiled_step=f.lower((w1,w2),input).compile()compiled_stats=compiled_step.memory_analysis()ifcompiled_statsisnotNone:total=compiled_stats.temp_size_in_bytes+compiled_stats.argument_size_in_bytes \+compiled_stats.output_size_in_bytes-compiled_stats.alias_size_in_bytesprint(f"Temp size:{compiled_stats.temp_size_in_bytes/(1024**2):.2f} MB")print(f"Argument size:{compiled_stats.argument_size_in_bytes/(1024**2):.2f} MB")print(f"Total size:{total/(1024**2):.2f} MB")result_activation=f((w1,w2),input)# Execute the function with weights and input# Verify numerical correctnessare_close=jnp.allclose(result_activation[0],# Result from activation offloading onlyresult[0],# Result from both activation and parameter offloadingrtol=1e-5,atol=1e-5)print(f"Results match within tolerance:{are_close}")print("Sample of results: ",result_activation[0][0,0,:5])
Temp size: 6.50 MBArgument size: 20.25 MBTotal size: 46.75 MBResults match within tolerance: TrueSample of results:  [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]

Activation offloading reduces temporary memory usage from 17.25 MB to 6.5 MB while input and output argument sizes remain the same. Totally 10.75 MB is saved. It is achieved by offloading activationx to host memory after the forward pass and loading it back to device memory before the backward pass.

Summary of Activation Offloading#

Activation offloading provides a powerful way to manage memory in large computations by:

  • Using checkpoint names to mark specific activations

  • Applying policies to control where and how activations are stored

  • Supporting common JAX patterns like scan operations

  • Moving selected activations to host memory when device memory is under budget

This approach is particularly useful when working with large models that would otherwise exceed device memory capacity.

Parameter Offloading#

Model parameters (also known as weights) can be offloaded to the host memory to optimize device memory usage during initialization. This is achieved by usingjax.jit() with a sharding strategy that specifies host memory kind.

While parameter offloading and activation offloading are distinct memory optimization techniques, the following example demonstrates parameter offloading built upon the activation offloading implementation shown earlier.

Parameter Placement for Computation#

Different from the earlierlayer function,jax.device_put() is applied to move parameterw1 andw2 to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes.

Note that the activation offloading implementation remains unchanged, using the same:

  • Checkpoint name"x"

  • Checkpoint policy

  • scanned function combiningjax.remat() andjax.lax.scan()

Parameter Initialization with Host Offloading#

During the initialization, parameterw1 andw2 are placed on host memory before being passed to thejax.jit() functionf, while keeping theinput variable on the device.

# Hybrid version: Both activation and parameter offloadingdefhybrid_layer(x,w):# Move model parameters w1 and w2 to device memory via device_putw1,w2=jax.tree.map(lambdax:jax.device_put(x,s_dev),w)x=checkpoint_name(x,"x")# Offload activation x to host memoryy=x@w1returny@w2,Nonedefhybrid_scanned(w,x):remat_layer=jax.remat(hybrid_layer,# Use hybrid_layer instead of layerpolicy=policy,# Use offloading policyprevent_cse=False)# Allow CSE optimizationsresult=jax.lax.scan(remat_layer,x,w)[0]returnjnp.sum(result)# Move model parameters w1 and w2 to the host via device_put# Initialize input and weights with small values (0.0001)wh1=jax.device_put(w1,s_host)wh2=jax.device_put(w2,s_host)# Compile and compute gradients of the scanned functionf=jax.jit(jax.grad(hybrid_scanned))# Apply JIT compilation to gradient computation# Analyze memory usagecompiled_step=f.lower((wh1,wh2),input).compile()compiled_stats=compiled_step.memory_analysis()ifcompiled_statsisnotNone:total=compiled_stats.temp_size_in_bytes+compiled_stats.argument_size_in_bytes \+compiled_stats.output_size_in_bytes-compiled_stats.alias_size_in_bytesprint(f"Temp size:{compiled_stats.temp_size_in_bytes/(1024**2):.2f} MB")print(f"Argument size:{compiled_stats.argument_size_in_bytes/(1024**2):.2f} MB")print(f"Total size:{total/(1024**2):.2f} MB")result_both=f((wh1,wh2),input)# Execute with both activation and parameter offloading# Verify numerical correctnessare_close=jnp.allclose(result_activation[0],# Result from activation offloading onlyresult_both[0],# Result from both activation and parameter offloadingrtol=1e-5,atol=1e-5)print(f"Results match within tolerance:{are_close}")
Temp size: 4.75 MBArgument size: 0.25 MBTotal size: 25.00 MBResults match within tolerance: True

This implementation demonstrates how offloading model parameters together with activation offloading to host memory can significantly reduce device memory usage.

Memory Analysis#

Baseline Memory Usage:

  • Input tensor: 0.25 MB (256 × 256 × 4 bytes)

  • Model parameters (w1, w2): 10 MB each (256 × 1024 × 4 bytes ≈ 1 MB per layer × 10 layers)

Memory Usage Comparison:

  • Argument size without parameter offloading: 20.25 MB (0.25 + 10 + 10)

  • Argument size with parameter offloading: 0.25 MB (only input remains)

  • Temporary memory without activation offloading: 17.25 MB

  • Temporary memory with activation offloading: 6.50 MB

  • Temporary memory with activation and parameter offloading: 4.75 MB

Key Optimizations#

  1. Parameter Offloading: Moving parameters (w1, w2) to host memory reduces argument size by 20 MB (from 20.25 MB to 0.25 MB).

  2. Activation Offloading: Moving activations to host memory reduces temporary memory usage by 10.75 MB (from 17.25 to 6.50 MB).

  3. Hybrid Strategy: The rematerialization of activation offloading helps avoid keeping weights on the device and reduce temporary memory usage by 1.75 MB (from 6.50 MB to 4.75 MB). Without it, JAX would be eager to keep the on-device copies of the weights alive for the backward pass.

Results#

Total Memory Savings: 33.5 MB (20 MB + 10.75 MB + 1.75 MB)

This hybrid approach demonstrates that parameter and activation offloading work synergistically to achieve significant memory reductions while maintaining computational correctness.

Limitations of Parameter Offloading#

jax.lax.scan() is crucial for effective parameter management. Using an explicit for loop would cause parameters to continuously occupy device memory, resulting in the same memory usage as without parameter offloading. Whilejax.lax.scan() allows specifying the scan axis, parameter offloading currently works only when scanning over axis 0. Scanning over other axes generates atranspose operation during compilation before returning parameters to the device, which is expensive and not supported on all platforms.

The offloading performance can vary for different device types. It may degrade performance due to memory transfers between host and device, so it’s important to consider this trade-off when designing your optimization strategy.

Optimizer State Offloading#

Optimizer state offloading is a memory management technique that stores optimizer states in host memory instead of device memory. This approach is particularly useful when optimizer states are large, as it reduces device memory usage.

A basic JAX implementation using the Adam optimizer can serve as a starting point, where all tensors are stored on the device. This will serve as a reference implementation before introducing optimizer state offloading.

Basic Implementation#

This section, let’s implement a simple model with the Adam optimizer. This implementation helps establish the baseline behavior before exploring optimizer state offloading. It is particularly useful for understanding memory patterns in large-scale neural network training.

In the code example below, a neural network training loop is included to use JAX and Optax’s Adam optimizer. The network consists of four linear layers with GELU activation functions, processing large matrices of size 7168x7168. The training process involves:

  • Forward pass: The input flows through four layers, each applying a linear transformation followed by GELU activation

  • Loss computation: Calculates mean squared error between output and input, plus L2 regularization

  • Backward pass: Computes gradients using automatic differentiation

  • Optimization step: Updates parameters using Adam optimizer with gradient clipping

The code uses JIT compilation to optimize performance and includes memory usage analysis to monitor the computational resources required during training. The memory analysis provides insights into temporary memory usage, argument sizes, and total memory consumption during the optimization step.

importoptaxDIM=7168# Initialize data and parameter w1, w2, w3 and w4input=jnp.ones((DIM,DIM))params={f'w{i}':jnp.ones((DIM,DIM))foriinrange(1,5)}# Initialize optimizeroptimizer=optax.chain(optax.clip_by_global_norm(1.0),optax.adam(learning_rate=0.1))opt_state=optimizer.init(params)defgelu(x):return0.5*x*(1+jnp.tanh(jnp.sqrt(2/jnp.pi)*(x+0.044715*x**3)))defsingle_layer(x,w):returnx@wdefforward(params,x):foriinrange(1,5):x=gelu(single_layer(x,params[f'w{i}']))returnxdefcompute_loss(params,inputs):outputs=forward(params,inputs)loss=jnp.mean((outputs-inputs)**2)l2_reg=0.001*sum(jnp.sum(w**2)forwinjax.tree_util.tree_leaves(params))returnloss+l2_regdefstep(params,opt_state,inputs):grads=jax.grad(lambdap:compute_loss(p,inputs))(params)updates,new_opt_state=optimizer.update(grads,opt_state,params)returnoptax.apply_updates(params,updates),new_opt_state# JIT compile the step function with proper shardingstep=jax.jit(step,donate_argnums=(0,1))# Run a optimization stepnew_params,new_opt_state=step(params,opt_state,input)# Analyze memory usagecompiled_step=step.lower(params,opt_state,input).compile()compiled_stats=compiled_step.memory_analysis()ifcompiled_statsisnotNone:total=compiled_stats.temp_size_in_bytes+compiled_stats.argument_size_in_bytes \+compiled_stats.output_size_in_bytes-compiled_stats.alias_size_in_bytesprint(f"Temp size:{compiled_stats.temp_size_in_bytes/(1024**3):.2f} GB")print(f"Argument size:{compiled_stats.argument_size_in_bytes/(1024**3):.2f} GB")print(f"Total size:{total/(1024**3):.2f} GB")
Temp size: 2.11 GBArgument size: 2.49 GBTotal size: 4.59 GB

Optimizer state offloading can be implemented as follows.

Setting Up Sharding and Memory Kinds#

jax.sharding.SingleDeivceSharding() is adopted to simplify the shardings for both device and host memory kinds. During the model state initialization, move the optimizer state to the host usingdevice_put().

Model and Training Step Implementation#

Next, define the model architecture, loss function, and training step. The key addition here is moving the optimizer state to device memory viadevice_put() at the beginning of each training step, as it’s needed for the parameter update on the device.

Running and Comparing Results#

After setting up the sharding, the optimizer state is moved to host memory and the step function is run withjax.jit().

The JIT compilation of the step function uses several important parameters:

  • donate_argnums=(0,): Indicates that the first argument (parameters) can be modified in-place, allowing JAX to reuse its memory

  • out_shardings: Specifies how output tensors should be sharded across the mesh (devices and hosts)

# Create sharding specifications for device and host memorys_dev=jax.sharding.SingleDeviceSharding(jax.devices()[0],memory_kind="device")s_host=jax.sharding.SingleDeviceSharding(jax.devices()[0],memory_kind="pinned_host")defstep(params,opt_state,inputs):grads=jax.grad(lambdap:compute_loss(p,inputs))(params)opt_state=jax.device_put(opt_state,s_dev)updates,new_opt_state=optimizer.update(grads,opt_state,params)new_params=optax.apply_updates(params,updates)returnnew_params,new_opt_stateparams={f'w{i}':jnp.ones((DIM,DIM))foriinrange(1,5)}opt_state=optimizer.init(params)# Initialize optimizeroptimizer=optax.chain(optax.clip_by_global_norm(1.0),optax.adam(learning_rate=0.1))# Optimizer state is placed on the host during initializationopt_state=jax.device_put(opt_state,s_host)# JIT compile the step function with proper sharding and memory optimizationstep=jax.jit(step,donate_argnums=(0,),out_shardings=(s_dev,s_host))# Run an optimization stepnew_params,offload_opt_state=step(params,opt_state,input)# Analyze memory usagecompiled_step=step.lower(params,opt_state,input).compile()compiled_stats=compiled_step.memory_analysis()ifcompiled_statsisnotNone:total=compiled_stats.temp_size_in_bytes+compiled_stats.argument_size_in_bytes \+compiled_stats.output_size_in_bytes-compiled_stats.alias_size_in_bytesprint(f"Temp size:{compiled_stats.temp_size_in_bytes/(1024**3):.2f} GB")print(f"Argument size:{compiled_stats.argument_size_in_bytes/(1024**3):.2f} MB")print(f"Total size:{total/(1024**3):.2f} GB")
Temp size: 1.91 GBArgument size: 0.96 MBTotal size: 2.87 GB

This implementation demonstrates how to:

  1. Set up sharding specifications fordevice andpinned_host

  2. Move optimizer states between host and device memory viajax.device_put()

  3. Useout_shardings to ensure proper memory placement

  4. Show the memory usage

This implementation demonstrates how offloading optimizer state to host memory can reduce device memory usage through a trade-off between argument size and temporary memory.

Memory Analysis:

  1. Argument Size Reduction:

    • The optimizer states are arguments of thejax.jit() function

    • By offloading these states to host memory, the argument size on device is reduced

  2. Temporary Memory Impact:

    • Offloading increases temporary memory usage

    • This is because outputs of optimizer states need memory buffers before being copied to host

    • The memory live ranges for these temporary buffers are extended due to the host-device transfers

  3. Latency Hiding Scheduling:

    • JAX uses XLA’s latency hiding scheduling to overlap computation with host-device transfers

    • The overlapping can cause tensors to have larger live ranges, which increases memory pressure on the device

    • This adaptive behavior helps maintain stable memory usage while still providing some performance benefits

  4. Memory Trade-off:

    • Total memory size with offloading: 2.87 GB

    • Total memory size without offloading: 4.59 GB

    • Net memory saving: 1.72 GB

while offloading increases temporary memory usage, the reduction in argument size more than compensates for this increase, resulting in an overall reduction in device memory usage.

Note: The optimizer states can be compared for numerical equivalence usingjax.tree_util.tree_map andjnp.allclose, but this verification step is omitted here for brevity.

Tools for Host Offloading#

jax.stages.Compiled.memory_analysis() API is utilized above to get memory usage information. For device memory analysis, refer toProfiling device memory. The profiling tools described inProfiling computation can help measure memory savings and performance impact from host offloading.


[8]ページ先頭

©2009-2026 Movatter.jp