Source code for tensorrt_llm.builder

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import os
import shutil
import time
from pathlib import Path
from typing import Dict, Optional, Union

import numpy as np
import tensorrt as trt
from pydantic import BaseModel, Field

from ._common import _is_building, check_max_num_tokens, serialize_engine
from ._utils import (get_sm_version, np_bfloat16, np_float8, str_dtype_to_trt,
                     to_json_file, trt_gte)
from .functional import PositionEmbeddingType
from .graph_rewriting import optimize
from .llmapi.kv_cache_type import KVCacheType
from .logger import logger
from .lora_helper import LoraConfig
from .models import PretrainedConfig, PretrainedModel
from .models.modeling_utils import SpeculativeDecodingMode, optimize_model
from .network import Network, net_guard
from .plugin import PluginConfig
from .quantization import QuantAlgo, QuantMode
from .version import __version__


class ConfigEncoder(json.JSONEncoder):
    '''JSON encoder that also serializes objects exposing ``model_dump``.'''

    def default(self, obj):
        '''Serialize ``obj`` via ``model_dump(mode='json')`` when it has one,
        otherwise defer to ``json.JSONEncoder.default``.
        '''
        if hasattr(obj, 'model_dump'):
            # Handle Pydantic models (including DecodingBaseConfig and subclasses)
            return obj.model_dump(mode='json')
        return super().default(obj)


class BuilderConfig(object):
    '''Pairs a ``trt.IBuilderConfig`` with arbitrary build attributes.

    Instances are produced by ``Builder.create_builder_config()``, which calls
    ``_init`` to attach the TensorRT config and the keyword attributes.
    '''

    def __init__(self, **kwargs):
        # intentionally use **kwargs, user should never call this ctor directly,
        # use Builder.create_builder_config() instead
        pass

    def _init(self, trt_builder_config, **kwargs):
        '''Store the TensorRT builder config and set every keyword as an attribute.

        @param trt_builder_config: the TensorRT builder config to wrap.
        @param kwargs: attributes to attach to this object.
        @return: ``self``, so the call can be chained onto the constructor.
        '''
        self._trt_builder_config = trt_builder_config
        for key, value in kwargs.items():
            setattr(self, key, value)
        return self

    @property
    def trt_builder_config(self) -> trt.IBuilderConfig:
        '''The wrapped TensorRT builder config.'''
        return self._trt_builder_config

    def to_dict(self) -> Dict:
        '''return a dict with keys
        {
            "builder_config": {
                # all key values set by the _init function
            },
            "plugin_config": {
                # the network plugin_config (if any) attached to this BuilderConfig object
                # inside the Builder.build_engine
            }
        }
        '''
        # The TRT config object and the plugin config are reported separately
        # (or not at all), so keep them out of the generic attribute dump.
        excluded = ('_trt_builder_config', 'plugin_config')
        config = {
            'builder_config': {
                key: getattr(self, key)
                for key in self.__dict__ if key not in excluded
            }
        }
        if hasattr(self, 'plugin_config'):
            assert isinstance(self.plugin_config, PluginConfig), \
                f"Found unexpected plugin_config object with type: {type(self.plugin_config)}"
            config['plugin_config'] = self.plugin_config.model_dump(
                mode="json")
        return config


