pytorch/ao
PyTorch native quantization and sparsity for training and inference
- Pre-train Llama-3.1-70B 1.5x faster with float8 training
- Recover 67% of quantized accuracy degradation on Gemma3-4B with QAT
- Quantize Llama-3-8B to int4 for 1.89x faster inference with 58% less memory
Latest News | Overview | Quick Start | Installation | Integrations | Inference | Training | Videos | Citation
- [Oct 25] QAT is now integrated into Unsloth for both full and LoRA fine-tuning! Try it out using this notebook.
- [Oct 25] MXFP8 MoE training prototype achieved ~1.45x speedup for the MoE layer in Llama4 Scout and ~1.25x speedup for the MoE layer in DeepSeekV3 671b, with comparable numerics to bfloat16! Check out the docs to try it out.
- [Sept 25] MXFP8 training achieved a 1.28x speedup on a Crusoe B200 cluster with a virtually identical loss curve to bfloat16!
- [Sept 19] TorchAO Quantized Models and Quantization Recipes Now Available on Hugging Face Hub!
- [Jun 25] Our TorchAO paper was accepted to CodeML @ ICML 2025!
Older news
- [May 25] QAT is now integrated into Axolotl for fine-tuning (docs)!
- [Apr 25] Float8 rowwise training yielded a 1.34-1.43x training speedup at 2k H100 GPU scale
- [Apr 25] TorchAO was added as a quantization backend to vLLM (docs)!
- [Mar 25] Our 2:4 Sparsity paper was accepted to SLLM @ ICLR 2025!
- [Jan 25] Our integration with GemLite and SGLang yielded 1.1-2x faster inference with int4 and float8 quantization across different batch sizes and tensor parallel sizes
- [Jan 25] We added 1-8 bit ARM CPU kernels for linear and embedding ops
- [Nov 24] We achieved 1.43-1.51x faster pre-training on Llama-3.1-70B and 405B using float8 training
- [Oct 24] TorchAO was added as a quantization backend to HF Transformers!
- [Sep 24] We officially launched TorchAO. Check out our blog here!
- [Jul 24] QAT recovered up to 96% of the accuracy degradation from quantization on Llama-3-8B
- [Jun 24] Semi-structured 2:4 sparsity achieved 1.1x inference speedup and 1.3x training speedup on the SAM and ViT models, respectively
- [Jun 24] Block sparsity achieved a 1.46x training speedup on the ViT model with <2% drop in accuracy
TorchAO is an easy-to-use quantization library for native PyTorch. TorchAO works out-of-the-box with torch.compile() and FSDP2 across most HuggingFace PyTorch models.
🟢 = stable, 🟡 = prototype, 🟠 = planned, ⚪ = not supported
| recommended hardware | weight | activation | quantized training | QAT | PTQ data algorithms | quantized inference |
|---|---|---|---|---|---|---|
| H100, B200 GPUs | float8 rowwise | float8 rowwise | 🟢(link) | 🟢(link) | ⚪ | 🟢(link) |
| H100 GPUs | int4 | float8 rowwise | ⚪ | 🟢(link) | 🟠 | 🟢(link) |
| A100 GPUs | int4 | bfloat16 | ⚪ | 🟢(link) | 🟡:HQQ,AWQ,GPTQ | 🟢(link) |
| A100 GPUs | int8 | bfloat16 | ⚪ | 🟢(link) | ⚪ | 🟢(link) |
| A100 GPUs | int8 | int8 | 🟡(link) | 🟢(link) | ⚪ | 🟢(link) |
| edge | intx (1..7) | bfloat16 | ⚪ | 🟢(link) | ⚪ | 🟢(link) |
🟢 = stable, 🟡 = prototype, 🟠 = planned, ⚪ = not supported
| recommended hardware | weight | activation | quantized training | QAT | PTQ data algorithms | quantized inference |
|---|---|---|---|---|---|---|
| B200, MI350x GPUs | mxfp8 | mxfp8 | 🟡(dense),(moe) | ⚪ | ⚪ | 🟡(link) |
| B200 GPUs | nvfp4 | nvfp4 | 🟠 | 🟡(link) | ⚪ | 🟡(link) |
| B200, MI350x GPUs | mxfp4 | mxfp4 | ⚪ | 🟠 | 🟠 | 🟡(link) |
| H100 GPUs | float8 128x128 (blockwise) | float8 1x128 | 🟠 | ⚪ | ⚪ | 🟡 |
- Quantization-Aware Training (QAT) README.md
- Post-Training Quantization (PTQ) README.md
- Sparsity README.md, which includes techniques such as 2:4 sparsity and block sparsity
- The prototype folder for other prototype features
Check out our docs for more details!
First, install TorchAO. We recommend installing the latest stable version:
pip install torchao
Quantize your model weights to int4!
```python
from torchao.quantization import Int4WeightOnlyConfig, quantize_

quantize_(
    model,
    Int4WeightOnlyConfig(
        group_size=32,
        int4_packing_format="tile_packed_to_4d",
        int4_choose_qparams_algorithm="hqq",
    ),
)
```
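Since TorchAO's quantized tensors compose with torch.compile, a minimal end-to-end sketch looks like the following (toy model, simplified config, CUDA and bfloat16 assumed; illustrative only, not a benchmark setup):

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# toy model purely for illustration -- in practice you quantize your own model
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False)).to(torch.bfloat16).cuda()
quantize_(model, Int4WeightOnlyConfig(group_size=32))

# the quantized weights compose with torch.compile for fused low-bit inference kernels
model = torch.compile(model, mode="max-autotune")
with torch.no_grad():
    out = model(torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda"))
```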
See our quick start guide for more details.
To install the latest stable version:
pip install torchao
Other installation options
```bash
# Nightly
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu128

# Different CUDA versions
pip install torchao --index-url https://download.pytorch.org/whl/cu126  # CUDA 12.6
pip install torchao --index-url https://download.pytorch.org/whl/cu129  # CUDA 12.9
pip install torchao --index-url https://download.pytorch.org/whl/cpu    # CPU only

# For developers
# Note: the `--no-build-isolation` flag is required.
USE_CUDA=1 pip install -e . --no-build-isolation
USE_CPP=0 pip install -e . --no-build-isolation
```

Please see the torchao compatibility table for version requirements for dependencies.
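To sanity-check the installation, something like the following should work (assuming torchao exposes a `__version__` attribute):

```python
# quick install check; __version__ availability is an assumption here
import torch
import torchao

print(f"torch {torch.__version__}, torchao {torchao.__version__}")
```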
TorchAO delivers substantial performance gains with minimal code changes:
- Int4 weight-only: 1.73x speedup with 65% less memory for Gemma3-12b-it on H100, with a slight impact on accuracy
- Float8 dynamic quantization: 1.5-1.6x speedup on gemma-3-27b-it, and 1.54x and 1.27x speedups on Flux.1-Dev and CogVideoX-5b respectively, on H100 with preserved quality
- Int8 activation quantization with int4 weight quantization: quantized Qwen3-4B runs at 14.8 tokens/s with 3379 MB memory usage on iPhone 15 Pro through ExecuTorch
- Int4 + 2:4 sparsity: 2.37x throughput with 67.7% memory reduction on Llama-3-8B
The following is our recommended flow for quantization and deployment:
```python
from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

# Create quantization configuration
quantization_config = TorchAoConfig(
    quant_type=Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
)

# Load and automatically quantize
quantized_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-32B",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)
```
If the above doesn't work for your model, you can instead use the quantize_ API from the quick start guide.
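For reference, a sketch of that alternative flow, quantizing an already-loaded HuggingFace model in place with the same config as above (an illustrative sketch, not the official recipe):

```python
from transformers import AutoModelForCausalLM
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow, quantize_

# load the model in high precision first, then quantize the weights in place
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-32B", dtype="auto", device_map="auto")
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```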
Serving with vLLM on a 1xH100 machine:
```bash
# Server
VLLM_DISABLE_COMPILE_CACHE=1 vllm serve pytorch/Qwen3-32B-FP8 --tokenizer Qwen/Qwen3-32B -O3

# Client
curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
  "model": "pytorch/Qwen3-32B-FP8",
  "messages": [
    {"role": "user", "content": "Give me a short introduction to large language models."}
  ],
  "temperature": 0.6,
  "top_p": 0.95,
  "top_k": 20,
  "max_tokens": 32768
}'
```
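Since vLLM exposes an OpenAI-compatible endpoint, the same request can also be made from Python with the openai client (assumed installed; this snippet is an illustrative addition, not part of the original instructions):

```python
# minimal Python client sketch against the OpenAI-compatible server started above
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="pytorch/Qwen3-32B-FP8",
    messages=[{"role": "user", "content": "Give me a short introduction to large language models."}],
    temperature=0.6,
    top_p=0.95,
    max_tokens=1024,
)
print(response.choices[0].message.content)
```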
We also support deployment to edge devices through ExecuTorch; for more details, see the quantization and serving guide. We also release pre-quantized models here.
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with TorchTune, we've developed a QAT recipe that demonstrates significant accuracy improvements over post-training quantization (PTQ), recovering 96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext for Llama3. For more details, please refer to the QAT README and the original blog:
```python
import torch
from torchao.quantization import quantize_, Int8DynamicActivationIntxWeightConfig, PerGroup
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
)
quantize_(my_model, QATConfig(base_config, step="prepare"))

# train model (not shown)

# convert
quantize_(my_model, QATConfig(base_config, step="convert"))
```
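For illustration, the `# train model (not shown)` step is just an ordinary fine-tuning loop on the prepared model; a minimal sketch with a hypothetical toy model and random data (not from the QAT README):

```python
import torch
import torch.nn.functional as F

# hypothetical toy model and data, purely to show the shape of the training step
my_model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.Linear(512, 512))
optimizer = torch.optim.AdamW(my_model.parameters(), lr=1e-4)

# after quantize_(my_model, QATConfig(base_config, step="prepare")), train as usual
for _ in range(100):
    x, y = torch.randn(8, 512), torch.randn(8, 512)
    loss = F.mse_loss(my_model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```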
Users can also combine LoRA + QAT to speed up training by 1.89x compared to vanilla QAT using this fine-tuning recipe.
torchao.float8 implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433. With torch.compile on, current results show throughput speedups of up to 1.5x at up to 512 GPU / 405B parameter scale (details):
```python
from torchao.float8 import convert_to_float8_training

convert_to_float8_training(m)
```
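A slightly fuller sketch of where this call sits in a training step (toy model and data, assumes an H100-class GPU with bfloat16 weights; illustrative only):

```python
import torch
from torchao.float8 import convert_to_float8_training

# toy model purely for illustration; dimensions are multiples of 16 as required by float8 gemms
m = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
    torch.nn.Linear(4096, 4096, bias=False),
).to(torch.bfloat16).cuda()

convert_to_float8_training(m)  # swap nn.Linear modules for float8 training linears
m = torch.compile(m)           # float8 training is designed to be used with torch.compile

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-4)
x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
loss = m(x).float().pow(2).mean()
loss.backward()
optimizer.step()
```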
Our float8 training is integrated into TorchTitan's pre-training flows so users can easily try it out. For more details, check out these blog posts about our float8 training support:
- Accelerating Large Scale Training and Convergence with PyTorch Float8 Rowwise on Crusoe 2K H200s
- Supercharging Training using float8 and FSDP2
- Efficient Pre-training of Llama 3-like model architectures using torchtitan on Amazon SageMaker
- Float8 in PyTorch
Other features (sparse training, memory-efficient optimizers)
We've added support for semi-structured 2:4 sparsity with 6% end-to-end speedups on ViT-L. Full blog here. The code change is a one-liner, with the full example available here:
```python
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear

swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
```
Optimizers like Adam can consume substantial GPU memory - 2x as much as the model parameters themselves. TorchAO provides two approaches to reduce this overhead:
1. Quantized optimizers: Reduce optimizer state memory by 2-4x by quantizing to lower precision
```python
from torchao.optim import AdamW8bit, AdamW4bit, AdamWFp8

optim = AdamW8bit(model.parameters())  # replace with AdamW4bit or AdamWFp8 for the 4-bit / fp8 versions
```
Our quantized optimizers are implemented in just a few hundred lines of PyTorch code and compiled for efficiency. While slightly slower than specialized kernels, they offer an excellent balance of memory savings and performance. See detailedbenchmarks here.
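As an illustration, the quantized optimizers drop into an ordinary training loop exactly like torch.optim.AdamW; a minimal sketch with a toy model on a CUDA device (hypothetical, for illustration only):

```python
import torch
from torchao.optim import AdamW8bit

# toy setup; in practice this is your existing model and data
model = torch.nn.Linear(1024, 1024).cuda()
optim = AdamW8bit(model.parameters(), lr=1e-3)

for _ in range(10):
    loss = model(torch.randn(32, 1024, device="cuda")).pow(2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad()
```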
2. CPU offloading: Move optimizer state and gradients to CPU memory
For maximum memory savings, we support single GPU CPU offloading that efficiently moves both gradients and optimizer state to CPU memory. This approach can reduce your VRAM requirements by 60% with minimal impact on training speed:
```python
import torch
from torchao.optim import CPUOffloadOptimizer

optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
# optionally restore optimizer state from a checkpoint
optim.load_state_dict(ckpt["optim"])
```
TorchAO is integrated into some of the leading open-source libraries, including:
- Unsloth now supports QAT: read the blog and guide.
- HuggingFace transformers, with a built-in inference backend and low-bit optimizers
- HuggingFace diffusers best practices with torch.compile and TorchAO in a standalone repo, diffusers-torchao
- vLLM for LLM serving: usage, detailed docs
- Integration with FBGEMM for SOTA kernels on server GPUs
- Integration with ExecuTorch for edge device deployment
- Axolotl for QAT and PTQ
- TorchTitan for float8 pre-training
- HuggingFace PEFT for LoRA, using TorchAO as their quantization backend
- TorchTune for our NF4 QLoRA, QAT, and float8 quantized fine-tuning recipes
- SGLang for LLM serving: usage
- Keynote talk at GPU MODE IRL
- Low precision dtypes at PyTorch conference
- Slaying OOMs at the Mastering LLM's course
- Advanced Quantization at CUDA MODE
- Chip Huyen's GPU Optimization Workshop
- Cohere for AI community talk
If you find the torchao library useful, please cite it in your work as below.
```bibtex
@software{torchao,
  title   = {TorchAO: PyTorch-Native Training-to-Serving Model Optimization},
  author  = {torchao},
  url     = {https://github.com/pytorch/ao},
  license = {BSD-3-Clause},
  month   = {oct},
  year    = {2024}
}
```