@@ -23,6 +23,7 @@
 from tensorrt_llm.executor import CppExecutorError
 from tensorrt_llm.executor.postproc_worker import PostprocParams
 from tensorrt_llm.inputs import prompt_inputs
+from tensorrt_llm.inputs.data import TokensPrompt
 from tensorrt_llm.inputs.utils import ConversationMessage, apply_chat_template
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
 from tensorrt_llm.llmapi import MultimodalEncoder
@@ -677,8 +678,16 @@ async def generator_wrapper(generator: AsyncIterator[Any]): |
                 if request.stream else completion_response_post_processor,
                 postproc_args=postproc_args,
             )
+
+            prompt = prompt_inputs(prompt)
+            if prompt.get("prompt") is not None:
+                prompt_token_ids, extra_processed_inputs = await asyncio.to_thread(self.llm.input_processor, prompt, sampling_params)
+                tokens_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, query_token_ids=extra_processed_inputs.get("query_token_ids") if extra_processed_inputs is not None else None)
+            else:
+                tokens_prompt = prompt
+
             promise = self.llm.generate_async(
-                inputs=prompt,
+                inputs=tokens_prompt,
                 sampling_params=sampling_params,
                 _postproc_params=postproc_params,
                 streaming=request.stream,
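For readers following the diff: the added branch pre-tokenizes plain-text prompts in a worker thread (via asyncio.to_thread) before handing token IDs to the engine, which keeps CPU-bound tokenization off the event loop under concurrent requests; already-tokenized prompts pass through unchanged. Below is a minimal, self-contained sketch of that pattern, not the library's actual API: tokenize, generate_async, and the local TokensPrompt TypedDict are hypothetical stand-ins for self.llm.input_processor, self.llm.generate_async, and tensorrt_llm.inputs.data.TokensPrompt.

# Minimal sketch of the pre-tokenization pattern above, independent of
# tensorrt_llm. All names here are illustrative stand-ins.
import asyncio
from typing import TypedDict


class TokensPrompt(TypedDict, total=False):
    # Mirrors the shape used in the diff: token IDs plus optional query IDs.
    prompt_token_ids: list[int]
    query_token_ids: list[int] | None


def tokenize(text: str) -> list[int]:
    # Stand-in for a blocking tokenizer call (the role self.llm.input_processor
    # plays in the diff).
    return [ord(c) for c in text]


async def generate_async(inputs: TokensPrompt) -> str:
    # Stand-in for the real engine call; just echoes the token count.
    return f"received {len(inputs['prompt_token_ids'])} tokens"


async def handle_request(prompt: str | TokensPrompt) -> str:
    if isinstance(prompt, str):
        # Tokenization is CPU-bound, so run it in a worker thread to keep
        # the event loop free, as the diff does with asyncio.to_thread.
        token_ids = await asyncio.to_thread(tokenize, prompt)
        tokens_prompt = TokensPrompt(prompt_token_ids=token_ids,
                                     query_token_ids=None)
    else:
        # Already tokenized: pass through unchanged.
        tokens_prompt = prompt
    return await generate_async(tokens_prompt)


if __name__ == "__main__":
    print(asyncio.run(handle_request("hello")))  # received 5 tokens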