class Builder():
    '''Thin wrapper around ``trt.Builder`` that creates networks and builder
    configs and builds / refits serialized engines.
    '''

    _ALLOWED_PRECISIONS = [
        'float32', 'float16', 'bfloat16', trt.DataType.HALF,
        trt.DataType.FLOAT, trt.DataType.BF16
    ]

    def __init__(self):
        super().__init__()
        self._trt_builder = trt.Builder(logger.trt_logger)
        self.strongly_typed = True

    @property
    def trt_builder(self) -> trt.Builder:
        '''The underlying ``trt.Builder``.'''
        return self._trt_builder

    def create_network(self) -> Network:
        '''Create an empty ``Network`` backed by a new TensorRT network definition.

        The STRONGLY_TYPED creation flag is added when ``self.strongly_typed``
        is set.
        '''
        creation_flags = 0
        # Explicit batch flag will be deprecated in TRT 10
        if "EXPLICIT_BATCH" in trt.NetworkDefinitionCreationFlag.__members__:
            creation_flags |= 1 << int(
                trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        if self.strongly_typed:
            creation_flags |= 1 << int(
                trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
        return Network()._init(self.trt_builder.create_network(creation_flags))

    def create_builder_config(self,
                              precision: Union[str, trt.DataType],
                              timing_cache: Union[str, Path,
                                                  trt.ITimingCache] = None,
                              tensor_parallel: int = 1,
                              use_refit: bool = False,
                              int8: bool = False,
                              strongly_typed: bool = True,
                              force_num_profiles: Optional[int] = None,
                              profiling_verbosity: str = "layer_names_only",
                              use_strip_plan: bool = False,
                              weight_streaming: bool = False,
                              precision_constraints: Optional[str] = "obey",
                              **kwargs) -> BuilderConfig:
        ''' @brief Create a builder config with given precisions and timing cache
            @param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS.
                An unsupported value is only reported through logger.error; the config is still created.
            @param timing_cache: a timing cache object or a path to a timing cache file.
                Anything else (including a non-existent path) falls back to a fresh cache with a warning.
            @param tensor_parallel: number of GPUs used for tensor parallel
            @param use_refit: set to accelerate multi-gpu building, build engine for 1 gpu and refit for the others
            @param int8: whether to build with int8 enabled or not. Can't be used together with refit option
            @param strongly_typed: AND-ed into self.strongly_typed, so once it is False on this Builder it stays False
            @param force_num_profiles: stored on the returned config; read by _add_optimization_profile
            @param profiling_verbosity: "detailed", "none"; any other value selects LAYER_NAMES_ONLY
            @param use_strip_plan: sets both the REFIT_INDIVIDUAL and STRIP_PLAN builder flags
            @param weight_streaming: sets the WEIGHT_STREAMING builder flag
            @param precision_constraints: "obey" sets OBEY_PRECISION_CONSTRAINTS whenever FP16, BF16 or FP8 is enabled
                (weakly typed builds only)
            @param kwargs: any other arguments users would like to attach to the config object as attributes.
                "quant_mode", "weight_sparsity" and "monitor_memory" are also consulted here.
            @return: A BuilderConfig object
        '''
        self.strongly_typed = self.strongly_typed and strongly_typed

        quant_mode = kwargs.get("quant_mode", QuantMode(0))
        if not strongly_typed and precision not in self._ALLOWED_PRECISIONS:
            logger.error(
                f"precision should be one of {self._ALLOWED_PRECISIONS}")

        config = self.trt_builder.create_builder_config()
        if weight_streaming:
            config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)

        # Precision flags only apply to weakly typed networks.
        if not self.strongly_typed:
            fp8 = quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache()
            reduced_precision = False
            if precision == 'float16' or precision == trt.DataType.HALF:
                config.set_flag(trt.BuilderFlag.FP16)
                reduced_precision = True
            elif precision == 'bfloat16' or precision == trt.DataType.BF16:
                config.set_flag(trt.BuilderFlag.BF16)
                reduced_precision = True
            if int8:
                config.set_flag(trt.BuilderFlag.INT8)
            if fp8:
                config.set_flag(trt.BuilderFlag.FP8)
                reduced_precision = True
            # Setting the flag once is equivalent to setting it in each branch.
            if reduced_precision and precision_constraints == 'obey':
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)

        if use_refit:
            config.set_flag(trt.BuilderFlag.REFIT)

        if use_strip_plan:
            # Use fine-grained refit when strip plan is enabled in TRT10.2+.
            config.set_flag(trt.BuilderFlag.REFIT_INDIVIDUAL)
            config.set_flag(trt.BuilderFlag.STRIP_PLAN)

        # Set TRT Engine profiling verbosity
        if profiling_verbosity == "detailed":
            config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
        elif profiling_verbosity == "none":
            config.profiling_verbosity = trt.ProfilingVerbosity.NONE
        else:
            config.profiling_verbosity = trt.ProfilingVerbosity.LAYER_NAMES_ONLY

        # set timing cache
        cache = None
        if timing_cache is not None:
            if isinstance(timing_cache, trt.ITimingCache):
                # use given cache
                cache = timing_cache
            elif isinstance(timing_cache,
                            (str, Path)) and os.path.exists(timing_cache):
                # read cache from file
                with open(timing_cache, "rb") as f:
                    cache = config.create_timing_cache(f.read())
            else:
                logger.warning(
                    "Invalid timing cache, using freshly created one")
        if cache is None:
            cache = config.create_timing_cache(b"")
        # When user does not given any existing cache, internally always created one
        # so the cache should never None here
        assert cache is not None and isinstance(cache, trt.ITimingCache)
        config.set_timing_cache(cache, ignore_mismatch=False)

        # set weight sparsity
        if kwargs.get("weight_sparsity", False):
            config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)

        # TODO: remove this constraint after trt 10.6 is integrated
        if trt_gte(10, 6):
            # set monitor memory
            if kwargs.get("monitor_memory", False):
                config.set_flag(trt.BuilderFlag.MONITOR_MEMORY)

        return BuilderConfig()._init(config,
                                     precision=precision,
                                     tensor_parallel=tensor_parallel,
                                     use_refit=use_refit,
                                     int8=int8,
                                     force_num_profiles=force_num_profiles,
                                     strongly_typed=self.strongly_typed,
                                     use_strip_plan=use_strip_plan,
                                     **kwargs)

    def _add_optimization_profile(self, network: Network,
                                  builder_config: BuilderConfig):
        '''Register the network inputs' shape profiles on the builder config.

        The number of profiles is taken from the first network input. Inputs
        without any profile are skipped. When ``builder_config`` carries a
        ``force_num_profiles`` smaller than that number, only that many
        profiles are added. Finally the named dimensions are cross-checked
        with ``_validate_named_dimensions``.
        '''
        assert isinstance(builder_config, BuilderConfig)
        assert isinstance(network, Network)
        input_tensors = network._inputs
        if len(input_tensors) == 0:
            logger.warning("There are no inputs in the network!")
            return
        num_profiles = len(next(iter(input_tensors.values())).profiles)
        force_num_profiles = getattr(builder_config, "force_num_profiles",
                                     None)
        for i in range(num_profiles):
            logger.debug(f'Adding optimization profile {i+1}/{num_profiles}')
            profile = self.trt_builder.create_optimization_profile()
            for input_name, input_tensor in input_tensors.items():
                if len(input_tensor.profiles) == 0:
                    continue
                shape_profile = input_tensor.profiles[i]
                min_shape = [*shape_profile.min]
                opt_shape = [*shape_profile.opt]
                max_shape = [*shape_profile.max]
                profile.set_shape(input_name, min_shape, opt_shape, max_shape)
                logger.debug(
                    f'{input_name}, min: {min_shape}, opt: {opt_shape}, max: {max_shape}, '
                    f'dimension names: {shape_profile.dimension_names}')
            ret = builder_config.trt_builder_config.add_optimization_profile(
                profile)
            logger.debug(f"Added optimization profile: #{ret}")
            if force_num_profiles is not None and (
                    i + 1
            ) == force_num_profiles and force_num_profiles < num_profiles:
                logger.warning(
                    f"Only adding {force_num_profiles} profiles instead of {num_profiles}."
                )
                break
        assert self._validate_named_dimensions(
            network, builder_config
        ), "Validation of the tensor dimension ranges failed, please check the dimension ranges, find the offensive tensor and dimension name in above the error log"

    def _validate_named_dimensions(self, network: Network,
                                   builder_config) -> bool:
        '''
            For each profile, validate that the named dimensions of different input tensors in this profile all have same range.
            TRT will validate the same condition, validate it earlier to make sure the modeling in TensorRT LLM are correct and
            makes the error msg more user friendly.

            @return: True when every named dimension has exactly one (min, opt, max) range per profile.
        '''
        valid = True
        num_profiles = builder_config.trt_builder_config.num_optimization_profiles
        for profile_idx in range(num_profiles):
            # dimension name -> [(input name, (min, opt, max)), ...]
            dimension_to_range = {}
            for input_name, input_tensor in network._inputs.items():
                # it's legal that a Tensor does not have dim_range?
                if len(input_tensor.profiles) == 0:
                    continue
                profile = input_tensor.profiles[profile_idx]
                for dim_idx, dim_name in enumerate(profile.dimension_names):
                    dim_range = (profile.min[dim_idx], profile.opt[dim_idx],
                                 profile.max[dim_idx])
                    dimension_to_range.setdefault(dim_name, []).append(
                        (input_name, dim_range))
            for dim, ranges in dimension_to_range.items():
                unique_ranges = {dim_range for _, dim_range in ranges}
                logger.debug(
                    f"Validating dimension: {dim}, ranges for this dim are: {unique_ranges}"
                )
                if len(unique_ranges) != 1:
                    logger.error(
                        f"Found illegal dimension setting for profile {profile_idx}, dimension name is: {dim}"
                    )
                    logger.error(
                        "Offensive tensors which have this dimension are:\n" +
                        "\n".join(f"{dim_range} {dim} {tensor_name}"
                                  for tensor_name, dim_range in ranges))
                    valid = False
        return valid

    @_is_building
    def refit_engine(self, network: Network,
                     engine_buffer) -> trt.IHostMemory:
        '''
            @brief: Refit one TensorRT engine using weights from the network,
                user should guarantee that the engine is built with REFIT flag, and the network has the same structure with the engine.
            @param engine_buffer: A serialized TensorRT engine.
            @param network: Network object.
            @return: A serialized TRT engine if refit successfully, None otherwise
        '''
        assert isinstance(network, Network)
        logger.info('Refit TRT engine')
        runtime = trt.Runtime(logger.trt_logger)
        engine = runtime.deserialize_cuda_engine(engine_buffer)

        tik = time.time()

        # Refit engine
        refitter = trt.Refitter(engine, logger.trt_logger)
        if network.named_parameters is None:
            logger.error(
                'Please set named parameters before building multiple engines.'
            )
            return None
        for name, param in network.named_parameters:
            # Resolve the weights once, against the network, the same way
            # build_engine does (it calls param._get_weights(network)).
            weights = param._get_weights(network)
            if weights is None or not refitter.set_named_weights(
                    name, weights):
                logger.error(f'Failed to refit weight: {name}')
                return None

        if not refitter.refit_cuda_engine():
            logger.error('Failed to refit engine.')
            return None

        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Total time of refitting {engine.name}: {t}')
        return engine.serialize()

    @_is_building
    def build_engine(self,
                     network: Network,
                     builder_config: BuilderConfig,
                     managed_weights: dict = None) -> trt.IHostMemory:
        '''
            @brief: Build one TensorRT engine from the network.
            @param network: Network object.
            @param builder_config: BuilderConfig object. Its plugin_config attribute is overwritten with
                the network's plugin_config, and optimization profiles are added when it has none yet.
            @param managed_weights: dict filled in place with {parameter name: value} for every parameter
                reported as managed by the network; must be provided when such parameters exist.
            @return: A serialized TRT engine.
        '''
        assert isinstance(network, Network)
        builder_config.plugin_config = network.plugin_config
        if builder_config.trt_builder_config.num_optimization_profiles == 0:
            self._add_optimization_profile(network, builder_config)
        logger.info(
            f"Total optimization profiles added: {builder_config.trt_builder_config.num_optimization_profiles}"
        )

        tik = time.time()
        # Rename weights
        if network.named_parameters is not None:
            managed_parameters = []
            for name, param in network.named_parameters:
                if param.is_managed(network):
                    assert managed_weights is not None, "managed_weights should be provided when enabled"
                    managed_parameters.append(param)
                    param.set_name(name, network)
                    continue
                if param._get_weights(network) is None:
                    if not param.is_buffer:
                        logger.debug(
                            f"Parameter {name} {param.raw_value.shape} {param.raw_value.dtype} was created"
                            " but unused in forward method, so not materialized to TRT network"
                        )
                    continue
                if not param.set_name(name, network):
                    raise RuntimeError(f'Failed to set weight: {name}')
                # This mark_weights_refittable has no side effect when refit_individual is not enabled.
                network.trt_network.mark_weights_refittable(name)

        network._fill_weights()
        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(
            f'Total time to initialize the weights in network {network.trt_network.name}: {t}'
        )

        # Build engine
        logger.info(f'Build TensorRT engine {network.trt_network.name}')
        tik = time.time()
        engine = self.trt_builder.build_serialized_network(
            network.trt_network, builder_config.trt_builder_config)
        assert engine is not None, 'Engine building failed, please check the error log.'

        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Total time of building {network.trt_network.name}: {t}')

        # Hand the managed parameter values back to the caller.
        if managed_weights is not None and network.named_parameters is not None:
            for param in managed_parameters:
                name = param.name
                value: np.ndarray = param._value
                if value is None:
                    logger.error(f'Failed to get weight: {name}')
                    continue
                if param.need_transpose:
                    # MOE has ndim=3 and uses plugin, no need to transpose
                    value = value.transpose(1, 0)  # WAR for bug 4641821
                managed_weights[name] = value

        return engine

    @staticmethod
    def save_timing_cache(builder_config: BuilderConfig,
                          out_path: str) -> bool:
        '''Serialize timing cache of given builder config to file specified by out_path
            return True if the cache is successfully serialized, False otherwise
        '''
        cache = builder_config.trt_builder_config.get_timing_cache()
        if cache is None:
            logger.warning(
                'No timing cache found in the given builder config, skip saving.'
            )
            return False
        with cache.serialize() as buffer, open(out_path, "wb") as f:
            f.write(buffer)
            # Push the data to disk before reporting success.
            f.flush()
            os.fsync(f)
        logger.info(f'Timing cache serialized to {out_path}')
        return True

    @staticmethod
    def save_config(builder_config: BuilderConfig, config_path: str):
        '''Write ``builder_config.to_dict()`` to ``config_path`` as JSON.'''
        config = builder_config.to_dict()
        to_json_file(config, config_path)
        logger.info(f'Config saved to {config_path}.')
