Integration with VLLM: Architecture and Usage Guide
Created On: Dec 18, 2025 | Last Updated On: Dec 18, 2025
This tutorial explains how TorchAO integrates with VLLM and what you need to implement to make a new quantization technique work end to end (E2E).
Configuration System
1. HuggingFace Model Configuration
TorchAO quantization is configured through the model’s config.json file:
```json
{
  "model_type": "llama",
  "quant_type": {
    "default": {
      "_type": "Int4WeightOnlyConfig",
      "_data": {
        "group_size": 128,
        "use_hqq": true
      }
    }
  }
}
```
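To make the `_type`/`_data` layout concrete, here is a stdlib-only sketch that pulls the config class name and its keyword arguments out of a config.json like the one above. In the real pipeline this step is handled by TorchAO's config deserialization (the integration diagram later in this guide shows VLLM calling `config_from_dict`); the helper `read_torchao_quant_type` is hypothetical and only reads the `default` entry.

```python
import json
from typing import Any, Dict, Tuple


def read_torchao_quant_type(config_json: str) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: return (config class name, constructor kwargs)
    for the "default" entry of a TorchAO quant_type block."""
    config = json.loads(config_json)
    default_entry = config["quant_type"]["default"]
    # "_type" names the AOBaseConfig subclass; "_data" holds its fields.
    return default_entry["_type"], dict(default_entry["_data"])


example = (
    '{"model_type":"llama","quant_type":{"default":{"_type":"Int4WeightOnlyConfig",'
    '"_data":{"group_size":128,"use_hqq":true}}}}'
)
name, kwargs = read_torchao_quant_type(example)
print(name, kwargs)  # Int4WeightOnlyConfig {'group_size': 128, 'use_hqq': True}
```

Read this way, the JSON entry corresponds to `Int4WeightOnlyConfig(group_size=128, use_hqq=True)`.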
2. TorchAO Configuration Classes
All quantization configs inherit from AOBaseConfig:
```python
from torchao.core.config import AOBaseConfig
from torchao.quantization import Int4WeightOnlyConfig

# Example configuration
config = Int4WeightOnlyConfig(
    group_size=128,
    use_hqq=True,
    version=1,
)
assert isinstance(config, AOBaseConfig)
```
Note
All quantization configurations inherit from torchao.core.config.AOBaseConfig, which provides serialization and validation capabilities.
3. FQN Configuration
For per-module control, use FqnToConfig, which maps fully qualified module names (FQNs) to configs:
```python
from torchao.quantization import FqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig

config = FqnToConfig({
    "model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64),
    "model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64),
    "model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(),
    "_default": Int4WeightOnlyConfig(group_size=128, version=1),  # Default for other modules
})
```
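The lookup rule implied by the `_default` entry can be modeled in plain Python: an exact module-name match wins, otherwise the `_default` config applies. This is a simplified sketch for illustration only; the helper `resolve_config` is hypothetical, the configs are stand-in strings, and the FqnToConfig documentation remains the authority on its full matching rules.

```python
from typing import Dict, Optional


def resolve_config(fqn: str, fqn_to_config: Dict[str, str]) -> Optional[str]:
    """Hypothetical model of per-module config resolution:
    exact FQN match first, then the "_default" entry, else no quantization."""
    if fqn in fqn_to_config:
        return fqn_to_config[fqn]
    return fqn_to_config.get("_default")


# Stand-in strings instead of real TorchAO config objects.
mapping = {
    "model.layers.0.self_attn.q_proj": "int4(group_size=64)",
    "model.layers.0.mlp.gate_proj": "int8",
    "_default": "int4(group_size=128)",
}
print(resolve_config("model.layers.0.mlp.gate_proj", mapping))     # int8
print(resolve_config("model.layers.5.self_attn.v_proj", mapping))  # int4(group_size=128)
```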
Usage Examples
1. Quantizing Models with HuggingFace Integration
```python
from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization import Int4WeightOnlyConfig

# Create quantization configuration
quantization_config = TorchAoConfig(
    quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True, version=1)
)

# Load and automatically quantize the model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)

# Save quantized model (see Serialization section below for safe_serialization details)
model.push_to_hub("your-username/Llama-3.2-1B-int4", safe_serialization=False)
```
See also
For more information on quantization configs, see torchao.quantization.Int4WeightOnlyConfig and torchao.quantization.Int8WeightOnlyConfig.
2. Serving with VLLM
```bash
# Start VLLM server with TorchAO quantized model
vllm serve your-username/Llama-3.2-1B-int4 \
    --quantization torchao \
    --dtype bfloat16
```
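Once the server is running, it can be queried over VLLM's OpenAI-compatible HTTP API. The sketch below only builds the request with the standard library; it assumes the default address http://localhost:8000 and the /v1/completions route, so adjust both to match your deployment. The helper `build_completion_request` is hypothetical.

```python
import json
import urllib.request


def build_completion_request(
    prompt: str,
    model: str = "your-username/Llama-3.2-1B-int4",
    base_url: str = "http://localhost:8000",  # assumed default host/port
    max_tokens: int = 32,
) -> urllib.request.Request:
    """Build (but do not send) a POST to an OpenAI-compatible completions route."""
    payload = {"model": model, "prompt": prompt, "max_tokens": max_tokens}
    return urllib.request.Request(
        f"{base_url}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_completion_request("The capital of France is")
# To send it against a running server:
# with urllib.request.urlopen(req) as resp:
#     print(json.loads(resp.read())["choices"][0]["text"])
```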
Adding New Quantization Methods to VLLM
Minimal Requirements for VLLM Compatibility
To make a new TorchAO quantization method work with VLLM, you need to implement a minimal set of tensor subclass operations that support tensor parallelism. VLLM uses narrow() and copy_() to move data from a state dict loaded on the host CPU to the device, and these calls require your subclass to handle specific aten operations (the handlers this guide uses are shown in step 2 below).
Why these?
VLLM’s tensor parallelism works by:
- narrow() - Slicing weight tensors along a chosen dimension
- Sharding - Distributing tensor chunks across multiple GPUs
- copy_() - Moving tensor data between devices
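The sharding step comes down to index arithmetic: each rank takes one contiguous chunk of a weight along the sharded dimension, which is exactly what a `narrow(dim, start, length)` call expresses. Below is a stdlib-only sketch with the weight stored as nested lists; it assumes the dimension divides evenly by the tensor-parallel world size, the helpers `shard_bounds` and `narrow` are hypothetical, and it is not VLLM's actual loader code.

```python
from typing import List, Tuple

Matrix = List[List[int]]


def shard_bounds(dim_size: int, rank: int, world_size: int) -> Tuple[int, int]:
    """Return (start, length) for this rank, i.e. the narrow() arguments."""
    if dim_size % world_size != 0:
        raise ValueError("this sketch assumes an evenly divisible dimension")
    length = dim_size // world_size
    return rank * length, length


def narrow(weight: Matrix, dim: int, start: int, length: int) -> Matrix:
    """List-based stand-in for tensor.narrow on a 2D weight."""
    if dim == 0:
        return [row[:] for row in weight[start:start + length]]
    if dim == 1:
        return [row[start:start + length] for row in weight]
    raise ValueError("only 2D weights are supported")


# A 4x4 weight split across 2 ranks.
weight = [[r * 4 + c for c in range(4)] for r in range(4)]
start, length = shard_bounds(len(weight), rank=1, world_size=2)
print(narrow(weight, 0, start, length))  # [[8, 9, 10, 11], [12, 13, 14, 15]]
start, length = shard_bounds(len(weight[0]), rank=0, world_size=2)
print(narrow(weight, 1, start, length))  # [[0, 1], [4, 5], [8, 9], [12, 13]]
```

For a weight laid out as [out_features, in_features], column-parallel layers split along dim 0 and row-parallel layers along dim 1, which is why the slice handler later in this guide accepts both dimensions.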
A helpful pattern for this is _apply_fn_to_data, a method that applies a given function to every Tensor-typed attribute on your class. Below is a generic implementation that should work for most subclasses; we use this pattern heavily in the torchao codebase:
```python
from typing import Callable


def _apply_fn_to_data(self, fn: Callable):
    """Applies a fn to all tensor components stored on this class"""
    tensor_names, ctx = self.__tensor_flatten__()

    # Apply the function to each tensor component
    new_tensors = {}
    for name in tensor_names:
        new_tensors[name] = fn(getattr(self, name))

    return self.__class__.__tensor_unflatten__(
        new_tensors,
        ctx,
        None,  # outer_size parameter
        None,  # outer_stride parameter
    )
```
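The same pattern can be exercised without PyTorch. The toy class below is hypothetical, with Python lists standing in for tensors; it implements flatten/unflatten and `_apply_fn_to_data` the same way as above, so a function applied to the wrapper reaches every inner component while the metadata is carried through unchanged.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


class ToyQuantized:
    """Toy stand-in for a tensor subclass: lists play the role of tensors."""

    tensor_data_attrs = ["quantized_data", "scale"]
    tensor_attributes = ["dtype"]

    def __init__(self, quantized_data: List[int], scale: List[float], dtype: str):
        self.quantized_data = quantized_data
        self.scale = scale
        self.dtype = dtype

    def __tensor_flatten__(self) -> Tuple[List[str], List[Any]]:
        return self.tensor_data_attrs, [getattr(self, a) for a in self.tensor_attributes]

    @classmethod
    def __tensor_unflatten__(
        cls,
        tensor_data_dict: Dict[str, Any],
        tensor_attributes: List[Any],
        outer_size: Optional[Tuple[int, ...]],
        outer_stride: Optional[Tuple[int, ...]],
    ) -> "ToyQuantized":
        return cls(*[tensor_data_dict[n] for n in cls.tensor_data_attrs], *tensor_attributes)

    def _apply_fn_to_data(self, fn: Callable) -> "ToyQuantized":
        tensor_names, ctx = self.__tensor_flatten__()
        new_tensors = {name: fn(getattr(self, name)) for name in tensor_names}
        return self.__class__.__tensor_unflatten__(new_tensors, ctx, None, None)


w = ToyQuantized([10, -20, 30, -40], [0.5], dtype="bfloat16")
# Apply a "clone" to every inner component, as the clone/detach-style handlers do.
clone = w._apply_fn_to_data(list)
print(clone.quantized_data, clone.scale, clone.dtype)  # [10, -20, 30, -40] [0.5] bfloat16
print(clone.quantized_data is w.quantized_data)  # False
```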
Step-by-Step Guide to Add a New Quantization Method
1. Create Your Tensor Subclass
Note
For more details on tensor subclasses and their design principles, please refer to the “What are Tensor Subclasses?” documentation.
```python
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Optional, Tuple

import torch

from torchao.core.config import AOBaseConfig
from torchao.utils import TorchAOBaseTensor


@dataclass
class MyNewQuantConfig(AOBaseConfig):
    """Configuration for your new quantization method"""
    bits: int = 8
    VERSION: ClassVar[int] = 1


class MyQuantizedTensor(TorchAOBaseTensor):
    """Example based on Float8Tensor - stores quantized data + scale"""
    tensor_data_attrs = ["quantized_data", "scale"]
    tensor_attributes = ["dtype"]

    def __new__(cls, quantized_data, scale, dtype):
        shape = quantized_data.shape
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, device=quantized_data.device, dtype=dtype, requires_grad=False
        )

    def __init__(self, quantized_data, scale, dtype):
        self.quantized_data = quantized_data
        self.scale = scale

    def __tensor_flatten__(self) -> Tuple[List[str], List]:
        """Serialize tensor subclass into plain tensors and metadata"""
        return self.tensor_data_attrs, [
            getattr(self, attr) for attr in self.tensor_attributes
        ]

    @classmethod
    def __tensor_unflatten__(
        cls,
        tensor_data_dict: Dict[str, torch.Tensor],
        tensor_attributes: List,
        outer_size: Optional[torch.Size],
        outer_stride: Optional[Tuple],
    ) -> "MyQuantizedTensor":
        """Reconstruct tensor subclass from serialized data"""
        return cls(
            *[tensor_data_dict[name] for name in cls.tensor_data_attrs],
            *tensor_attributes,
        )
```
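The constructor above only stores already-quantized data and a scale. To illustrate where those two components could come from, here is a stdlib-only sketch of symmetric per-row int8 quantization (one scale per output row). This is a generic technique chosen for the example, not the scheme used by Float8Tensor or by any specific TorchAO config, and the helper names are hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def quantize_per_row_int8(weight: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Symmetric per-row int8: scale = max|w| / 127, q = round(w / scale)."""
    quantized, scales = [], []
    for row in weight:
        max_abs = max(abs(v) for v in row)
        scale = max_abs / 127 if max_abs > 0 else 1.0  # avoid division by zero
        quantized.append([max(-128, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return quantized, scales


def dequantize_per_row(quantized: List[List[int]], scales: List[float]) -> Matrix:
    """Inverse mapping: w_hat = q * scale for each row."""
    return [[q * s for q in row] for row, s in zip(quantized, scales)]


weight = [[0.5, -1.27, 0.0], [2.54, 1.0, -0.3]]
q, s = quantize_per_row_int8(weight)
print(q)  # [[50, -127, 0], [127, 50, -15]]
```

Keeping one scale per row is what lets a dim-0 shard carry its own scales along with it, while a dim-1 shard keeps every row's scale.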
2. Implement Required VLLM Operations
```python
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import fill_defaults

aten = torch.ops.aten


@MyQuantizedTensor.implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(func)
    )


@MyQuantizedTensor.implements([aten._to_copy.default])
def _(func, types, args, kwargs):
    # Copy every inner tensor, honoring a requested device (e.g. CPU -> GPU).
    # Only the device is forwarded so the quantized data keeps its storage dtype.
    device = kwargs.get("device")
    return return_and_correct_aliasing(
        func, args, kwargs,
        args[0]._apply_fn_to_data(
            lambda x: torch.clone(x) if device is None else x.to(device, copy=True)
        ),
    )


@MyQuantizedTensor.implements([aten.slice.Tensor])
def _(func, types, args, kwargs):
    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
    if dim == 0 or dim == 1:
        # NOTE the slicing here will likely be different for different quant techniques
        return return_and_correct_aliasing(
            func, args, kwargs,
            self._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)),
        )
    else:
        raise NotImplementedError(f"Slicing along dim={dim} not supported")
```
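The NOTE in the slice handler matters because the inner components do not always share a shape. With a per-row scale of shape (rows,), slicing dim 0 must slice the scale too, while slicing dim 1 must leave it untouched. A stdlib-only sketch of that rule, using nested lists and a hypothetical `slice_quantized` helper (step is omitted for brevity):

```python
from typing import List, Optional, Tuple


def slice_quantized(
    qdata: List[List[int]],
    scale: List[float],
    dim: int,
    start: int,
    end: Optional[int],
) -> Tuple[List[List[int]], List[float]]:
    """Hypothetical slice rule for a 2D weight with a per-row scale of shape (rows,)."""
    if dim == 0:
        # Selecting rows: each kept row keeps its own scale.
        return qdata[start:end], scale[start:end]
    if dim == 1:
        # Selecting columns: every row survives, so the per-row scale is unchanged.
        return [row[start:end] for row in qdata], list(scale)
    raise NotImplementedError(f"Slicing along dim={dim} not supported")


qdata = [[1, 2, 3, 4], [5, 6, 7, 8]]
scale = [0.1, 0.2]
print(slice_quantized(qdata, scale, dim=0, start=1, end=2))  # ([[5, 6, 7, 8]], [0.2])
print(slice_quantized(qdata, scale, dim=1, start=0, end=2))  # ([[1, 2], [5, 6]], [0.1, 0.2])
```

The generic handler above applies the same slice to every component, which is only correct when each component has the same extent along the sliced dimension; a per-group or blockwise scale would need its own mapping from element offsets to group offsets.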
3. Register with TorchAO’s Quantization System
```python
import torch
from torchao.quantization.transform_module import register_quantize_module_handler


@register_quantize_module_handler(MyNewQuantConfig)
def _my_quant_transform(module: torch.nn.Module, config: MyNewQuantConfig):
    """Transform function that applies your quantization to a module"""
    weight = module.weight

    # Your quantization logic here (my_quantization_function is a placeholder)
    quantized_weight = my_quantization_function(weight, config)

    # Replace the weight with your quantized tensor
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    return module
```
Important
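The torchao.quantization.transform_module.register_quantize_module_handler() decorator registers your config class with TorchAO’s quantization system, so that quantize_(model, MyNewQuantConfig()) dispatches to your transform for each module it quantizes.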
Key Implementation Details
Hardware-Specific Linear Operations
The linear implementation you register on your quantized tensor determines which hardware is supported and what actually runs when torch.nn.functional.linear() is called on it.
```python
import torch


@MyQuantizedTensor.implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    input_tensor, weight_tensor = args[0], args[1]
    bias = args[2] if len(args) > 2 else None

    # This is where you define what hardware your method supports.
    # my_cutlass_linear / my_triton_linear and the use_*_kernel flags are placeholders.
    if hasattr(weight_tensor, "use_cutlass_kernel"):
        return my_cutlass_linear(input_tensor, weight_tensor, bias)
    elif hasattr(weight_tensor, "use_triton_kernel"):
        return my_triton_linear(input_tensor, weight_tensor, bias)
    else:
        # Fallback - dequantize and use standard linear
        # (assumes MyQuantizedTensor defines a dequantize() method)
        return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
```
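The fallback branch also serves as a numerical reference when you add a fast kernel: dequantize, then compute y = x @ W^T + b. Below is a stdlib-only sketch of that reference path for a per-row-scaled integer weight stored as nested lists (an assumed layout; the helper names are hypothetical), which a custom kernel's output could be compared against within a tolerance.

```python
from typing import List, Optional

Matrix = List[List[float]]


def dequantize(qdata: List[List[int]], scale: List[float]) -> Matrix:
    """Per-row dequantization: W[i][j] = qdata[i][j] * scale[i]."""
    return [[q * s for q in row] for row, s in zip(qdata, scale)]


def reference_linear(
    x: Matrix, qdata: List[List[int]], scale: List[float], bias: Optional[List[float]] = None
) -> Matrix:
    """Fallback path: y = x @ W.T + bias, with W stored as [out_features, in_features]."""
    weight = dequantize(qdata, scale)
    out = []
    for x_row in x:
        y_row = [sum(a * b for a, b in zip(x_row, w_row)) for w_row in weight]
        if bias is not None:
            y_row = [y + b for y, b in zip(y_row, bias)]
        out.append(y_row)
    return out


x = [[1.0, 2.0]]
qdata = [[2, 4], [-1, 3]]  # 2 output features, 2 input features
scale = [0.5, 1.0]
print(reference_linear(x, qdata, scale, bias=[0.0, 1.0]))  # [[5.0, 6.0]]
```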
Compilation Benefits
torch.compile() removes most of the overhead of tensor subclasses, and compilation is on by default in VLLM.
Trade-offs of Tensor Subclasses
Compilation: torch.compile() is essential for removing subclass overhead. Without it, unless your model is heavily GPU-bound, CPU-side dispatch overhead can severely impact performance.
The checkpoint defines the behavior of the model. You might be thinking, “don’t all checkpoints do this?” They do, but people typically think of a torch.Tensor as just its data. In reality it is a full class that brings along the dispatcher and every kernel ATen has registered for it. When you define a tensor subclass, you are building a separate little world: one with a different data representation, but also one where you must explicitly define which ops you support and provide implementations for all the hardware you want to target. This can feel like spooky action at a distance at first, but it is very powerful. Case in point: supporting tensor parallelism takes only three op definitions.
Serialization and Model Sharing
SafeTensors Support
Current Status: TorchAO quantized models cannot yet be serialized with safetensors due to tensor subclass limitations. When saving quantized models, you must use safe_serialization=False.
Workaround: For production use, save models with safe_serialization=False when pushing to the HuggingFace Hub.
Future Work: The TorchAO team is actively working on safetensors support for tensor subclasses. Track progress at pytorch/ao#2338.
Integration Architecture Diagrams
1. High-Level Model Flow: Transformers → VLLM + TorchAO
This diagram shows the end-to-end flow from model creation to serving:
```mermaid
graph LR
    A[HuggingFace Model] --> B[Transformers AutoModel]
    B --> C{Quantization Config?}
    C -->|TorchAO Config| D[Apply TorchAO Quantization]
    C -->|No Config| E[Standard Model]
    D --> F[Quantized Model w/ Tensor Subclasses]
    E --> G[Standard PyTorch Model]
    F --> H[VLLM Model Loading]
    G --> H
    H --> I[VLLM Distributed Engine]
    I --> J[Tensor Parallel Sharding]
    J --> K[Optimized Inference]

    style D fill:#e1f5fe
    style F fill:#f3e5f5
    style J fill:#e8f5e8
```
2. TorchAO Integration Points in VLLM
This shows how VLLM detects and applies TorchAO quantization:
```mermaid
graph LR
    A[Model Config Detection] --> B{quantization=torchao?}
    B -->|Yes| C[TorchAOConfig.from_config]
    B -->|No| D[Other Quantization Methods]
    C --> E[Parse HF quant_type]
    E --> F[config_from_dict]
    F --> G[AOBaseConfig Instance]
    G --> H[get_quant_method per layer]
    H --> I{Layer Type?}
    I -->|LinearBase| J[TorchAOLinearMethod]
    I -->|Other| K[UnquantizedLinearMethod]
    J --> L[create_weights]
    L --> M[torchao_quantize_param_data]
    M --> N[Quantized Tensor Subclass]

    style C fill:#e1f5fe
    style G fill:#f3e5f5
    style N fill:#e8f5e8
```
3. Kernel Dispatch: Bringing External Kernels to VLLM
This illustrates how tensor subclasses enable custom kernel dispatch within VLLM:
```mermaid
graph LR
    A[F.linear Call in VLLM] --> B[MyQuantTensor torch_function]
    B --> C[Custom implements Handler]
    C --> D{Hardware Check}
    D --> E[Dispatch to External Kernel]
    E --> F[Execute Optimized Kernel]
    F --> G[Return Result to VLLM]

    subgraph "External Libraries"
        H[TorchAO CUTLASS]
        I[TorchAO Triton]
        J[FBGEMM-GPU]
        K[Custom Libraries]
    end

    subgraph "Tensor Subclass Code"
        L[implements F.linear]
        M[custom_linear_impl]
        N[call external kernel]
    end

    E --> H
    E --> I
    E --> J
    E --> K
    C --> L
    L --> M
    M --> N
    N --> E

    style B fill:#e8f6ff,color:#000
    style C fill:#fff3e0,color:#000
    style E fill:#e8f5e8,color:#000
    style L fill:#f3e5f5,color:#000
```