PyTorch-Native Training-to-Serving Model Optimization

  • Pre-train Llama-3.1-70B 1.5x faster with float8 training
  • Recover 67% of quantized accuracy degradation on Gemma3-4B with QAT
  • Quantize Llama-3-8B to int4 for 1.89x faster inference with 58% less memory

📣 Latest News

Older news

🌅 Overview

TorchAO is an easy-to-use quantization library for native PyTorch. TorchAO works out-of-the-box with torch.compile() and FSDP2 across most HuggingFace PyTorch models.

Stable Workflows

🟢 = stable, 🟡 = prototype, 🟠 = planned, ⚪ = not supported

| recommended hardware | weight | activation | quantized training | QAT | PTQ data algorithms | quantized inference |
|---|---|---|---|---|---|---|
| H100, B200 GPUs | float8 rowwise | float8 rowwise | 🟢 (link) | 🟢 (link) | ⚪ | 🟢 (link) |
| H100 GPUs | int4 | float8 rowwise | ⚪ | 🟢 (link) | 🟠 | 🟢 (link) |
| A100 GPUs | int4 | bfloat16 | ⚪ | 🟢 (link) | 🟡: HQQ, AWQ, GPTQ | 🟢 (link) |
| A100 GPUs | int8 | bfloat16 | ⚪ | 🟢 (link) | ⚪ | 🟢 (link) |
| A100 GPUs | int8 | int8 | 🟡 (link) | 🟢 (link) | ⚪ | 🟢 (link) |
| edge | intx (1..7) | bfloat16 | ⚪ | 🟢 (link) | ⚪ | 🟢 (link) |
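
To make the table concrete, here is a minimal sketch of how two of the stable rows map onto `quantize_` configs (the toy model, dtype, and device handling are illustrative, not a prescribed recipe):

```python
import torch
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt8WeightConfig,      # "int8 / int8" row
    Float8DynamicActivationFloat8WeightConfig,  # "float8 rowwise" row
    PerRow,
)

# toy model standing in for a real network
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16)

# int8 dynamic activation + int8 weight (A100 row)
quantize_(model, Int8DynamicActivationInt8WeightConfig())

# float8 rowwise activation + weight (H100 / B200 row) would instead be:
# quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```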

Prototype Workflows

🟢 = stable, 🟡 = prototype, 🟠 = planned, ⚪ = not supported

| recommended hardware | weight | activation | quantized training | QAT | PTQ data algorithms | quantized inference |
|---|---|---|---|---|---|---|
| B200, MI350x GPUs | mxfp8 | mxfp8 | 🟡 (dense), (moe) | ⚪ | ⚪ | 🟡 (link) |
| B200 GPUs | nvfp4 | nvfp4 | 🟠 | 🟡 (link) | ⚪ | 🟡 (link) |
| B200, MI350x GPUs | mxfp4 | mxfp4 | ⚪ | 🟠 | 🟠 | 🟡 (link) |
| H100 GPUs | float8 128x128 (blockwise) | float8 1x128 | 🟠 | ⚪ | ⚪ | 🟡 |

Other

Check out our docs for more details!

🚀 Quick Start

First, install TorchAO. We recommend installing the latest stable version:

```bash
pip install torchao
```

Quantize your model weights to int4!

```python
from torchao.quantization import Int4WeightOnlyConfig, quantize_

quantize_(
    model,
    Int4WeightOnlyConfig(
        group_size=32,
        int4_packing_format="tile_packed_to_4d",
        int4_choose_qparams_algorithm="hqq",
    ),
)
```
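
For illustration, a minimal end-to-end sketch of the same call on a toy model, combined with `torch.compile` for inference (the model, shapes, and device are made up; int4 weight-only kernels generally expect bfloat16 weights on a CUDA device):

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# toy model and input; in practice this would be a real LLM
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")

# quantize the weights to int4, as in the snippet above
quantize_(model, Int4WeightOnlyConfig(group_size=32))

# torchao composes with torch.compile for fast fused kernels
model = torch.compile(model)
with torch.no_grad():
    y = model(x)
```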

See our quick start guide for more details.

🛠 Installation

To install the latest stable version:

```bash
pip install torchao
```
Other installation options
```bash
# Nightly
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu128

# Different CUDA versions
pip install torchao --index-url https://download.pytorch.org/whl/cu126  # CUDA 12.6
pip install torchao --index-url https://download.pytorch.org/whl/cu129  # CUDA 12.9
pip install torchao --index-url https://download.pytorch.org/whl/cpu    # CPU only

# For developers
# Note: the `--no-build-isolation` flag is required.
USE_CUDA=1 pip install -e . --no-build-isolation
USE_CPP=0 pip install -e . --no-build-isolation
```

Please see the torchao compatibility table for the version requirements of its dependencies.

🔎 Inference

TorchAO delivers substantial performance gains with minimal code changes. The following is our recommended flow for quantization and deployment:

```python
from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

# Create quantization configuration
quantization_config = TorchAoConfig(
    quant_type=Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
)

# Load and automatically quantize
quantized_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-32B",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)
```

An alternative quantization API, for cases where the above doesn't work, is the quantize_ API described in our quick start guide.
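
For illustration, a sketch of that alternative path: load the model with plain transformers and then call `quantize_` directly with the same float8 config (the model name and dtype choices mirror the snippet above and are otherwise illustrative):

```python
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import (
    quantize_,
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
)

# load the model as usual, then quantize it in place
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-32B", dtype=torch.bfloat16, device_map="auto"
)
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```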

Serving with vLLM on a 1xH100 machine:

```bash
# Server
VLLM_DISABLE_COMPILE_CACHE=1 vllm serve pytorch/Qwen3-32B-FP8 --tokenizer Qwen/Qwen3-32B -O3
```

```bash
# Client
curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
  "model": "pytorch/Qwen3-32B-FP8",
  "messages": [
    {"role": "user", "content": "Give me a short introduction to large language models."}
  ],
  "temperature": 0.6,
  "top_p": 0.95,
  "top_k": 20,
  "max_tokens": 32768
}'
```
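
If you quantize your own model instead of using the pre-quantized `pytorch/Qwen3-32B-FP8` checkpoint, you will need to save or upload it before serving; a sketch, assuming the `quantized_model` from the transformers snippet above (torchao-quantized weights are tensor subclasses, so they are saved with `safe_serialization=False`):

```python
# assumes `quantized_model` from the transformers snippet above
output_dir = "Qwen3-32B-FP8"  # hypothetical local path

# torchao-quantized weights are tensor subclasses, so save without safetensors
quantized_model.save_pretrained(output_dir, safe_serialization=False)

# optionally push to the Hub so `vllm serve <repo_id>` can pull it directly
# quantized_model.push_to_hub("your-org/Qwen3-32B-FP8", safe_serialization=False)
```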

We also support deployment to edge devices through ExecuTorch; for more details, see our quantization and serving guide. We also release pre-quantized models here.

🚅 Training

Quantization-Aware Training

Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with TorchTune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering 96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the QAT README and the original blog:

```python
import torch
from torchao.quantization import quantize_, Int8DynamicActivationIntxWeightConfig, PerGroup
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
)
quantize_(my_model, QATConfig(base_config, step="prepare"))

# train model (not shown)

# convert
quantize_(my_model, QATConfig(base_config, step="convert"))
```
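
The `# train model (not shown)` step is ordinary fine-tuning of the prepared model; a minimal sketch, assuming a hypothetical `train_dataloader` and a standard optimizer and loss:

```python
import torch
import torch.nn.functional as F

# after quantize_(my_model, QATConfig(base_config, step="prepare")),
# the model contains fake-quantized linears and trains like any other model
optimizer = torch.optim.AdamW(my_model.parameters(), lr=1e-5)

for inputs, labels in train_dataloader:  # hypothetical dataloader
    optimizer.zero_grad()
    logits = my_model(inputs)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    loss.backward()
    optimizer.step()

# then run quantize_(my_model, QATConfig(base_config, step="convert")) as above
```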

Users can also combine LoRA + QAT to speed up training by 1.89x compared to vanilla QAT using this fine-tuning recipe.

Quantized training

torchao.float8 implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433. With torch.compile on, current results show throughput speedups of up to 1.5x at up to 512 GPU / 405B parameter count scale (details):

```python
from torchao.float8 import convert_to_float8_training

convert_to_float8_training(m)
```
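
A slightly fuller sketch on a toy model, showing the optional `module_filter_fn` argument (to leave selected layers in high precision) and `torch.compile`, which the speedups above assume:

```python
import torch
from torchao.float8 import convert_to_float8_training

# toy model; float8 training targets H100-class GPUs and linear dims divisible by 16
m = torch.nn.Sequential(
    torch.nn.Linear(2048, 4096),
    torch.nn.Linear(4096, 128),
).to(torch.bfloat16).cuda()

# optionally skip modules, e.g. keep the final projection in bfloat16
def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
    return fqn != "1"

convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# the reported speedups assume torch.compile
m = torch.compile(m)
```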

Our float8 training is integrated into TorchTitan's pre-training flows so users can easily try it out. For more details, check out these blog posts about our float8 training support:

Other features (sparse training, memory efficient optimizers)

Sparse Training

We've added support for semi-structured 2:4 sparsity with 6% end-to-end speedups on ViT-L. Full blog here. The code change is a one-liner, with the full example available here:

```python
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear

swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
```
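
For context, a sketch of where a key like `"seq.0"` comes from: the dict maps fully qualified module names in your model to `SemiSparseLinear`, so only the listed linears are swapped (the toy model below is illustrative):

```python
import torch
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # "seq.0" below is the fully qualified name of this first Linear
        self.seq = torch.nn.Sequential(
            torch.nn.Linear(1024, 4096),
            torch.nn.ReLU(),
            torch.nn.Linear(4096, 1024),
        )

    def forward(self, x):
        return self.seq(x)

model = ToyModel().cuda().half()

# swap only the listed linears to train with runtime 2:4 sparsity
swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
```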

Memory-efficient optimizers

Optimizers like Adam can consume substantial GPU memory, up to 2x as much as the model parameters themselves. TorchAO provides two approaches to reduce this overhead:

1. Quantized optimizers: Reduce optimizer state memory by 2-4x by quantizing to lower precision

```python
from torchao.optim import AdamW8bit, AdamW4bit, AdamWFp8

optim = AdamW8bit(model.parameters())  # replace with AdamW4bit or AdamWFp8 for the 4-bit / fp8 versions
```

Our quantized optimizers are implemented in just a few hundred lines of PyTorch code and compiled for efficiency. While slightly slower than specialized kernels, they offer an excellent balance of memory savings and performance. See detailed benchmarks here.
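
Because these optimizers follow the standard torch.optim interface, they drop into an existing training loop unchanged; a brief sketch with a toy model and random data:

```python
import torch
from torchao.optim import AdamW8bit

model = torch.nn.Linear(4096, 4096).cuda()   # toy model
optim = AdamW8bit(model.parameters())        # drop-in for torch.optim.AdamW

for _ in range(10):
    x = torch.randn(8, 4096, device="cuda")
    loss = model(x).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
```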

2. CPU offloading: Move optimizer state and gradients to CPU memory

For maximum memory savings, we support single-GPU CPU offloading that efficiently moves both gradients and optimizer state to CPU memory. This approach can reduce your VRAM requirements by 60% with minimal impact on training speed:

```python
import torch
from torchao.optim import CPUOffloadOptimizer

optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])  # ckpt: a previously saved checkpoint dict
```

🔗 Integrations

TorchAO is integrated into some of the leading open-source libraries, including HuggingFace transformers, vLLM, TorchTune, TorchTitan, and ExecuTorch (see the sections above).

🎥 Videos

💬 Citation

If you find the torchao library useful, please cite it in your work as below.

```bibtex
@software{torchao,
  title   = {TorchAO: PyTorch-Native Training-to-Serving Model Optimization},
  author  = {torchao},
  url     = {https://github.com/pytorch/ao},
  license = {BSD-3-Clause},
  month   = {oct},
  year    = {2024}
}
```