class BuildConfig(BaseModel):
    """Configuration class for TensorRT LLM engine building parameters.

    This class contains all the configuration parameters needed to build a TensorRT LLM engine,
    including sequence length limits, batch sizes, optimization settings, and various features.
    """
    max_input_len: int = Field(
        default=1024, description="Maximum length of input sequences.")
    max_seq_len: Optional[int] = Field(
        default=None,
        description=
        "The maximum possible sequence length for a single request, including both input and generated "
        "output tokens.")
    opt_batch_size: int = Field(
        default=8, description="Optimal batch size for engine optimization.")
    max_batch_size: int = Field(
        default=2048, description="Maximum batch size the engine can handle.")
    max_beam_width: int = Field(
        default=1, description="Maximum beam width for beam search decoding.")
    max_num_tokens: int = Field(
        default=8192,
        description="Maximum number of batched input tokens after padding is "
        "removed in each batch.")
    opt_num_tokens: Optional[int] = Field(
        default=None,
        description=
        "Optimal number of batched input tokens for engine optimization.")
    max_prompt_embedding_table_size: int = Field(
        default=0,
        description="Maximum size of prompt embedding table for prompt tuning.")
    kv_cache_type: Optional[KVCacheType] = Field(
        default=None,
        description=
        "Type of KV cache to use (CONTINUOUS or PAGED). If None, defaults to PAGED."
    )
    gather_context_logits: bool = Field(
        default=False,
        description="Whether to gather logits during context phase.")
    gather_generation_logits: bool = Field(
        default=False,
        description="Whether to gather logits during generation phase.")
    strongly_typed: bool = Field(default=True,
                                 description="Whether to use strongly_typed.")
    force_num_profiles: Optional[int] = Field(
        default=None,
        description=
        "Force a specific number of optimization profiles. If None, auto-determined."
    )
    profiling_verbosity: str = Field(
        default='layer_names_only',
        description=
        "Verbosity level for TensorRT profiling ('layer_names_only', 'detailed', 'none')."
    )
    enable_debug_output: bool = Field(
        default=False,
        description="Whether to enable debug output during building.")
    max_draft_len: int = Field(
        default=0,
        description="Maximum length of draft tokens for speculative decoding.")
    speculative_decoding_mode: SpeculativeDecodingMode = Field(
        default=SpeculativeDecodingMode.NONE,
        description="Mode for speculative decoding (NONE, MEDUSA, EAGLE, etc.)."
    )
    use_refit: bool = Field(
        default=False,
        description="Whether to enable engine refitting capabilities.")
    input_timing_cache: Optional[str] = Field(
        default=None,
        description=
        "Path to input timing cache file. If None, no input cache used.")
    output_timing_cache: str = Field(
        default='model.cache',
        description="Path to output timing cache file.")
    lora_config: LoraConfig = Field(
        default_factory=LoraConfig,
        description="Configuration for LoRA (Low-Rank Adaptation) fine-tuning.")
    weight_sparsity: bool = Field(
        default=False,
        description="Whether to enable weight sparsity optimization.")
    weight_streaming: bool = Field(
        default=False,
        description="Whether to enable weight streaming for large models.")
    plugin_config: PluginConfig = Field(
        default_factory=PluginConfig,
        description="Configuration for TensorRT LLM plugins.")
    use_strip_plan: bool = Field(
        default=False,
        description="Whether to use stripped plan for engine building.")
    max_encoder_input_len: int = Field(
        default=1024,
        description="Maximum encoder input length for encoder-decoder models.")
    dry_run: bool = Field(
        default=False,
        description=
        "Whether to perform a dry run without actually building the engine.")
    visualize_network: Optional[str] = Field(
        default=None,
        description=
        "Path to save network visualization. If None, no visualization generated."
    )
    monitor_memory: bool = Field(
        default=False,
        description="Whether to monitor memory usage during building.")
    use_mrope: bool = Field(
        default=False,
        description=
        "Whether to use Multi-RoPE (Rotary Position Embedding) optimization.")

    # Since we have some overlapping between kv_cache_type, paged_kv_cache, and paged_state (later two will be deprecated in the future),
    # we need to handle it given model architecture.
    def update_kv_cache_type(self, model_architecture: str):
        """Reconcile ``kv_cache_type`` with the plugin config's paged flags.

        ``paged_state`` is the governing flag for 'MambaForCausalLM' and
        'RecurrentGemmaForCausalLM'; every other architecture uses
        ``paged_kv_cache``. Whichever of ``kv_cache_type`` / the governing
        flag is unset is derived from the other; when both are unset the
        cache type defaults to PAGED. The non-governing flag is then forced
        to False, except for RecurrentGemma where ``paged_kv_cache`` mirrors
        ``paged_state``.

        Args:
            model_architecture: Architecture name of the model being built.

        Raises:
            AssertionError: If ``kv_cache_type`` and the governing paged flag
                are both set but disagree.
        """
        uses_paged_state = model_architecture in (
            'MambaForCausalLM', 'RecurrentGemmaForCausalLM')
        paged_kv_cache_attr = 'paged_state' if uses_paged_state else 'paged_kv_cache'
        assert self.plugin_config is not None
        paged_kv_cache_val = getattr(self.plugin_config, paged_kv_cache_attr)

        if self.kv_cache_type is not None:
            wants_paged = self.kv_cache_type == KVCacheType.PAGED
            if paged_kv_cache_val is not None:
                # Both were given explicitly: they have to agree.
                assert bool(paged_kv_cache_val) == wants_paged, (
                    f"kv_cache_type ({self.kv_cache_type}) conflicts with "
                    f"plugin_config.{paged_kv_cache_attr} ({paged_kv_cache_val})"
                )
            else:
                setattr(self.plugin_config, paged_kv_cache_attr, wants_paged)
        elif paged_kv_cache_val is not None:
            self.kv_cache_type = (KVCacheType.PAGED if paged_kv_cache_val else
                                  KVCacheType.CONTINUOUS)
        else:
            # Nothing specified anywhere: default to a paged cache.
            self.kv_cache_type = KVCacheType.PAGED
            setattr(self.plugin_config, paged_kv_cache_attr,
                    self.kv_cache_type == KVCacheType.PAGED)

        assert self.kv_cache_type is not None and getattr(
            self.plugin_config, paged_kv_cache_attr) is not None

        def override_attri(attr_name, value):
            # Always write the value; only warn when it replaces a different
            # explicit setting.
            val = getattr(self.plugin_config, attr_name)
            if val is not None and val != value:
                logger.warning(f'Overriding {attr_name} to {value}')
            setattr(self.plugin_config, attr_name, value)

        # Init other paged kvcache attri to false. For RecurrentGemma, we only support paged_state and paged_kv_cache have
        # the same values. All other models should only consume either of the value and set other to False.
        is_recurrent_gemma = model_architecture == 'RecurrentGemmaForCausalLM'
        if paged_kv_cache_attr == 'paged_state':
            override_attri(
                'paged_kv_cache',
                getattr(self.plugin_config, paged_kv_cache_attr)
                if is_recurrent_gemma else False)
        else:
            override_attri('paged_state', False)

    @classmethod
    def from_json_file(cls, config_file):
        """Build a config from a JSON file whose top-level keys are field names.

        Args:
            config_file: Path of the JSON file to read.

        Returns:
            An instance of ``cls`` (so subclasses get their own type back).
        """
        with open(config_file, encoding="utf-8") as f:
            config = json.load(f)
        return cls(**config)
