Multi-Head, Multi-Query, and Group-Query Attention#
This document details the implementation of multi-head attention (MHA), multi-query attention (MQA), and group-query attention (GQA) for autoregressive models in TensorRT LLM’s PyTorch backend.
Multi-head attention involves a sequence of batched matrix multiplications, a softmax operation, and another batched matrix multiplication, as described in the Attention Is All You Need paper. Multi-query attention (MQA) and group-query attention (GQA) are variants of MHA that use fewer KV heads than the number of query heads. TensorRT LLM provides several implementations using different backends in tensorrt_llm/_torch/attention_backend/. The following sections explain how to use these implementations and provide a brief guide on implementing new backends.
Attention Backends#
There are currently three available attention backends: the vanilla backend, the TRT-LLM backend, and the Flashinfer backend. You can specify the desired attention backend using PyTorchConfig.attn_backend. For instance, to use the Flashinfer backend, pass attn_backend="flashinfer" to the LLM constructor, as shown below.
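A minimal sketch of this; the model path is illustrative and should be replaced with your own checkpoint:

```python
from tensorrt_llm import LLM

# The model path below is illustrative; replace it with your own checkpoint.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    attn_backend="flashinfer",  # select the Flashinfer attention backend
)
```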
The vanilla backend, VanillaAttention, is a reference implementation designed primarily for inflight batching and linear KV cache support. While it serves as a useful baseline, it is not recommended for production use due to its limited optimizations.
In contrast, the Flashinfer backend, FlashInferAttention, is performance-optimized and supports both inflight batching and paged KV cache. It also includes the following advanced features:
FP8 Quantization: This feature enables the quantization of inputs and KV cache into FP8 format, significantly reducing memory usage and improving computational throughput.
RoPE Fusion: By integrating rotary position embedding (RoPE) directly into the attention computation, this feature enhances efficiency and reduces overhead.
The TRT-LLM backend, TrtllmAttention, serves as the default backend and supports all the features available in the Flashinfer backend while being further optimized for enhanced performance. It is the recommended choice for production environments. Additionally, it offers the following advanced features:
Fused QKV Input: It can accept a single QKV tensor as input, which is more efficient compared to using separate Q, K, and V tensors.
FP8 Output: It supports outputting the attention result in FP8 format, fusing quantization into the attention computation process.
Implement a New Attention Backend#
You can implement a new attention backend to integrate other attention libraries. An attention backend consists of an AttentionBackend class and an AttentionMetadata class. There are three stages in the PyTorch runtime that involve the attention backend (a minimal sketch of this lifecycle follows the list):

1. Model construction: During the model’s __init__, call AttentionBackend.__init__ to create an attention backend for each layer.
2. Metadata preparation: Before each forward step of the model:
   - If the metadata is uninitialized, call AttentionMetadata.__init__ to create the attention metadata.
   - If using CUDA graphs, call AttentionMetadata.create_cuda_graph_metadata to convert the metadata to CUDA graph metadata, which pre-allocates all tensors and can be used to capture CUDA graphs. Do not re-allocate any tensors stored inside AttentionMetadata after the initial warmup run when using CUDA graphs.
   - To prepare parameters of the input and KV cache, call AttentionMetadata.prepare to convert from the existing metadata and the KV cache manager.
3. Single step forward: During the forward pass of each attention layer, call AttentionBackend.forward to perform the attention operation. The AttentionMetadata will be provided as a forward argument.
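The sketch below illustrates these three stages. It is not a runnable program: the class names (MyAttentionBackend, MyAttentionMetadata), the argument lists, and the surrounding variables (q, k, v, kv_cache_manager, use_cuda_graphs) are illustrative assumptions, not the exact TensorRT LLM signatures.

```python
# 1. Model construction: create one backend instance per attention layer.
backend = MyAttentionBackend(layer_idx=0, num_heads=32, head_dim=128, num_kv_heads=8)

# 2. Metadata preparation: before each forward step of the model.
metadata = MyAttentionMetadata(max_num_requests=8, max_num_tokens=8192,
                               kv_cache_manager=kv_cache_manager)
if use_cuda_graphs:
    # Pre-allocates all tensors so the step can be captured into a CUDA graph.
    metadata = metadata.create_cuda_graph_metadata(max_batch_size=8)
metadata.prepare()  # fill the fields from the current batch and KV cache manager

# 3. Single step forward: inside the forward pass of each attention layer.
output = backend.forward(q, k, v, metadata)
```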
Implement AttentionMetadata#
The AttentionMetadata class stores metadata from the batched input and KV cache for the attention backend. It contains the following predefined fields:
| Field | Type | Description |
|---|---|---|
| max_num_requests | int | The max number of requests in a single batch. |
| num_contexts | int | The number of context-phase sequences in the batch. |
| num_generations | int | The number of generation-phase sequences in the batch. |
| max_num_tokens | int | The max number of tokens in all requests in a single batch. |
| num_tokens | int | Number of tokens in the batch. |
| num_ctx_tokens | int | Number of tokens in sequences in the context phase. |
| kv_cache_manager | KVCacheManager | The KV cache manager. |
| is_cuda_graph | bool | Whether CUDA graph is enabled. |
| seq_lens | Tensor | The length of each sequence in the batch. The shape is (batch_size), located on CPU memory. |
| seq_lens_cuda | Tensor | A copy of seq_lens stored on GPU memory. |
| context_lens | Tensor | The length of each context-phase sequence in the batch. The shape is (num_contexts). |
| position_ids | Optional[Tensor] | The position of each token in each sequence. May be None if positional embedding is applied outside of the backend. |
| request_ids | List[int] | The request ID of each sequence in the batch. |
| prompt_lens | List[int] | The prompt length of each sequence in the batch. |
| kv_cache_params | KVCacheParams | The parameters for the KV cache. |
During AttentionMetadata.__init__, you can initialize additional fields for the new attention metadata. For example, the Flashinfer metadata initializes decode_wrapper here. During AttentionMetadata.prepare, the runtime will fill all predefined fields, and you can fill your customized fields according to these predefined fields. For example, the Flashinfer metadata fills qo_indptr by combining context_lens and num_generations here.
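As a rough sketch, a custom metadata class can derive such fields in prepare from the predefined ones. The field name qo_indptr and the exact base-class hook points below are illustrative assumptions, not the real Flashinfer implementation:

```python
import torch

class MyAttentionMetadata(AttentionMetadata):
    def prepare(self) -> None:
        # The predefined fields (context_lens, num_generations, ...) are filled
        # by the runtime; derive any custom fields from them here.
        qo_lens = torch.cat([
            self.context_lens.to(torch.int32),
            torch.ones(self.num_generations, dtype=torch.int32),
        ])
        # Cumulative query lengths per request, with a leading zero, similar in
        # spirit to the qo_indptr built by the Flashinfer metadata.
        self.qo_indptr = torch.cumsum(
            torch.cat([torch.zeros(1, dtype=torch.int32), qo_lens]), dim=0
        ).int()
```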
Implement AttentionBackend#
The AttentionBackend delegates the attention operation to the backend implementation.
Its __init__ accepts the following arguments:
| Field | Type | Description |
|---|---|---|
| layer_idx | int | The index of the attention layer in the model. |
| num_heads | int | The number of query heads. |
| head_dim | int | The size of each attention head. |
| num_kv_heads | Optional[int] | The number of KV heads. Defaults to num_heads if None. |
| quant_config | QuantConfig | Optional quantization configuration. If None, no quantization is applied. |
| pos_embd_params | PositionalEmbeddingParams | Optional parameters defining how positional embedding should be applied. If None, positional embedding should be applied by the model before calling the backend. Otherwise, the backend is in charge of applying positional embedding and may cache K without embedding it first. |
Its forward accepts the following arguments:
| Field | Type | Description |
|---|---|---|
| q | Tensor | Query tensor with shape (num_tokens, num_heads * head_dim). |
| k | Tensor | Key tensor with shape (num_tokens, num_kv_heads * head_dim). |
| v | Tensor | Value tensor with shape (num_tokens, num_kv_heads * head_dim). |
| metadata | AttentionMetadata | Metadata for the attention operation. |
| attention_mask | AttentionMask | Optional attention mask. If None, causal mask is applied. |
For example, the Flashinfer backend calls append_paged_kv_cache and then wrapper.run to perform the attention operation here.
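A skeleton of a backend might look like the following. The call my_attention_kernel stands in for whatever library routine actually performs the attention, and the Metadata class attribute is shown as one common way to associate a backend with its metadata class; both are assumptions for illustration only:

```python
import torch

class MyAttentionBackend(AttentionBackend):
    Metadata = MyAttentionMetadata  # associate the backend with its metadata class

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                metadata: MyAttentionMetadata, attention_mask=None,
                **kwargs) -> torch.Tensor:
        # Typical steps: append this step's K/V to the KV cache managed by
        # metadata.kv_cache_manager, then run the attention kernel over the
        # packed query tokens described by metadata.seq_lens_cuda.
        return my_attention_kernel(
            q, k, v,
            kv_cache_manager=metadata.kv_cache_manager,
            seq_lens=metadata.seq_lens_cuda,
            causal=attention_mask is None,
        )
```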
The Features of the TrtllmAttention Backend#
The following sections introduce some features of the default TrtllmAttention backend.
Packed Tensors#
In the TrtllmAttention backend, the attention operator supports packed (i.e. non-padded) QKV inputs. A naive layout for the QKV inputs pads the sequences that are shorter than max_sequence_length to the maximum length, which may result in excessive memory consumption as well as unneeded computations on padding tokens (in the various matrix multiplications that surround the MHA block). To overcome that problem, TensorRT LLM supports a mode without padding where the tokens of the different sequences are packed together and the user provides the operator with a 1D tensor containing the lengths of the different sequences.
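A rough illustration of the packed layout, with illustrative shapes:

```python
import torch

# Three sequences of different lengths, packed along a single token dimension.
hidden_dim = 4096
seq_lens = torch.tensor([5, 2, 9], dtype=torch.int32)           # 1D lengths tensor
packed_qkv = torch.randn(int(seq_lens.sum()), 3 * hidden_dim)    # [16, 3 * hidden_dim]

# A padded layout would instead allocate [3, max_sequence_length, 3 * hidden_dim]
# and spend memory and FLOPs on padding tokens for the two shorter sequences.
```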
Context and Generation Phases#
TheTrtllmAttention backend encapsulates different implementations for bothcontext and generation phases into a single custom torch op.
Context Phase#
A context-phase implementation without optimization maps to a sequence of GPU kernels that store the intermediate Q*K^T tensor in memory before calling the softmax operator. It is the slowest method and has a significant memory footprint (growing quadratically with the sequence length).
The TrtllmAttention backend instead triggers a single kernel that performs the whole MHA/MQA block. For short sequences, that kernel uses a vanilla implementation of MHA/MQA. For longer sequences, it uses the Flash Attention algorithm as described in FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness and FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.
Currently, the implementation triggers extra kernels that apply pre-processing to the elements (like RoPE) and populate the KV cache (see below). In a future release, the number of such kernels may be reduced to improve the overall performance.
FP8 Context FMHA#
When FP8 quantization is activated, the attention can be further accelerated byenabling FP8 Context FMHA.
FP8 Paged Context FMHA is also supported with the FP8 quantization workflow. You need to specify use_paged_context_fmha=True for the attention operator.
Please be aware that this feature is only supported on Ada, Hopper and above.
Generation Phase#
The generation phase is implemented using a single kernel called the masked multi-head attention in TensorRT LLM. That kernel is able to apply pre-processing on the Q, K, and V elements on-the-fly: it adds the QKV bias, applies RoPE, and performs dequantization and quantization. TensorRT LLM will continue to add (or enable) additional features in future releases, such as enabling support for IA3.
The masked MHA kernel has a special version that distributes the work across multiple CUDA thread blocks on the GPU for cases where the GPU occupancy is low. That mode, called multi-block, is always enabled. NVIDIA recommends that users test that mode in scenarios where both the batch size and the number of heads in the model are relatively small. The definition of ‘small’ in that context is hard to quantify because it depends on the model of the GPU. However, NVIDIA currently recommends testing that mode when batch_size * num_heads is less than the number of multi-processors on the GPU. This guidance may be subject to change in the future.
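One way to inspect that rule of thumb on the current GPU (the batch size and head count below are illustrative):

```python
import torch

batch_size, num_heads = 4, 8  # illustrative values
sm_count = torch.cuda.get_device_properties(0).multi_processor_count
if batch_size * num_heads < sm_count:
    print("Low-occupancy case: the multi-block mode is worth benchmarking.")
```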
Note that even if the multi-block mode is enabled, the attention operator willnot immediately trigger the multi-block version of the GPU kernel. There is aminimum number of tokens (input + generated) that are required for themulti-block version to become more efficient than the “vanilla” implementationthat uses a single CUDA thread-block per head. It is controlled by an internalheuristic.
Note also that the masked MHA kernels use shared memory proportional to the sequence length, so there can be cases where the GPU’s shared memory is not sufficient when the multi-block mode is not enabled. To make the masked MHA kernel work in those cases, the multi-block mode is forced on and a warning message is printed in the log.
XQA Optimization#
XQA optimization is another optimization for MQA/GQA in the generation phase.It currently only supports a limited number of model configurations, such as the LLAMA2 70B model.
Support matrix of the XQA optimization:
FP16 / BF16 compute data type.
FP16 / BF16 / FP8 / INT8 KV cache data type.
Paged KV cache (8 / 16 / 32 / 64 / 128 tokens per block).
By default, this feature is enabled. Note that a heuristic algorithm is also used to decide whether to use the XQA kernel or the masked MHA kernel to get better performance. If you want to use the XQA kernel whenever possible, set TRTLLM_FORCE_XQA=1 to force its use when the model configuration is supported. Supported configurations can be found using the shouldUse function of the DecoderXQARunner class in cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h.
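For example, the environment variable can be exported in the shell or set from Python before the model is loaded:

```python
import os

# Set before loading the model so the runtime picks it up.
os.environ["TRTLLM_FORCE_XQA"] = "1"
```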
In-flight Batching#
TensorRT LLM supports in-flight batching of requests (also known as continuous batching or iteration-level batching) for higher serving throughput. With this feature, sequences in the context phase can be processed together with sequences in the generation phase. The purpose of that technique is to better interleave requests to reduce latency as well as make better use of the GPUs. For efficiency reasons (1), the support for in-flight batching requires the input tensors to be packed (no padding).
In the current implementation, the sequences that are going through the context phase must be before the sequences in the generation phase in the input tensor. For example, for sequences S0, S1, and S2, if S0 and S2 are in context phase (and S1 in generation), tokens from S0 and S2 must appear before the tokens of S1 in the input tensor.
(1) Padding sequences in the generation phase, which contain a single token each, to the length of the maximum input sequence is an inefficient use of resources.
Chunked Context#
Originally, the common behavior was to process all context tokens at once. This feature splits the context into several chunks. In this way, the context chunks can be batched with more tokens during the generation phase, which is expected to increase the total throughput. Chunking contexts also removes constraints on the input length. Except for the last one, the size of each context chunk needs to be an integer multiple of the KV-cache block size.
To enable this feature, the FMHA paged kv-cache also needs to be enabled.
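For instance, with a KV-cache block size of 64 tokens and a chunk size of 8 blocks, a 1234-token context would be split as follows (the numbers are illustrative):

```python
tokens_per_block = 64
chunk_size = 8 * tokens_per_block  # 512 tokens: a whole number of KV-cache blocks
input_len = 1234

chunks = [min(chunk_size, input_len - start)
          for start in range(0, input_len, chunk_size)]
print(chunks)  # [512, 512, 210]; only the last chunk is not a multiple of the block size
```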
KV Cache#
In the generation phase, a common optimization is to provide the MHA kernel with a cache containing the values of the past K and V elements that have already been computed. That cache is known as the KV cache. TensorRT LLM uses that technique to accelerate its generation phase. In TensorRT LLM, there is one KV cache per Transformer layer, which means that there are as many KV caches as layers in a model. The current version of TensorRT LLM supports two different types of KV caches: contiguous and paged KV caches.
Contiguous KV Cache#
The contiguous KV cache is a monolithic tensor. Its shape is:
[max_batch_size * max_beam_width, 2, num_heads, max_seqlen, hidden_dim_per_head].
That implementation uses a lot more memory than needed when the sequences are shorter than the maximum sequence length (even if they end up close to the limit after the generation of many output tokens, it may take a lot of steps to reach that point).
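A quick way to see the cost, using illustrative dimensions:

```python
# Size of a contiguous KV cache for a single layer (FP16, illustrative numbers).
max_batch_size, max_beam_width = 8, 1
num_heads, max_seqlen, hidden_dim_per_head = 8, 4096, 128
bytes_per_elem = 2  # FP16

num_elems = (max_batch_size * max_beam_width * 2 * num_heads
             * max_seqlen * hidden_dim_per_head)
print(f"{num_elems * bytes_per_elem / 2**20:.0f} MiB per layer")
# The full amount is allocated up front, even while sequences are still short.
```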
Paged KV Cache#
The paged KV cache decomposes the KV cache into blocks that are distributed to the different requests by a cache manager during processing. That cache manager keeps track of the sequences, allocates new blocks from a pool, and recycles those blocks when required. See the implementation of KVCacheManager.
INT8/FP8 KV Caches#
In its current implementation, even if the rest of the network runs in INT8 or FP8, the attention operator works with FP32, FP16, and BFloat16 inputs and outputs. However, TensorRT LLM supports INT8 and FP8 KV caches (QuantMode.INT8_KV_CACHE and QuantMode.FP8_KV_CACHE).
The attention operator populates the KV cache. When INT8 or FP8 KV caches are enabled, the input values have to be quantized to 8 bits using a scaling factor. The scaling factor is stored in the kv_cache_scaling_factor tensor. Its shape is [1], and only per-tensor quantization is supported in the current version. Quantization uses the inverse scale, since the operator multiplies as fp_value * (1.0 / kv_cache_scaling_factor).
During generation, the values read from the cache are dequantized on-the-fly in the MHA/MQA kernel. Dequantization is defined as quantized_value * kv_cache_scaling_factor.
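The two formulas above can be written out as follows (per-tensor scale, illustrative values):

```python
import torch

kv_cache_scaling_factor = torch.tensor(0.05)  # per-tensor scale
k = torch.randn(8, dtype=torch.float16)

# Write path: multiply by the inverse scale, then round and clamp to INT8.
k_quant = torch.clamp(
    torch.round(k.float() * (1.0 / kv_cache_scaling_factor)), -128, 127
).to(torch.int8)

# Read path: dequantized on the fly inside the MHA/MQA kernel.
k_dequant = k_quant.float() * kv_cache_scaling_factor
```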
Sliding Window Attention, Cyclic (Rolling Buffer) KV Cache#
TensorRT LLM has a feature called Cyclic KV Cache, which treats the KV cache as a circular buffer. This means that it only stores the KV cache for the last N tokens, where N is determined by the attention_window_size parameter in TrtllmAttention.forward. When the cache is full, new tokens’ KV cache will overwrite the “least recently used” caches.
In the context phase, if the input length surpasses attention_window_size, sliding window attention will be activated; attention_window_size serves as the sliding window size in that case.
This feature helps to reduce the memory footprint of the kv cache whendealing with very long sequences.
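Conceptually, a token’s slot in the cyclic cache is its position modulo the window size, so once the window is full new tokens overwrite the oldest entries. This is an illustration only, not the actual indexing code:

```python
attention_window_size = 4  # illustrative window size

for token_pos in range(7):
    slot = token_pos % attention_window_size
    print(f"token {token_pos} -> cache slot {slot}")
# Tokens 4, 5, and 6 reuse slots 0, 1, and 2, evicting the oldest entries.
```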
Note that the cyclic KV cache feature currently doesn’t work with beam search, as the context KV cache is shared across beams.
StreamingLLM#
The StreamingLLM feature uses window attention to enable efficient and stable LLM inference on long texts, which means that only N tokens need to be stored in the KV cache. Similar to the cyclic KV cache feature in TensorRT LLM, the attention_window_size parameter is used to determine N. Different from the cyclic KV cache feature, the first S tokens, called sink tokens, are always kept in the attention window, where S is determined by the sink_token_length parameter. In the context phase, however, the self-attention is dense in the official implementation of StreamingLLM: it uses all of the tokens for computation and only saves N tokens to the KV cache.
In addition, the relative position embedding is also changed in StreamingLLM. When determining the relative distance and adding positional information to tokens, StreamingLLM uses the positions within the cache rather than those in the original text.
sink_token_length is also used to enable this feature.
Beam-Search#
The attention operator supports beam search. In the context phase, a single beam is computed per input sequence. In the generation phase, the MHA/MQA/GQA kernel uses an additional tensor to reconstruct the correct path for each beam. That tensor is called the cache_indirection. Its shape is [batch_size, beam_width, max_seqlen].
For a sequence si, a beam bi, and a token ti, the element cache_indirection[si][bi][ti] is an integer between 0 and beam_width-1 that indicates which path in the beam to read the K and V elements from in the KV cache. This tensor is populated in the sampling stage.
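The following toy example shows how cache_indirection is read for one sequence and one beam; the shapes and values are illustrative:

```python
import torch

batch_size, beam_width, max_seqlen = 1, 2, 4
# cache_indirection[si][bi][ti] tells which beam's K/V entries to read.
cache_indirection = torch.tensor([[[0, 0, 1, 1],
                                   [0, 0, 0, 1]]])  # [batch_size, beam_width, max_seqlen]

si, bi = 0, 1
for ti in range(max_seqlen):
    src_beam = int(cache_indirection[si, bi, ti])
    print(f"token {ti}: read K/V from beam {src_beam}")
```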
Input QKV tensor#
The input QKV tensor packs the Q, K, and V tensors (concatenated along the last dimension) after the projection of the hidden states. It is a 3D tensor. RoPE and quantization to INT8 or FP8 (when needed) are performed by the GPT attention operator.
In packed mode, its shape is [num_tokens, 3 * hidden_dim] where num_tokens is the total number of tokens in the batch. For the sequences in context phase, the number of tokens of a sequence corresponds to its input length (even if the beam width is greater than 1 for beam search). For the sequences in generation phase, there are beam_width tokens per sequence. The beam width can be different for each sequence.
The following pseudo code explains how the number of tokens is computed:
```python
num_tokens = 0

# Add the length of each sequence in context phase.
for seq in context_phase:
    num_tokens += seq.length

# Add the width of the beam for each sequence in generation phase.
for seq in generation_phase:
    num_tokens += seq.beam_width
```
Rotary Positional Embedding (RoPE)#
The attention operator can perform the computation of the Rotary Positional Embedding (RoPE). When that operation is enabled (rotary_embedding_dim is set to a value greater than 0), it is fused with other operations. The GPT operator supports the GPT-NeoX and GPT-J forms of RoPE by setting position_embedding_type to PositionEmbeddingType.rope_gpt_neox or PositionEmbeddingType.rope_gptj.
ALiBi#
The attention operator can apply ALiBi to the result of theQ*K^Tproduct. The bias is computed on-the-fly from the ALiBi slopes in the optimizedkernel.
Scaling factor(s)#
In MHA, the output of the Q*K^T product is scaled by a constant value that is computed as:

norm_factor = 1.f / (q_scaling * sqrt(head_size)).
Cross Attention#
In addition to the MHA-based self-attention needed by GPT-style decoder-only models, the attention operator also supports cross-attention.
This enables the attention operator to be used more broadly as a generic decoder component. For example, the Encoder-Decoder model uses it to run both the self-attention and cross-attention modules in its decoder.