Sequence Packing and Dynamic Batching#

This document describes the sequence packing and dynamic batching features implemented in NeMo-RL to optimize training efficiency for variable-length sequences.

Table of Contents#

  1. Problem

  2. Sequence Packing and Dynamic Batching

  3. Sequence Packing

  4. Dynamic Batching

  5. Configuration

  6. Integration with Training Pipeline

  7. Metrics and Monitoring

  8. Usage

Problem#

Challenge: Variable Sequence Lengths in RL/SFT#

RL and SFT workloads exhibit highly variable sequence lengths, since many datasets have sequence-length distributions that follow Zipf’s law:

  • Skewed Distribution: Most sequences are short, with a few very long sequences

  • Padding Inefficiency: Traditional fixed-length batching requires padding all sequences to the maximum length, resulting in:

    • Wasted computation on pad tokens

    • Underutilized GPU memory

    • Poor GPU compute efficiency

  • Memory Constraints: Batch size is often limited by the longest sequences in the batch

Without optimization, 50-70% of computation can be wasted on padding tokens.

Sequence Packing and Dynamic Batching#

NeMo-RL implements two mutually exclusive approaches to address variable sequence lengths:

  1. Sequence Packing: Concatenates multiple sequences into a single “packed” sequence, eliminating most padding.

  2. Dynamic Batching: Groups sequences of similar lengths and adjusts microbatch sizes based on total token count, reducing the excess padding.

Important Notes#

  • Dynamic batching and sequence packing cannot be enabled simultaneously; they are mutually exclusive.

  • Compatible with Context Parallelism (CP)

  • Requires FlashAttention-2 for packed sequences

Sequence Packing#

Sequence packing concatenates multiple variable-length sequences into a single sequence, eliminating the need for padding tokens. This approach maximizes GPU utilization by ensuring all computational resources are used for meaningful tokens.

```
Unpacked: (# == useful token, p == padding token)
000ppp
111111
22pppp
333ppp
~40% padding

Packed:
00011111122333p  # some padding may still be required as discussed later, but it is significantly reduced
```
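To make the saving concrete, here is a quick back-of-the-envelope check in plain Python (an illustration, not NeMo-RL code) using the four sequence lengths from the diagram above:

```python
# Hypothetical illustration: compare padded vs. packed token counts
# for the four sequences (lengths 3, 6, 2, 3) in the diagram above.
seq_lens = [3, 6, 2, 3]

# Fixed-length batching pads every sequence to the longest one.
padded_total = len(seq_lens) * max(seq_lens)  # 4 * 6 = 24 tokens
useful_total = sum(seq_lens)                  # 14 tokens
waste = 1 - useful_total / padded_total       # ~0.42, i.e. ~40% padding

# Packing concatenates the sequences instead, so (modulo a small
# alignment pad) the total is just the number of useful tokens.
packed_total = useful_total                   # 14 tokens
```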

Implementation Details#

1. Packing Process (BatchedDataDict.shard_by_batch_size)#

```python
# Located in: nemo_rl/distributed/batched_data_dict.py
def shard_by_batch_size(
    self,
    shards: int,
    sequence_packing_args: Optional[SequencePackingArgs] = None,
):
    # 1. Get bin packer for specified algorithm
    bin_packer = get_packer(
        algorithm=sequence_packing_args["algorithm"],
        bin_capacity=sequence_packing_args["max_tokens_per_microbatch"],
    )
    # 2. Pack sequences into bins per chunk
    for chunk_idx in range(num_chunks):
        chunk_bin_assignments = bin_packer.pack(
            sequence_lengths=chunk_padded_seqlens_list
        )
    # 3. Create sharded microbatches from packed bins
```

This method does not actually concatenate the sequences and create the packed tensor. Rather, it reorders the elements in the batch and creates metadata such that after you call your workers with RayWorkerGroup.run_all_workers_sharded_data, each worker can call BatchedDataDict.make_microbatch_iterator_for_packable_sequences locally to return an iterator over batches, where each batch contains elements that should be packed together. For an example of this, take a look at the MegatronPolicyWorker’s train function.

We have the policy backends perform the actual packing because implementations can vary widely on how exactly it should be done and what metadata needs to be collected.

2. Packing Algorithms (nemo_rl/data/packing/algorithms.py)#

Four packing algorithms are implemented, but we recommend using Modified First Fit Decreasing for the best packing efficiency:

Concatenative Packer#
  • Sequential concatenation until bin capacity is reached

  • O(n)

  • Simple, deterministic packing for debugging

Modified First Fit Decreasing (MFFD)#
  • Johnson & Garey (1985) heuristic with 5-phase packing strategy

  • O(n log n + n*m)

  • Best bin utilization

  • Phases:

    1. Classify items (large: >C/2, medium: >C/3, small: >C/6, tiny: ≤C/6, where C is the bin capacity)

    2. Create one bin per large item

    3. Add medium items to large bins (forward pass)

    4. Add pairs of small items (backward pass)

    5. Greedy fit remaining items

    6. Apply FFD to leftovers

First Fit Decreasing (FFD)#
  • Sort sequences by length (descending), place each in first fitting bin

  • O(n log n + n*m) where m = number of bins

  • Good general-purpose algorithm

First Fit Shuffle#
  • Randomly shuffle sequences, then apply first-fit

  • O(n*m)

  • When sequence order doesn’t matter
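To illustrate the family of algorithms above, here is a minimal first-fit-decreasing sketch (an illustration, not the implementation in nemo_rl/data/packing/algorithms.py): sort lengths in descending order, place each item in the first bin with room, and open a new bin when none fits.

```python
# Minimal FFD sketch (illustrative, not the NeMo-RL implementation).
def first_fit_decreasing(lengths, bin_capacity):
    bins = []   # each bin is a list of indices into `lengths`
    loads = []  # current token load per bin
    # Sort indices by sequence length, longest first.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    for i in order:
        for b, load in enumerate(loads):
            if load + lengths[i] <= bin_capacity:
                bins[b].append(i)
                loads[b] += lengths[i]
                break
        else:
            # No existing bin fits; open a new one.
            bins.append([i])
            loads.append(lengths[i])
    return bins

# e.g. first_fit_decreasing([3, 6, 2, 3], bin_capacity=8)
# packs sequence 1 (len 6) with sequence 2 (len 2) in one bin,
# and sequences 0 and 3 (len 3 each) in another.
```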

Usage with Context Parallelism#

For long sequences with context parallelism (CP > 1):

  • Individual sequences must be padded to a multiple of cp_size * 2 * tp_size, where the factor of 2 ensures load balancing for causal attention
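The padding rule can be sketched as a small helper (a hypothetical illustration; the function name is not part of the NeMo-RL API):

```python
# Hypothetical helper showing the padding rule: each sequence is padded
# up to the next multiple of cp_size * 2 * tp_size before CP sharding.
def pad_len_for_cp(seq_len: int, cp_size: int, tp_size: int) -> int:
    multiple = cp_size * 2 * tp_size
    return ((seq_len + multiple - 1) // multiple) * multiple

# With cp_size=2, tp_size=1 the multiple is 4:
# pad_len_for_cp(5, 2, 1) -> 8, pad_len_for_cp(8, 2, 1) -> 8
```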

Understanding CP Load Balancing#

```
Given a sequence of length 6, CP 2:

0 1 2 3 4 5

The attention mask is:

  | 0 1 2 3 4 5
--+------------
0 | 1 0 0 0 0 0
1 | 1 1 0 0 0 0
2 | 1 1 1 0 0 0
3 | 1 1 1 1 0 0
4 | 1 1 1 1 1 0
5 | 1 1 1 1 1 1

If we were to naively chunk this sequence into CP chunks, we would have:

CP0:
  | 0 1 2
--+------
0 | 1 0 0
1 | 1 1 0   +   send KV 0 1 2
2 | 1 1 1

CP1:
  | 3 4 5                            | 0 1 2
--+------                          --+------
3 | 1 0 0                          3 | 1 1 1
4 | 1 1 0   +   recv KV 0 1 2   +  4 | 1 1 1
5 | 1 1 1                          5 | 1 1 1

Here, CP1 ends up with more than double the work of CP0, stalling training on CP0.

To fix this, we can chunk the sequence into 2*CP chunks (and pad to accommodate):

| 0 1 | 2 3 | 4 5 | p p |
|--V--|--V--|--V--|--V--|
| CP0 | CP1 | CP1 | CP0 |

Now, the work looks like this:

CP0:
  | 0 1                                           | 2 3 4 5 p p
--+----                                         --+------------
0 | 1 0   +   send KV 0 1, recv KV 2 3 4 5   +  p | 1 1 1 1 1 0
1 | 1 1                                         p | 1 1 1 1 1 1

CP1:
  | 2 3 4 5                                           | 0 1
--+--------                                         --+----
2 | 1 0 0 0                                         2 | 1 1
3 | 1 1 0 0   +   send KV 2 3 4 5, recv KV 0 1   +  3 | 1 1
4 | 1 1 1 0                                         4 | 1 1
5 | 1 1 1 1                                         5 | 1 1

Much more even!
```
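The load-balanced assignment can be sketched in a few lines (an illustration, not the framework's internal code): split the padded sequence into 2*CP chunks and give rank r chunks r and (2*CP - 1 - r).

```python
# Sketch of load-balanced CP chunk assignment (illustrative only):
# chunk i and chunk (2*cp_size - 1 - i) both go to rank i.
def cp_shard(tokens, cp_size):
    n_chunks = 2 * cp_size
    chunk = len(tokens) // n_chunks  # assumes len(tokens) % n_chunks == 0
    chunks = [tokens[i * chunk:(i + 1) * chunk] for i in range(n_chunks)]
    return [chunks[r] + chunks[n_chunks - 1 - r] for r in range(cp_size)]

# For the 6-token example padded to 8 ("012345pp"):
# rank 0 gets chunks 0 and 3 -> "01" + "pp"
# rank 1 gets chunks 1 and 2 -> "23" + "45"
```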

With Sequence Packing + CP, we pack and CP-shard per sequence to take full advantage of the load-balancing properties of CP sharding.

```
Input batch:
00000ppp
11111111
2ppppppp
333ppppp

CP = 2
First pad every sequence to a multiple of 2*CP*TP = 4:
[00000ppp, 11111111, 2ppp, 333p]

Now CP-shard each individual sequence and pack:
CP0: 00pp 1111 2p 3p    packed: 00pp11112p3p
CP1: 000p 1111 pp 33    packed: 000p1111pp33
```

Internally, DTensor and Megatron-Core are made aware of sequence packing through either FlashAttentionArgs or PackedSeqParams, which contain cu_seqlens_q and cu_seqlens_kv: the cumulative sequence lengths of the sequences in the packed batch without CP.
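The cumulative-lengths metadata is simple to compute; a minimal sketch (illustrative, not the framework code):

```python
# Illustrative computation of cu_seqlens: the cumulative offset of each
# sequence inside the packed batch, as varlen attention kernels expect.
def cu_seqlens(seq_lens):
    out = [0]
    for n in seq_lens:
        out.append(out[-1] + n)
    return out

# Packing sequences of lengths [3, 6, 2, 3] gives [0, 3, 9, 11, 14]:
# sequence i occupies packed positions [out[i], out[i+1]).
```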

Nuances#

  • When using Sequence Packing with Megatron + Pipeline Parallelism (PP), note that all packed sequences will be padded up to the maximum packed sequence length, because PP requires maintaining a fixed-size [batch x seqlen] buffer for PP communication. In practice, however, we find that packing is so efficient that this hardly makes a difference.

All together, we see speedups in the ~2-3x range when enabling sequence packing.

Dynamic Batching#

Dynamic batching optimizes microbatch formation by:

  1. Sorting sequences by length within batches (while respecting chunk boundaries, so there are no training-order differences).

  2. Grouping sequences to achieve target token count per microbatch.

  3. Padding sequences to configurable multiples for hardware alignment.

Cannot be used with sequence packing

Architecture#

Processing Pipeline#

```
┌─────────────────┐    ┌──────────────────┐    ┌─────────────────┐
│   Input Batch   │ ── │ Sort by Length   │ ── │ Group by Tokens │
│                 │    │ (within chunks)  │    │                 │
└─────────────────┘    └──────────────────┘    └─────────────────┘
                                                        │
┌─────────────────┐    ┌──────────────────┐    ┌────────V────────┐
│ Dynamic Micros  │ <─ │ Pad to Multiple  │ <─ │ Calculate Sizes │
│                 │    │                  │    │                 │
└─────────────────┘    └──────────────────┘    └─────────────────┘
```
```
Input batch:
00ppppp
1111ppp
2222222
333333p
444pppp
5555ppp

MBS = 16 tokens

Dynamic Batching will re-order this batch to minimize padding:
1. Sort: 2222222 333333p 1111ppp 5555ppp 444pppp 00ppppp
2. Chunk by MBS token count:
   MBS0: 2222222 333333p
   MBS1: 1111 5555 444p 00pp
```

Note how we're able to remove a huge chunk of padding this way and do the full batch with fewer microbatches than we would otherwise need.
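The grouping step above can be sketched as a greedy loop (an illustration, not the NeMo-RL implementation): walk the length-sorted sequences and close a microbatch whenever adding the next sequence would push count × max-padded-length over the token budget.

```python
# Illustrative grouping by token budget (not the NeMo-RL code).
# `sorted_lens` must be sorted in descending order.
def group_by_token_budget(sorted_lens, max_tokens):
    groups, current = [], []
    for n in sorted_lens:
        candidate = current + [n]
        # Cost of a microbatch = count * max length within it (after padding).
        if len(candidate) * max(candidate) > max_tokens and current:
            groups.append(current)
            candidate = [n]
        current = candidate
    if current:
        groups.append(current)
    return groups

# Reproduces the example above: two len-7 sequences fill MBS0,
# and 4+4+3+2 pad to 4 tokens each, exactly filling the 16-token MBS1.
```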

Implementation Details#

Sorting and Load Balancing (nemo_rl/distributed/batched_data_dict.py)

```python
if dynamic_batching_args is not None:
    # Sort sequences by length within each chunk
    for chunk_idx in range(num_chunks):
        chunk_seqlens = self.data[input_lengths_key][chunk_start:chunk_end]
        chunk_idx_indices = sorted(range(batch_size), key=lambda i: chunk_seqlens[i])
        # Stride sorted sequences across DP ranks for load balancing
        chunk_idx_indices = [chunk_idx_indices[i::shards] for i in range(shards)]
```

Dynamic Shape Processing (nemo_rl/distributed/batched_data_dict.py)

```python
# In the batched datadict, everything is padded up to the max seqlen. This truncates
# everything in one dynamic batch to just pad up to the max within this batch.
def make_microbatch_iterator_with_dynamic_shapes(self):
    for seqlen, (start_idx, end_idx) in zip(
        self.micro_batch_lengths[0], self.micro_batch_indices[0]
    ):
        mb = self.slice(start_idx, end_idx)
        mb.truncate_tensors(dim=sequence_dim, truncated_len=seqlen)
        yield mb
```

Interface#

```python
class BatchedDataDict(UserDict, Generic[DictT]):
    def shard_by_batch_size(
        self,
        shards: int,
        dynamic_batching_args: Optional[DynamicBatchingArgs] = None,
        sequence_packing_args: Optional[SequencePackingArgs] = None,
    ) -> list[SlicedDataDict]:
        # Main entry point for both packing and dynamic batching
```

Similar to Sequence Packing, we do not actually create the dynamic batches upon the call to shard_by_batch_size; we just reorder sequences and create metadata internally. After a call to RayWorkerGroup.run_all_workers_sharded_data, the workers can run make_microbatch_iterator_with_dynamic_shapes to get the true dynamic batches.

Nuances#

  • Dynamic batching cannot be used with Megatron + PP because PP requires a fixed [batch x seqlen] buffer for PP communication. Please use Sequence Packing instead.

  • Dynamic batching is almost always slower than Sequence Packing, but it does not require that your model (and in particular, your attention variant) have sequence packing implemented (which can be complicated). We recommend always using Sequence Packing where possible and falling back to Dynamic Batching as a last resort.

Configuration#

Dynamic Batching Configuration#

```python
class DynamicBatchingArgs(TypedDict):
    max_tokens_per_microbatch: int  # Target tokens per microbatch
    sequence_length_round: int      # Padding alignment multiple
    input_key: str                  # Input tensor key ("input_ids")
    input_lengths_key: str          # Sequence lengths key ("input_lengths")
```

Sequence Packing Configuration#

```python
class SequencePackingArgs(TypedDict):
    max_tokens_per_microbatch: int     # Bin capacity for packing
    input_key: str                     # Input tensor key
    input_lengths_key: str             # Sequence lengths key
    algorithm: str                     # Packing algorithm name
    sequence_length_pad_multiple: int  # CP/TP alignment factor
```

Integration with Training Pipeline#

Loss Function Integration#

A key design consideration was that the loss-function writer should not need to be aware of whether sequence packing is enabled. To achieve this, we created a SequencePackingLossWrapper, which takes the packed next_token_logits and the unpacked auxiliary loss-function data and runs the loss function on each sequence individually. Since the loss function’s computation time is typically trivial, we don’t see a slowdown from this approach. With this, the loss function can be written as though it deals with typical, unpacked batched data (as long as it is capable of processing one sequence at a time).
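The idea can be sketched as follows (hypothetical names, not the actual SequencePackingLossWrapper): slice each sequence out of the packed logits using the cu_seqlens offsets and call the unmodified per-sequence loss.

```python
# Hypothetical sketch of the wrapper idea (not the real implementation):
# run the per-sequence loss_fn on each slice of the packed logits.
def packed_loss(loss_fn, packed_logits, cu_seqlens, per_seq_data):
    per_seq_losses = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        seq_logits = packed_logits[start:end]  # one unpacked sequence
        per_seq_losses.append(loss_fn(seq_logits, per_seq_data[i]))
    return sum(per_seq_losses) / len(per_seq_losses)
```

The loss function itself stays oblivious to packing: it only ever sees a single contiguous sequence and its own auxiliary data.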

If your loss function cannot assume batch independence, however, then neither Dynamic Batching nor Sequence Packing will work (e.g., DPO, issue #719).

Metrics and Monitoring#

Packing Efficiency Metrics (nemo_rl/data/packing/metrics.py)#

  • Bin Utilization: Percentage of bin capacity used

  • Waste Ratio: Fraction of capacity unused due to packing constraints

  • Bin Balance: Measure of load distribution evenness across bins

  • Packing Efficiency: Ratio of theoretical minimum to actual bins used
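As an illustration of the first and last metrics above (a sketch, not the code in nemo_rl/data/packing/metrics.py):

```python
import math

# Illustrative metric computations (not the NeMo-RL implementation).
def bin_utilization(bin_loads, bin_capacity):
    # Fraction of total bin capacity holding useful tokens;
    # the waste ratio is 1 - utilization.
    return sum(bin_loads) / (len(bin_loads) * bin_capacity)

def packing_efficiency(seq_lens, n_bins_used, bin_capacity):
    # Ratio of theoretical minimum bin count to actual bins used.
    theoretical_min = math.ceil(sum(seq_lens) / bin_capacity)
    return theoretical_min / n_bins_used

# e.g. two bins of capacity 8 holding 8 and 6 tokens:
# utilization 14/16 = 0.875; packing [3, 6, 2, 3] into 2 bins is optimal.
```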

Usage#

Sequence Packing Configuration#

```yaml
# examples/configs/grpo_math_1B.yaml
policy:
  sequence_packing:
    enabled: True
    train_mb_tokens: 2048      # Target tokens per microbatch
    logprob_mb_tokens: 2048
    algorithm: "modified_first_fit_decreasing"  # Best algorithm
    sequence_length_round: 64  # Hardware alignment
  dynamic_batching:
    enabled: False             # Mutually exclusive
```

Dynamic Batching Configuration#

```yaml
# examples/configs/grpo_math_8B.yaml
policy:
  dynamic_batching:
    enabled: True
    train_mb_tokens: 4096
    logprob_mb_tokens: 8192
    sequence_length_round: 64
  sequence_packing:
    enabled: False  # Mutually exclusive
```

Framework Compatibility#

Sequence Packing Requirements:

  • Megatron or DTensor policy

  • FlashAttention-2 for efficient packed attention

  • If using CP with Megatron, you must use sequence packing. If using CP with DTensor, you cannot yet use packing (WIP, Issue #520)

Dynamic Batching Requirements:

  • Any policy framework

  • Pipeline parallelism size = 1

  • Cannot be used with torch.compile since shapes change.


References#

Johnson & Garey (1985) - Modified First Fit Decreasing
