Sampling

The PyTorch backend supports most of the sampling features that the C++ backend supports, such as temperature, top-k and top-p sampling, beam search, stop words, bad words, penalties, context and generation logits, log probabilities, guided decoding, and logits processors.

General usage

To use these features:

  1. Enable the enable_trtllm_sampler option in the LLM class

  2. Pass a SamplingParams object with the desired options to the generate() function

The following example prepares two identical prompts that can give different results because of the sampling parameters chosen:

```python
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          enable_trtllm_sampler=True)
sampling_params = SamplingParams(
    temperature=1.0,
    top_k=8,
    top_p=0.5,
)
llm.generate(["Hello, my name is",
              "Hello, my name is"], sampling_params)
```

Note: The enable_trtllm_sampler option is not currently supported with speculative decoders such as MTP or Eagle-3, so only a smaller subset of sampling options is available when they are used.
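The example combines temperature, top_k and top_p, which act on the same next-token distribution. As a rough illustration only (this is not TensorRT-LLM's sampler, and the helper name sample_next_token is hypothetical), one common way to apply the three to a single step's logits is sketched below in plain Python. Implementations differ in details such as whether top-p mass is measured before or after the top-k set is renormalized.

```python
import math
import random
from typing import List, Optional


def sample_next_token(logits: List[float],
                      temperature: float = 1.0,
                      top_k: Optional[int] = None,
                      top_p: Optional[float] = None,
                      rng: Optional[random.Random] = None) -> int:
    """Hypothetical helper: draw one token index from raw logits."""
    if rng is None:
        rng = random.Random()
    # Temperature scaling (must be > 0 here): < 1.0 sharpens, > 1.0 flattens.
    scaled = [x / temperature for x in logits]
    # Softmax, subtracting the max for numerical stability.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Token indices sorted by descending probability.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # Top-k: keep only the k most likely tokens.
    if top_k is not None and top_k > 0:
        order = order[:top_k]
    # Top-p (nucleus): keep the smallest prefix whose probability mass
    # (measured on the full distribution in this sketch) reaches top_p.
    if top_p is not None:
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        order = kept
    # random.choices renormalizes the surviving weights before drawing.
    weights = [probs[i] for i in order]
    return rng.choices(order, weights=weights, k=1)[0]


if __name__ == "__main__":
    rng = random.Random(0)
    print(sample_next_token([2.0, 1.0, 0.0, -1.0], top_k=8, top_p=0.5, rng=rng))  # prints 0
```

With these logits the most likely token already holds about 64% of the probability mass, so top_p=0.5 leaves only that token and top_k=8 has no further effect. A low top_p can therefore make sampling close to deterministic even at temperature 1.0.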

Beam search

Beam search is a decoding strategy that maintains multiple candidate sequences (beams) during text generation, exploring different possible continuations to find higher-quality outputs. Unlike greedy decoding or sampling, beam search considers multiple hypotheses simultaneously.
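The idea can be illustrated with a toy search over a hand-written next-token table. The sketch below is a simplified beam search in plain Python, not TensorRT-LLM's implementation: toy_step and beam_search are hypothetical names, and finished hypotheses simply leave the beam instead of being replaced by new candidates.

```python
import math
from typing import Callable, Dict, List, Tuple

Seq = Tuple[str, ...]
# A step function maps a token prefix to {next_token: log_probability}.
StepFn = Callable[[Seq], Dict[str, float]]


def beam_search(step: StepFn, beam_width: int, max_len: int,
                eos: str = "</s>") -> List[Tuple[Seq, float]]:
    """Simplified beam search ranked by cumulative log-probability."""
    beams: List[Tuple[Seq, float]] = [((), 0.0)]
    finished: List[Tuple[Seq, float]] = []
    for _ in range(max_len):
        # Expand every live beam by every possible next token.
        candidates = [(seq + (tok,), score + logp)
                      for seq, score in beams
                      for tok, logp in step(seq).items()]
        # Keep only the beam_width best expansions across all beams.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_width]:
            if seq[-1] == eos:
                finished.append((seq, score))  # hypothesis is complete
            else:
                beams.append((seq, score))
        if not beams:
            break
    finished.extend(beams)  # unfinished beams still count at max_len
    return sorted(finished, key=lambda c: c[1], reverse=True)


def toy_step(prefix: Seq) -> Dict[str, float]:
    """Hypothetical next-token table standing in for a model."""
    table = {
        (): {"a": 0.6, "b": 0.4},
        ("a",): {"x": 0.55, "</s>": 0.45},
        ("b",): {"y": 0.9, "</s>": 0.1},
    }
    probs = table.get(prefix, {"</s>": 1.0})
    return {tok: math.log(p) for tok, p in probs.items()}


if __name__ == "__main__":
    for width in (1, 2):
        best_seq, best_score = beam_search(toy_step, width, max_len=5)[0]
        print(width, best_seq, round(math.exp(best_score), 2))
```

Greedy decoding (beam_width=1) commits to "a" because it is locally more likely and ends with sequence probability 0.33, while a width of 2 keeps "b" alive and finds "b y" with probability 0.36. Returning n of the best_of beams corresponds, in this sketch, to taking the first n entries of the sorted result.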

To enable beam search, you must:

  1. Enable the use_beam_search option in the SamplingParams object

  2. Set the max_beam_width parameter in the LLM class to match the best_of parameter in SamplingParams

  3. Disable overlap scheduling by setting the disable_overlap_scheduler parameter of the LLM class to True

  4. Disable CUDA Graphs by passing None to the cuda_graph_config parameter of the LLM class

Parameter Configuration:

  • best_of: Controls the number of beams processed during generation (beam width)

  • n: Controls the number of output sequences returned (can be less than best_of)

  • If best_of is omitted, the number of beams processed defaults to n

  • max_beam_width in the LLM class must equal best_of in SamplingParams
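These rules can be restated as a small pre-flight check run before constructing the LLM and SamplingParams objects. BeamConfig and check_beam_config are hypothetical helpers, not part of TensorRT-LLM; they encode only the constraints listed above, plus the assumption that n may not exceed the beam width.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BeamConfig:
    """Hypothetical container mirroring the parameters described above."""
    max_beam_width: int            # LLM(max_beam_width=...)
    n: int = 1                     # SamplingParams(n=...)
    best_of: Optional[int] = None  # SamplingParams(best_of=...)


def check_beam_config(cfg: BeamConfig) -> int:
    """Return the effective beam width, or raise ValueError on a mismatch."""
    # If best_of is omitted, the beam width defaults to n.
    beam_width = cfg.best_of if cfg.best_of is not None else cfg.n
    # Assumption: n (sequences returned) cannot exceed the beams processed.
    if cfg.n > beam_width:
        raise ValueError(f"n={cfg.n} cannot exceed best_of={beam_width}")
    # LLM.max_beam_width must equal SamplingParams.best_of.
    if cfg.max_beam_width != beam_width:
        raise ValueError(
            f"max_beam_width={cfg.max_beam_width} must equal best_of={beam_width}")
    return beam_width


if __name__ == "__main__":
    print(check_beam_config(BeamConfig(max_beam_width=4, n=3, best_of=4)))  # prints 4
```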

The following example demonstrates beam search with a beam width of 4, returning the top 3 sequences:

```python
from tensorrt_llm import LLM, SamplingParams

llm = LLM(
    model='nvidia/Llama-3.1-8B-Instruct-FP8',
    enable_trtllm_sampler=True,
    max_beam_width=4,  # must equal SamplingParams.best_of
    disable_overlap_scheduler=True,
    cuda_graph_config=None,
)
sampling_params = SamplingParams(
    best_of=4,  # must equal LLM.max_beam_width
    use_beam_search=True,
    n=3,  # return top 3 sequences
)
llm.generate(["Hello, my name is",
              "Hello, my name is"], sampling_params)
```

Guided decoding

Guided decoding constrains the generated output to a predefined structured format, ensuring that it follows a specific schema or pattern.

The PyTorch backend supports guided decoding with the XGrammar and Low-level Guidance (llguidance) backends and the following formats:

  • JSON schema

  • JSON object

  • Regular expressions

  • Extended Backus-Naur form (EBNF) grammar

  • Structural tags
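Both backends work by masking out, at each step, the tokens that would make the output violate the format before a token is chosen. As a simplified illustration, assuming a "grammar" that is just a list of allowed strings and a tiny hand-made vocabulary, the masking step can be sketched in plain Python; allowed_mask and constrained_greedy are hypothetical names, and this is not how XGrammar or llguidance are implemented internally.

```python
import math
from typing import Callable, List


def allowed_mask(prefix: str, vocab: List[str], targets: List[str]) -> List[bool]:
    """True for each token that keeps prefix + token a prefix of some target."""
    return [any(t.startswith(prefix + tok) for t in targets) for tok in vocab]


def constrained_greedy(logits_fn: Callable[[str], List[float]],
                       vocab: List[str], targets: List[str],
                       max_steps: int = 10) -> str:
    """Greedy decoding where invalid tokens are masked to -inf at each step."""
    out = ""
    for _ in range(max_steps):
        if out in targets:  # a complete valid output has been produced
            break
        logits = logits_fn(out)
        mask = allowed_mask(out, vocab, targets)
        masked = [x if ok else -math.inf for x, ok in zip(logits, mask)]
        best = max(range(len(vocab)), key=lambda i: masked[i])
        if masked[best] == -math.inf:
            break  # no token can extend the output validly
        out += vocab[best]
    return out


if __name__ == "__main__":
    vocab = ["tr", "ue", "fa", "lse", "maybe"]
    # The unconstrained favorite is "maybe", which is not an allowed output.
    prefers_maybe = lambda prefix: [0.0, 0.0, 0.5, 0.1, 5.0]
    print(constrained_greedy(prefers_maybe, vocab, ["true", "false"]))  # prints false
```

Even though "maybe" has the highest logit, the mask forces the output to "false". Real backends replace the string-prefix test with a matcher compiled from the JSON schema, regular expression or grammar.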

To enable guided decoding, you must:

  1. Set the guided_decoding_backend parameter to 'xgrammar' or 'llguidance' in the LLM class

  2. Create a GuidedDecodingParams object with the desired format specification

    • Note: Depending on the format, pass the specification through the matching constructor parameter (json, regex, grammar, or structural_tag).

  3. Pass the GuidedDecodingParams object to the guided_decoding parameter of the SamplingParams object

The following example demonstrates guided decoding with a JSON schema:

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import GuidedDecodingParams

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          guided_decoding_backend='xgrammar')
structure = '{"title": "Example JSON", "type": "object", "properties": {...}}'
guided_decoding_params = GuidedDecodingParams(json=structure)
sampling_params = SamplingParams(
    guided_decoding=guided_decoding_params,
)
llm.generate("Generate a JSON response", sampling_params)
```

You can find a more detailed example of guided decoding here.
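In the JSON schema example, the "properties" entry stands for the schema's own field definitions. As a sketch, assuming a hypothetical object with name and age fields, a complete schema string can be built with the standard json module instead of being written by hand, and a generated response can be spot-checked the same way. The matches_schema helper is hypothetical and checks only required keys and simple types; it is not a full JSON Schema validator.

```python
import json

# Hypothetical schema for illustration; replace the properties with your own.
schema = {
    "title": "Example JSON",
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}
# Serialized string, in the same form as `structure` in the example.
structure = json.dumps(schema)

_TYPES = {"string": str, "integer": int, "object": dict}


def matches_schema(text: str, schema: dict) -> bool:
    """Spot-check output: a JSON object with required keys and simple types.

    Not a full validator (for example, booleans pass as integers because
    bool subclasses int in Python).
    """
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return False
    if not isinstance(data, dict):
        return False
    if any(key not in data for key in schema.get("required", [])):
        return False
    return all(isinstance(data[key], _TYPES[spec["type"]])
               for key, spec in schema.get("properties", {}).items()
               if key in data)


if __name__ == "__main__":
    print(matches_schema('{"name": "Ada", "age": 36}', schema))  # prints True
```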

Logits processor

Logits processors allow you to modify the logits produced by the network before sampling, enabling custom generation behavior and constraints.

To use a custom logits processor:

  1. Create a custom class that inherits from LogitsProcessor and implements the __call__ method

  2. Pass an instance of this class to the logits_processor parameter of SamplingParams

The following example demonstrates logits processing:

```python
import torch
from typing import List, Optional

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.sampling_params import LogitsProcessor


class MyCustomLogitsProcessor(LogitsProcessor):

    def __call__(self,
                 req_id: int,
                 logits: torch.Tensor,
                 token_ids: List[List[int]],
                 stream_ptr: Optional[int],
                 client_id: Optional[int]) -> None:
        # Implement your custom in-place logits processing logic
        # (this placeholder squares each logit in place)
        logits *= logits


llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8')
sampling_params = SamplingParams(
    logits_processor=MyCustomLogitsProcessor()
)
llm.generate(["Hello, my name is"], sampling_params)
```

You can find a more detailed example of logits processors here.
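A more typical in-place edit than squaring is to ban certain tokens by setting their logits to negative infinity, so that no sampling strategy can select them. The sketch below keeps the same __call__ arguments but, as an assumption made so it runs without torch or a GPU, receives logits as a nested Python list with one row per beam; BanTokensProcessor is a hypothetical name. With a real torch.Tensor inside a LogitsProcessor subclass, an indexing assignment such as logits[..., banned_ids] = float('-inf') would play the same role.

```python
import math
from typing import List, Optional


class BanTokensProcessor:
    """Hypothetical stand-in: bans token ids by setting their logits to -inf.

    A real implementation would subclass LogitsProcessor and receive a
    torch.Tensor; a nested list (one row per beam) is used here instead.
    """

    def __init__(self, banned_ids: List[int]):
        self.banned_ids = banned_ids

    def __call__(self, req_id: int, logits: List[List[float]],
                 token_ids: List[List[int]], stream_ptr: Optional[int],
                 client_id: Optional[int]) -> None:
        # Modify in place and return None, like the processor above.
        for row in logits:
            for tok in self.banned_ids:
                row[tok] = -math.inf


if __name__ == "__main__":
    logits = [[0.1, 2.0, 0.5], [1.0, 0.2, 3.0]]
    BanTokensProcessor([1])(0, logits, [[5, 6]], None, None)
    print(logits)  # prints [[0.1, -inf, 0.5], [1.0, -inf, 3.0]]
```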