[ICML 2024] Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference


[paper] [poster] [slides]

News

  • [2024/10] 🔥 We released Quest support for the Llama-3.1 and Mistral-v0.3 model families! Check out our example here.

TL;DR

Quest is an efficient long-context LLM inference framework that leverages query-aware sparsity in the KV cache to reduce memory movement during attention and thus boost throughput.

Abstract

As the demand for long-context large language models (LLMs) increases, models with context windows of up to 128K or 1M tokens are becoming increasingly prevalent. However, long-context LLM inference is challenging because inference speed decreases significantly as the sequence length grows. This slowdown is primarily caused by loading a large KV cache during self-attention. Previous works have shown that a small portion of critical tokens dominates the attention outcome. However, we observe that the criticality of a token highly depends on the query.

To this end, we propose Quest, a query-aware token criticality estimation algorithm. Quest keeps track of the minimal and maximal Key values in each KV cache page and estimates the criticality of a given page using the Query vectors. By loading only the Top-K critical KV cache pages for attention, Quest significantly speeds up self-attention without sacrificing accuracy. We show that Quest achieves up to 7.03× self-attention speedup, which reduces inference latency by 2.23×, while performing well on tasks with long dependencies and incurring negligible accuracy loss.
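
For intuition, the page-selection idea can be sketched in a few lines of PyTorch. This is an illustrative sketch only; the tensor names, shapes, and function names below are assumptions, not the repository's actual kernels or API.

import torch

def page_criticality(query, key_min, key_max):
    # query:   [head_dim]             decode-step Query vector of one head
    # key_min: [num_pages, head_dim]  element-wise minimum of Keys in each page
    # key_max: [num_pages, head_dim]  element-wise maximum of Keys in each page
    # Per channel, q_i * k_i is at most max(q_i * min_i, q_i * max_i) for any
    # Key in the page; summing over channels upper-bounds the attention score
    # that any token in the page can reach for this query.
    upper = torch.maximum(query * key_min, query * key_max)
    return upper.sum(dim=-1)                      # [num_pages]

def select_pages(query, key_min, key_max, page_budget):
    # Load only the Top-K most critical pages for the sparse attention pass.
    scores = page_criticality(query, key_min, key_max)
    return torch.topk(scores, min(page_budget, scores.numel())).indices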

Installation

  1. Clone this repo (also clone submodules):

git clone --recurse-submodules https://github.com/mit-han-lab/quest
cd quest

  2. Install dependency libraries:

conda create -yn quest python=3.10
conda activate quest

# Quest
pip install -e .

# Flash-Attention
pip install ninja packaging
pip install flash-attn==2.6.3 --no-build-isolation

# Install CMake (with version >= 3.26.4)
conda install cmake

# Build libraft
cd kernels/3rdparty/raft
./build.sh libraft

  3. Compile kernel benchmarks (Optional). Remember to configure environment variables for CUDA (check the tutorial):

cd kernels
mkdir build && cd build
cmake ..
make -j

  4. Build end-to-end operators with PyBind:

# This will automatically build and link the operators
cd quest/ops
bash setup.sh

Accuracy Evaluation

Our evaluations are based on the LongChat-7B-v1.5-32K and Yarn-Llama2-7B-128K models, which are capable of handling long-context text generation. We evaluate on both the passkey retrieval task and the LongBench benchmark. We provide several scripts to reproduce the results in our paper:

To get the Passkey Retrieval results, please modify and execute:

bash scripts/passkey.sh
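
For context, passkey retrieval hides a short numeric key inside a long stretch of filler text and asks the model to recall it. A generic sketch of such a prompt builder is shown below; it is illustrative only (the actual evaluation is driven by scripts/passkey.sh, and the filler text and phrasing here are assumptions):

import random

def build_passkey_prompt(target_words=30000,
                         filler="The grass is green. The sky is blue. The sun is yellow. "):
    # Hypothetical helper: repeat filler text up to roughly the target length
    # and hide a random 5-digit passkey at a random position.
    passkey = random.randint(10000, 99999)
    haystack = filler * (target_words // len(filler.split()))
    needle = f"The pass key is {passkey}. Remember it. {passkey} is the pass key. "
    pos = random.randint(0, len(haystack))
    prompt = ("There is a pass key hidden inside a lot of irrelevant text. "
              "Find it and memorize it.\n"
              + haystack[:pos] + needle + haystack[pos:]
              + "\nWhat is the pass key?")
    return prompt, passkey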

To reproduce the LongBench results, please modify and execute:

bash scripts/longbench.sh

To evaluate the perplexity results on PG-19, please execute:

bash scripts/ppl_eval.sh
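
For reference, perplexity is the exponential of the mean per-token negative log-likelihood over the evaluation text. A generic sketch of that computation for a Huggingface-style causal LM is shown below; it is not the logic of scripts/ppl_eval.sh, and it processes the text in fixed-size chunks without carrying context across chunk boundaries:

import torch

@torch.no_grad()
def perplexity(model, input_ids, chunk=4096):
    # input_ids: [1, seq_len] token ids of the evaluation text.
    nll, count = 0.0, 0
    for start in range(0, input_ids.size(1) - 1, chunk):
        ids = input_ids[:, start:start + chunk + 1]
        logits = model(ids[:, :-1]).logits            # predict the next token
        logp = torch.log_softmax(logits.float(), dim=-1)
        targets = ids[:, 1:]
        nll += -logp.gather(-1, targets.unsqueeze(-1)).sum().item()
        count += targets.numel()
    return torch.exp(torch.tensor(nll / count)).item()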

Efficiency Evaluation

Kernels and end-to-end efficiency are evaluated on NVIDIA Ada6000 and RTX4090 GPUs with CUDA 12.4. We provide several scripts to reproduce the results in our paper:

Kernel-level Efficiency

We also release the unit tests and benchmarks used for the kernel implementations. Correctness of the kernels is verified by unit tests in kernels/src/test, while performance is evaluated with NVBench in kernels/src/bench. We also verify the correctness of the PyBind operators in quest/tests against PyTorch results via PyTest.

To test the correctness of kernels, please execute:

cd kernels/build
./test_batch_decode # or any other operator

Or utilize PyTest:

cd quest/tests
PYTHONPATH=$PYTHONPATH:../../ pytest
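
As an illustration of the kind of property such tests can check against a PyTorch reference, the sketch below verifies that a per-page criticality score computed from Key min/max metadata upper-bounds the true attention logit of every token in that page. This is a hypothetical test, not the repository's actual test code:

import torch

def test_page_score_upper_bounds_true_logits():
    torch.manual_seed(0)
    head_dim, page_size, num_pages = 64, 16, 8
    q = torch.randn(head_dim)
    k = torch.randn(num_pages, page_size, head_dim)

    key_min = k.min(dim=1).values                  # [num_pages, head_dim]
    key_max = k.max(dim=1).values
    upper = torch.maximum(q * key_min, q * key_max).sum(-1)

    true_logits = (k * q).sum(-1)                  # [num_pages, page_size]
    assert torch.all(upper + 1e-5 >= true_logits.max(dim=1).values)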

To reproduce the kernel performance shown in the paper, please execute:

cd kernels/build
./bench_batch_decode -a seqlen=4096 -a page_budget=[64,512] # or any other operator


End-to-end Efficiency

Quest achieves up to 2.23× end-to-end speedup while performing well on tasks with long dependencies, with negligible accuracy loss.

We incorporate all implemented operators into a full pipeline to evaluate end-to-end efficiency in text generation. Built on Huggingface Transformers, we provide a KV-cache manager that supports query-aware sparsity, as shown in quest/models/QuestAttention.py.
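
A rough sketch of what such a page-based KV-cache manager does is shown below. This is an assumption-laden illustration (the class and method names are made up), not the implementation in quest/models/QuestAttention.py:

import torch

class PagedKVCache:
    def __init__(self, page_size=16):
        self.page_size = page_size
        self.keys, self.values = [], []        # token-level storage (one head)
        self.key_min, self.key_max = [], []    # per-page Key metadata

    def append(self, k, v):
        # k, v: [head_dim] tensors for one newly generated token.
        self.keys.append(k)
        self.values.append(v)
        page_id = (len(self.keys) - 1) // self.page_size
        if page_id == len(self.key_min):       # first token of a new page
            self.key_min.append(k.clone())
            self.key_max.append(k.clone())
        else:                                  # update running min/max
            self.key_min[page_id] = torch.minimum(self.key_min[page_id], k)
            self.key_max[page_id] = torch.maximum(self.key_max[page_id], k)

    def topk_pages(self, q, page_budget):
        # Estimate per-page criticality from the metadata and return the
        # indices of the pages worth loading for this decode-step query.
        kmin, kmax = torch.stack(self.key_min), torch.stack(self.key_max)
        scores = torch.maximum(q * kmin, q * kmax).sum(-1)
        return torch.topk(scores, min(page_budget, scores.numel())).indices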

To reproduce the end-to-end efficiency results in Figure 10, please execute:

bash scripts/bench_efficiency_e2e.sh

For the qualitative analysis of baselines, we use the FlashInfer kernel to estimate the performance of H2O and TOVA. To reproduce the results in Figure 11, please execute:

bash scripts/bench_kernels.sh

Examples

We provide several examples to demonstrate the usage of Quest. These examples are implemented with the end-to-end integration of Quest operators and can be executed with the following commands (please make sure you have set up all the operators):

python3 scripts/example_textgen.py

This runs a long-context summarization example with the LongChat-7B-v1.5-32K model.

You can also try scripts/example_demo.py to test the performance of Quest on your own text generation tasks. We provide a simple interface to load the model and generate text with Quest operators. The demo is an example with a 32K-token input on FP16 LongChat-7B-v1.5-32K; Quest with a 2048-token budget achieves a 1.7× speedup compared to the full-cache FlashInfer version.

TODOs

  • Support GQA models

Reference

If you find this project helpful to your research, please consider citing our paper:

@misc{tang2024quest,
      title={Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference},
      author={Jiaming Tang and Yilong Zhao and Kan Zhu and Guangxuan Xiao and Baris Kasikci and Song Han},
      year={2024},
      eprint={2406.10774},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Related Projects

This codebase utilizes lm_eval to evaluate perplexity and zero-shot accuracy. It also adapts code snippets from H2O, StreamingLLM, and Punica. Our kernels are implemented based on FlashInfer (a performant and extensible kernel library for LLM serving) and tested with NVBench. Thanks for the great work from our community!

H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models

TOVA: Transformers are Multi-State RNNs

StreamingLLM: Efficient Streaming Language Models with Attention Sinks

AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration

