Sampling Techniques Showcase
Source: NVIDIA/TensorRT-LLM.
"""
This example demonstrates various sampling techniques available in TensorRT-LLM.
It showcases different sampling parameters and their effects on text generation.

Every ``demonstrate_*`` function accepts an optional pre-built ``LLM`` so a
single model instance can be shared across demonstrations; when none is passed,
a temporary one for ``DEFAULT_MODEL`` is created for that call only.
"""

import contextlib
from typing import Optional

import click

from tensorrt_llm import LLM, SamplingParams

# Model loaded when no LLM instance and no --model value are supplied.
DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Example prompts to demonstrate different sampling techniques
prompts = [
    "What is the future of artificial intelligence?",
    "Describe a beautiful sunset over the ocean.",
    "Write a short story about a robot discovering emotions.",
]


def _llm_context(llm: Optional[LLM] = None):
    """Return a context manager that yields the LLM a demonstration should use.

    Args:
        llm: A caller-owned LLM. It is yielded unchanged and NOT shut down on
            exit, because the caller may reuse it.

    Returns:
        ``contextlib.nullcontext(llm)`` when *llm* is given. Otherwise a new
        ``LLM(model=DEFAULT_MODEL)``, whose own context-manager support is relied
        on to release it when the ``with`` block exits.
    """
    if llm is not None:
        return contextlib.nullcontext(llm)
    return LLM(model=DEFAULT_MODEL)


def demonstrate_greedy_decoding(prompt: str, llm: Optional[LLM] = None):
    """Demonstrates greedy decoding with temperature=0.

    Args:
        prompt: Text prompt to complete.
        llm: Optional LLM to reuse; a temporary DEFAULT_MODEL instance otherwise.
    """
    print("\n🎯 === GREEDY DECODING ===")
    print("Using temperature=0 for deterministic, focused output")

    with _llm_context(llm) as model:
        sampling_params = SamplingParams(
            max_tokens=50,
            temperature=0.0,  # Greedy decoding
        )

        response = model.generate(prompt, sampling_params)
        print(f"Prompt: {prompt}")
        print(f"Response: {response.outputs[0].text}")


def demonstrate_temperature_sampling(prompt: str, llm: Optional[LLM] = None):
    """Demonstrates temperature sampling with different temperature values.

    Args:
        prompt: Text prompt to complete.
        llm: Optional LLM to reuse; a temporary DEFAULT_MODEL instance otherwise.
    """
    print("\n🌡️ === TEMPERATURE SAMPLING ===")
    print(
        "Higher temperature = more creative/random, Lower temperature = more focused"
    )

    with _llm_context(llm) as model:
        temperatures = [0.3, 0.7, 1.0, 1.5]
        for temp in temperatures:
            sampling_params = SamplingParams(
                max_tokens=50,
                temperature=temp,
            )

            response = model.generate(prompt, sampling_params)
            print(f"Temperature {temp}: {response.outputs[0].text}")


def demonstrate_top_k_sampling(prompt: str, llm: Optional[LLM] = None):
    """Demonstrates top-k sampling with different k values.

    Args:
        prompt: Text prompt to complete.
        llm: Optional LLM to reuse; a temporary DEFAULT_MODEL instance otherwise.
    """
    print("\n🔝 === TOP-K SAMPLING ===")
    print("Only consider the top-k most likely tokens at each step")

    with _llm_context(llm) as model:
        top_k_values = [1, 5, 20, 50]

        for k in top_k_values:
            sampling_params = SamplingParams(
                max_tokens=50,
                temperature=0.8,  # Use moderate temperature
                top_k=k,
            )

            response = model.generate(prompt, sampling_params)
            print(f"Top-k {k}: {response.outputs[0].text}")


def demonstrate_top_p_sampling(prompt: str, llm: Optional[LLM] = None):
    """Demonstrates top-p (nucleus) sampling with different p values.

    Args:
        prompt: Text prompt to complete.
        llm: Optional LLM to reuse; a temporary DEFAULT_MODEL instance otherwise.
    """
    print("\n🎯 === TOP-P (NUCLEUS) SAMPLING ===")
    print("Only consider tokens whose cumulative probability is within top-p")

    with _llm_context(llm) as model:
        top_p_values = [0.1, 0.5, 0.9, 0.95]

        for p in top_p_values:
            sampling_params = SamplingParams(
                max_tokens=50,
                temperature=0.8,  # Use moderate temperature
                top_p=p,
            )

            response = model.generate(prompt, sampling_params)
            print(f"Top-p {p}: {response.outputs[0].text}")


def demonstrate_combined_sampling(prompt: str, llm: Optional[LLM] = None):
    """Demonstrates combined top-k and top-p sampling.

    Args:
        prompt: Text prompt to complete.
        llm: Optional LLM to reuse; a temporary DEFAULT_MODEL instance otherwise.
    """
    print("\n🔄 === COMBINED TOP-K + TOP-P SAMPLING ===")
    print("Using both top-k and top-p together for balanced control")

    with _llm_context(llm) as model:
        sampling_params = SamplingParams(
            max_tokens=50,
            temperature=0.8,
            top_k=40,  # Consider top 40 tokens
            top_p=0.9,  # Within 90% cumulative probability
        )

        response = model.generate(prompt, sampling_params)
        print(f"Combined (k=40, p=0.9): {response.outputs[0].text}")


