
FLUTE: Flexible Lookup Table Engine for LUT-quantized LLMs


Update

  • February 2025. HIGGS will appear in NAACL 2025.
  • January 9, 2025. Added (very) experimental support for removing specialization on shapes + GPU via auto-tune.
  • December 12, 2024. Added support for the Hadamard Transform (via HadaCore).
  • November 26, 2024. Added support for vector (de)quantization (vector_size=2), as part of HIGGS.
  • October 5, 2024. FLUTE will appear in EMNLP 2024 (Findings).
  • September 15, 2024. Added experimental support for loading pre-quantized FLUTE models in HuggingFace.
  • September 6, 2024. Added (unlearned) NF-quantized LLaMA-3.1 (405B) models: base and instruction-tuned.
  • August 31, 2024. Added support and an example for Learned Normal Float (NFL) quantization.
  • August 26, 2024. Added support for converting bitsandbytes models into FLUTE models.
  • August 5, 2024. Added quantized LLaMA-3.1 (8B/70B) models.
  • August 2, 2024. Added support for the RTX 4090.
  • July 27, 2024. Added support for LLaMA-3.1 (405B) and tuned BF16 performance. FP16 is still the recommended data type, especially for 3-bit settings.

Installation

Install FLUTE with pip or from source:

```bash
# For CUDA 12.1
pip install flute-kernel

# For CUDA 11.8
pip install flute-kernel -i https://flute-ai.github.io/whl/cu118

# For CUDA 12.4
pip install flute-kernel -i https://flute-ai.github.io/whl/cu124
```
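As a quick sanity check, the package should import cleanly after installation (a minimal smoke test; it does not exercise the CUDA kernel itself):

```bash
python -c "import flute; print(flute.__file__)"
```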

Head over to Getting Started and try it out!

Background

Uniform quantization converts full-precision weights to lower-precision intervals of equal size. Lookup table (LUT) quantization is a flexible variant of non-uniform quantization which can map intervals to arbitrary values via a lookup table.

| Uniform (Integer) Quantization | Lookup Table Quantization |
|---|---|
| $\widehat{\mathbf{W}} = \mathtt{float}(\mathbf{Q}) \cdot \mathbf{s}$ | $\widehat{\mathbf{W}} = \mathtt{tableLookup}(\mathbf{Q}, \mathtt{table}) \cdot \mathbf{s}$ |

where $\mathbf{Q}$ denotes the quantized weights, $\mathbf{s}$ the (group-wise) scales, and $\widehat{\mathbf{W}}$ the de-quantized weights. Here are some examples of lookup tables supported in FLUTE.

| Examples | Notes |
|---|---|
| `int4`, `int3`, `int2` | recovers uniform/integer quantization |
| `fp4`, `fp3`, `fp2` | |
| `nf4`, `nf3`, `nf2` | generalizes the `nf4` data format introduced in QLoRA |
| any arbitrary table | you could even learn it! |
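For intuition, here is a minimal NumPy sketch of what LUT de-quantization computes (illustrative only, not the fused FLUTE kernel; the shapes, the random table, and the grouping axis are assumptions for the example):

```python
import numpy as np

bits, group_size = 4, 64
K, N = 4096, 4096                                                  # weight shape (illustrative)

table = np.sort(np.random.randn(2 ** bits)).astype(np.float16)    # e.g., an nf4-style table
Q = np.random.randint(0, 2 ** bits, size=(K, N), dtype=np.uint8)  # quantized codes
scales = np.random.rand(K // group_size, N).astype(np.float16)    # one scale per group of weights

# W_hat = tableLookup(Q, table) * s, with each group of `group_size`
# consecutive weights (grouped along K here) sharing one scale
W_hat = table[Q] * np.repeat(scales, group_size, axis=0)
print(W_hat.shape, W_hat.dtype)
```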

New Models Powered by FLUTE

The flexibility of the kernel could lead to new quantization algorithms. As a proof of concept, we are releasing a few models quantized using Learned Normal Float (NFL) --- a simple extension of the nf4 data format introduced in QLoRA. NFL initializes the lookup table and the scales with those from NF quantization, and then uses calibration data to learn the scales via a straight-through estimator for the gradient with respect to the scales.
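As a rough illustration of the straight-through trick (a hedged sketch only, not the code in flute.integrations.learnable; the helper name and tensor layout are made up for the example), the non-differentiable table lookup is treated as the identity in the backward pass, so gradients reach the scales:

```python
import torch

def nfl_dequantize_ste(w_groups: torch.Tensor, scales: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    """Hypothetical STE sketch: w_groups is (..., group_size), scales is (..., 1), table is (2**bits,)."""
    x = w_groups / scales
    # nearest table entry (non-differentiable step)
    idx = torch.argmin((x.unsqueeze(-1) - table).abs(), dim=-1)
    x_q = table[idx]
    # straight-through estimator: the forward pass uses the quantized values,
    # the backward pass pretends quantization was the identity, so the gradient
    # with respect to `scales` is well-defined
    x_q = x + (x_q - x).detach()
    return x_q * scales
```

In this picture, `scales` (and optionally `table`) would be `torch.nn.Parameter`s optimized against a calibration loss.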

Benchmarks

For additional benchmarks, detailed breakdowns, and corresponding instruction-tuned models, please refer to the paper and the model zoo.

LLaMA-3.1

| | Wiki PPL | C4 PPL | LLM Eval Avg. | | Wiki PPL | C4 PPL | LLM Eval Avg. |
|---|---|---|---|---|---|---|---|
| LLaMA-3.1 (8B) | 6.31 | 9.60 | 69.75 | LLaMA-3.1 (70B) | 2.82 | 7.18 | 75.45 |
| + NFL W4G64 | 6.24 | 10.06 | 69.13 | + NFL W4G64 | 3.09 | 7.53 | 74.84 |
| + NFL W3G64 | 7.23 | 11.83 | 65.66 | + NFL W3G64 | 4.29 | 8.91 | 72.65 |

Gemma-2

| | Wiki PPL | C4 PPL | LLM Eval Avg. | | Wiki PPL | C4 PPL | LLM Eval Avg. |
|---|---|---|---|---|---|---|---|
| Gemma-2 (9B) | 6.88 | 10.12 | 73.12 | Gemma-2 (27B) | 5.70 | 8.98 | 75.71 |
| + NFL W4G64 | 6.49 | 10.35 | 72.50 | + NFL W4G64 | 5.69 | 9.31 | 74.11 |

Getting Started

FLUTE + vLLM

FLUTE-quantized models (Model Zoo) can be directly served using existing frameworks such as vLLM.

```diff
- python -m vllm.entrypoints.openai.api_server \
+ python -m flute.integrations.vllm vllm.entrypoints.openai.api_server \
    --model [MODEL] \
    --revision [REVISION] \
    --tensor-parallel-size [TP_SIZE] \
+   --quantization flute
```

For example, the following command runs the FLUTE-quantized LLaMA-3.1 (8B) on a single GPU.

```bash
python -m flute.integrations.vllm vllm.entrypoints.openai.api_server \
    --model radi-cho/Meta-Llama-3.1-8B-FLUTE \
    --quantization flute
```

We can then query the vLLM server as usual.

```bash
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "radi-cho/Meta-Llama-3.1-8B-FLUTE",
        "prompt": "San Francisco is a",
        "max_tokens": 7,
        "temperature": 0
    }'
```
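Equivalently, the server can be queried from Python with the OpenAI client (a sketch; the `api_key` value is a placeholder, since vLLM's OpenAI-compatible server does not require one by default):

```python
from openai import OpenAI

# point the client at the local vLLM server started above
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="radi-cho/Meta-Llama-3.1-8B-FLUTE",
    prompt="San Francisco is a",
    max_tokens=7,
    temperature=0,
)
print(completion.choices[0].text)
```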

FLUTE + HuggingFace

FLUTE also runs out of the box with HuggingFace and its accelerate extension. This integration is mostly experimental and not optimized; performance-sensitive users should use the vLLM integration instead.

  1. Loading a pre-quantized FLUTE model.

```diff
  import flute.integrations.huggingface

- model = AutoModelForCausalLM.from_pretrained(
+ model = flute.integrations.huggingface.from_pretrained(
      "radi-cho/Meta-Llama-3.1-8B-FLUTE",
      # all of your favorite HF flags will be forwarded
      device_map="auto")
```
  2. Loading and quantizing a dense model.

```python
import flute.integrations.base

flute.integrations.base.prepare_model_flute(
    name="model.model.layers",
    module=model.model.layers,  # for LLaMA-3 and Gemma-2
    num_bits=num_bits,
    group_size=group_size,
    fake=False,
    handle_hooks=True)  # for `accelerate` hooks
```

After this, the model can be used as normal, as in the sketch below. Please check out the quantization guide for more information.
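For example, a minimal generation call might look like this (a sketch; it assumes the tokenizer is shipped with the same checkpoint):

```python
from transformers import AutoTokenizer

# load the tokenizer that accompanies the FLUTE checkpoint loaded above
tokenizer = AutoTokenizer.from_pretrained("radi-cho/Meta-Llama-3.1-8B-FLUTE")

inputs = tokenizer("San Francisco is a", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```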

Support and Compatibility

Kernel

| Description | Supported (via pip) | Supported (build from source) |
|---|---|---|
| Input dtypes | `torch.float16`, `torch.bfloat16` | |
| Bits | 4-bit, 3-bit | 2-bit |
| Group Sizes | 32, 64, 128, 256 | |
| GPUs | A100, A6000, RTX 4090 | H100 (unoptimized) |

Warning

In the current release, we noticed that torch.bfloat16 is slower than torch.float16. This is likely due to a lack of tuning, and to the fact that Ampere GPUs lack hardware acceleration for bfloat16 vectorized atomic-add.

Warning

We noticed several numerically unstable situations using bits=4, group-size=256, GPU=A100, though this is relatively rare (8 of 9360 test cases failed). We also noticed correctness issues in some situations with bits=4, group-size=256, dtype=bfloat16, GPU=RTX 4090 (1 of 52 test cases failed). We will be looking into this, but we suggest avoiding these particular use cases (W4G256) for now.

Models

Note

As of the current release, the kernel is shape-specialized for legacy reasons (i.e., we tune tile sizes, etc., for each matrix shape). Please see the chart below for the supported use cases, since different platforms and tensor-parallel sizes change the matrix shapes. We plan to add support for a broader range of shapes in the near future. In the meantime, please let us know if you have any specific models in mind and we will be happy to add support for them.

| Model | Single GPU / Pipeline Parallel | Tensor Parallel |
|---|---|---|
| LLaMA-3/3.1 (8B) | ✅ | |
| LLaMA-3/3.1 (70B) | ✅ | 2 or 4 GPUs |
| LLaMA-3.1 (405B) | ✅ | 4 or 8 GPUs |
| Gemma-2 (9B) | ✅ | |
| Gemma-2 (27B) | ✅ | 2 or 4 GPUs |

Model Zoo

Note

The models we release here are trained on more data and hence different from those in the paper.

Tip

The HuggingFace Hub links are for NFL W4G64 quantization by default. To use the NFL W3G64 quantization, add --revision nfl_w3g64.

LLaMA-3.1 (8B)

| | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
|---|---|---|---|---|---|---|---|---|
| Unquantized | 6.31 | 9.60 | 79.16 | 82.20 | 52.65 | 60.71 | 74.03 | 69.75 |
| NFL W4G64 | 6.24 | 10.06 | 79.38 | 81.61 | 51.54 | 59.57 | 73.56 | 69.13 |
| NFL W3G64 | 7.23 | 11.83 | 77.91 | 76.98 | 46.33 | 56.74 | 70.32 | 65.66 |

LLaMA-3.1 (70B)

| | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
|---|---|---|---|---|---|---|---|---|
| Unquantized | 2.82 | 7.18 | 82.81 | 85.31 | 59.64 | 67.49 | 82.00 | 75.45 |
| NFL W4G64 | 3.09 | 7.53 | 83.03 | 85.52 | 58.19 | 67.04 | 80.43 | 74.84 |
| NFL W3G64 | 4.29 | 8.91 | 82.04 | 83.29 | 54.78 | 64.99 | 78.14 | 72.65 |

Note that the weights are in the branch nf_w4g64 and thus --revision nf_w4g64 is needed, since these are not on the default branch.

| | Wiki | C4 |
|---|---|---|
| NFL W4G64 | 6.78 | 11.11 |
| NFL W3G64 | 7.73 | 12.83 |

| | Wiki | C4 |
|---|---|---|
| NFL W4G64 | 4.15 | 9.18 |
| NFL W3G64 | 4.74 | 9.48 |

Note that the weights are in the branch nf_w4g64 and thus --revision nf_w4g64 is needed, since these are not on the default branch.

| | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
|---|---|---|---|---|---|---|---|---|
| Unquantized | 6.1 | 9.2 | 79.9 | 80.1 | 50.4 | 60.2 | 72.8 | 68.6 |
| NFL W4G64 | 6.11 | 9.38 | 79.33 | 79.79 | 49.74 | 59.22 | 73.95 | 68.41 |
| NFL W3G64 | 7.13 | 11.06 | 78.78 | 76.22 | 44.37 | 56.69 | 70.32 | 65.28 |

| | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
|---|---|---|---|---|---|---|---|---|
| Unquantized | 2.9 | 6.9 | 82.4 | 86.9 | 60.3 | 66.4 | 80.6 | 75.3 |
| NFL W4G64 | 3.03 | 7.03 | 82.15 | 85.98 | 57.85 | 66.17 | 79.79 | 74.39 |
| NFL W3G64 | 4.15 | 8.10 | 80.74 | 83.71 | 55.29 | 64.05 | 78.45 | 72.45 |
| | Wiki | C4 |
|---|---|---|
| NFL W4G64 | 6.78 | 10.61 |
| NFL W3G64 | 7.75 | 12.28 |

| | Wiki | C4 |
|---|---|---|
| NFL W4G64 | 3.67 | 7.95 |
| NFL W3G64 | 4.90 | 10.86 |
Gemma-2 (9B)

| | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
|---|---|---|---|---|---|---|---|---|
| Unquantized | 6.88 | 10.12 | 81.39 | 87.37 | 61.35 | 61.23 | 74.27 | 73.12 |
| NFL W4G64 | 6.49 | 10.35 | 81.28 | 86.24 | 59.30 | 60.40 | 75.30 | 72.50 |
| NFL W3G64 | 7.06 | 11.14 | 80.52 | 83.16 | 55.46 | 58.28 | 72.69 | 70.02 |

Gemma-2 (27B)

| | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
|---|---|---|---|---|---|---|---|---|
| Unquantized | 5.70 | 8.98 | 83.24 | 87.84 | 62.88 | 65.35 | 79.24 | 75.71 |
| NFL W4G64 | 5.69 | 9.31 | 82.53 | 86.45 | 59.22 | 64.13 | 78.21 | 74.11 |
| | Wiki | C4 |
|---|---|---|
| NFL W4G64 | 6.88 | 11.02 |
| NFL W3G64 | 7.35 | 11.72 |

| | Wiki | C4 |
|---|---|---|
| NFL W4G64 | 5.91 | 9.71 |

Quantizing Your Own Models

We provide two APIs to quantize custom models. The easiest way is to use the command line interface.

Simple Normal Float Quantization

```bash
python -m flute.integrations.base \
    --pretrained_model_name_or_path meta-llama/Meta-Llama-3-70B-Instruct \
    --save_directory Meta-Llama-3-70B-Instruct-NF4 \
    --num_bits 4 \
    --group_size 128
```

The CLI essentially wraps around the following Python API,

```python
from transformers import (
    LlamaForCausalLM,
    Gemma2ForCausalLM,
    AutoModelForCausalLM,
)
import flute.integrations.base

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    device_map="cpu",
    torch_dtype="auto")

if isinstance(model, (LlamaForCausalLM, Gemma2ForCausalLM)):
    flute.integrations.base.prepare_model_flute(
        name="model.model.layers",
        module=model.model.layers,
        num_bits=num_bits,
        group_size=group_size,
        fake=False)
else:
    # more models to come
    raise NotImplementedError
```

Converting bitsandbytes Models into FLUTE Models

While FLUTE has its own Normal Float (NF) implementation, we can convert an existing HuggingFace model quantized via bitsandbytes into the FLUTE format. To do so, just add two lines to the Python API call,

```diff
 flute.integrations.base.prepare_model_flute(
     name="model.model.layers",
     module=model.model.layers,
     num_bits=num_bits,
     group_size=group_size,
     fake=False,
+    prepare_bnb_layers=True,
+    default_bnb_dtype=torch.float16,
 )
```

It's worth noting that we do not support double quantization, and the conversion will materialize the first-level scales.
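Putting it together, a hypothetical end-to-end conversion could look like the following sketch (the model id, the 4-bit NF4 config, and the group size of 64 matching bitsandbytes' default block size are all assumptions for the example):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import flute.integrations.base

# Load a model quantized on the fly by bitsandbytes (NF4, no double quantization,
# since double quantization is not supported by the conversion).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # example model id
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    ),
)

# Convert the bitsandbytes layers into FLUTE layers.
flute.integrations.base.prepare_model_flute(
    name="model.model.layers",
    module=model.model.layers,
    num_bits=4,
    group_size=64,                    # assumed to match bitsandbytes' block size
    fake=False,
    prepare_bnb_layers=True,
    default_bnb_dtype=torch.float16,
)
```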

Learned Normal Float Quantization (NFL)

NFL initializes the lookup table and the scales with those from NF quantization, and then uses calibration data to learn the scales via a straight-through estimator for the gradient with respect to the scales.

To use NFL quantization, call the following function before prepare_model_flute. We also provide an example Jupyter notebook to illustrate the entire process.

```python
import flute.integrations.learnable

flute.integrations.learnable.learn_scales(
    model=model,
    tokenizer=tokenizer,
    num_bits=num_bits,
    group_size=group_size,
    custom_corpora=list_of_corpora,
    samples=num_samples,
)
```

Extending to New Models (Experimental)

At the moment, the FLUTE kernel is specialized to the combination of GPU, matrix shapes, data types, bits, and group sizes. This means that adding support for new models requires tuning the kernel configurations for the corresponding use cases. We hope to add support for just-in-time tuning, but in the meantime, here is how to tune the kernel ahead of time.

Step 1: Build the raw version of the library that exposes all templates.

  1. Reset the previously tuned kernel,

```bash
cp flute/csrc/qgemm_kernel_generated.template.cu flute/csrc/qgemm_kernel_generated.cu
```

  2. Un-comment the combination(s) to tune in flute/csrc/qgemm_kernel_raw_generated.cu,

```cpp
INSTANTIATE_TEMPLATE(NUM_SMs, DTYPE, cute::uint16_t, __half2, BITS, GROUP_SIZE);
```
Example for W4G64 on A100
```diff
-// INSTANTIATE_TEMPLATE(108, cute::half_t    , cute::uint16_t, __half2       , 4, 64);
+INSTANTIATE_TEMPLATE(108, cute::half_t    , cute::uint16_t, __half2       , 4, 64);
-// INSTANTIATE_TEMPLATE(108, cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 4, 64);
+INSTANTIATE_TEMPLATE(108, cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 4, 64);
```
  3. Remove the settings not tuned in flute/csrc/qgemm.cpp, flute/__init__.py, and flute/ops.py.

Note

Although the build may still succeed with the other settings included, it could break the linking process and require re-compiling the library.

Example for W4G64 on A100
```diff
diff --git a/flute/csrc/qgemm.cpp b/flute/csrc/qgemm.cpp
index 84bae95..c4a0236 100644
--- a/flute/csrc/qgemm.cpp
+++ b/flute/csrc/qgemm.cpp
@@ -314,3 +313,0 @@ qgemm_raw_simple(const at::Tensor& input,
-        case 32:                                      \
-            RUN_QGEMM_RAW(T, NUM_BITS, 32);           \
-            break;                                    \
@@ -320,6 +316,0 @@ qgemm_raw_simple(const at::Tensor& input,
-        case 128:                                     \
-            RUN_QGEMM_RAW(T, NUM_BITS, 128);          \
-            break;                                    \
-        case 256:                                     \
-            RUN_QGEMM_RAW(T, NUM_BITS, 256);          \
-            break;                                    \
@@ -335,6 +325,0 @@ qgemm_raw_simple(const at::Tensor& input,
-        case 2:                                          \
-            RUN_QGEMM_RAW_SWITCH_GROUP_SIZE(T, 2);       \
-            break;                                       \
-        case 3:                                          \
-            RUN_QGEMM_RAW_SWITCH_GROUP_SIZE(T, 3);       \
-            break;                                       \
@@ -381 +366 @@ TORCH_LIBRARY(flute, m) {
-    // m.def("qgemm_raw_simple_80(Tensor input, Tensor weight, Tensor(a!) output, Tensor scales, Tensor table, Tensor table2, Tensor(b!) workspace, int num_bits, int group_size, int template_id) -> ()");
+    m.def("qgemm_raw_simple_80(Tensor input, Tensor weight, Tensor(a!) output, Tensor scales, Tensor table, Tensor table2, Tensor(b!) workspace, int num_bits, int group_size, int template_id) -> ()");
@@ -391 +376 @@ TORCH_LIBRARY_IMPL(flute, CUDA, m) {
-    // m.impl("qgemm_raw_simple_80", &qgemm_raw_simple<cute::Int<108>>);
+    m.impl("qgemm_raw_simple_80", &qgemm_raw_simple<cute::Int<108>>);
```
```diff
diff --git a/flute/__init__.py b/flute/__init__.py
index 34b1a26..f524841 100644
--- a/flute/__init__.py
+++ b/flute/__init__.py
@@ -69 +69 @@ QGEMM_SIMPLE_DICT = {
-# QGEMM_RAW_SIMPLE_DICT = {
+QGEMM_RAW_SIMPLE_DICT = {
@@ -71 +71 @@ QGEMM_SIMPLE_DICT = {
-#     108: cast(QGEMM_RAW_SIMPLE_TYPE, torch.ops.flute.qgemm_raw_simple_80),
+    108: cast(QGEMM_RAW_SIMPLE_TYPE, torch.ops.flute.qgemm_raw_simple_80),
@@ -73 +73 @@ QGEMM_SIMPLE_DICT = {
-# }
+}
@@ -76 +76 @@ qgemm_simple     = QGEMM_SIMPLE_DICT[NUM_SMS]
-qgemm_raw_simple = None  # QGEMM_RAW_SIMPLE_DICT[NUM_SMS]
+qgemm_raw_simple = QGEMM_RAW_SIMPLE_DICT[NUM_SMS]
```
```diff
diff --git a/flute/ops.py b/flute/ops.py
index 9fd91a2..80782ea 100644
--- a/flute/ops.py
+++ b/flute/ops.py
@@ -124 +124 @@ def _qgemm_simple_89_abstract(
-# @torch.library.register_fake("flute::qgemm_raw_simple_80")
+@torch.library.register_fake("flute::qgemm_raw_simple_80")
```
  4. Build from source (see instructions below).

```bash
pip install -e . --no-build-isolation  # `--no-build-isolation` is optional
```

Depending on the number of configurations to tune, this could take on the order of tens of minutes to hours.

Step 2: Tune FLUTE on the new matrix shapes.

```python
import torch
from flute.tune import TuneTask, tune_tasks_legacy

tasks = [
    TuneTask(
        M=1,                  # batch size (x sequence length, usually 1 for token-by-token generation)
        N=1024,               # parameter dimension (note: when using tensor parallelism, this could change)
        K=4096,               # parameter dimension (note: when using tensor parallelism, this could change)
        num_bits=4,           # number of bits
        group_size=64,        # group size
        num_sms=108,          # number of streaming multiprocessors of the GPU
        dtype=torch.float16,  # data type
        device=torch.device("cuda:0"),
    ),
]

tune_tasks_legacy(tasks)
```

After this step is complete, artifacts will be saved in flute/data/.

Step 3: Build the newly-tuned kernel

```bash
# remove changes
git checkout -- flute/csrc/

# generate new dispatching logic based on tuning artifacts
bash scripts/codegen_tuned.sh

# remove changes
git checkout -- \
    flute/ops.py \
    flute/__init__.py

# build
pip install -e . --no-build-isolation
```

Note that if only one data type is tuned, you will also need to edit flute/utils.py.

Example
```diff
diff --git a/flute/utils.py b/flute/utils.py
index 5add543..13f49c0 100644
--- a/flute/utils.py
+++ b/flute/utils.py
@@ -270,7 +270,7 @@ def pack(
         K, N = W.shape
         template_ids = []
-        for dtype in [torch.float16, torch.bfloat16]:
+        for dtype in [torch.float16]:
             template_id = TEMPLATE_TUNED_WITHOUT_M_CONFIGS[(
                 NUM_SMS,
                 num_bits,
```

Finally, please follow the examples in tests/ to verify that the kernel is working correctly.

Build From Source

  1. Clone the CUTLASS library.

```bash
# Unfortunately, the path is hard-coded as of now. If you install CUTLASS
# in a different directory, please make sure the corresponding path in
# `setup.py` is updated.
cd /workspace

git clone https://github.com/NVIDIA/cutlass.git
cd cutlass
git checkout v3.4.1
```

  2. Build.

```bash
git clone https://github.com/HanGuo97/flute
cd flute
pip install -e .
```

Note: the build process requires the local CUDA version (nvcc --version) to match PyTorch's CUDA version. If the build throws an error related to a CUDA version mismatch, try adding --no-build-isolation.

Acknowledgement and Citation

Special thanks to Dmytro Ivchenko, Yijie Bei, and the Fireworks AI team for helpful discussion. If you find any of the models or code in this repo useful, please feel free to cite:

```bibtex
@inproceedings{flute2024,
  title={Fast Matrix Multiplications for Lookup Table-Quantized LLMs},
  author={Guo, Han and Brandon, William and Cholakov, Radostin and Ragan-Kelley, Jonathan and Xing, Eric and Kim, Yoon},
  booktitle={Findings of the Association for Computational Linguistics: EMNLP 2024},
  pages={12419--12433},
  year={2024}
}

@article{higgs2024,
  title={Pushing the Limits of Large Language Model Quantization via the Linearity Theorem},
  author={Malinovskii, Vladimir and Panferov, Andrei and Ilin, Ivan and Guo, Han and Richt{\'a}rik, Peter and Alistarh, Dan},
  journal={arXiv preprint arXiv:2411.17525},
  year={2024}
}
```