classEngineConfig:def__init__(self,pretrained_config:'PretrainedConfig',build_config:'BuildConfig',version:str):self.pretrained_config=pretrained_configself.build_config=build_configself.version=version@classmethoddeffrom_json_file(cls,config_file):withopen(config_file)asf:returncls.from_json_str(f.read())@classmethoddeffrom_json_str(cls,config_str):config=json.loads(config_str)returncls(PretrainedConfig.from_dict(config['pretrained_config']),BuildConfig(**config['build_config']),config['version'])defto_dict(self):build_config=self.build_config.model_dump(mode="json")build_config.pop('dry_run',None)# Not an Engine Characteristicbuild_config.pop('visualize_network',None)# Not an Engine Characteristicreturn{'version':self.version,'pretrained_config':self.pretrained_config.to_dict(),'build_config':build_config,}classEngine:def__init__(self,config:EngineConfig,engine:Union[trt.IHostMemory,None],managed_weights:dict[str,np.ndarray]={},):self.config=configself.engine=engineself.managed_weights=managed_weightsifself.managed_weightsisNone:self.managed_weights={}forname,valueinself.managed_weights.items():ifnotvalue.flags['C_CONTIGUOUS']:self.managed_weights[name]=np.ascontiguousarray(value)defsave(self,engine_dir:str):os.makedirs(engine_dir,exist_ok=True)lora_config=self.config.build_config.lora_configlora_dirs=lora_config.lora_dirroot_lora_dir=os.path.join(engine_dir,'lora')iflen(lora_dirs)>0:os.makedirs(root_lora_dir,exist_ok=True)forindex,lora_dirinenumerate(lora_dirs):iflora_config.lora_ckpt_source=='hf':target_lora_dir=f"{root_lora_dir}/{index}"os.makedirs(target_lora_dir,exist_ok=True)shutil.copy2(os.path.join(lora_dir,'adapter_config.json'),target_lora_dir)weight_file=os.path.join(lora_dir,'adapter_model.bin')ifos.path.exists(weight_file):shutil.copy2(weight_file,target_lora_dir)weight_file=os.path.join(lora_dir,'adapter_model.safetensors')ifos.path.exists(weight_file):shutil.copy2(weight_file,target_lora_dir)lora_config.lora_dir[index]=f"lora/{index}"eliflora_config
.lora_ckpt_source=='nemo':target_lora_file=f"{root_lora_dir}/{index}.nemo"shutil.copyfile(lora_dir,target_lora_file)lora_config.lora_dir[index]=f"lora/{index}.nemo"else:ifos.path.exists(root_lora_dir)andos.path.isdir(root_lora_dir):shutil.rmtree(root_lora_dir)ifself.config.pretrained_config.mapping.rank==0:config_dict=self.config.to_dict()ifself.config.pretrained_config.quant_algo==QuantAlgo.MIXED_PRECISION:quant_dict={'version':self.config.version,}quant_dict.update(config_dict['pretrained_config']['quantization'])config_dict['pretrained_config']['quantization'].pop('quantized_layers',None)withopen(os.path.join(engine_dir,'quant_cfg.json'),"w",encoding="utf-8")asf:json.dump(quant_dict,f,indent=4,cls=ConfigEncoder)withopen(os.path.join(engine_dir,'config.json'),"w",encoding="utf-8")asf:json.dump(config_dict,f,indent=4,cls=ConfigEncoder)ifself.engineisnotNone:serialize_engine(self.engine,os.path.join(engine_dir,f'rank{self.config.pretrained_config.mapping.rank}.engine'))ifself.managed_weightsisnotNoneandlen(self.managed_weights)>0:fn=os.path.join(engine_dir,f'rank{self.config.pretrained_config.mapping.rank}_managed_weights.safetensors')serialize_managed_weights(self.managed_weights,fn)@classmethoddeffrom_dir(cls,engine_dir:str,rank:int=0):withopen(os.path.join(engine_dir,f'rank{rank}.engine'),'rb')asf:engine_buffer=f.read()mw_path=os.path.join(engine_dir,f'rank{rank}_managed_weights.safetensors')managed_weights=deserialize_managed_weights(mw_path)ifos.path.exists(mw_path)elseNoneconfig=EngineConfig.from_json_file(os.path.join(engine_dir,'config.json'))config.pretrained_config.set_rank(rank)returncls(config,engine_buffer,managed_weights)@classmethoddeffrom_buffer(cls,engine_buffer:Union[trt.IHostMemory,bytes],json_config_str:str,rank:int=0):config=EngineConfig.from_json_str(json_config_str)config.pretrained_config.set_rank(rank)returncls(config,engine_buffer)defget_engine_version(engine_dir:str)->Union[None,str]:engine_dir=Path(engine_dir)config_path=engine_dir/"confi
g.json"withopen(config_path,'r')asf:config=json.load(f)if'version'notinconfig:returnNonereturnconfig['version']defoptimize_model_with_config(model:PretrainedModel,build_config:BuildConfig):gemm_swiglu_plugin=build_config.plugin_config.gemm_swiglu_pluginlow_latency_gemm_swiglu_plugin=build_config.plugin_config.low_latency_gemm_swiglu_pluginifgemm_swiglu_pluginorlow_latency_gemm_swiglu_plugin:ifnotbuild_config.plugin_config.use_fused_mlp:raiseRuntimeError("GemmSwiGLU plugin requires --use_fused_mlp flag")ifgemm_swiglu_pluginnotin["fp8"]andlow_latency_gemm_swiglu_pluginnotin["fp8"]:raiseRuntimeError(f"GemmSwiGLU plugin currently has limited support: fp8 only, "f"got:{gemm_swiglu_plugin}"f"got:{low_latency_gemm_swiglu_plugin}")ifbuild_config.plugin_config.lora_pluginisnotNone:model.use_lora(build_config.lora_config)is_enc_dec=model.config.architecturein["EncoderModel","DecoderModel"]# FusedMLP does not support RecurrentGemma FP8 currently.is_recurrent_gemma=model.config.architecturein["RecurrentGemmaForCausalLM"]is_fp8=model.config.quantization.quant_algo==QuantAlgo.FP8model=optimize_model(model,share_embedding_table=True,use_ootb_moe=build_config.plugin_config.moe_pluginisNone,use_fused_mlp=(build_config.plugin_config.use_fused_mlpandnotis_enc_decandnot(is_recurrent_gemmaandis_fp8)),gemm_swiglu_plugin_dtype=gemm_swiglu_plugin,low_latency_gemm_swiglu_plugin_dtype=low_latency_gemm_swiglu_plugin,use_fused_rg_lru=is_recurrent_gemma,use_unfused_qkv_gemm=False,use_prompt_tuning=(build_config.max_prompt_embedding_table_size>0),use_lora=build_config.plugin_config.lora_pluginisnotNone,max_lora_rank=build_config.lora_config.max_lora_rank,use_fp8_context_fmha=(model.config.quantization.quant_algoin[QuantAlgo.FP8,QuantAlgo.W4A8_AWQ,QuantAlgo.NVFP4]andbuild_config.plugin_config.use_fp8_context_fmha),fuse_fp4_quant=build_config.plugin_config.fuse_fp4_quant,use_optimize_cross_qkv=True,use_dora=build_config.plugin_config.dora_plugin)ifis_enc_dec:model.precompute_relative_attention_bia
s(build_config)returnmodeldef_init_max_seq_len(model_config,build_config):""" If max_seq_len is not specified, set it to max_position_embeddings * rotary_factor Additional checks to ensure max_seq_len, max_input_len, and max_num_tokens have valid values. """# Extract rotary scaling which will be used for checks and default value of max_seq_lenrotary_scaling=getattr(model_config,"rotary_scaling",None)ifrotary_scalingisnotNone:rotary_type=rotary_scaling.get('type',rotary_scaling.get('rope_type'))rotary_factor=rotary_scaling.get('factor',1.0)ifrotary_typenotin("su","longrope","llama3")else1else:rotary_factor=1ifmodel_config.architecture=="EncoderModel":ifbuild_config.max_seq_lenisNone:build_config.max_seq_len=build_config.max_input_lenlogger.info(f'max_seq_len is not specified for EncoderModel, using --max_input_len.')assertbuild_config.max_input_len==build_config.max_seq_len,f"EncoderModel should have same --max_input_len ({build_config.max_input_len}) and --max_seq_len ({build_config.max_seq_len})."ifbuild_config.max_seq_lenisNone:# Step 1: Find the upper bound of max_seq_lendeduced_max_seq_len=2048ifmodel_config.max_position_embeddingsisnotNone:deduced_max_seq_len=model_config.max_position_embeddings# Step 2: Scale max_seq_len with rotary scalingifrotary_factor!=1:deduced_max_seq_len=math.ceil(deduced_max_seq_len*rotary_factor)logger.warning(f'max_seq_len is scaled to{deduced_max_seq_len} by rotary scaling{rotary_factor}')# Step 3: Assign the new max_seq_lenbuild_config.max_seq_len=int(deduced_max_seq_len)logger.info(f'max_seq_len is not specified, using deduced value{deduced_max_seq_len}')else:ifnotbuild_config.plugin_config.streamingllmandmodel_config.max_position_embeddingsisnotNone \andmodel_config.position_embedding_type!=PositionEmbeddingType.relative:ifbuild_config.max_seq_len>model_config.max_position_embeddings*rotary_factor:logger.warning(f'max_seq_len{build_config.max_seq_len} is larger than max_position_embeddings{model_config.max_position_embeddings} * 
rotary scaling{rotary_factor}, ''the model accuracy might be affected')ifbuild_config.max_input_len>build_config.max_seq_len:logger.warning(f'max_input_len is{build_config.max_input_len} is larger than max_seq_len{build_config.max_seq_len}, clipping it to max_seq_len')build_config.max_input_len=build_config.max_seq_len# Check and may modify max_num_tokens and opt_num_tokens (need to happen after max_seq_len is deduced)max_num_tokens,opt_num_tokens=check_max_num_tokens(max_num_tokens=build_config.max_num_tokens,opt_num_tokens=build_config.opt_num_tokens,max_batch_size=build_config.max_batch_size,max_input_len=build_config.max_input_len,max_seq_len=build_config.max_seq_len,max_beam_width=build_config.max_beam_width,remove_input_padding=build_config.plugin_config.remove_input_padding,enable_context_fmha=build_config.plugin_config.context_fmha,tokens_per_block=build_config.plugin_config.tokens_per_block,multiple_profiles=build_config.plugin_config.multiple_profiles,)build_config.max_num_tokens,build_config.opt_num_tokens=max_num_tokens,opt_num_tokensifbuild_config.plugin_config.remove_input_paddingandbuild_config.plugin_config.context_fmha:ifbuild_config.max_input_len:logger.warning('padding removal and fMHA are both enabled, max_input_len is not required and will be ignored')else:assertbuild_config.max_input_lenisnotNone,'padding removal and fMHA aren\'t both enabled, max_input_len is required'ifbuild_config.max_seq_len:assertbuild_config.max_input_len<=build_config.max_seq_len,'max_input_len should not be larger than 
def _managed_weight_dtype_pairs():
    '''Return the (tag, numpy dtype) pairs used by the managed-weights file format.

    The order matters for serialization: the first pair whose dtype compares
    equal to the array dtype wins, matching the original if/elif chain.
    '''
    return (
        ("F32", np.float32),
        ("F16", np.float16),
        ("BF16", np_bfloat16),
        ("F8_E4M3", np_float8),
        ("I64", np.int64),
        ("I32", np.int32),
        ("I8", np.int8),
    )