def demonstrate_multiple_sequences(prompt: str, llm: Optional[LLM] = None):
    """Demonstrates generating multiple sequences with different sampling.

    Args:
        prompt: Text prompt to complete.
        llm: Optional LLM to reuse; a temporary DEFAULT_MODEL instance otherwise.
    """
    print("\n📚 === MULTIPLE SEQUENCES ===")
    print("Generate multiple different responses for the same prompt")

    with _llm_context(llm) as model:
        sampling_params = SamplingParams(
            max_tokens=40,
            temperature=0.8,
            top_k=50,
            top_p=0.95,
            n=3,  # Generate 3 different sequences
        )

        response = model.generate(prompt, sampling_params)
        print(f"Prompt: {prompt}")
        for i, output in enumerate(response.outputs):
            print(f"Sequence {i+1}: {output.text}")


def demonstrate_with_logprobs(prompt: str, llm: Optional[LLM] = None):
    """Demonstrates generation with log probabilities.

    Args:
        prompt: Text prompt to complete.
        llm: Optional LLM to reuse; a temporary DEFAULT_MODEL instance otherwise.
    """
    print("\n📊 === GENERATION WITH LOG PROBABILITIES ===")
    print("Get probability information for generated tokens")

    with _llm_context(llm) as model:
        sampling_params = SamplingParams(
            max_tokens=20,
            temperature=0.7,
            top_k=50,
            logprobs=True,  # Return log probabilities
        )

        response = model.generate(prompt, sampling_params)
        output = response.outputs[0]

        print(f"Prompt: {prompt}")
        print(f"Generated: {output.text}")
        print(f"Logprobs: {output.logprobs}")


def run_all_demonstrations(model_path: Optional[str] = None):
    """Run all sampling demonstrations against a single shared model.

    Args:
        model_path: Model path or name to load; DEFAULT_MODEL when None.
    """
    print("🚀 TensorRT LLM Sampling Techniques Showcase")
    print("=" * 50)

    # Use the first prompt for most demonstrations
    demo_prompt = prompts[0]

    # Load the model once and share it, instead of reloading it for every demo.
    with LLM(model=model_path or DEFAULT_MODEL) as llm:
        demonstrate_greedy_decoding(demo_prompt, llm)
        demonstrate_temperature_sampling(demo_prompt, llm)
        demonstrate_top_k_sampling(demo_prompt, llm)
        demonstrate_top_p_sampling(demo_prompt, llm)
        demonstrate_combined_sampling(demo_prompt, llm)
        # TODO[Superjomn]: enable them once pytorch backend supports
        # demonstrate_multiple_sequences(demo_prompt, llm)
        # demonstrate_beam_search(demo_prompt, llm)
        demonstrate_with_logprobs(demo_prompt, llm)

    print("\n🎉 All sampling demonstrations completed!")


# Single-demo dispatch table for --demo (everything except "all").
_DEMOS = {
    "greedy": demonstrate_greedy_decoding,
    "temperature": demonstrate_temperature_sampling,
    "top_k": demonstrate_top_k_sampling,
    "top_p": demonstrate_top_p_sampling,
    "combined": demonstrate_combined_sampling,
    "multiple": demonstrate_multiple_sequences,
    "logprobs": demonstrate_with_logprobs,
}

# Choices still accepted by --demo but with no implementation yet; rejected with
# a clear error instead of silently doing nothing.
_UNSUPPORTED_DEMOS = ("beam", "creative")


@click.command()
@click.option("--model",
              type=str,
              default=None,
              help="Path to the model or model name")
@click.option("--demo",
              type=click.Choice([
                  "greedy", "temperature", "top_k", "top_p", "combined",
                  "multiple", "beam", "logprobs", "creative", "all"
              ]),
              default="all",
              help="Which demonstration to run")
@click.option("--prompt", type=str, default=None, help="Custom prompt to use")
def main(model: Optional[str], demo: str, prompt: Optional[str]):
    """
    Showcase various sampling techniques in TensorRT-LLM.

    Examples:
        python llm_sampling.py --demo all
        python llm_sampling.py --demo temperature --prompt "Tell me a joke"
        python llm_sampling.py --demo top_k --model path/to/your/model
    """
    demo_prompt = prompt or prompts[0]

    if demo == "all":
        run_all_demonstrations(model)
        return

    # Fail before loading any model, so an unimplemented choice costs nothing.
    if demo in _UNSUPPORTED_DEMOS:
        raise click.UsageError(
            f"The '{demo}' demonstration is not available yet.")

    # Run specific demonstration, honoring --model.
    with LLM(model=model or DEFAULT_MODEL) as llm:
        _DEMOS[demo](demo_prompt, llm)


if __name__ == "__main__":
    main()