Source code for tensorrt_llm.quantization.mode

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.# SPDX-License-Identifier: Apache-2.0## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.fromenumimportIntFlag,autofromtypingimportOptionalfromstrenumimportStrEnumfrom.._utilsimportBaseEnumMeta
class QuantAlgo(StrEnum, metaclass=BaseEnumMeta):
    """String-valued names of the quantization algorithms known to this module.

    Every member gets its value from ``auto()``. Members are declared in the
    same order as before, so iteration order is unchanged.
    """

    # Weight-only schemes (W<bits>A16), plain and AWQ/GPTQ variants.
    W8A16 = auto()
    W4A16 = auto()
    W4A16_AWQ = auto()
    W4A8_AWQ = auto()
    W8A16_GPTQ = auto()
    W4A16_GPTQ = auto()

    # W8A8 SmoothQuant variants.
    W8A8_SQ_PER_CHANNEL = auto()
    W8A8_SQ_PER_TENSOR_PLUGIN = auto()
    W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN = auto()
    W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN = auto()
    W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN = auto()

    # W4A8 QServe variants.
    W4A8_QSERVE_PER_GROUP = auto()
    W4A8_QSERVE_PER_CHANNEL = auto()

    # FP8 variants.
    FP8 = auto()
    FP8_PER_CHANNEL_PER_TOKEN = auto()
    FP8_BLOCK_SCALES = auto()

    INT8 = auto()
    MIXED_PRECISION = auto()

    # FP4-based variants.
    NVFP4 = auto()
    W4A8_NVFP4_FP8 = auto()
    W4A8_MXFP4_FP8 = auto()
    W4A8_MXFP4_MXFP8 = auto()
    W4A16_MXFP4 = auto()

    NO_QUANT = auto()
# Algorithms accepted as the weight/activation `quant_algo`: every QuantAlgo
# member except INT8. Built in enum definition order so that the list is
# identical on every run (going through a set made the order depend on string
# hashing, which varies between interpreter invocations).
QUANT_ALGO_LIST = [algo for algo in QuantAlgo if algo != QuantAlgo.INT8]

# Algorithms accepted as the `kv_cache_quant_algo`.
KV_CACHE_QUANT_ALGO_LIST = [QuantAlgo.FP8, QuantAlgo.INT8, QuantAlgo.NVFP4]

# The SmoothQuant W8A8 algorithms whose names end in `_PLUGIN`.
W8A8_SQ_PLUGIN_LIST = [
    QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN,
]