def serialize_managed_weights(managed_weights: dict[str, np.ndarray],
                              path: str | Path,
                              metadata=None) -> None:
    '''Write managed weights to ``path`` as a length-prefixed JSON header plus raw data.

    File layout:
      * 8 bytes: little-endian length of the JSON header, in bytes;
      * the JSON header, mapping each weight name to its ``dtype`` tag,
        ``shape`` and ``data_offsets`` ([begin, end) relative to the end of
        the header); ``metadata``, when given, is stored under
        ``"__metadata__"``;
      * the raw bytes of every weight, in iteration order of ``managed_weights``.

    Args:
        managed_weights: mapping from weight name to numpy array.
        path: destination file; it is created or overwritten.
        metadata: optional JSON-serializable object stored in the header.

    Raises:
        RuntimeError: if an array has a dtype with no tag in the format.
    '''
    dtype_pairs = _managed_weight_dtype_pairs()

    header = {}
    if metadata is not None:
        header["__metadata__"] = metadata

    begin = 0
    for name, value in managed_weights.items():
        size = value.size * value.itemsize
        for tag, candidate in dtype_pairs:
            if value.dtype == candidate:
                dtype = tag
                break
        else:
            raise RuntimeError(f"Unsupported dtype: {value.dtype}")
        header[name] = {
            "dtype": dtype,
            "shape": value.shape,
            "data_offsets": [begin, begin + size],
        }
        begin += size

    # Encode first so the length prefix counts the bytes actually written,
    # not the characters of the str.
    header_bytes = json.dumps(header).encode()
    header_json_len = len(header_bytes)

    with open(path, "wb") as f:
        logger.info(
            f"Serializing {len(managed_weights)} managed weights to {path}...")
        f.write(header_json_len.to_bytes(8, byteorder="little"))
        f.write(header_bytes)
        for name, value in managed_weights.items():
            logger.debug(f"Serializing managed weight: {name}")
            # The file write needs a C-contiguous buffer; only copy when the
            # array is not already laid out that way.
            if not value.flags.c_contiguous:
                value = np.ascontiguousarray(value)
            f.write(value.data)


