Numerical Precision#

This document describes the different quantization recipes implemented in TensorRT-LLM and contains a support matrix for the different models.

FP32, FP16 and BF16#

The different models implemented in TensorRT-LLM work with 32-bit IEEE floating-point (FP32) numbers. When checkpoints are available, the models also support 16-bit IEEE floating-point numbers (FP16) and 16-bit Bfloat16 (BF16), as described here.

Quantization and Dequantization (Q/DQ)#

Given a floating-point number x and a floating-point scaling factor s, TensorRT-LLM implements INT8 quantization as:

q = int8.satfinite(x * s)

Given an INT8 number q and a floating-point scaling factor s, TensorRT-LLM implements INT8 dequantization to the floating-point (FP) type as:

x = static_cast<FP>(q) * s
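The two formulas above can be sketched in NumPy. This is a minimal illustration, not the TensorRT-LLM implementation: int8.satfinite is modeled as round-to-nearest followed by saturation to the INT8 range, and the dequantization scale is chosen as the reciprocal of the quantization scale so that the round trip approximately recovers x.

```python
import numpy as np

def quantize(x: np.ndarray, s: float) -> np.ndarray:
    """q = int8.satfinite(x * s): round, then saturate to [-128, 127]."""
    return np.clip(np.round(x * s), -128, 127).astype(np.int8)

def dequantize(q: np.ndarray, s: float) -> np.ndarray:
    """x = static_cast<FP>(q) * s."""
    return q.astype(np.float32) * s

x = np.array([0.1, -0.5, 1.2], dtype=np.float32)
s_quant = 127.0 / float(np.abs(x).max())  # map the absolute maximum to 127
q = quantize(x, s_quant)
x_hat = dequantize(q, 1.0 / s_quant)      # reciprocal scale for dequantization
```

Note that the quantization scale and the dequantization scale are reciprocals of each other; the error of the round trip is bounded by half a quantization step.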

Consider a matrix (2D tensor) of shape M x N (M rows and N columns), where M is the number of tokens and N is the number of channels. TensorRT-LLM has the three following modes to quantize and dequantize the elements of the tensor:

  • Per-tensor: It uses a single scaling factor for all the elements,

  • Per-token: It uses a different scaling factor for each token. There are M scaling factors in that case,

  • Per-channel: It uses a different scaling factor for each channel. There are N scaling factors in that case.

Note that the per-token and per-channel scaling modes can be used together (that is, they are not mutually exclusive).

In pseudocode, the quantization can be implemented as follows for the three different modes:

```python
# Per-tensor scaling.
for mi in range(M):
    for ni in range(N):
        q[mi][ni] = int8.satfinite(x[mi][ni] * s)

# Per-token scaling.
for mi in range(M):
    for ni in range(N):
        q[mi][ni] = int8.satfinite(x[mi][ni] * s[mi])

# Per-channel scaling.
for mi in range(M):
    for ni in range(N):
        q[mi][ni] = int8.satfinite(x[mi][ni] * s[ni])
```
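The three loops above can also be expressed as vectorized NumPy operations. This is an illustrative sketch in which int8_satfinite models the saturating conversion as round-and-clip; the per-token and per-channel cases differ only in the shape of the scale tensor, which NumPy broadcasting then applies along the right axis.

```python
import numpy as np

def int8_satfinite(v: np.ndarray) -> np.ndarray:
    # Model of int8.satfinite: round to nearest, saturate to [-128, 127].
    return np.clip(np.round(v), -128, 127).astype(np.int8)

M, N = 4, 8
rng = np.random.default_rng(0)
x = rng.standard_normal((M, N)).astype(np.float32)

s_tensor = np.float32(42.0)                                          # 1 scale
s_token = rng.uniform(1.0, 100.0, size=(M, 1)).astype(np.float32)    # M scales
s_channel = rng.uniform(1.0, 100.0, size=(1, N)).astype(np.float32)  # N scales

q_per_tensor = int8_satfinite(x * s_tensor)
q_per_token = int8_satfinite(x * s_token)      # one scale per row (token)
q_per_channel = int8_satfinite(x * s_channel)  # one scale per column (channel)
```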

INT8 SmoothQuant (W8A8)#

The SmoothQuant technique was introduced in https://arxiv.org/abs/2211.10438. It is a method to run inference using INT8 for both activations and weights while maintaining the accuracy of the network (on downstream tasks).

As explained in the research paper, preprocessing must be applied to the weights of the model. TensorRT-LLM includes scripts to prepare the model to run using the SmoothQuant method.
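The core of that preprocessing, as described in the paper, migrates quantization difficulty from activations to weights: each input channel j gets a smoothing factor s_j = max|X_j|^alpha / max|W_j|^(1-alpha), activations are divided by s, and weights are multiplied by s, leaving the linear layer mathematically unchanged. The sketch below illustrates the idea with hypothetical calibration statistics; it is not the TensorRT-LLM preprocessing script.

```python
import numpy as np

def smoothing_factors(act_absmax, weight, alpha=0.5, eps=1e-5):
    """s_j = max|X_j|^alpha / max|W_j|^(1-alpha), per input channel.
    Dividing activations by s and multiplying weights by s keeps
    x @ W.T == (x / s) @ (W * s).T, so the model output is unchanged."""
    w_absmax = np.abs(weight).max(axis=0)  # |W| maximum per input channel
    return (np.maximum(act_absmax, eps) ** alpha
            / np.maximum(w_absmax, eps) ** (1.0 - alpha))

# Hypothetical calibration data: per-channel activation maxima and a weight.
act_absmax = np.array([8.0, 0.5, 3.0])
W = np.array([[0.2, -1.0, 0.4],
              [0.1, 0.8, -0.3]])          # shape (out_features, in_features)

s = smoothing_factors(act_absmax, W)
W_smoothed = W * s                         # folded into the weights offline

x = np.array([[1.0, 2.0, 3.0]])            # a hypothetical activation
y_ref = x @ W.T
y_smooth = (x / s) @ W_smoothed.T          # numerically identical layer
```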

Examples of how to enable SmoothQuant for GPT, GPT-J, and LLaMA can be found in the examples/quantization folder of that release.

INT4 and INT8 Weight-Only (W4A16 and W8A16)#

The INT4 and INT8 Weight-Only techniques consist of quantizing the weights of a model and dequantizing those weights on-the-fly in linear layers (Matmuls). The activations are encoded using floating-point values (FP16 or BF16).

To use the INT4/INT8 Weight-Only methods, the user must determine the scaling factors used to quantize and dequantize the weights of the model.
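One common way to determine those scaling factors is symmetric per-channel quantization, where each output channel's scale maps its absolute maximum onto the INT8 range. The sketch below illustrates this for W8A16; it is an assumption about a reasonable calibration scheme, not the specific TensorRT-LLM recipe.

```python
import numpy as np

def quantize_weights_per_channel(w: np.ndarray):
    """Symmetric per-channel INT8 weight quantization (one scale per column).
    The scale is chosen so that each column's absolute maximum maps to 127."""
    absmax = np.abs(w).max(axis=0)        # one value per output channel
    scale = absmax / 127.0                # kept in FP for on-the-fly dequant
    q = np.clip(np.round(w / scale), -128, 127).astype(np.int8)
    return q, scale

rng = np.random.default_rng(1)
w = rng.standard_normal((16, 4)).astype(np.float32)

q, scale = quantize_weights_per_channel(w)
w_hat = q.astype(np.float32) * scale      # what the kernel reconstructs at runtime
```

At inference time only q (INT8) and scale (FP) are stored; the linear layer dequantizes q on the fly and runs the Matmul in FP16/BF16.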

This release includes examples for GPT and LLaMA.

GPTQ and AWQ (W4A16)#

The GPTQ and AWQ techniques are presented in https://arxiv.org/abs/2210.17323 and https://arxiv.org/abs/2306.00978, respectively. TensorRT-LLM supports per-group scaling factors and zero-offsetting in linear layers to implement the GPTQ and AWQ methods. See the WeightOnlyGroupwiseQuantMatmulPlugin plugin and the corresponding weight_only_groupwise_quant_matmul Python function for details.
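The storage format implied by per-group scaling with a zero offset can be sketched as follows: each group of consecutive input channels shares one scale and one zero point, and weights are stored as unsigned 4-bit codes in [0, 15]. This is an illustrative sketch of that format (with a tiny group size for readability), not the plugin's actual packing.

```python
import numpy as np

def quantize_groupwise_int4(w: np.ndarray, group_size: int = 2):
    """Asymmetric per-group INT4 quantization: each group of `group_size`
    consecutive rows shares one scale and one zero offset per column."""
    K, N = w.shape
    wg = w.reshape(K // group_size, group_size, N)
    lo = wg.min(axis=1, keepdims=True)
    hi = wg.max(axis=1, keepdims=True)
    scale = (hi - lo) / 15.0               # 4-bit codes span 0..15
    zero = np.round(-lo / scale)           # zero offset, also per group
    q = np.clip(np.round(wg / scale) + zero, 0, 15)
    return q.reshape(K, N), scale, zero

def dequantize_groupwise(q, scale, zero, group_size: int = 2):
    qg = q.reshape(-1, group_size, q.shape[1])
    return ((qg - zero) * scale).reshape(q.shape)

rng = np.random.default_rng(2)
w = rng.standard_normal((4, 3)).astype(np.float32)
q, scale, zero = quantize_groupwise_int4(w)
w_hat = dequantize_groupwise(q, scale, zero)
```

Typical deployments use a group size of 64 or 128; smaller groups cost more metadata but reduce quantization error.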

This release includes examples of applying GPTQ to GPT-NeoX and LLaMA-v2, as well as an example of using AWQ with GPT-J.

FP8 (Hopper)#

This release of TensorRT-LLM contains implementations of FP8 for GPT-NeMo, GPT-J, and LLaMA. Those examples can be found in examples/quantization.

NVFP4 (Blackwell)#

LLaMA and Mixtral can run in the NVFP4 data type. Those examples can be found in the LLaMA examples.

Support Matrix#

This release of TensorRT-LLM contains the following examples:

| Model          | FP32 | FP16 | BF16 | FP8 | NVFP4 | W8A8 SQ | W8A16 | W4A16 | W4A16 AWQ | W4A16 GPTQ |
|----------------|------|------|------|-----|-------|---------|-------|-------|-----------|------------|
| Baichuan       | Y    | Y    | Y    | Y   | .     | Y       | Y     | Y     | Y         | Y          |
| BERT           | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| BLIP-2         | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| BLOOM          | Y    | Y    | Y    | Y   | .     | Y       | Y     | Y     | .         | .          |
| ChatGLM        | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| ChatGLM-v2     | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| ChatGLM-v3     | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| DBRX           | Y    | Y    | Y    | .   | .     | .       | Y     | Y     | .         | .          |
| Falcon         | Y    | Y    | Y    | Y   | .     | .       | Y     | Y     | Y         | .          |
| Flan-T5        | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| Gemma          | Y    | Y    | Y    | Y   | .     | Y       | Y     | Y     | Y         | .          |
| GPT            | Y    | Y    | Y    | Y   | .     | Y       | Y     | Y     | .         | .          |
| GPT-J          | Y    | Y    | Y    | Y   | .     | Y       | Y     | Y     | Y         | .          |
| GPT-NeMo       | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| GPT-NeoX       | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | Y          |
| InternLM       | Y    | Y    | Y    | .   | .     | Y       | Y     | Y     | .         | .          |
| InternLM2      | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| LLaMA          | Y    | Y    | Y    | Y   | Y     | Y       | Y     | Y     | Y         | Y          |
| LLaMA-v2       | Y    | Y    | Y    | Y   | Y     | Y       | Y     | Y     | Y         | Y          |
| Mamba          | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| Mistral        | Y    | Y    | Y    | Y   | .     | Y       | Y     | Y     | Y         | .          |
| Mixtral        | Y    | Y    | Y    | Y   | Y     | .       | Y     | Y     | .         | .          |
| MPT            | Y    | Y    | Y    | Y   | .     | Y       | Y     | Y     | Y         | .          |
| OPT            | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| Phi            | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| Qwen           | Y    | Y    | Y    | .   | .     | Y       | Y     | Y     | Y         | Y          |
| RecurrentGemma | Y    | Y    | Y    | Y   | .     | Y       | .     | .     | Y         | .          |
| Replit Code    | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| SantaCoder     | Y    | Y    | Y    | .   | .     | .       | Y     | Y     | .         | .          |
| Skywork        | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| StarCoder1     | Y    | Y    | Y    | .   | .     | .       | Y     | Y     | .         | .          |
| StarCoder2     | Y    | Y    | Y    | Y   | .     | .       | Y     | Y     | .         | .          |
| T5             | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| Whisper        | Y    | Y    | Y    | .   | .     | .       | Y     | Y     | .         | .          |
| BLIP2-OPT      | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| BLIP2-T5       | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |
| LLaVA          | Y    | Y    | Y    | Y   | .     | Y       | Y     | Y     | Y         | Y          |
| VILA           | Y    | Y    | Y    | Y   | .     | Y       | Y     | Y     | Y         | Y          |
| Nougat         | Y    | Y    | Y    | .   | .     | .       | .     | .     | .         | .          |

Note: The vision component of multi-modal models (BLIP2-OPT/BLIP2-T5/LLaVA/VILA/Nougat) uses FP16 by default. The language component decides which quantization methods are supported by a given multi-modal model.

Technical Detail: The QuantMode Flags#

The quantization method is controlled by the QuantMode flags. The different fields are:

  • INT4_WEIGHTS, the weights are quantized to 4 bits (W4A*),

  • INT8_WEIGHTS, the weights are quantized to 8 bits (W8A*),

  • ACTIVATIONS, the activations are quantized to 8 bits (W*A8),

  • PER_CHANNEL, the scaling factors are defined per channel,

  • PER_TOKEN, the scaling factors are defined per token,

  • PER_GROUP, the scaling factors are defined per group.

There are three additional flags to control TensorRT-LLM:

  • INT8_KV_CACHE, the K/V cache stores K and V using 8-bit integers,

  • FP8_KV_CACHE, the K/V cache stores K and V using 8-bit floating-point numbers,

  • FP8_QDQ, TensorRT-LLM relies on automatic fusion of Q/DQ nodes in TensorRT.
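How such bit flags combine to describe a recipe can be illustrated with Python's enum.IntFlag. The class below is a hypothetical sketch mirroring the fields listed above, not the actual tensorrt_llm QuantMode class or its bit layout.

```python
from enum import IntFlag, auto

class QuantModeSketch(IntFlag):
    """Illustrative bit flags mirroring the fields described above.
    Hypothetical sketch only; not the real tensorrt_llm QuantMode API."""
    INT4_WEIGHTS = auto()
    INT8_WEIGHTS = auto()
    ACTIVATIONS = auto()
    PER_CHANNEL = auto()
    PER_TOKEN = auto()
    PER_GROUP = auto()
    INT8_KV_CACHE = auto()
    FP8_KV_CACHE = auto()
    FP8_QDQ = auto()

# SmoothQuant (W8A8) with per-token and per-channel scaling:
sq_mode = (QuantModeSketch.INT8_WEIGHTS | QuantModeSketch.ACTIVATIONS
           | QuantModeSketch.PER_TOKEN | QuantModeSketch.PER_CHANNEL)

# Group-wise INT4 weight-only (W4A16), as used by GPTQ/AWQ:
w4a16_mode = QuantModeSketch.INT4_WEIGHTS | QuantModeSketch.PER_GROUP
```

Because the modes are bit flags, a single integer fully describes a recipe, and individual properties can be tested with bitwise membership.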