Simple and efficient PyTorch-native transformer text generation in <1000 LOC of Python.
meta-pytorch/gpt-fast
Simple and efficient PyTorch-native transformer text generation.
Featuring:
- Very low latency
- <1000 lines of python
- No dependencies other than PyTorch and sentencepiece
- int8/int4 quantization
- Speculative decoding
- Tensor parallelism
- Supports NVIDIA and AMD GPUs
This is NOT intended to be a "framework" or "library". It is intended to show off what kind of performance you can get with native PyTorch :) Please copy-paste and fork as you desire.
For an in-depth walkthrough of what's in this codebase, see this blog post.
See the rest of this page for benchmarks of LLaMA family models.
We also support Mixtral 8x7B, a high-quality sparse mixture-of-experts (MoE) model. Its average token generation rates (tokens/s) are:
| | 1 GPU | 2 GPUs | 4 GPUs | 8 GPUs |
|---|---|---|---|---|
| baseline (bfloat16) | OOM | 96.67 | 155.35 | 227.82 |
| int8 | 97.92 | 155.03 | 216.87 | 279.35 |
The benchmarks were run on an 8xA100-80GB, power limited to 330W, with a hybrid cube mesh topology. All benchmarks are run at batch size=1, so the reported tokens/s numbers are equivalent to "tokens/s/user". They also use a very small prompt length (just 5 tokens).
For more details about Mixtral 8x7B, please check this page or this note.
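Because every run uses batch size 1, a tokens/s figure converts directly into per-token latency for a single user. A minimal sketch of that arithmetic, using values copied from the Mixtral 8x7B table above (the helper name `ms_per_token` is illustrative, not part of gpt-fast):

```python
# Convert single-stream throughput (batch size 1) into per-token latency.
# Assumption: one user per run, so latency is simply the inverse of throughput.

def ms_per_token(tokens_per_sec: float) -> float:
    """Milliseconds spent generating each token for a single user."""
    return 1000.0 / tokens_per_sec

# (precision, number of GPUs) -> tokens/s, copied from the Mixtral 8x7B table.
mixtral = {
    ("bfloat16", 2): 96.67,
    ("int8", 1): 97.92,
    ("int8", 8): 279.35,
}

for (precision, gpus), tps in mixtral.items():
    print(f"{precision:>8} on {gpus} GPU(s): {tps:7.2f} tok/s -> {ms_per_token(tps):5.2f} ms/token")
```

For example, int8 on 8 GPUs works out to roughly 3.6 ms per token for one user.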
In the spirit of keeping the repo minimal, here are various examples of extensions you can make to gpt-fast as PRs.
Projects inspired by gpt-fast in the community:
- gpt-blazing: applies the same performance optimization strategy to more models (e.g., baichuan2).
- gptfast: applies a subset of the performance optimizations to all Hugging Face models
- gpt-accelera: extends gpt-fast to SFT/RM/PPO training and batched inference to optimize throughput
Install required packages:
pip install -r requirements.txt
To download Llama models, go to https://huggingface.co/meta-llama/Llama-2-7b and follow the steps to obtain access. Then log in with huggingface-cli login
Models tested/supported
- tinyllamas/stories{15,42,100}
- openlm-research/open_llama_7b
- meta-llama/Llama-2-7b-chat-hf
- meta-llama/Llama-2-13b-chat-hf
- meta-llama/Llama-2-70b-chat-hf
- codellama/CodeLlama-7b-Python-hf
- codellama/CodeLlama-34b-Python-hf
- mistralai/Mistral-7B-v0.1
- mistralai/Mistral-7B-Instruct-v0.1
- mistralai/Mistral-7B-Instruct-v0.2
- meta-llama/Meta-Llama-3-8B
- meta-llama/Meta-Llama-3.1-8B
- meta-llama/Meta-Llama-3.1-70B
- meta-llama/Meta-Llama-3.1-405B
For example, to convert Llama-2-7b-chat-hf:
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
./scripts/prepare.sh $MODEL_REPO
Benchmarks were run on an 8xA100-80GB, power limited to 330W, with a hybrid cube mesh topology. All benchmarks are run at batch size=1, so the reported tokens/s numbers are equivalent to "tokens/s/user". They also use a very small prompt length (just 5 tokens).
| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
|---|---|---|---|
| Llama-2-7B | Base | 104.9 | 1397.31 |
| | 8-bit | 155.58 | 1069.20 |
| | 4-bit (G=32) | 196.80 | 862.69 |
| Llama-2-70B | Base | OOM | |
| | 8-bit | 19.13 | 1322.58 |
| | 4-bit (G=32) | 25.25 | 1097.66 |
| Llama-3.1-8B | Base | 93.89 | 1410.76 |
| | 8-bit | 137.64 | 1030.89 |
| Llama-3.1-70B | Base | OOM | |
| | 8-bit | 18.04 | 1253.78 |
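At batch size 1, decoding is dominated by reading the model weights once per generated token, so the bandwidth column is roughly (bytes of weights) × (tokens/s). The sketch below works through that relationship under simplifying assumptions: it ignores the KV cache and activations, and it uses a rounded nominal parameter count (70B) rather than a measured size. It shows why a 70B model in bfloat16 does not fit on one 80 GB GPU, and recovers the bytes read per token implied by the Llama-2-7B rows. The helper names are illustrative.

```python
# Rough memory arithmetic for batch-size-1 decoding.
# Assumption: weight reads dominate memory traffic; KV cache and activations are ignored.

GPU_MEMORY_BYTES = 80e9  # one A100-80GB

def weight_bytes(num_params: float, bits_per_weight: float) -> float:
    """Approximate storage for the weights alone (ignores quantization scales)."""
    return num_params * bits_per_weight / 8

def implied_model_gb(bandwidth_gbps: float, tokens_per_sec: float) -> float:
    """GB read per generated token implied by the table: bandwidth / throughput."""
    return bandwidth_gbps / tokens_per_sec

# Nominal (rounded) parameter count for Llama-2-70B.
for bits in (16, 8, 4):
    size = weight_bytes(70e9, bits)
    fits = size < GPU_MEMORY_BYTES
    print(f"Llama-2-70B @ {bits:2d}-bit: {size / 1e9:6.1f} GB -> {'fits' if fits else 'OOM'} on one GPU")

# Llama-2-7B rows copied from the table: technique -> (tokens/s, GB/s).
llama2_7b = {"Base": (104.9, 1397.31), "8-bit": (155.58, 1069.20), "4-bit (G=32)": (196.80, 862.69)}
for technique, (tps, bw) in llama2_7b.items():
    print(f"Llama-2-7B {technique:>12}: ~{implied_model_gb(bw, tps):.2f} GB read per token")
```

The implied sizes (about 13.3, 6.9 and 4.4 GB) shrink with the bit width, which is consistent with quantization speeding up decoding by reducing the bytes that must be read per token.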
Speculative sampling, with verifier Llama-70B (int4) and draft Llama-7B (int4): 48.4 tok/s
| Model | Number of GPUs | Tokens/Second | Memory Bandwidth (GB/s) |
|---|---|---|---|
| Llama-2-7B | 1 | 104.9 | 1397.31 |
| | 2 | 168.84 | 1181.99 |
| | 4 | 254.02 | 955.83 |
| | 8 | 328.43 | 704.10 |
| Llama-2-70B | 1 | OOM | |
| | 2 | 21.32 | 1481.87 |
| | 4 | 38.01 | 1340.76 |
| | 8 | 62.50 | 1135.29 |
| Llama-3.1-8B | 1 | 93.83 | 1408.37 |
| | 2 | 149.10 | 1197.32 |
| | 4 | 217.21 | 986.32 |
| | 8 | 276.01 | 772.60 |
| Llama-3.1-70B | 1 | OOM | |
| | 2 | 16.03 | 1130.81 |
| | 4 | 37.45 | 1360.53 |
| | 8 | 58.78 | 1129.61 |
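Scaling across GPUs at batch size 1 is well short of linear, and it can help to express the rows above as speedup and scaling efficiency. A small sketch of that arithmetic for Llama-2-7B (the helper names are illustrative, not part of gpt-fast; the numbers are copied from the table):

```python
# Express tensor-parallel throughput as speedup and fraction of linear scaling.
# Numbers are copied from the Llama-2-7B rows of the table above.

def speedup(tps_n: float, tps_1: float) -> float:
    """Throughput on n GPUs relative to a single GPU."""
    return tps_n / tps_1

def efficiency(tps_n: float, tps_1: float, n_gpus: int) -> float:
    """Fraction of ideal linear scaling achieved (1.0 would be perfect)."""
    return speedup(tps_n, tps_1) / n_gpus

llama2_7b = {1: 104.9, 2: 168.84, 4: 254.02, 8: 328.43}  # GPUs -> tokens/s
base = llama2_7b[1]
for n, tps in llama2_7b.items():
    print(f"{n} GPU(s): {speedup(tps, base):.2f}x speedup, {efficiency(tps, base, n):.0%} of linear")
```

For Llama-2-7B this comes to roughly 1.6x on 2 GPUs and 3.1x on 8. The 70B models are OOM on a single GPU, so their efficiency cannot be computed against a one-GPU baseline this way.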
Tensor parallelism on 8 GPUs combined with quantization:
| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
|---|---|---|---|
| Llama-2-70B | Base | 62.50 | 1135.29 |
| | 8-bit | 80.44 | 752.04 |
| | 4-bit (G=32) | 90.77 | 548.10 |
| Llama-3.1-70B | Base | 58.78 | 1129.61 |
| | 8-bit | 75.58 | 726.57 |
| Llama-3.1-405B | 8-bit | 15.60 | 815.87 |
AMD benchmarks were run on one GCD of an MI250X.
| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
|---|---|---|---|
| Llama-2-7B | Base | 76.33 | 1028.70 |
| | 8-bit | 101.86 | 700.06 |
Model definition in model.py, generation code in generate.py.
python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is"
To squeeze out a little more performance, you can also compile the prefill with --compile_prefill, though this increases compilation time.
Choose the device to use with:
# Currently supported devices: cuda, cpu
export DEVICE=cuda
To generate the int8 version of the model:
# Spits out model at checkpoints/$MODEL_REPO/model_int8.pth
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
To run with int8, just pass the int8 checkpoint to generate.py.
python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --device $DEVICE
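Int8 weight-only quantization stores each weight as an 8-bit integer plus a floating-point scale, and the weights are dequantized for the matmul. The pure-Python sketch below shows one common variant, assumed here for illustration: symmetric per-output-channel quantization, with one scale per row chosen so the row's largest magnitude maps to 127. It is not gpt-fast's quantize.py, which operates on PyTorch tensors, and its function names are hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]

def quantize_int8_per_channel(w: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Symmetric per-row int8 quantization: q = round(w / scale), scale = max|row| / 127."""
    q_rows, scales = [], []
    for row in w:
        max_abs = max(abs(x) for x in row)
        scale = max_abs / 127 if max_abs > 0 else 1.0  # all-zero row: any scale works
        q_rows.append([max(-128, min(127, round(x / scale))) for x in row])
        scales.append(scale)
    return q_rows, scales

def dequantize(q: List[List[int]], scales: List[float]) -> Matrix:
    """Recover approximate float weights: w ~= q * scale (per row)."""
    return [[v * s for v in row] for row, s in zip(q, scales)]

def matvec(w: Matrix, x: List[float]) -> List[float]:
    return [sum(a * b for a, b in zip(row, x)) for row in w]

# Toy 2x3 weight matrix (hypothetical values).
w = [[0.50, -1.27, 0.03], [2.00, 0.10, -0.70]]
x = [1.0, 2.0, 3.0]
q, scales = quantize_int8_per_channel(w)
print("int8 codes:", q)
print("exact:", matvec(w, x), "int8:", matvec(dequantize(q, scales), x))
```

Each dequantized weight is within half a scale step of the original, which is why the int8 outputs track the exact ones closely.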
To generate the int4 version of the model:
# Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32
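The `--groupsize 32` flag (the G=32 in the benchmark tables) gives each run of 32 consecutive weights in a row its own quantization parameters, which keeps 4-bit error manageable. The pure-Python sketch below assumes an asymmetric min/max scheme (16 levels per group, one scale and one offset per group); the exact rounding, zero-point convention, and bit packing in gpt-fast's quantize.py may differ, and the function names are hypothetical.

```python
import random
from typing import List, Tuple

def quantize_int4_groupwise(row: List[float], groupsize: int = 32) -> Tuple[List[int], List[Tuple[float, float]]]:
    """Quantize one weight row to 4-bit codes (0..15) with a (scale, min) pair per group."""
    assert len(row) % groupsize == 0, "row length must be a multiple of groupsize"
    codes, params = [], []
    for start in range(0, len(row), groupsize):
        group = row[start:start + groupsize]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / 15 if hi > lo else 1.0  # 16 levels span [lo, hi]
        codes.extend(max(0, min(15, round((x - lo) / scale))) for x in group)
        params.append((scale, lo))
    return codes, params

def dequantize_int4_groupwise(codes: List[int], params: List[Tuple[float, float]], groupsize: int = 32) -> List[float]:
    """Map each code back to code * scale + min using its group's parameters."""
    out: List[float] = []
    for g, (scale, lo) in enumerate(params):
        out.extend(c * scale + lo for c in codes[g * groupsize:(g + 1) * groupsize])
    return out

random.seed(0)
row = [random.gauss(0.0, 1.0) for _ in range(128)]  # 4 groups of 32 weights
codes, params = quantize_int4_groupwise(row, groupsize=32)
approx = dequantize_int4_groupwise(codes, params, groupsize=32)
max_err = max(abs(a - b) for a, b in zip(row, approx))
print(f"{len(params)} groups, max abs error {max_err:.4f}")
```

Smaller groups track local weight ranges more tightly (lower error) at the cost of storing more scales and offsets.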
To run with int4, just pass the int4 checkpoint to generate.py.
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth --compile --device $DEVICE
To generate with speculative sampling (DRAFT_MODEL_REPO should point to a smaller model than MODEL_REPO):
In this example, the "smaller" model is just the int8 quantized version of the model.
export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int8.pth
Note: Running on an A100 80GB, albeit power-limited to 330 watts. Empirically, peak bandwidth appears to be about 1700 GB/s.
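In speculative sampling, a cheap draft model proposes several tokens and the large model checks them together, so the expensive model runs fewer sequential steps. The sketch below shows only the simplest greedy form of the accept/reject rule, using hypothetical toy "models" (plain Python functions that return the next token id). gpt-fast's generate.py uses a probabilistic acceptance rule rather than this exact-match check, so treat this as an illustration of the control flow only.

```python
from typing import Callable, List

# A greedy "model" here is just a function: context (token ids) -> next token id.
NextToken = Callable[[List[int]], int]

def speculative_step(target: NextToken, draft: NextToken, context: List[int], k: int) -> List[int]:
    """Run one round of greedy speculative decoding and return the newly accepted tokens."""
    # 1. The cheap draft model proposes k tokens autoregressively.
    proposal: List[int] = []
    ctx = list(context)
    for _ in range(k):
        tok = draft(ctx)
        proposal.append(tok)
        ctx.append(tok)

    # 2. The target model checks the proposal position by position. A real implementation
    #    scores all k positions in one batched forward pass; this sketch calls it per position.
    accepted: List[int] = []
    ctx = list(context)
    for tok in proposal:
        expected = target(ctx)
        if expected != tok:
            accepted.append(expected)  # first disagreement: keep the target's token, drop the rest
            return accepted
        accepted.append(tok)
        ctx.append(tok)

    # 3. All k draft tokens were accepted, so the target's next prediction is a bonus token.
    accepted.append(target(ctx))
    return accepted

# Hypothetical toy models: the target counts upward mod 10,
# while the draft agrees with it except right after a 2.
def target(ctx: List[int]) -> int:
    return (ctx[-1] + 1) % 10

def draft(ctx: List[int]) -> int:
    return 0 if ctx[-1] == 2 else (ctx[-1] + 1) % 10

seq = [0]
while len(seq) < 12:
    seq += speculative_step(target, draft, seq, k=4)
print(seq)
```

With greedy acceptance, the output is identical to decoding with the target model alone; the draft only changes how many rounds of target verification are needed.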
To run with tensor parallelism (here on 2 GPUs):
ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=2 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth
We use the EleutherAI evaluation harness to evaluate model accuracy. To evaluate, make sure the evaluation harness is installed, then pass your model checkpoint and the desired tasks to eval.py:
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile --tasks hellaswag winogrande
Note: Generative tasks are currently not supported for gpt-fast.
Installation instructions for the evaluation harness: https://github.com/EleutherAI/lm-evaluation-harness/tree/master#install
We have a pure PyTorch implementation of GPTQ that uses torch._dynamo.export to access the model structure. You can generate a GPTQ-quantized version of the int4 model by using the same quantize command but adding 'gptq' to the quantization mode, i.e.:
# Spits out model at checkpoints/$MODEL_REPO/model_int4-gptq.g32.pth
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --groupsize 32 --calibration_tasks wikitext --calibration_seq_length 2048
You can then evaluate or generate text with this model in the same way as above.
gpt-fast is released under the BSD 3 license.
Thanks to:
- Lightning AI for supporting PyTorch and for work on flash attention, int8 quantization, and LoRA fine-tuning.
- GGML for driving forward fast, on-device inference of LLMs
- Karpathy for spearheading simple, interpretable and fast LLM implementations
- MLC-LLM for pushing 4-bit quantization performance on heterogeneous hardware