importatexitimportjsonimportosimportshutilimportsocketimporttempfileimporttimeimportweakreffromcollections.abcimportMappingfrompathlibimportPathfromtypingimportAny,List,Literal,Optional,Sequence,UnionimporttransformersfromtqdmimporttqdmfromtransformersimportPreTrainedTokenizerBasefromtensorrt_llm._utilsimportmpi_disabledfromtensorrt_llm.inputs.dataimportTextPromptfromtensorrt_llm.inputs.multimodalimportMultimodalInput,MultimodalParamsfromtensorrt_llm.inputs.registryimportDefaultInputProcessorfromtensorrt_llm.llmapiimporttracingfromtensorrt_llm.metrics.enumsimportMetricNamesfrom.._utilsimportnvtx_range_debugfrom..bindingsimportexecutorastllmfrom..bindingsimportsteady_clock_nowfrom..builderimportEngineConfigfrom..disaggregated_paramsimportDisaggregatedParamsfrom..executorimport(DetokenizedGenerationResultBase,GenerationExecutor,GenerationResult,IterationResult,LoRARequest,PostprocWorkerConfig,PromptAdapterRequest)from..executor.postproc_workerimportPostprocParamsfrom..executor.utilsimport(create_mpi_comm_session,get_spawn_proxy_process_env)from..inputsimport(PromptInputs,create_input_processor,create_input_processor_with_hash,get_cache_salt_id,prompt_inputs)from..loggerimportloggerfrom..sampling_paramsimportSamplingParamsfrom..scheduling_paramsimportSchedulingParamsfrom.llm_argsimport(TORCH_LLMARGS_EXPLICIT_DOCSTRING,TRT_LLMARGS_EXPLICIT_DOCSTRING,PeftCacheConfig,PybindMirror,TorchLlmArgs,TrtLlmArgs)from.llm_utilsimport(CachedModelLoader,KvCacheRetentionConfig,LlmBuildStats,ModelLoader,_ModelRuntimeContext)from.mpi_sessionimportMpiPoolSession,external_mpi_comm_availablefrom.tokenizerimportTokenizerBase,_xgrammar_tokenizer_info# TODO[chunweiy]: move the following symbols back to utils scope, and remove the following importfrom.utilsimport(append_docstring,exception_handler,get_device_count,logger_debug,set_api_status)[docs]classRequestOutput(DetokenizedGenerationResultBase,GenerationResult):"""The output data of a completion request to the LLM. Attributes: request_id (int): The unique ID of the request. prompt (str, optional): The prompt string of the request. prompt_token_ids (List[int]): The token ids of the prompt. outputs (List[CompletionOutput]): The output sequences of the request. context_logits (torch.Tensor, optional): The logits on the prompt token ids. mm_embedding_handle (Dict[str, Any], optional): The multimodal embedding handle of the request. finished (bool): Whether the whole request is finished. """[docs]def__init__(self)->None:raiseRuntimeError(f"{self.__class__.__name__} is designed to be instantiated using{self.__class__.__name__}._from_generation_result by GenerationExecutor. "f"Users are not expected to create{self.__class__.__name__} directly.") @classmethoddef_from_generation_result(cls,generation_result:GenerationResult,prompt:Optional[str]=None,tokenizer:Optional[TokenizerBase]=None)->'RequestOutput':inst=cls.__new__(cls)inst.__dict__.update(generation_result.__dict__)inst.tokenizer=tokenizerinst._streaming=generation_result._streaminginst._prompt=promptreturninst@propertydefprompt(self)->Optional[str]:returnself._promptdef_repr_fields(self):return["request_id","prompt","prompt_token_ids","outputs","finished","mm_embedding_handle"] TRT_LLM_DOCSTRING=TRT_LLMARGS_EXPLICIT_DOCSTRING+""" Attributes: tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any. workspace (pathlib.Path): The directory to store intermediate files. llm_id (str): The unique ID of the LLM instance."""TORCH_LLM_DOCSTRING=TORCH_LLMARGS_EXPLICIT_DOCSTRING+""" Attributes: tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any. llm_id (str): The unique ID of the LLM instance."""classBaseLLM:""" The base class for all LLM classes. """def__init__(self,model:Union[str,Path],tokenizer:Optional[Union[str,Path,TokenizerBase,PreTrainedTokenizerBase]]=None,tokenizer_mode:Literal['auto','slow']='auto',skip_tokenizer_init:bool=False,trust_remote_code:bool=False,tensor_parallel_size:int=1,dtype:str="auto",revision:Optional[str]=None,tokenizer_revision:Optional[str]=None,**kwargs:Any)->None:self._executor_cls=kwargs.pop("executor_cls",GenerationExecutor)self._orchestrator_type=kwargs.get("orchestrator_type",None)self._llm_id=Nonelog_level=logger.levellogger.set_level("info")# force display the backendtry:backend=kwargs.get('backend',None)ifbackend=="pytorch":logger.info("Using LLM with PyTorch backend")llm_args_cls=TorchLlmArgsifself._orchestrator_type=="ray"ormpi_disabled():self._orchestrator_type="ray"os.environ["TLLM_DISABLE_MPI"]="1"# Propagate to args constructionkwargs["orchestrator_type"]="ray"elifbackend=='_autodeploy':logger.info("Using LLM with AutoDeploy backend")from.._torch.auto_deploy.llm_argsimport \LlmArgsasAutoDeployLlmArgsllm_args_cls=AutoDeployLlmArgselse:logger.info("Using LLM with TensorRT backend")llm_args_cls=TrtLlmArgs# check the kwargs and raise ValueError directlyvalid_keys=set(list(llm_args_cls.model_fields.keys())+['_mpi_session','backend'])forkeyinkwargs:ifkeynotinvalid_keys:raiseValueError(f"{self.__class__.__name__} got invalid argument:{key}")self.args=llm_args_cls.from_kwargs(model=model,tokenizer=tokenizer,tokenizer_mode=tokenizer_mode,skip_tokenizer_init=skip_tokenizer_init,trust_remote_code=trust_remote_code,tensor_parallel_size=tensor_parallel_size,dtype=dtype,revision=revision,tokenizer_revision=tokenizer_revision,**kwargs)exceptExceptionase:logger.error(f"Failed to parse the arguments for the LLM constructor:{e}")raiseefinally:logger.set_level(log_level)# restore the log levellogger_debug(f"LLM.args.mpi_session:{self.args.mpi_session}\n","yellow")self.mpi_session=self.args.mpi_sessionifself.args.parallel_config.is_multi_gpu:ifget_device_count()<self.args.parallel_config.world_size_per_node:raiseRuntimeError(f"Only{get_device_count()} GPUs are available, but{self.args.parallel_config.world_size} are required.")logger.info(f'start MpiSession with{self.args.parallel_config.world_size} workers')ifnotself.mpi_session:mpi_process_pre_spawned:bool=get_spawn_proxy_process_env()ifnotmpi_process_pre_spawned:logger_debug(f"LLM create MpiPoolSession\n","yellow")self.mpi_session=MpiPoolSession(n_workers=self.args.parallel_config.world_size)else:logger_debug(f"LLM create MpiCommSession\n","yellow")self.mpi_session=create_mpi_comm_session(self.args.parallel_config.world_size)try:# Due to the Executor can only accept a engine path, we need to save the engine to a directoryself._engine_dir:Optional[Path]=Noneself._executor:Optional[GenerationExecutor]=Noneifself._on_trt_backend:self._workspace=tempfile.TemporaryDirectory(suffix="-llm-workspace",dir=self.args.workspace)else:self._workspace=Noneself._hf_model_dir:Optional[Path]=Noneself._hf_model_config=Noneself._generation_config=Noneself.runtime_context:Optional[_ModelRuntimeContext]=Noneself.llm_build_stats=LlmBuildStats()self._build_model()exceptException:ifself.mpi_sessionisnotNone:self.mpi_session.shutdown()raisetry:ifself.args.otlp_traces_endpoint:tracing.init_tracer("trt.llm",self.args.otlp_traces_endpoint)logger.info(f"Initialized OTLP tracer successfully, endpoint:{self.args.otlp_traces_endpoint}")exceptExceptionase:logger.error(f"Failed to initialize OTLP tracer:{e}")exception_handler.register(self,'shutdown')atexit.register(LLM._shutdown_wrapper,weakref.ref(self))@property@set_api_status("beta")defllm_id(self)->str:ifself._llm_idisNone:hostname=socket.gethostname()pid=os.getpid()timestamp=int(time.time()*1000)self._llm_id=f"{hostname}-{pid}-{timestamp}"returnself._llm_iddefgenerate(self,inputs:Union[PromptInputs,Sequence[PromptInputs]],sampling_params:Optional[Union[SamplingParams,List[SamplingParams]]]=None,use_tqdm:bool=True,lora_request:Optional[Union[LoRARequest,Sequence[LoRARequest]]]=None,prompt_adapter_request:Optional[Union[PromptAdapterRequest,Sequence[PromptAdapterRequest]]]=None,kv_cache_retention_config:Optional[Union[KvCacheRetentionConfig,Sequence[KvCacheRetentionConfig]]]=None,disaggregated_params:Optional[Union[DisaggregatedParams,Sequence[DisaggregatedParams]]]=None,scheduling_params:Optional[Union[SchedulingParams,List[SchedulingParams]]]=None,cache_salt:Optional[Union[str,Sequence[str]]]=None,)->Union[RequestOutput,List[RequestOutput]]:"""Generate output for the given prompts in the synchronous mode. Synchronous generation accepts either single prompt or batched prompts. Args: inputs (tensorrt_llm.inputs.data.PromptInputs, Sequence[tensorrt_llm.inputs.data.PromptInputs]): The prompt text or token ids. It can be single prompt or batched prompts. sampling_params (tensorrt_llm.sampling_params.SamplingParams, List[tensorrt_llm.sampling_params.SamplingParams], optional): The sampling params for the generation. Defaults to None. A default one will be used if not provided. use_tqdm (bool): Whether to use tqdm to display the progress bar. Defaults to True. lora_request (tensorrt_llm.executor.request.LoRARequest, Sequence[tensorrt_llm.executor.request.LoRARequest], optional): LoRA request to use for generation, if any. Defaults to None. prompt_adapter_request (tensorrt_llm.executor.request.PromptAdapterRequest, Sequence[tensorrt_llm.executor.request.PromptAdapterRequest], optional): Prompt Adapter request to use for generation, if any. Defaults to None. kv_cache_retention_config (tensorrt_llm.bindings.executor.KvCacheRetentionConfig, Sequence[tensorrt_llm.bindings.executor.KvCacheRetentionConfig], optional): Configuration for the request's retention in the KV Cache. Defaults to None. disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, Sequence[tensorrt_llm.disaggregated_params.DisaggregatedParams], optional): Disaggregated parameters. Defaults to None. scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, List[tensorrt_llm.scheduling_params.SchedulingParams], optional): Scheduling parameters. Defaults to None. cache_salt (str, Sequence[str], optional): If specified, KV cache will be salted with the provided string to limit the kv cache reuse to the requests with the same string. Defaults to None. Returns: Union[tensorrt_llm.llmapi.RequestOutput, List[tensorrt_llm.llmapi.RequestOutput]]: The output data of the completion request to the LLM. """unbatched=notisinstance(inputs,list)ifnotunbatched:ifisinstance(inputs[0],int):unbatched=Trueifunbatched:inputs=[inputs]inputs=[prompt_inputs(i)foriininputs]def_item_at(maybe_batched:Union[Any,Sequence[Any]],pos:int)->Any:ifisinstance(maybe_batched,list):returnmaybe_batched[pos]else:returnmaybe_batchedfutures=[]fori,request_inputsinenumerate(inputs):future=self.generate_async(request_inputs,sampling_params=_item_at(sampling_params,i),lora_request=_item_at(lora_request,i),prompt_adapter_request=_item_at(prompt_adapter_request,i),kv_cache_retention_config=_item_at(kv_cache_retention_config,i),disaggregated_params=_item_at(disaggregated_params,i),scheduling_params=_item_at(scheduling_params,i),cache_salt=_item_at(cache_salt,i),streaming=False,)futures.append(future)forfutureintqdm(futures,desc="Processed requests",dynamic_ncols=True,disable=notuse_tqdm):future.result()ifunbatched:futures=futures[0]returnfutures@nvtx_range_debug("LLM.generate_async",color="green",category="LLM")defgenerate_async(self,inputs:PromptInputs,sampling_params:Optional[SamplingParams]=None,lora_request:Optional[LoRARequest]=None,prompt_adapter_request:Optional[PromptAdapterRequest]=None,streaming:bool=False,kv_cache_retention_config:Optional[KvCacheRetentionConfig]=None,disaggregated_params:Optional[DisaggregatedParams]=None,trace_headers:Optional[Mapping[str,str]]=None,_postproc_params:Optional[PostprocParams]=None,scheduling_params:Optional[SchedulingParams]=None,cache_salt:Optional[str]=None,)->RequestOutput:"""Generate output for the given prompt in the asynchronous mode. Asynchronous generation accepts single prompt only. Args: inputs (tensorrt_llm.inputs.data.PromptInputs): The prompt text or token ids; it must be single prompt. sampling_params (tensorrt_llm.sampling_params.SamplingParams, optional): The sampling params for the generation. Defaults to None. A default one will be used if not provided. lora_request (tensorrt_llm.executor.request.LoRARequest, optional): LoRA request to use for generation, if any. Defaults to None. prompt_adapter_request (tensorrt_llm.executor.request.PromptAdapterRequest, optional): Prompt Adapter request to use for generation, if any. Defaults to None. streaming (bool): Whether to use the streaming mode for the generation. Defaults to False. kv_cache_retention_config (tensorrt_llm.bindings.executor.KvCacheRetentionConfig, optional): Configuration for the request's retention in the KV Cache. Defaults to None. disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, optional): Disaggregated parameters. Defaults to None. trace_headers (Mapping[str, str], optional): Trace headers. Defaults to None. scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, optional): Scheduling parameters. Defaults to None. cache_salt (str, optional): If specified, KV cache will be salted with the provided string to limit the kv cache reuse to the requests with the same string. Defaults to None. Returns: tensorrt_llm.llmapi.RequestOutput: The output data of the completion request to the LLM. """# Check if the worker is shutting downifself._executorisNoneorself._executor.is_shutdown():raiseRuntimeError("LLM is shutting down")arrival_time=steady_clock_now()ifself.args.return_perf_metricselseNonesampling_params=self._prepare_sampling_params(sampling_params)cache_salt_id=get_cache_salt_id(cache_salt)ifcache_saltisnotNoneelseNone# With pytorch backend, py_executor has logic to handle max_tokens of 1,# so set to 1 to avoid allocating unnecessary KV cache blocks for single request# TODO: Also support for trt backendis_ctx_only=disaggregated_paramsisnotNoneanddisaggregated_params.request_type=="context_only"is_gen_only=disaggregated_paramsisnotNoneanddisaggregated_params.request_type=="generation_only"is_mm_disagg=disaggregated_paramsisnotNoneanddisaggregated_params.multimodal_embedding_handlesisnotNoneifis_ctx_onlyandnotself._on_trt_backend:sampling_params.max_tokens=1inputs=prompt_inputs(inputs)ifnotinputs.get("prompt")andinputs.get("prompt_token_ids")and(inputs.get("multi_modal_data")orinputs.get("multi_modal_embeddings"))andnotisinstance(self.input_processor,DefaultInputProcessor):# VLMs need to process/tokenize the prompt in their own wayprompt=self.tokenizer.decode(inputs['prompt_token_ids'])inputs=TextPrompt(prompt=prompt,multi_modal_data=inputs.get("multi_modal_data"),mm_processor_kwargs=inputs.get("mm_processor_kwargs"))ifsampling_params.add_special_tokens:logger.debug("Setting add_special_tokens to False because prompt_token_ids were provided to generate. VLMs will re-encode the prompt.")sampling_params.add_special_tokens=Falsequery_token_ids=Nonemultimodal_params=Noneifis_mm_disagg:ifnotself.input_processor.support_mm_disagg:raiseValueError("Multimodal disaggregated inference is not supported for this model")mm_handles=disaggregated_params.multimodal_embedding_handlesprompt_token_ids,mm_token_length,mm_token_positions=self.input_processor.get_prompt_token_ids(inputs,mm_handles)prompt=inputs.get("prompt",None)query_token_ids=inputs.get("query_token_ids",None)ifis_gen_only:raiseValueError("Generation-only mode should not need multimodal parameters")else:mm_hashes=disaggregated_params.multimodal_hashesmultimodal_input=MultimodalInput.from_components(mm_hashes,mm_token_positions,mm_token_length)multimodal_params=MultimodalParams(multimodal_input=multimodal_input,multimodal_data={"multimodal_embedding":mm_handles})elif"prompt_token_ids"ininputs:prompt_token_ids=inputs['prompt_token_ids']prompt=Nonequery_token_ids=inputs.get("query_token_ids",None)elif"prompt"ininputs:if'multi_modal_data'ininputs:# TODO: The current design uses a wrapper for existing input processor (input_processor_with_hash)# to handle/add multimodal hashes, positions, and lengths. Now we only support image modality.# In the future, we should refactor this to:# 1. Extend support for more modalities and models# 2. Decouple input processor into distinct phases (preprocessor (all preprocessing logics), vision model (fuse in model fwd), etc.input_processor_with_hash=create_input_processor_with_hash(self.input_processor)withnvtx_range_debug("input_processor_with_hash"):prompt_token_ids,extra_processed_inputs=input_processor_with_hash(inputs,sampling_params)elif'multi_modal_embeddings'ininputs:mm_embedding_info=inputs['multi_modal_embeddings']prompt_token_ids,extra_processed_inputs=self.input_processor.attach_multimodal_embeddings(inputs,mm_embedding_info,sampling_params)else:withnvtx_range_debug("input_processor"):prompt_token_ids,extra_processed_inputs=self.input_processor(inputs,sampling_params)prompt=inputs['prompt']ifextra_processed_inputsisnotNone:query_token_ids=extra_processed_inputs.get('query_token_ids')# Create unified MultimodalParamsmultimodal_params=MultimodalParams(multimodal_input=extra_processed_inputs.get('multimodal_input'),multimodal_data=extra_processed_inputs.get('multimodal_data'))# Only pass it if it has contentifnotmultimodal_params.has_content():multimodal_params=Noneelse:# Convert to shared tensor handle to reduce IPC overheadmultimodal_params.to_handle("multimodal_data")else:raiseTypeError(f"The inputs must be type str or list of int, but got{type(inputs)}")self._check_arguments(len(prompt_token_ids),len(query_token_ids)ifquery_token_idsisnotNoneelse0,sampling_params,is_gen_only=is_gen_only)if_postproc_params:_postproc_params.postproc_args.num_prompt_tokens=len(prompt_token_ids)result=self._executor.generate_async(prompt_token_ids,query_token_ids=query_token_ids,sampling_params=sampling_params,lora_request=lora_request,prompt_adapter_request=prompt_adapter_request,streaming=streaming,kv_cache_retention_config=kv_cache_retention_config,disaggregated_params=disaggregated_params,trace_headers=trace_headers,postproc_params=_postproc_params,multimodal_params=multimodal_params,scheduling_params=scheduling_params,cache_salt_id=cache_salt_id,arrival_time=arrival_time,)ifsampling_params.return_perf_metrics:result.metrics_dict.update({MetricNames.ARRIVAL_TIMESTAMP:time.time()})returnRequestOutput._from_generation_result(result,prompt,self.tokenizer)@set_api_status("beta")defget_stats(self,timeout:Optional[float]=2)->List[dict]:'''Get iteration statistics from the runtime. To collect statistics, call this function after prompts have been submitted with LLM().generate(). Args: timeout (float, optional): Max wait time in seconds when retrieving stats from queue. Defaults to 2. Returns: List[dict]: A list of runtime stats as dict. e.g., ['{"cpuMemUsage": ..., "iter": 0, ...}', '{"cpuMemUsage": ..., "iter": 1, ...}'] '''returnself._executor.get_stats(timeout=timeout)@set_api_status("beta")defget_stats_async(self,timeout:Optional[float]=2)->IterationResult:'''Get iteration statistics from the runtime. To collect statistics, you can call this function in an async coroutine or the /metrics endpoint (if you're using trtllm-serve) after prompts have been submitted. Args: timeout (float, optional): Max wait time in seconds when retrieving stats from queue. Defaults to 2. Returns: tensorrt_llm.executor.result.IterationResult: An async iterable object containing runtime stats. '''returnself._executor.aget_stats(timeout=timeout)@set_api_status("beta")defget_kv_cache_events(self,timeout:Optional[float]=2)->List[dict]:'''Get iteration KV events from the runtime. KV events are used to track changes and operations within the KV Cache. Types of events: - KVCacheCreatedData: Indicates the creation of cache blocks. - KVCacheStoredData: Represents a sequence of stored blocks. - KVCacheRemovedData: Contains the hashes of blocks that are being removed from the cache. - KVCacheUpdatedData: Captures updates to existing cache blocks. To enable KV events: - set `event_buffer_max_size` to a positive integer in the `KvCacheConfig`. - set `enable_block_reuse` to True in the `KvCacheConfig`. Args: timeout (float, optional): Max wait time in seconds when retrieving events from queue. Defaults to 2. Returns: List[dict]: A list of runtime events as dict. '''returnself._executor.get_kv_events(timeout=timeout)@set_api_status("beta")defget_kv_cache_events_async(self,timeout:Optional[float]=2)->IterationResult:'''Get iteration KV events from the runtime. KV events are used to track changes and operations within the KV Cache. Types of events: - KVCacheCreatedData: Indicates the creation of cache blocks. - KVCacheStoredData: Represents a sequence of stored blocks. - KVCacheRemovedData: Contains the hashes of blocks that are being removed from the cache. - KVCacheUpdatedData: Captures updates to existing cache blocks. To enable KV events: - set `event_buffer_max_size` to a positive integer in the `KvCacheConfig`. - set `enable_block_reuse` to True in the `KvCacheConfig`. Args: timeout (float, optional): Max wait time in seconds when retrieving events from queue. . Defaults to 2. Returns: tensorrt_llm.executor.result.IterationResult: An async iterable object containing runtime events. '''returnself._executor.aget_kv_events(timeout=timeout)def_prepare_sampling_params(self,sampling_params:Optional[SamplingParams]=None)->SamplingParams:ifsampling_paramsisNone:sampling_params=SamplingParams()ifisinstance(sampling_params,SamplingParams):ifsampling_params.end_idisNone:ifself.tokenizerisNone:raiseValueError("tokenizer is required to reset end_id if it is None, or you can explicitly specify the end_id for sampling_params")sampling_params._setup(self.tokenizer,self._hf_model_config,self._generation_config)else:raiseTypeError(f"The sampling_params must be type SamplingParams or None, but got{type(sampling_params)}")# auto enabled context and/or generation logits flags, as they are required by logprob computation for TRT backend.ifself.args.backendnotin["pytorch","_autodeploy"]:ifsampling_params.prompt_logprobsandnotsampling_params.return_context_logits:sampling_params.return_context_logits=Truesampling_params._context_logits_auto_enabled=Trueifsampling_params.logprobsandnotsampling_params.return_generation_logits:sampling_params.return_generation_logits=Truesampling_params._generation_logits_auto_enabled=Trueifsampling_params._stream_intervalisNone:sampling_params._stream_interval=getattr(self.args,"stream_interval",1)sampling_params.return_perf_metrics=sampling_params.return_perf_metricsorself.args.return_perf_metricsreturnsampling_paramsdef_check_arguments(self,prompt_len:int,query_len:int,sampling_params:SamplingParams,is_gen_only:bool)->None:ifself.args.backendin["pytorch","_autodeploy"]:# Check prompt length and query length against max_num_tokens to filter illegal requests.# Skip check for gen-only requestsifself.args.backend=="pytorch"andnotself.args.enable_chunked_prefillandnotis_gen_only:max_num_tokens=self.args.max_num_tokensifmax_num_tokensandprompt_len/self.args.parallel_config.cp_size+query_len>max_num_tokens:raiseValueError(f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) should not exceed "f"max_num_tokens ({max_num_tokens})")returnbuild_config=self.args.build_configbuilt_enging_cfg_file=Path(self.args.model)/'config.json'withopen(built_enging_cfg_file)asf:built_enging_cfg=json.load(f)max_seq_len=built_enging_cfg['build_config']['max_seq_len']if'build_config'inbuilt_enging_cfgelsebuild_config.max_seq_len# TODO: Remove this check and left the request verification to cpp runtimeif(notself.args.enable_chunked_prefill)and(prompt_len/self.args.parallel_config.cp_size+query_len+(sampling_params.max_tokensor0)>max_seq_len):raiseValueError(f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}) and query length ({query_len}) max_tokens ({sampling_params.max_tokens}) should not exceed "f"max_seq_len ({max_seq_len})")ifsampling_params.use_beam_searchandsampling_params.best_of>build_config.max_beam_width:ifsampling_params.n==sampling_params.best_of:raiseValueError(f"sampling_params.n ({sampling_params.n}) cannot exceed max_beam_width ({build_config.max_beam_width}) when use_beam_search is True")else:raiseValueError(f"sampling_params.best_of ({sampling_params.best_of}) cannot exceed max_beam_width ({build_config.max_beam_width}) when use_beam_search is True")max_batch_size=self.args.max_batch_sizeifmax_batch_sizeisNone:max_batch_size=build_config.max_batch_sizeifnotsampling_params.use_beam_searchandsampling_params.best_of>max_batch_size:ifsampling_params.n==sampling_params.best_of:raiseValueError(f"sampling_params.n ({sampling_params.n}) cannot exceed max_batch_size ({max_batch_size}) when use_beam_search is False")else:raiseValueError(f"sampling_params.best_of ({sampling_params.best_of}) cannot exceed max_batch_size ({max_batch_size}) when use_beam_search is False")ifsampling_params.prompt_logprobsandnotbuild_config.gather_context_logits:raiseValueError(f"`sampling_params's prompt_logprobs={sampling_params.prompt_logprobs}` requires `gather_context_logits=True` "f"in the `BuildConfig` when constructing the LLM. "f"Example: LLM(..., build_config=BuildConfig(gather_context_logits=True)).")ifsampling_params.logprobsandnotself.args.gather_generation_logits:raiseValueError(f"`sampling_params.logprobs={sampling_params.logprobs}` requires `gather_generation_logits=True` "f"to be passed explicitly to the `LLM()` constructor.")def_build_model(self):model_loader=CachedModelLoader(self.args,mpi_session=self.mpi_session,workspace=self._workspace,llm_build_stats=weakref.proxy(self.llm_build_stats))self._engine_dir,self._hf_model_dir=model_loader()@propertydef_on_trt_backend(self)->bool:returnisinstance(self.args,TrtLlmArgs)def_try_load_tokenizer(self)->Optional[TokenizerBase]:ifself.args.skip_tokenizer_init:returnNoneifself.args.tokenizerisnotNone:assertisinstance(self.args.tokenizer,TokenizerBase)returnself.args.tokenizerifself.runtime_contextisnotNone:returnself.runtime_context.tokenizer# TODO smor- need to refine what is the desired behavior if lora is enabled# in terms of the tokenizer initialization processifhasattr(self.args,"backend")andself.args.backendin["pytorch","_autodeploy"]andself.args.lora_configisnotNone:num_lora_dirs=len(self.args.lora_config.lora_dir)ifnum_lora_dirs==1:tokenizer_path=self.args.lora_config.lora_dir[0]try:tokenizer=ModelLoader.load_hf_tokenizer(tokenizer_path,trust_remote_code=self.args.trust_remote_code,use_fast=self.args.tokenizer_mode!='slow')iftokenizerisNone:tokenizer_path=self.args.modelelse:returntokenizerexceptException:tokenizer_path=self.args.modelelse:tokenizer_path=self.args.modelelse:tokenizer_path=self.args.modelreturnModelLoader.load_hf_tokenizer(tokenizer_path,trust_remote_code=self.args.trust_remote_code,use_fast=self.args.tokenizer_mode!='slow')@propertydeftokenizer(self)->Optional[TokenizerBase]:ifhasattr(self,"input_processor"):ifhasattr(self.input_processor,"tokenizer"):returnself.input_processor.tokenizerreturnself._tokenizer@tokenizer.setterdeftokenizer(self,tokenizer:TokenizerBase):self._tokenizer=tokenizerdef_try_load_generation_config(self)->Optional[transformers.GenerationConfig]:returnModelLoader.load_hf_generation_config(self.args.model)def_try_load_hf_model_config(self)->Optional[transformers.PretrainedConfig]:returnModelLoader.load_hf_model_config(self.args.model)@set_api_status("beta")defshutdown(self)->None:ifhasattr(self,"_executor")andself._executorisnotNone:self._executor.shutdown()self._executor=Noneifhasattr(self,'mpi_session')andself.mpi_sessionisnotNone:self.mpi_session.shutdown()self.mpi_session=None@staticmethoddef_shutdown_wrapper(self_ref):# Retrieve the instance if it still existsinstance=self_ref()ifinstanceisnotNone:instance.shutdown()def__enter__(self):returnselfdef__exit__(self,exc_type,exc_value,traceback)->Literal[False]:# https://github.com/microsoft/pyright/issues/7009#issuecomment-1894135045delexc_value,tracebackself.shutdown()returnFalse# propagate exceptionsdef__getstate__(self):raiseRuntimeError("LLM object can not be pickled.")def__del__(self):self.shutdown()@append_docstring(TRT_LLM_DOCSTRING)class_TrtLLM(BaseLLM):"""LLM class is the main class for running a LLM model using TensorRT LLM backend. Parameters:"""def__init__(self,model:Union[str,Path],tokenizer:Optional[Union[str,Path,TokenizerBase,PreTrainedTokenizerBase]]=None,tokenizer_mode:Literal['auto','slow']='auto',skip_tokenizer_init:bool=False,trust_remote_code:bool=False,tensor_parallel_size:int=1,dtype:str="auto",revision:Optional[str]=None,tokenizer_revision:Optional[str]=None,**kwargs:Any)->None:# TODO: deprecate backend in LLM kwargssuper().__init__(model,tokenizer,tokenizer_mode,skip_tokenizer_init,trust_remote_code,tensor_parallel_size,dtype,revision,tokenizer_revision,**kwargs)@propertydefworkspace(self)->Path:returnPath(self._workspace.name)ifself._on_trt_backendelseNonedefsave(self,engine_dir:str)->None:"""Save the built engine to the given path. Args: engine_dir (str): The path to save the engine. """logger.info(f"Save model to{engine_dir}")ifself._engine_dirisNone:raiseRuntimeError("The engine is not built yet.")ifself._engine_dir.absolute()==os.path.abspath(engine_dir):returnifnotself.mpi_sessionornotself.mpi_session.is_comm_session():shutil.copytree(self._engine_dir,engine_dir,dirs_exist_ok=True)else:# NFS is fragile, so we copy files one by onetarget_engine_dir=Path(engine_dir)target_engine_dir.mkdir(parents=True,exist_ok=True)# copy files one by oneforfileinself._engine_dir.iterdir():logger_debug(f"Copying{file} to{target_engine_dir/file.name}\n")shutil.copy(file,target_engine_dir/file.name)def_build_model(self):super()._build_model()# update the model_dir to a local dir for the runtime, such as tokenizer loading.ifself._engine_dirisnotNone:self.args.model=self._engine_dir# Tokenizer and config loading should be after calling model_loader(), since model_loader() may download the model from HF hub.# It should also be before bindings ExecutorConfig, which may depend on tokenizer info.self._tokenizer=self._try_load_tokenizer()self._hf_model_config=self._try_load_hf_model_config()self._generation_config=self._try_load_generation_config()# Multimodal special handling:# 1. Default load_tokenizer may fail because MM has different tokenizer configuration. Hence we initialize it inside input processor# 2. May need to modify model weights for MM (e.g., resize vocab embedding). We must do such operation via input processor's __init__self.input_processor=create_input_processor(self._hf_model_dir,self.tokenizer)self._tokenizer=self.input_processor.tokenizermax_batch_size=self.args.max_batch_sizemax_num_tokens=self.args.max_num_tokensmax_seq_len=self.args.max_seq_lenbuild_config=self.args.build_configmax_batch_size=max_batch_sizeorbuild_config.max_batch_sizemax_num_tokens=max_num_tokensorbuild_config.max_num_tokensmax_seq_len=max_seq_lenorbuild_config.max_seq_lenself._executor_config=tllm.ExecutorConfig(max_beam_width=self.args.max_beam_width,scheduler_config=PybindMirror.maybe_to_pybind(self.args.scheduler_config),batching_type=PybindMirror.maybe_to_pybind(self.args.batching_type)ortllm.BatchingType.INFLIGHT,max_batch_size=max_batch_size,max_num_tokens=max_num_tokens,gather_generation_logits=self.args.gather_generation_logits,fail_fast_on_attention_window_too_large=getattr(self.args,'fail_fast_on_attention_window_too_large',False))# also set executor_config.max_seq_len in TRT workflow, to deduce default max_tokensifmax_seq_lenisnotNone:self._executor_config.max_seq_len=max_seq_lenelse:engine_config=EngineConfig.from_json_file(self._engine_dir/"config.json")self._executor_config.max_seq_len=engine_config.build_config.max_seq_lenifself.args.kv_cache_configisnotNone:self._executor_config.kv_cache_config=PybindMirror.maybe_to_pybind(self.args.kv_cache_config)ifos.getenv("FORCE_DETERMINISTIC","0")=="1":# Disable KV cache reuse for deterministic modeself._executor_config.kv_cache_config.enable_block_reuse=Falseself._executor_config.kv_cache_config.enable_partial_reuse=Falseifself.args.peft_cache_configisnotNone:self._executor_config.peft_cache_config=PybindMirror.maybe_to_pybind(self.args.peft_cache_config)lora_config=Noneifself.args.build_config.plugin_config.lora_plugin:engine_config=EngineConfig.from_json_file(self._engine_dir/"config.json")lora_config=engine_config.build_config.lora_configifself.args.lora_configisnotNone:logger.info("Overriding lora_config from engine with lora_config from LLM args")lora_config=self.args.lora_configmax_lora_rank=lora_config.max_lora_ranknum_lora_modules=engine_config.pretrained_config.num_hidden_layers* \len(lora_config.lora_target_modules+lora_config.missing_qkv_modules)peft_cache_config_model=PeftCacheConfig.from_pybind(self._executor_config.peft_cache_config)ifself._executor_config.peft_cache_configisnotNoneelsePeftCacheConfig()iflora_config.max_lorasisnotNone:peft_cache_config_model.num_device_module_layer= \max_lora_rank*num_lora_modules*lora_config.max_lorasiflora_config.max_cpu_lorasisnotNone:peft_cache_config_model.num_host_module_layer= \max_lora_rank*num_lora_modules*lora_config.max_cpu_lorasself._executor_config.peft_cache_config=peft_cache_config_model._to_pybind()ifself.args.decoding_configisnotNone:self._executor_config.decoding_config=self.args.decoding_configifself.args.guided_decoding_backend=='xgrammar':self._executor_config.guided_decoding_config=tllm.GuidedDecodingConfig(backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,**_xgrammar_tokenizer_info(self.tokenizer))elifself.args.guided_decoding_backendisnotNone:raiseValueError(f"Unsupported guided decoding backend{self.args.guided_decoding_backend}")self._executor_config.normalize_log_probs=self.args.normalize_log_probsself._executor_config.enable_chunked_context=self.args.enable_chunked_prefillself._executor_config.max_beam_width=self.args.max_beam_widthorself.args.build_config.max_beam_widthifself.args.extended_runtime_perf_knob_configisnotNone:self._executor_config.extended_runtime_perf_knob_config=PybindMirror.maybe_to_pybind(self.args.extended_runtime_perf_knob_config)ifself.args.cache_transceiver_configisnotNone:self._executor_config.cache_transceiver_config=PybindMirror.maybe_to_pybind(self.args.cache_transceiver_config)self._executor_config.llm_parallel_config=self.args.parallel_configreturn_logits=(self.args.gather_generation_logitsor(self.args.build_configandself.args.build_config.gather_context_logits))self._executor=self._executor_cls.create(self._engine_dir,executor_config=self._executor_config,batched_logits_processor=self.args.batched_logits_processor,model_world_size=self.args.parallel_config.world_size,mpi_session=self.mpi_session,reuse_mpi_comm=external_mpi_comm_available(self.args.parallel_config.world_size),return_logits=return_logits,postproc_worker_config=PostprocWorkerConfig(num_postprocess_workers=self.args.num_postprocess_workers,postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir,),is_llm_executor=True)@append_docstring(TORCH_LLM_DOCSTRING)class_TorchLLM(BaseLLM):"""LLM class is the main class for running a LLM model using PyTorch backend. Parameters:"""def__init__(self,model:Union[str,Path],tokenizer:Optional[Union[str,Path,TokenizerBase,PreTrainedTokenizerBase]]=None,tokenizer_mode:Literal['auto','slow']='auto',skip_tokenizer_init:bool=False,trust_remote_code:bool=False,tensor_parallel_size:int=1,dtype:str="auto",revision:Optional[str]=None,tokenizer_revision:Optional[str]=None,**kwargs:Any)->None:# TODO: deprecate backend in LLM kwargsbackend=kwargs.pop("backend","pytorch")# Validate that users don't pass TrtLlmArgs-specific argumentsself._validate_args_for_torch_backend(kwargs)super().__init__(model,tokenizer,tokenizer_mode,skip_tokenizer_init,trust_remote_code,tensor_parallel_size,dtype,revision,tokenizer_revision,backend=backend,**kwargs)@set_api_status("prototype")def_collective_rpc(self,method:str,args:tuple[Any,...]=(),kwargs:Optional[dict]=None,non_block:bool=False,unique_reply_rank:Optional[int]=None)->list[Any]:""" Execute an RPC call on all GPU workers. Currently, this is only supported for RayExecutor. Args: method (str): The name of the worker method to execute. args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to (). kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None. non_block (bool): Whether to block until all workers have completed the RPC call. Defaults to False. unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply. Defaults to None. Returns: list[Any]: A list of results from each worker. """ifhasattr(self._executor,'collective_rpc'):returnself._executor.collective_rpc(method,args,kwargs,non_block,unique_reply_rank)else:raiseValueError(f"Executor type{type(self._executor)} does not support collective RPC.")def_build_model(self):super()._build_model()assertself._engine_dirisNone# Tokenizer and config loading should be after calling model_loader(), since model_loader() may download the model from HF hub.# It should also be before bindings ExecutorConfig, which may depend on tokenizer info.self._tokenizer=self._try_load_tokenizer()self._hf_model_config=self._try_load_hf_model_config()self._generation_config=self._try_load_generation_config()# Multimodal special handling:# 1. Default load_tokenizer may fail because MM has different tokenizer configuration. Hence we initialize it inside input processor# 2. May need to modify model weights for MM (e.g., resize vocab embedding). We must do such operation via input processor's __init__checkpoint_format=getattr(self.args,"checkpoint_format",None)self.input_processor=create_input_processor(self._hf_model_dir,self.tokenizer,checkpoint_format)self._tokenizer=self.input_processor.tokenizer# TODO: revisit gather_context_logitsreturn_logits=self.args.gather_generation_logitsself._executor=self._executor_cls.create(self._engine_dir,executor_config=None,batched_logits_processor=self.args.batched_logits_processor,model_world_size=self.args.parallel_config.world_size,mpi_session=self.mpi_session,reuse_mpi_comm=external_mpi_comm_available(self.args.parallel_config.world_size),return_logits=return_logits,postproc_worker_config=PostprocWorkerConfig(num_postprocess_workers=self.args.num_postprocess_workers,postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir,),is_llm_executor=True,hf_model_dir=self._hf_model_dir,tokenizer=self.tokenizer,llm_args=self.args)def_validate_args_for_torch_backend(self,kwargs:dict)->None:"""Validate that users don't pass TrtLlmArgs-specific arguments when using PyTorch backend. """trtllm_fields=set(TrtLlmArgs.model_fields.keys())torchllm_fields=set(TorchLlmArgs.model_fields.keys())trtllm_specific_fields=trtllm_fields-torchllm_fields# Check if any TrtLlmArgs-specific arguments are passedtrtllm_specific_args=[]forkeyinkwargs:ifkeyintrtllm_specific_fields:trtllm_specific_args.append(key)iftrtllm_specific_args:raiseValueError(f"The following arguments are specific to TensorRT backend and cannot be used with PyTorch backend:{trtllm_specific_args}.\n"f"Please use 'from tensorrt_llm._tensorrt_engine import LLM' instead to use the TensorRT backend.")[docs]classLLM(_TorchLLM):[docs]def__init__(self,model:Union[str,Path],tokenizer:Optional[Union[str,Path,TokenizerBase,PreTrainedTokenizerBase]]=None,tokenizer_mode:Literal['auto','slow']='auto',skip_tokenizer_init:bool=False,trust_remote_code:bool=False,tensor_parallel_size:int=1,dtype:str="auto",revision:Optional[str]=None,tokenizer_revision:Optional[str]=None,**kwargs:Any)->None:super().__init__(model,tokenizer,tokenizer_mode,skip_tokenizer_init,trust_remote_code,tensor_parallel_size,dtype,revision,tokenizer_revision,**kwargs) # sphinx will ignore the LLM's docstring if it is not explicitly setLLM.__doc__= \f"""LLM class is the main class for running a LLM model. For more details about the arguments, please refer to :class:`TorchLlmArgs`. Parameters:"""+TORCH_LLM_DOCSTRING