# NOTE(review): presumably the algorithms handled by the ModelOpt quantization
# flow -- confirm against the code that consumes this set.
MODELOPT_FLOW_QUANTIZATIONS = {
    QuantAlgo.W4A16_AWQ,
    QuantAlgo.FP8,
    QuantAlgo.W8A8_SQ_PER_CHANNEL,
    QuantAlgo.W4A8_AWQ,
}
class QuantMode(IntFlag):
    """Bit flags describing which quantization features are enabled.

    Each feature occupies one bit, so several features can be combined with
    ``|``. The ``is_*`` / ``has_*`` predicates test bits, the ``set_*`` methods
    return a new mode with extra bits set, and the static constructors build a
    mode from boolean switches or from ``QuantAlgo`` values.
    """

    # [WARNING] KEEP BELOW DEFINITION IN SYNC WITH cpp/tensorrt_llm/common/quantization.h

    # The weights are quantized to 4 bits.
    INT4_WEIGHTS = auto()
    # The weights are quantized to 8 bits.
    INT8_WEIGHTS = auto()
    # The activations are quantized.
    ACTIVATIONS = auto()
    # The method uses one scaling factor per channel. It's pre-computed (static) from the weights.
    PER_CHANNEL = auto()
    # The method uses one scaling factor per token. It's computed on-the-fly.
    PER_TOKEN = auto()
    # The method uses one scaling factor per group. It's pre-computed (static) from the weights.
    PER_GROUP = auto()
    # The KV cache is quantized in INT8.
    INT8_KV_CACHE = auto()
    # The KV cache is quantized in FP8.
    FP8_KV_CACHE = auto()
    # FP8 QDQ
    FP8_QDQ = auto()
    # FP8 rowwise
    FP8_ROWWISE = auto()
    # FP8 block scales for Deepseek
    FP8_1x128_128x128 = auto()
    # W4A8 qserve
    W4A8_QSERVE = auto()
    # FP4
    NVFP4 = auto()
    NVFP4_KV_CACHE = auto()
    # W4A8 NVFP4
    W4A8_NVFP4_FP8 = auto()
    # W4A8 MXFP4
    W4A8_MXFP4_FP8 = auto()
    W4A8_MXFP4_MXFP8 = auto()
    W4A16_MXFP4 = auto()

    # The smallest power-of-two that is not used by a flag. Do not call auto() after that line.
    COUNT = auto()

    # Bitmask to detect if weights, activations or both are quantized.
    WEIGHTS_AND_ACTIVATIONS = INT4_WEIGHTS | INT8_WEIGHTS | ACTIVATIONS

    # The mask of all valid flags.
    VALID_FLAGS = COUNT - 1

    def __deepcopy__(self, memo):
        """Return ``self``; deep-copying a mode yields the same object."""
        return self

    def _all(self, bits, mask=VALID_FLAGS):
        """Return True if, within ``mask``, exactly the bits in ``bits`` are set.

        All the bits set? You can restrict the test to the bits indicated by
        ``mask``; bits outside ``mask`` are ignored.
        """
        return (self & mask) == bits

    def _any(self, bits):
        """Return True if at least one of the bits in ``bits`` is set."""
        return (self & bits) != 0

    def is_int8_weight_only(self):
        """Return True if, of the weight/activation bits, only INT8_WEIGHTS is set."""
        return self._all(self.INT8_WEIGHTS, self.WEIGHTS_AND_ACTIVATIONS)

    def is_int4_weight_only(self):
        """Return True if, of the weight/activation bits, only INT4_WEIGHTS is set."""
        return self._all(self.INT4_WEIGHTS, self.WEIGHTS_AND_ACTIVATIONS)

    def is_weight_only(self):
        """Return True for INT4 or INT8 weight-only modes."""
        return self.is_int4_weight_only() or self.is_int8_weight_only()

    def is_int8_weight_only_per_group(self):
        """Return True for INT8 weight-only modes that also set PER_GROUP."""
        return self.is_int8_weight_only() and self._any(self.PER_GROUP)

    # TODO: Using the current flags cannot distinguish between w4aFP8 AWQ and w4a8 QServe.
    def is_qserve_w4a8(self):
        """Return True if the W4A8_QSERVE bit is set."""
        return self._any(self.W4A8_QSERVE)

    def is_int4_weight_only_per_group(self):
        """Return True for INT4 weight-only modes that also set PER_GROUP."""
        return self.is_int4_weight_only() and self._any(self.PER_GROUP)

    def has_act_and_weight_quant(self):
        """Return True if, of the weight/activation bits, exactly INT8_WEIGHTS and ACTIVATIONS are set."""
        return self._all(self.INT8_WEIGHTS | self.ACTIVATIONS,
                         self.WEIGHTS_AND_ACTIVATIONS)

    def has_act_or_weight_quant(self):
        """Return True if any of INT4_WEIGHTS, INT8_WEIGHTS or ACTIVATIONS is set."""
        return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS
                         | self.ACTIVATIONS)

    def has_per_token_dynamic_scaling(self):
        """Return True if the PER_TOKEN bit is set."""
        return self._any(self.PER_TOKEN)

    def has_fp8_block_scales(self):
        """Return True if the FP8_1x128_128x128 bit is set."""
        return self._any(self.FP8_1x128_128x128)

    def has_act_static_scaling(self):
        """Return True if neither PER_TOKEN nor FP8_ROWWISE is set."""
        return not self.has_per_token_dynamic_scaling(
        ) and not self.has_fp8_rowwise()

    def has_per_channel_scaling(self):
        """Return True if the PER_CHANNEL bit is set."""
        return self._any(self.PER_CHANNEL)

    def has_per_group_scaling(self):
        """Return True if the PER_GROUP bit is set."""
        return self._any(self.PER_GROUP)

    def has_int8_kv_cache(self):
        """Return True if the INT8_KV_CACHE bit is set."""
        return self._any(self.INT8_KV_CACHE)

    def has_fp8_kv_cache(self):
        """Return True if the FP8_KV_CACHE bit is set."""
        return self._any(self.FP8_KV_CACHE)

    def has_fp4_kv_cache(self):
        """Return True if the NVFP4_KV_CACHE bit is set."""
        return self._any(self.NVFP4_KV_CACHE)

    def has_kv_cache_quant(self):
        """Return True if any of the INT8/FP8/NVFP4 KV-cache bits is set."""
        return (self.has_int8_kv_cache() or self.has_fp8_kv_cache()
                or self.has_fp4_kv_cache())

    def has_fp8_qdq(self):
        """Return True if the FP8_QDQ bit is set."""
        return self._any(self.FP8_QDQ)

    def has_fp8_rowwise(self):
        """Return True if the FP8_ROWWISE bit is set."""
        return self._any(self.FP8_ROWWISE)

    def has_nvfp4(self):
        """Return True if the NVFP4 bit is set."""
        return self._any(self.NVFP4)

    def has_w4a8_nvfp4_fp8(self):
        """Return True if the W4A8_NVFP4_FP8 bit is set."""
        return self._any(self.W4A8_NVFP4_FP8)

    def has_w4a8_mxfp4_fp8(self):
        """Return True if the W4A8_MXFP4_FP8 bit is set."""
        return self._any(self.W4A8_MXFP4_FP8)

    def has_w4a8_mxfp4_mxfp8(self):
        """Return True if the W4A8_MXFP4_MXFP8 bit is set."""
        return self._any(self.W4A8_MXFP4_MXFP8)

    def has_w4a16_mxfp4(self):
        """Return True if the W4A16_MXFP4 bit is set."""
        return self._any(self.W4A16_MXFP4)

    def has_mxfp4(self):
        """Return True if any of the three MXFP4 bits is set."""
        return self._any(self.W4A8_MXFP4_FP8 | self.W4A8_MXFP4_MXFP8
                         | self.W4A16_MXFP4)

    def has_weight_quant(self):
        """Return True if INT4_WEIGHTS or INT8_WEIGHTS is set."""
        return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS)

    def has_any_quant(self, exclude_kv_cache: bool = False):
        """Return True if any quantization bit is set.

        Args:
            exclude_kv_cache: When True, the three KV-cache bits are ignored
                and only the weight/activation/format bits are considered.
        """
        has_quant = self._any(self.INT4_WEIGHTS
                              | self.INT8_WEIGHTS
                              | self.ACTIVATIONS
                              | self.FP8_QDQ
                              | self.FP8_ROWWISE
                              | self.W4A8_QSERVE
                              | self.FP8_1x128_128x128
                              | self.NVFP4
                              | self.W4A8_NVFP4_FP8
                              | self.W4A8_MXFP4_FP8
                              | self.W4A16_MXFP4
                              | self.W4A8_MXFP4_MXFP8)
        if exclude_kv_cache:
            return has_quant

        # Both operands are plain bools, so a logical `or` is the right operator.
        return has_quant or self._any(self.INT8_KV_CACHE | self.FP8_KV_CACHE
                                      | self.NVFP4_KV_CACHE)

    def set_int8_kv_cache(self):
        """Return this mode with INT8_KV_CACHE additionally set."""
        return self | self.INT8_KV_CACHE

    def set_fp8_kv_cache(self):
        """Return this mode with FP8_KV_CACHE additionally set."""
        return self | self.FP8_KV_CACHE

    def set_fp4_kv_cache(self):
        """Return this mode with NVFP4_KV_CACHE additionally set."""
        return self | self.NVFP4_KV_CACHE

    def set_fp8_qdq(self):
        """Return this mode with FP8_QDQ additionally set."""
        return self | self.FP8_QDQ

    def set_fp8_rowwise(self):
        """Return this mode with FP8_ROWWISE, PER_TOKEN and PER_CHANNEL additionally set."""
        return self | self.FP8_ROWWISE | self.PER_TOKEN | self.PER_CHANNEL

    @staticmethod
    def from_description(quantize_weights=False,
                         quantize_activations=False,
                         per_token=False,
                         per_channel=False,
                         per_group=False,
                         use_int4_weights=False,
                         use_int8_kv_cache=False,
                         use_fp8_kv_cache=False,
                         use_fp8_qdq=False,
                         use_fp8_block_scales=False,
                         use_fp8_rowwise=False,
                         use_nvfp4=False,
                         use_w4a8_nvfp4_fp8=False,
                         use_w4a8_qserve=False,
                         use_w4a8_mxfp4_fp8=False,
                         use_w4a8_mxfp4_mxfp8=False,
                         use_w4a16_mxfp4=False):
        """Build a QuantMode from individual boolean switches.

        Each true switch sets the corresponding flag. ``use_int4_weights`` only
        selects INT4 instead of INT8 when ``quantize_weights`` is true, and
        ``use_fp8_rowwise`` also sets PER_TOKEN and PER_CHANNEL.

        Returns:
            The combined QuantMode.

        Raises:
            ValueError: If activations are quantized without weights, or if
                ``per_token``/``per_channel`` is requested without quantizing
                both weights and activations.
        """

        def raise_error():
            # Report every argument so the offending combination is visible.
            raise ValueError(f"Unsupported combination of QuantMode args: "
                             f"{quantize_weights=}, "
                             f"{quantize_activations=}, "
                             f"{per_token=}, "
                             f"{per_channel=}, "
                             f"{per_group=}, "
                             f"{use_int4_weights=}, "
                             f"{use_int8_kv_cache=}, "
                             f"{use_fp8_kv_cache=}, "
                             f"{use_fp8_qdq=}, "
                             f"{use_fp8_block_scales=}, "
                             f"{use_fp8_rowwise=}, "
                             f"{use_nvfp4=}, "
                             f"{use_w4a8_nvfp4_fp8=}, "
                             f"{use_w4a8_qserve=}, "
                             f"{use_w4a8_mxfp4_fp8=}, "
                             f"{use_w4a8_mxfp4_mxfp8=}, "
                             f"{use_w4a16_mxfp4=}")

        # We must quantize weights when we quantize activations.
        if quantize_activations and not quantize_weights:
            raise_error()

        # If we set per_token or per_channel, we must quantize both weights and activations.
        if (per_token or per_channel) and not (quantize_weights
                                               and quantize_activations):
            raise_error()

        mode = QuantMode(0)

        # Do we quantize the weights - if so, do we use INT4 or INT8?
        if quantize_weights and use_int4_weights:
            mode |= QuantMode.INT4_WEIGHTS
        elif quantize_weights:
            mode |= QuantMode.INT8_WEIGHTS

        # Do we quantize the activations?
        if quantize_activations:
            mode |= QuantMode.ACTIVATIONS

        # Per-channel/per-token/per-group additional flags.
        if per_channel:
            mode |= QuantMode.PER_CHANNEL
        if per_token:
            mode |= QuantMode.PER_TOKEN
        if per_group:
            mode |= QuantMode.PER_GROUP

        # Int8 KV cache
        if use_int8_kv_cache:
            mode |= QuantMode.INT8_KV_CACHE

        # FP8 KV cache
        if use_fp8_kv_cache:
            mode |= QuantMode.FP8_KV_CACHE

        if use_fp8_qdq:
            mode |= QuantMode.FP8_QDQ

        if use_fp8_rowwise:
            mode |= (QuantMode.FP8_ROWWISE | QuantMode.PER_TOKEN
                     | QuantMode.PER_CHANNEL)

        if use_fp8_block_scales:
            mode |= QuantMode.FP8_1x128_128x128

        if use_nvfp4:
            mode |= QuantMode.NVFP4

        if use_w4a8_nvfp4_fp8:
            mode |= QuantMode.W4A8_NVFP4_FP8

        # W4A8 QServe
        if use_w4a8_qserve:
            mode |= QuantMode.W4A8_QSERVE

        if use_w4a8_mxfp4_fp8:
            mode |= QuantMode.W4A8_MXFP4_FP8

        if use_w4a8_mxfp4_mxfp8:
            mode |= QuantMode.W4A8_MXFP4_MXFP8

        if use_w4a16_mxfp4:
            mode |= QuantMode.W4A16_MXFP4

        return mode

    @staticmethod
    def use_smooth_quant(per_token=False, per_channel=False):
        """Build a mode with INT8 weights and quantized activations, plus the requested scaling flags."""
        return QuantMode.from_description(True, True, per_token, per_channel)

    @staticmethod
    def use_qserve(per_group):
        """Build a mode with INT4 weights, quantized activations and W4A8_QSERVE, optionally per-group."""
        return QuantMode.from_description(quantize_weights=True,
                                          quantize_activations=True,
                                          per_group=per_group,
                                          use_int4_weights=True,
                                          use_w4a8_qserve=True)

    @staticmethod
    def use_weight_only(use_int4_weights=False, per_group=False):
        """Build a weight-only mode (INT8 by default, INT4 on request), optionally per-group."""
        return QuantMode.from_description(quantize_weights=True,
                                          quantize_activations=False,
                                          per_token=False,
                                          per_channel=False,
                                          per_group=per_group,
                                          use_int4_weights=use_int4_weights)

    # The QuantAlgo annotations are forward references so that defining this
    # class does not evaluate the QuantAlgo name.
    @staticmethod
    def from_quant_algo(
        quant_algo: Optional["QuantAlgo"] = None,
        kv_cache_quant_algo: Optional["QuantAlgo"] = None,
    ) -> "QuantMode":
        """Translate QuantAlgo selections into a QuantMode.

        Args:
            quant_algo: Weight/activation algorithm; None or a member of
                QUANT_ALGO_LIST. Algorithms without a dedicated branch below
                (and None) contribute no bits.
            kv_cache_quant_algo: KV-cache algorithm; None or a member of
                KV_CACHE_QUANT_ALGO_LIST.

        Returns:
            The mode for ``quant_algo`` with the KV-cache bit for
            ``kv_cache_quant_algo`` added, if any.

        Raises:
            AssertionError: If either argument is outside its allowed list.
        """
        assert quant_algo is None or quant_algo in QUANT_ALGO_LIST, (
            f"Unsupported quant_algo: {quant_algo!r}")
        assert (kv_cache_quant_algo is None
                or kv_cache_quant_algo in KV_CACHE_QUANT_ALGO_LIST), (
                    f"Unsupported kv_cache_quant_algo: {kv_cache_quant_algo!r}")

        if quant_algo == QuantAlgo.W8A16:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=False)
        elif quant_algo == QuantAlgo.W4A16:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True)
        elif quant_algo == QuantAlgo.W4A16_AWQ:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True,
                                                   per_group=True)
        elif quant_algo == QuantAlgo.W4A8_AWQ:
            # Same bits as W4A16_AWQ; the flags carry no separate W4A8 AWQ marker.
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True,
                                                   per_group=True)
        elif quant_algo == QuantAlgo.W4A16_GPTQ:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True,
                                                   per_group=True)
        elif quant_algo == QuantAlgo.W8A16_GPTQ:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=False,
                                                   per_group=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_CHANNEL:
            quant_mode = QuantMode.use_smooth_quant(per_token=False,
                                                    per_channel=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=False,
                                                    per_channel=False)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=True,
                                                    per_channel=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=False,
                                                    per_channel=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=True,
                                                    per_channel=False)
        elif quant_algo == QuantAlgo.W4A8_QSERVE_PER_GROUP:
            quant_mode = QuantMode.use_qserve(per_group=True)
        elif quant_algo == QuantAlgo.W4A8_QSERVE_PER_CHANNEL:
            quant_mode = QuantMode.use_qserve(per_group=False)
        elif quant_algo == QuantAlgo.FP8:
            quant_mode = QuantMode.from_description(use_fp8_qdq=True)
        elif quant_algo == QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN:
            quant_mode = QuantMode.from_description(use_fp8_rowwise=True)
        elif quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
            quant_mode = QuantMode.from_description(use_fp8_block_scales=True)
        elif quant_algo == QuantAlgo.NVFP4:
            quant_mode = QuantMode.from_description(use_nvfp4=True)
        elif quant_algo == QuantAlgo.W4A8_NVFP4_FP8:
            quant_mode = QuantMode.from_description(use_w4a8_nvfp4_fp8=True)
        elif quant_algo == QuantAlgo.W4A8_MXFP4_FP8:
            quant_mode = QuantMode.from_description(use_w4a8_mxfp4_fp8=True)
        elif quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8:
            quant_mode = QuantMode.from_description(use_w4a8_mxfp4_mxfp8=True)
        elif quant_algo == QuantAlgo.W4A16_MXFP4:
            quant_mode = QuantMode.from_description(use_w4a16_mxfp4=True)
        else:
            quant_mode = QuantMode(0)

        if kv_cache_quant_algo == QuantAlgo.INT8:
            quant_mode = quant_mode.set_int8_kv_cache()
        elif kv_cache_quant_algo == QuantAlgo.FP8:
            quant_mode = quant_mode.set_fp8_kv_cache()
        elif kv_cache_quant_algo == QuantAlgo.NVFP4:
            quant_mode = quant_mode.set_fp4_kv_cache()

        return quant_mode

    def to_dict(self):
        """Return the mode as a dict of named settings.

        All values are booleans from the corresponding predicates, except
        ``weight_only_precision``, which is ``'int8'`` for INT8 weight-only
        modes and ``'int4'`` in every other case.
        """
        return {
            'use_smooth_quant':
            self.has_act_and_weight_quant(),
            'per_channel':
            self.has_per_channel_scaling(),
            'per_token':
            self.has_per_token_dynamic_scaling(),
            'per_group':
            self.has_per_group_scaling(),
            'int8_kv_cache':
            self.has_int8_kv_cache(),
            'enable_fp8':
            self.has_fp8_qdq(),
            'enable_fp8_rowwise':
            self.has_fp8_rowwise(),
            'enable_fp8_block_scales':
            self.has_fp8_block_scales(),
            'enable_nvfp4':
            self.has_nvfp4(),
            'enable_w4a8_nvfp4_fp8':
            self.has_w4a8_nvfp4_fp8(),
            'enable_w4a8_mxfp4_fp8':
            self.has_w4a8_mxfp4_fp8(),
            'enable_w4a8_mxfp4_mxfp8':
            self.has_w4a8_mxfp4_mxfp8(),
            'enable_w4a16_mxfp4':
            self.has_w4a16_mxfp4(),
            'fp8_kv_cache':
            self.has_fp8_kv_cache(),
            'use_weight_only':
            self.is_weight_only(),
            'weight_only_precision':
            'int8' if self.is_int8_weight_only() else 'int4',
        }
class GroupwiseQuantAlgo:
    """Namespace of integer constants for group-wise quantization options.

    Each constant is a distinct power of two.
    NOTE(review): presumably these are OR-ed together into a bitmask by
    callers -- confirm against the code that consumes them.
    """
    BIAS = 1 << 0
    ZERO = 1 << 1
    PRE_QUANT_SCALE = 1 << 2
    W4A8_ALPHA = 1 << 3
    INT8_WEIGHT = 1 << 4