def deserialize_managed_weights(path: str | Path) -> dict[str, np.ndarray]:
    '''Read back a file produced by ``serialize_managed_weights``.

    Args:
        path: file to read.

    Returns:
        Mapping from weight name to a numpy array built with ``np.frombuffer``
        over the bytes read from the file and reshaped to the recorded shape.
        The optional ``"__metadata__"`` header entry is not a weight and is
        skipped.

    Raises:
        RuntimeError: if a header entry carries an unknown dtype tag.
    '''
    tag_to_dtype = dict(_managed_weight_dtype_pairs())

    with open(path, "rb") as f:
        header_json_len = int.from_bytes(f.read(8), byteorder="little")
        header = json.loads(f.read(header_json_len).decode())
        # Offsets in the header are relative to the first byte after the
        # 8-byte length prefix and the header itself.
        data_start = header_json_len + 8

        managed_weights = {}
        for name, info in header.items():
            if name == "__metadata__":
                # Written by the serializer when metadata is supplied; it has
                # no dtype/shape/data_offsets fields.
                continue
            tag = info["dtype"]
            shape = info["shape"]
            begin, end = info["data_offsets"]
            if tag not in tag_to_dtype:
                raise RuntimeError(f"Unsupported dtype: {tag}")
            f.seek(begin + data_start)
            buf = f.read(end - begin)
            managed_weights[name] = np.frombuffer(
                buf, dtype=tag_to_dtype[tag]).reshape(shape)
        return managed_weights


