# Control generated text using a logits processor

Source: NVIDIA/TensorRT-LLM.

  1fromtypingimportList,Optional  2  3importtorch  4fromtransformersimportPreTrainedTokenizer  5  6fromtensorrt_llmimportLLM  7fromtensorrt_llm.sampling_paramsimportLogitsProcessor,SamplingParams  8  9 10deftext_to_token(tokenizer:PreTrainedTokenizer,text:str,last:bool): 11tokens=tokenizer.encode(text,add_special_tokens=False) 12 13max_token_count=1 14bos_token_added=getattr(tokenizer,'bos_token',None)andgetattr( 15tokenizer,'bos_token_id',None)intokens 16prefix_token_added=getattr(tokenizer,'add_prefix_space', 17None)isnotFalse 18ifbos_token_addedorprefix_token_added: 19max_token_count=2 20 21ifnotlastandlen(tokens)>max_token_count: 22raiseException( 23f"Can't convert{text} to token. It has{len(tokens)} tokens.") 24 25returntokens[-1] 26 27 28# The recommended way to create a customized logits processor: 29#     * Subclass LogitsProcessor and implement the processing logics in the __call__ method. 30#     * Create an instance and pass to SamplingParams. 31# More LogitsProcessors references can be found at https://github.com/NVIDIA/logits-processor-zoo. 32classGenLengthLogitsProcessor(LogitsProcessor): 33""" 34    A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token 35    based on the length of the generated sequence, encouraging or discouraging shorter answers. 36    WARNING: Create a new object before every model.generate call since token_count is accumulated. 37 38    Parameters 39    ---------- 40    tokenizer: The tokenizer used by the LLM. 41    boost_factor (float): A factor to boost the likelihood of the EOS token as the sequence length increases. 42                        Suggested value range is [-1.0, 1.0]. Negative values are used for the opposite effect. 43    p (int, optional): The power to which the token count is raised when computing the boost value. Default is 2. 
44    complete_sentences (bool, optional): If True, boosts EOS token likelihood only when the last token is a full stop 45                                        or a new line. Default is False. 46 47    """ 48 49def__init__(self, 50tokenizer, 51boost_factor:float, 52p:int=2, 53complete_sentences:bool=False): 54self.eos_token=tokenizer.eos_token_id 55self.boost_factor=boost_factor 56self.p=p 57self.token_count=0 58self.full_stop_token=text_to_token(tokenizer, 59"It is a sentence.", 60last=True) 61self.new_line_token=text_to_token(tokenizer, 62"It is a new line\n", 63last=True) 64self.complete_sentences=complete_sentences 65 66def__call__(self,req_ids:int,logits:torch.Tensor,ids:List[List[int]], 67stream_ptr,client_id:Optional[int]): 68boost_val=self.boost_factor*(self.token_count**self.p)/(10** 69self.p) 70 71stream=Noneifstream_ptrisNoneelsetorch.cuda.ExternalStream( 72stream_ptr) 73 74withtorch.cuda.stream(stream): 75ids=torch.LongTensor(ids).to(logits.device,non_blocking=True) 76 77ifself.complete_sentences: 78enabled=(ids[:,-1]==self.full_stop_token)|( 79ids[:,-1]==self.new_line_token) 80logits[:,:,self.eos_token]+=enabled*boost_val 81else: 82logits[:,:,self.eos_token]+=boost_val 83 84self.token_count+=1 85 86 87defmain(): 88 89llm=LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0") 90 91# Sample prompts 92prompts=[ 93"The future of AI is", 94"The future of AI is", 95] 96 97# Generate text 98forprompt_id,promptinenumerate(prompts): 99ifprompt_id%2==0:100# Without logit processor101sampling_params=SamplingParams(top_p=1,max_tokens=200)102else:103# Each prompt can be specified with a logits processor at runtime104sampling_params=SamplingParams(105temperature=0.8,106top_p=0.95,107logits_processor=GenLengthLogitsProcessor(108llm.tokenizer,boost_factor=1,complete_sentences=True))109110output=llm.generate(prompt,sampling_params)111print(112f"Prompt:{output.prompt!r}, Generated text:{output.outputs[0].text!r}"113)114115# Got output like:116# Prompt (original): "bright, 
and it's not just for big companies. Small businesses can also benefit from AI technology. Here are some ways:\n\n1. Improved customer service: AI can help businesses provide better customer service by analyzing customer data and providing personalized recommendations.117#                    This can help businesses improve their customer experience and increase customer loyalty.\n\n2. Increased productivity: AI can help businesses automate repetitive tasks, freeing up employees to focus on more complex tasks. This can118#                    help businesses increase productivity and reduce costs.\n\n3. Enhanced marketing: AI can help businesses create more personalized marketing campaigns by analyzing customer data and targeting specific audiences. This can help businesses119#                    increase their marketing ROI and drive more sales.\n\n4. Improved supply chain management: AI can help businesses optimize their supply chain by analyzing data on demand,"'120#121# Prompt (with GenLenthLogitsProcesor): "bright, and it's not just for big companies. Small businesses can also benefit from AI technology."122123124if__name__=='__main__':125main()