# Speculative Decoding

Source: NVIDIA/TensorRT-LLM.

 1fromtypingimportOptional 2 3importclick 4 5fromtensorrt_llmimportLLM,SamplingParams 6fromtensorrt_llm.llmapiimport(EagleDecodingConfig,KvCacheConfig, 7MTPDecodingConfig,NGramDecodingConfig) 8 9prompts=[10"What is the capital of France?",11"What is the future of AI?",12]131415defrun_MTP(model:Optional[str]=None):16spec_config=MTPDecodingConfig(num_nextn_predict_layers=1,17use_relaxed_acceptance_for_thinking=True,18relaxed_topk=10,19relaxed_delta=0.01)2021llm=LLM(22# You can change this to a local model path if you have the model downloaded23model=modelor"nvidia/DeepSeek-R1-FP4",24speculative_config=spec_config,25)2627forpromptinprompts:28response=llm.generate(prompt,SamplingParams(max_tokens=10))29print(response.outputs[0].text)303132defrun_Eagle3():33spec_config=EagleDecodingConfig(34max_draft_len=3,35speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",36eagle3_one_model=True)3738kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.8)3940llm=LLM(41model="meta-llama/Llama-3.1-8B-Instruct",42speculative_config=spec_config,43kv_cache_config=kv_cache_config,44)4546forpromptinprompts:47response=llm.generate(prompt,SamplingParams(max_tokens=10))48print(response.outputs[0].text)495051defrun_ngram():52spec_config=NGramDecodingConfig(53max_draft_len=3,54max_matching_ngram_size=3,55is_keep_all=True,56is_use_oldest=True,57is_public_pool=True,58)5960llm=LLM(61model="meta-llama/Llama-3.1-8B-Instruct",62speculative_config=spec_config,63# ngram doesn't work with overlap_scheduler64disable_overlap_scheduler=True,65)6667forpromptinprompts:68response=llm.generate(prompt,SamplingParams(max_tokens=10))69print(response.outputs[0].text)707172@click.command()73@click.argument("algo",74type=click.Choice(["MTP","EAGLE3","DRAFT_TARGET","NGRAM"]))75@click.option("--model",76type=str,77default=None,78help="Path to the model or model 
name.")79defmain(algo:str,model:Optional[str]=None):80algo=algo.upper()81ifalgo=="MTP":82run_MTP(model)83elifalgo=="EAGLE3":84run_Eagle3()85elifalgo=="NGRAM":86run_ngram()87else:88raiseValueError(f"Invalid algorithm:{algo}")899091if__name__=="__main__":92main()