Generate text with multiple LoRA adapters

Source: NVIDIA/TensorRT-LLM.

 1 2importargparse 3fromtypingimportOptional 4 5fromhuggingface_hubimportsnapshot_download 6 7fromtensorrt_llmimportLLM 8fromtensorrt_llm.executorimportLoRARequest 9fromtensorrt_llm.lora_helperimportLoraConfig101112defmain(chatbot_lora_dir:Optional[str],mental_health_lora_dir:Optional[str],13tarot_lora_dir:Optional[str]):1415# Download the LoRA adapters from huggingface hub, if not provided via command line args.16ifchatbot_lora_dirisNone:17chatbot_lora_dir=snapshot_download(18repo_id="snshrivas10/sft-tiny-chatbot")19ifmental_health_lora_dirisNone:20mental_health_lora_dir=snapshot_download(21repo_id=22"givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational")23iftarot_lora_dirisNone:24tarot_lora_dir=snapshot_download(25repo_id="barissglc/tinyllama-tarot-v1")2627# Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config.28# This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support.29lora_config=LoraConfig(lora_dir=[chatbot_lora_dir],30max_lora_rank=64,31max_loras=3,32max_cpu_loras=3)33llm=LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",34lora_config=lora_config)3536# Sample prompts37prompts=[38"Hello, tell me a story: ",39"Hello, tell me a story: ",40"I've noticed you seem a bit down lately. Is there anything you'd like to talk about?",41"I've noticed you seem a bit down lately. 
Is there anything you'd like to talk about?",42"In this reading, the Justice card represents a situation where",43"In this reading, the Justice card represents a situation where",44]4546# At runtime, multiple LoRA adapters can be specified via lora_request; None means no LoRA used.47foroutputinllm.generate(prompts,48lora_request=[49None,50LoRARequest("chatbot",1,chatbot_lora_dir),51None,52LoRARequest("mental-health",2,53mental_health_lora_dir),None,54LoRARequest("tarot",3,tarot_lora_dir)55]):56prompt=output.prompt57generated_text=output.outputs[0].text58print(f"Prompt:{prompt!r}, Generated text:{generated_text!r}")5960# Got output like61# Prompt: 'Hello, tell me a story: ', Generated text: '1. Start with a question: "What\'s your favorite color?" 2. Ask a question that leads to a story: "What\'s your'62# Prompt: 'Hello, tell me a story: ', Generated text: '1. A person is walking down the street. 2. A person is sitting on a bench. 3. A person is reading a book.'63# Prompt: "I've noticed you seem a bit down lately. Is there anything you'd like to talk about?", Generated text: "\n\nJASON: (smiling) No, I'm just feeling a bit overwhelmed lately. I've been trying to"64# Prompt: "I've noticed you seem a bit down lately. Is there anything you'd like to talk about?", Generated text: "\n\nJASON: (sighs) Yeah, I've been struggling with some personal issues. I've been feeling like I'm"65# Prompt: 'In this reading, the Justice card represents a situation where', Generated text: 'you are being asked to make a decision that will have a significant impact on your life. The card suggests that you should take the time to consider all the options'66# Prompt: 'In this reading, the Justice card represents a situation where', Generated text: 'you are being asked to make a decision that will have a significant impact on your life. 
It is important to take the time to consider all the options and make'676869if__name__=='__main__':70parser=argparse.ArgumentParser(71description="Generate text with multiple LoRA adapters")72parser.add_argument('--chatbot_lora_dir',73type=str,74default=None,75help='Path to the chatbot LoRA directory')76parser.add_argument('--mental_health_lora_dir',77type=str,78default=None,79help='Path to the mental health LoRA directory')80parser.add_argument('--tarot_lora_dir',81type=str,82default=None,83help='Path to the tarot LoRA directory')84args=parser.parse_args()85main(args.chatbot_lora_dir,args.mental_health_lora_dir,86args.tarot_lora_dir)