LoRA (Low-Rank Adaptation)

LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique that adapts large language models to specific tasks without modifying the original model weights. Instead of fine-tuning all parameters, LoRA trains small low-rank decomposition matrices whose product is added to the frozen base weights at inference time.
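Concretely, for a frozen weight matrix W0 of shape (d, k), LoRA learns two much smaller matrices B (d x r) and A (r x k) with rank r much smaller than min(d, k), and computes h = W0·x + (alpha/r)·B·A·x. The sketch below illustrates the update and the parameter savings; the dimensions and scaling factor are illustrative, not TensorRT-LLM defaults:

import torch

d, k, r = 4096, 4096, 8            # base weight dims vs. LoRA rank
alpha = 16                         # LoRA scaling factor

W0 = torch.randn(d, k)             # frozen base weight (not trained)
A = torch.randn(r, k) * 0.01       # trainable down-projection
B = torch.zeros(d, r)              # trainable up-projection (zero-init,
                                   # so the adapter starts as a no-op)

x = torch.randn(k)
h = W0 @ x + (alpha / r) * (B @ (A @ x))   # base output + low-rank update

# Trainable parameters: r*(d+k) = 65,536 vs. d*k = 16,777,216 for full fine-tuning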

Table of Contents

  1. Background

  2. Basic Usage

  3. Advanced Usage

  4. TRTLLM serve with LoRA

  5. TRTLLM bench with LoRA

Background

The PyTorch backend provides LoRA support, allowing you to:

  • Load and apply multiple LoRA adapters simultaneously

  • Switch between different adapters for different requests

  • Use LoRA with quantized models

  • Support both HuggingFace and NeMo LoRA formats

Basic Usage

Single LoRA Adapter

from tensorrt_llm import LLM
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.sampling_params import SamplingParams

# Configure LoRA
lora_config = LoraConfig(
    lora_dir=["/path/to/lora/adapter"],
    max_lora_rank=8,
    max_loras=1,
    max_cpu_loras=1
)

# Initialize LLM with LoRA support
llm = LLM(model="/path/to/base/model", lora_config=lora_config)

# Create LoRA request
lora_request = LoRARequest("my-lora-task", 0, "/path/to/lora/adapter")

# Generate with LoRA
prompts = ["Hello, how are you?"]
sampling_params = SamplingParams(max_tokens=50)
outputs = llm.generate(prompts, sampling_params, lora_request=[lora_request])
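Each element of outputs corresponds to one prompt. Assuming the LLM API's usual request-output shape, the generated text can be read back like this:

for output in outputs:
    # output.prompt is the input; output.outputs[0].text is the completion
    print(f"Prompt: {output.prompt!r}")
    print(f"Generated: {output.outputs[0].text!r}")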

Multi-LoRA Support

# Configure for multiple LoRA adapters
lora_config = LoraConfig(
    lora_target_modules=['attn_q', 'attn_k', 'attn_v'],
    max_lora_rank=8,
    max_loras=4,
    max_cpu_loras=8
)

llm = LLM(model="/path/to/base/model", lora_config=lora_config)

# Create multiple LoRA requests
lora_req1 = LoRARequest("task-1", 0, "/path/to/adapter1")
lora_req2 = LoRARequest("task-2", 1, "/path/to/adapter2")

prompts = [
    "Translate to French: Hello world",
    "Summarize: This is a long document..."
]

# Apply different LoRAs to different prompts (one request per prompt)
sampling_params = SamplingParams(max_tokens=50)
outputs = llm.generate(prompts, sampling_params, lora_request=[lora_req1, lora_req2])
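Since the lora_request list maps one-to-one onto prompts, per-prompt requests for a larger batch can be built programmatically. A small sketch; the adapter names and paths below are hypothetical placeholders:

# Hypothetical adapter registry; replace with real adapter paths
adapters = {
    "translate": "/path/to/adapter1",
    "summarize": "/path/to/adapter2",
}

# One LoRARequest per prompt, with stable integer adapter IDs
lora_requests = [
    LoRARequest(name, uid, path)
    for uid, (name, path) in enumerate(adapters.items())
]
outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)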

Advanced Usage

LoRA with Quantization

from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

# Configure quantization
quant_config = QuantConfig(
    quant_algo=QuantAlgo.FP8,
    kv_cache_quant_algo=QuantAlgo.FP8
)

# LoRA works with quantized models
llm = LLM(
    model="/path/to/model",
    quant_config=quant_config,
    lora_config=lora_config
)
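Requests then look exactly as they do without quantization. For example, reusing the lora_request from the single-adapter example above (a sketch, not additional API surface):

# The LoRA delta is applied on top of the quantized base weights at runtime
outputs = llm.generate(
    ["Briefly explain FP8 quantization."],
    SamplingParams(max_tokens=32),
    lora_request=[lora_request],
)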

NeMo LoRA Format

# For NeMo-format LoRA checkpoints
lora_config = LoraConfig(
    lora_dir=["/path/to/nemo/lora"],
    lora_ckpt_source="nemo",
    max_lora_rank=8
)

lora_request = LoRARequest(
    "nemo-task",
    0,
    "/path/to/nemo/lora",
    lora_ckpt_source="nemo"
)

Cache Management

from tensorrt_llm.llmapi.llm_args import PeftCacheConfig

# Fine-tune cache sizes
peft_cache_config = PeftCacheConfig(
    host_cache_size=1024*1024*1024,  # 1GB CPU cache
    device_cache_percent=0.1         # 10% of free GPU memory
)

llm = LLM(
    model="/path/to/model",
    lora_config=lora_config,
    peft_cache_config=peft_cache_config
)
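When choosing host_cache_size, it helps to estimate per-adapter memory first. A rough back-of-envelope sketch, assuming fp16 adapter weights and illustrative dimensions; this is not an official sizing formula:

# Each adapted module of shape (d, k) at rank r stores A (r x k) and B (d x r)
num_modules = 3                  # e.g., attn_q, attn_k, attn_v
d = k = 4096                     # illustrative hidden dims
r = 8                            # LoRA rank
bytes_per_param = 2              # fp16

bytes_per_adapter = num_modules * r * (d + k) * bytes_per_param
print(f"~{bytes_per_adapter / 2**20:.1f} MiB per adapter")   # ~0.4 MiB here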

TRTLLM serve with LoRA

YAML Configuration

Create an extra_llm_api_options.yaml file:

lora_config:
  lora_target_modules: ['attn_q', 'attn_k', 'attn_v']
  max_lora_rank: 8

Starting the Server

python -m tensorrt_llm.commands.serve /path/to/model \
    --extra_llm_api_options extra_llm_api_options.yaml

Client Usage

import openai

client = openai.OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="dummy"
)

response = client.completions.create(
    model="/path/to/model",
    prompt="What is the capital city of France?",
    max_tokens=20,
    extra_body={
        "lora_request": {
            "lora_name": "lora-example-0",
            "lora_int_id": 0,
            "lora_path": "/path/to/lora_adapter"
        }
    },
)
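The server returns a standard OpenAI completion object, so the generated text is read the usual way:

print(response.choices[0].text)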

TRTLLM bench with LoRA

YAML Configuration

Create an extra_llm_api_options.yaml file:

lora_config:
  lora_dir:
    - /workspaces/tensorrt_llm/loras/0
  max_lora_rank: 64
  max_loras: 8
  max_cpu_loras: 8
  lora_target_modules:
    - attn_q
    - attn_k
    - attn_v
  trtllm_modules_to_hf_modules:
    attn_q: q_proj
    attn_k: k_proj
    attn_v: v_proj

Run trtllm-bench

trtllm-bench --model $model_path throughput \
    --dataset $dataset_path \
    --extra_llm_api_options extra_llm_api_options.yaml \
    --num_requests 64 \
    --concurrency 16