def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
    '''Build an engine from the given model and the options in build_config.

    WARNING: this function may change the state of the given model object in
    some optimization passes, to avoid cloning a model, since LLM models
    normally consume a lot of memory. Create a fresh model object if you need
    to build with different options.

    The build_config argument itself is not modified: a deep copy is taken
    first and all adjustments below are applied to that copy.

    Args:
        model: the model to trace into a network.
        build_config: build options.

    Returns:
        An ``Engine`` wrapping the engine config, the built engine (``None``
        when ``build_config.dry_run`` is set) and the managed weights dict
        (``None`` unless the plugin config enables ``manage_weights``).

    Raises:
        RuntimeError: for quantization modes rejected on 100 <= SM < 120, or
            when managed weights are requested for weight-only quantization
            with the weight-only plugin disabled.
    '''
    tic = time.time()
    # avoid changing the input config
    build_config = build_config.model_copy(deep=True)
    plugin_config = build_config.plugin_config
    model_config = model.config
    quant_mode = model_config.quant_mode

    plugin_config.dtype = model_config.dtype
    build_config.update_kv_cache_type(model_config.architecture)

    _init_max_seq_len(model_config, build_config)

    if plugin_config.streamingllm:
        plugin_config.use_paged_context_fmha = False
        logger.warning(
            "Paged Context FMHA is disabled because StreamingLLM is not supported when enabling paged KV context FMHA."
        )
    if plugin_config.reduce_fusion and (
            model_config.mapping.tp_size == 1 or
        (model_config.architecture != "LlamaForCausalLM"
         and model_config.architecture != "Gemma2ForCausalLM"
         and model_config.architecture != "MedusaForCausalLM")):
        logger.warning('Overriding reduce_fusion to False')
        plugin_config.reduce_fusion = False
    if plugin_config.user_buffer and not plugin_config.reduce_fusion:
        logger.warning('Overriding user_buffer to False')
        plugin_config.user_buffer = False
    if plugin_config.norm_quant_fusion and (
            plugin_config.reduce_fusion
            or model_config.architecture != "LlamaForCausalLM"
            or model_config.quantization.quant_algo != QuantAlgo.NVFP4):
        logger.warning('Overriding norm_quant_fusion to False')
        plugin_config.norm_quant_fusion = False

    if model_config.quantization.quant_algo == QuantAlgo.FP8 or \
            model_config.quantization.kv_cache_quant_algo == QuantAlgo.FP8:
        build_config.strongly_typed = True

    if hasattr(model_config, 'max_draft_len'):
        # If model.config has 'max_draft_len' but build_config not specified,
        # use the value of model.config.max_draft_len to set the value of
        # build_config.max_draft_len
        if build_config.max_draft_len == 0:
            build_config.max_draft_len = model_config.max_draft_len

    if hasattr(model_config, 'redrafter_num_beams') and hasattr(
            model_config, 'redrafter_draft_len_per_beam'):
        build_config.max_draft_len = (
            model_config.redrafter_num_beams *
            model_config.redrafter_draft_len_per_beam)
        if build_config.speculative_decoding_mode != SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS:
            logger.warning(
                'speculative_decoding_mode is not EXPLICIT_DRAFT_TOKENS for ReDrafter model. Overwriting speculative_decoding_mode'
            )
        build_config.speculative_decoding_mode = SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS

    if build_config.speculative_decoding_mode != SpeculativeDecodingMode.NONE:
        logger.info(
            f'Increasing max_seq_len ({build_config.max_seq_len}) '
            f'by max_draft_len ({build_config.max_draft_len}) '
            'to account for speculative decoding implementation specifics. '
            'Maximum number of generated tokens remains the same. '
            f'New max_seq_len is set to {build_config.max_seq_len + build_config.max_draft_len}'
        )
        build_config.max_seq_len += build_config.max_draft_len

    if build_config.speculative_decoding_mode == SpeculativeDecodingMode.EAGLE:
        assert hasattr(model_config, 'num_eagle_layers')
        num_eagle_layers = model_config.num_eagle_layers
        logger.info(
            f'Increasing max_seq_len ({build_config.max_seq_len}) '
            f'by num_eagle_layers ({num_eagle_layers}) '
            'to account for EAGLE implementation specifics. '
            'Maximum number of generated tokens remains the same. '
            f'New max_seq_len is set to {build_config.max_seq_len + num_eagle_layers}'
        )
        build_config.max_seq_len += num_eagle_layers

    if build_config.speculative_decoding_mode != SpeculativeDecodingMode.NONE:
        num_tokens = build_config.max_batch_size * (
            build_config.max_draft_len + 1)
        if build_config.max_num_tokens < num_tokens:
            logger.info(
                f'max_num_tokens ({build_config.max_num_tokens}) is smaller than '
                'max_batch_size * (max_draft_len + 1) = '
                f'({build_config.max_batch_size} * ({build_config.max_draft_len} + 1)). '
                f'New max_num_tokens is set to {num_tokens}.')
            build_config.max_num_tokens = num_tokens

    # Logics to control paged_context_fmha and fp8_context_fmha
    if not plugin_config.context_fmha:
        plugin_config.use_fp8_context_fmha = False
        plugin_config.use_paged_context_fmha = False
        logger.warning(
            "Context FMHA is disabled, FP8 Context FMHA and Paged Context FMHA are disabled."
        )
    elif model_config.quantization.quant_algo not in (QuantAlgo.FP8,
                                                      QuantAlgo.W4A8_AWQ,
                                                      QuantAlgo.NVFP4):
        if plugin_config.use_fp8_context_fmha:
            plugin_config.use_fp8_context_fmha = False
            logger.warning(
                "FP8 Context FMHA is disabled because it must be used together with the fp8 quantization workflow."
            )
        if plugin_config.use_paged_context_fmha and quant_mode.has_fp8_kv_cache(
        ):
            plugin_config.use_paged_context_fmha = False
            logger.warning(
                "FP8 Paged Context FMHA is disabled because FP8 context FMHA is disabled."
            )
    elif get_sm_version() < 89:
        plugin_config.use_fp8_context_fmha = False
        logger.warning(
            "FP8 context FMHA is disabled because it is only supported on Ada and Hopper Arch."
        )
        if plugin_config.use_paged_context_fmha and quant_mode.has_fp8_kv_cache(
        ):
            plugin_config.use_paged_context_fmha = False
            logger.warning(
                "FP8 Paged Context FMHA is disabled because FP8 context FMHA is disabled."
            )
    elif plugin_config.use_paged_context_fmha:
        if not quant_mode.has_fp8_kv_cache(
        ) and plugin_config.use_fp8_context_fmha:
            plugin_config.use_fp8_context_fmha = False
            logger.warning(
                "FP8 Paged Context FMHA is disabled because it must be used together with fp8 KV Cache."
            )
        elif quant_mode.has_fp8_kv_cache(
        ) and not plugin_config.use_fp8_context_fmha:
            plugin_config.use_fp8_context_fmha = True
            logger.warning(
                "FP8 Context FMHA is enabled to support FP8 Paged Context FMHA."
            )

    if plugin_config.use_paged_context_fmha and quant_mode.has_int8_kv_cache():
        plugin_config.use_paged_context_fmha = False
        logger.warning(
            "Paged Context FMHA is disabled because it doesn't work with int8 kv cache currently."
        )

    if 100 <= get_sm_version() < 120:
        if quant_mode.is_int8_weight_only() or quant_mode.is_int4_weight_only(
        ) or quant_mode.has_int8_kv_cache():
            raise RuntimeError(
                "INT8/INT4 quantization is not supported on SM>=100.")
        if quant_mode.has_act_and_weight_quant():
            raise RuntimeError("SmoothQuant is not supported on SM>=100.")
        if quant_mode.has_per_channel_scaling(
        ) or quant_mode.has_per_token_dynamic_scaling():
            raise RuntimeError(
                "Per-channel or per-token scaling is not supported on SM>=100.")

    model = optimize_model_with_config(model, build_config)

    builder = Builder()
    builder_config = builder.create_builder_config(
        precision=model.config.dtype,
        use_refit=build_config.use_refit,
        timing_cache=build_config.input_timing_cache,
        int8=(model.config.quant_mode.has_act_or_weight_quant()
              and not model.config.quant_mode.has_per_group_scaling())
        or model.config.quant_mode.has_int8_kv_cache(),
        strongly_typed=build_config.strongly_typed,
        force_num_profiles=build_config.force_num_profiles,
        profiling_verbosity=build_config.profiling_verbosity,
        quant_mode=model.config.quant_mode,
        use_strip_plan=build_config.use_strip_plan,
        weight_sparsity=build_config.weight_sparsity,
        weight_streaming=build_config.weight_streaming,
        monitor_memory=build_config.monitor_memory,
    )

    network = builder.create_network()
    # The network shares the (copied) plugin config object with build_config,
    # so later changes made through either name are visible through both.
    network.plugin_config = build_config.plugin_config

    use_weight_only = model.config.quant_mode.is_weight_only()
    per_group = model.config.quant_mode.has_per_group_scaling()
    use_smooth_quant = model.config.quant_mode.has_act_and_weight_quant()
    use_qserve = model.config.quant_mode.is_qserve_w4a8()
    use_fp8_rowwise = model.config.quant_mode.has_fp8_rowwise()
    use_fp4_gemm = model.config.quant_mode.has_nvfp4()
    disable_weight_only_quant_plugin = getattr(
        model.config, 'disable_weight_only_quant_plugin', False)

    if use_fp4_gemm and network.plugin_config._explicitly_disable_gemm_plugin is False:
        logger.info(
            'NVFP4 quantization detected, by default enabling NVFP4 GEMM plugin. To use OOTB GEMM, please explicitly set gemm_plugin to "disable"'
        )
        network.plugin_config.gemm_plugin = "nvfp4"

    if build_config.plugin_config.manage_weights:
        if use_weight_only and disable_weight_only_quant_plugin:
            raise RuntimeError(
                "Manage weights of weight only quant works only with plugin currently."
            )

    if use_weight_only and not disable_weight_only_quant_plugin:
        if per_group:
            network.plugin_config.weight_only_groupwise_quant_matmul_plugin = model.config.dtype
        else:
            network.plugin_config.weight_only_quant_matmul_plugin = model.config.dtype
    if use_smooth_quant and model.config.quantization._use_plugin_sq and build_config.plugin_config.smooth_quant_plugins:
        network.plugin_config.set_smooth_quant_plugins(model.config.dtype)
    if use_qserve:
        network.plugin_config.set_qserve_plugins(model.config.dtype)
    if use_fp8_rowwise:
        network.plugin_config.set_fp8_rowwise_quant_plugins(model.config.dtype)
    nccl_plugin = model.config.dtype if model.config.mapping.world_size > 1 else None
    network.plugin_config.set_nccl_plugin(nccl_plugin)

    with net_guard(network):
        # Prepare
        network.set_named_parameters(model.named_parameters())

        # Forward
        prepare_input_args = {
            "max_batch_size":
            build_config.max_batch_size,
            "max_input_len":
            build_config.max_input_len,
            "max_seq_len":
            build_config.max_seq_len,
            "use_cache":
            build_config.kv_cache_type != KVCacheType.DISABLED,
            "max_beam_width":
            build_config.max_beam_width,
            "max_num_tokens":
            build_config.max_num_tokens,
            "opt_num_tokens":
            build_config.opt_num_tokens,
            "prompt_embedding_table_size":
            build_config.max_prompt_embedding_table_size,
            "max_draft_len":
            build_config.max_draft_len,
            "speculative_decoding_draft_tokens_external":
            build_config.speculative_decoding_mode ==
            SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL,
            "gather_context_logits":
            build_config.gather_context_logits,
            "lora_target_modules":
            build_config.lora_config.lora_target_modules,
        }

        if model.config.architecture == "DecoderModel" or "mllama" in model.config.architecture.lower(
        ):
            # "max_seq_len" is already populated above with the same value.
            prepare_input_args[
                "max_decoder_input_len"] = build_config.max_input_len
            prepare_input_args[
                "max_encoder_input_len"] = build_config.max_encoder_input_len

        if model.config.architecture == "WhisperEncoder":
            # This architecture is only given the batch size.
            prepare_input_args = {
                "max_batch_size": build_config.max_batch_size,
            }

        if build_config.speculative_decoding_mode == SpeculativeDecodingMode.EAGLE:
            prepare_input_args[
                "spec_decoding_is_generation_length_variable"] = True
            assert build_config.max_batch_size <= 512, "Max batch size > 512 is not supported for EAGLE"
            assert build_config.max_draft_len <= 256, "Max draft len > 256 is not supported for EAGLE"

        if build_config.speculative_decoding_mode == SpeculativeDecodingMode.LOOKAHEAD_DECODING:
            prepare_input_args[
                "spec_decoding_is_generation_length_variable"] = True

        if model.config.architecture == "Qwen2VLForConditionalGeneration" or model.config.architecture == "Qwen2VLModel":
            prepare_input_args[
                'mrope_rotary_cos_sin_size'] = model.config.max_position_embeddings * model.config.rotary_embedding_dim

        if build_config.speculative_decoding_mode == SpeculativeDecodingMode.EAGLE and not build_config.plugin_config.use_paged_context_fmha:
            logger.warning(
                "Paged Context FMHA is required for EAGLE. Turning it on")
            build_config.plugin_config.use_paged_context_fmha = True

        inputs = model.prepare_inputs(**prepare_input_args)
        model(**inputs)

        if build_config.enable_debug_output:
            for k, v in model.named_network_outputs():
                network._mark_output(v, k,
                                     str_dtype_to_trt(model.config.dtype))

    if model.config.architecture != "DecoderModel":
        optimize(network)

    if build_config.visualize_network is not None:
        with net_guard(network):
            network.to_onnx(build_config.visualize_network)

    # Network -> Engine
    logger.info(
        f"Total time of constructing network from module object {time.time()-tic} seconds"
    )
    managed_weights = {} if network.plugin_config.manage_weights else None
    engine = None if build_config.dry_run else builder.build_engine(
        network, builder_config, managed_weights)
    engine_config = EngineConfig(model.config, build_config, __version__)

    if build_config.output_timing_cache is not None and model.config.mapping.rank == 0:
        ok = builder.save_timing_cache(builder_config,
                                       build_config.output_timing_cache)
        assert ok, "Failed to save timing cache."

    import psutil

    # Get the current process
    current_process = psutil.Process()
    # Get resource usage for the current process (self)
    rusage_s = current_process.memory_info()
    # Get resource usage for all child processes
    children = current_process.children(recursive=True)
    rusage_c = [child.memory_info() for child in children]
    logger.info(
        f"Build phase peak memory: {rusage_s.rss / 1024 / 1024:.2f} MB, children: {sum(ru.rss for ru in rusage_c) / 1024 / 1024:.2f} MB"
    )

    return Engine(engine_config, engine, managed_weights)