KV Cache Connector

Source: NVIDIA/TensorRT-LLM.

  1  2importos  3importsys  4fromdataclassesimportdataclass,field  5frompathlibimportPath  6fromtempfileimportTemporaryDirectory  7  8importclick  9importtorch 10 11fromtensorrt_llmimportLLM,SamplingParams,logger 12fromtensorrt_llm._torch.pyexecutor.kv_cache_connectorimport( 13KvCacheConnectorScheduler,KvCacheConnectorWorker,SchedulerOutput) 14fromtensorrt_llm.bindings.internal.batch_managerimportLlmRequest 15fromtensorrt_llm.llmapi.llm_argsimportKvCacheConnectorConfig,TorchLlmArgs 16 17# This is a simple example of the use of the KV cache connector. 18# It persists KV cache contents into a folder, and can load them back on subsequent runs. 19# See tensorrt_llm/_torch/pyexecutor/connector.py for details about the KV cache connector interface. 20# NOTE: This example connector implementation is NOT suitable for production use. 21 22CONNECTOR_CACHE_FOLDER_KEY="CONNECTOR_CACHE_FOLDER" 23 24 25@dataclass 26classPersistentKvCacheConnectorMetadata: 27load:list[tuple[str,int]]=field(default_factory=list) 28save:list[tuple[str,int]]=field(default_factory=list) 29 30 31classPersistentKvCacheConnectorWorker(KvCacheConnectorWorker): 32 33def__init__(self,llm_args:TorchLlmArgs): 34super().__init__(llm_args) 35 36self.kv_cache_tensor=None 37 38defregister_kv_caches(self,kv_cache_tensor:torch.Tensor): 39assertself.kv_cache_tensorisNone,"KV cache tensor already registered" 40self.kv_cache_tensor=kv_cache_tensor 41 42defstart_load_kv(self,stream:torch.cuda.Stream): 43# Do all loads synchronously, and blockwise. 44forpath,block_idinself._metadata.load: 45cpu_tensor=torch.load(path,map_location="cpu") 46 47# Copy into the device block. 48self.kv_cache_tensor[block_id].copy_(cpu_tensor,non_blocking=False) 49 50defwait_for_layer_load(self,layer_idx:int,stream:torch.cuda.Stream): 51pass 52 53defsave_kv_layer(self,layer_idx:int,stream:torch.cuda.Stream): 54pass 55 56defwait_for_save(self,stream:torch.cuda.Stream): 57 58# Make sure the forward pass is complete before beginning our save. 
59stream.synchronize() 60 61forpath,block_idinself._metadata.save: 62cpu_tensor=self.kv_cache_tensor[block_id].cpu() 63 64# Don't write anything if this specific block already exists. 65ifPath(path).exists(): 66continue 67 68# Do a blocking save to the file. This way, we only return once all saves are complete. 69torch.save(cpu_tensor,path) 70 71defget_finished( 72self,finished_gen_req_ids:list[int], 73started_loading_req_ids:list[int])->tuple[list[int],list[int]]: 74 75return[],[] 76 77 78classPersistentKvCacheConnectorLeader(KvCacheConnectorScheduler): 79 80def__init__(self,llm_args:TorchLlmArgs): 81super().__init__(llm_args) 82 83self.block_size=self._llm_args.kv_cache_config.tokens_per_block 84self.pending_loads={} 85 86self.cache_folder=os.environ.get(CONNECTOR_CACHE_FOLDER_KEY, 87"./connector_cache") 88 89os.makedirs(self.cache_folder,exist_ok=True) 90 91defbuild_connector_meta(self,scheduler_output:SchedulerOutput): 92# NOTE: This is a simplified implementation, and does not work with chunked prefill. 93 94metadata=PersistentKvCacheConnectorMetadata() 95 96forreqinscheduler_output.new_requests: 97# If we don't have any pending loads for this request, we can skip it. 
98ifreq.request_idnotinself.pending_loads: 99continue100101num_computed_blocks=req.computed_position//self.block_size102block_ids=req.new_block_ids103104pending_load=self.pending_loads[req.request_id]105106forfile_path,block_posinzip(107pending_load,range(num_computed_blocks,len(block_ids))):108metadata.load.append((file_path,block_ids[block_pos]))109110# Break up the remainder of the token sequence into chunks.111chunks=self._chunk_tokens(req.new_tokens)112113# For each chunk that isn't already on device, and isn't in our connector cache, we need to save it.114forblock_posinrange(num_computed_blocks+len(pending_load),115len(block_ids)):116iflen(chunks[block_pos])==self.block_size:117hashed_tokens=self._hash_tokens(chunks[block_pos])118119file_path=self._file_path(hashed_tokens)120121metadata.save.append((file_path,block_ids[block_pos]))122123self.pending_loads={}124125returnmetadata126127def_hash_tokens(self,tokens:list[int])->int:128returnabs(hash(tuple(tokens)))129130def_file_path(self,hash_value:int)->Path:131returnPath(self.cache_folder)/f"{hash_value}.pt"132133def_chunk_tokens(self,tokens:list[int])->list[list[int]]:134return[135tokens[i:i+self.block_size]136foriinrange(0,len(tokens),self.block_size)137]138139defget_num_new_matched_tokens(140self,request:LlmRequest,141num_computed_tokens:int)->tuple[int,bool]:142self.pending_loads[request.request_id]=[]143144# Don't bother with sequences with partial matches.145if(num_computed_tokens%self.block_size)!=0:146return0,False147148computed_blocks=num_computed_tokens//self.block_size149150# Get all the tokens that don't have a cache hit on device.151remaining_tokens=request.get_tokens(0)[computed_blocks*152self.block_size:]153154remaining_chunks=self._chunk_tokens(remaining_tokens)155156# For each chunk, check if it exists in our cache.157forchunkinremaining_chunks:158# Only do full blocks.159iflen(chunk)==self.block_size:160hashed_tokens=self._hash_tokens(chunk)161162file_path=self._file_path(hashed_tokens)163164# 
If we get a cache hit, we want to load it into device.165# Otherwise, we can stop looking.166iffile_path.exists():167self.pending_loads[request.request_id].append(file_path)168else:169break170171logger.info(172f"KV CONNECTOR: Matched{len(self.pending_loads[request.request_id])} blocks for request{request.request_id}"173)174175returnlen(176self.pending_loads[request.request_id])*self.block_size,False177178defrequest_finished(self,request:LlmRequest,179cache_block_ids:list[int])->bool:180# We don't do any asynchronous saving, so always return False181returnFalse182183defupdate_state_after_alloc(self,request:LlmRequest,184block_ids:list[int]):185pass186187188@click.command()189@click.argument("model",type=str)190defmain(model:str):191sys.path.append(os.path.join(192os.path.dirname(__file__),193"..",194))195196this_module=__file__[__file__.rfind("/")+1:__file__.rfind(".py")]197198kv_connector_config=KvCacheConnectorConfig(199connector_module=this_module,200connector_scheduler_class="PersistentKvCacheConnectorLeader",201connector_worker_class="PersistentKvCacheConnectorWorker",202)203204connector_cache_dir=TemporaryDirectory()205os.environ[CONNECTOR_CACHE_FOLDER_KEY]=connector_cache_dir.name206207llm=LLM(model=model,208backend="pytorch",209cuda_graph_config=None,210kv_connector_config=kv_connector_config)211212test_text=(213"Nvidia Corporation is an American technology company headquartered in Santa Clara, California."214"Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "215"system on a chips (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "216"and mobile and automotive applications. 
Tell me about the company.")217218sampling_params=SamplingParams(max_tokens=32)219220output=llm.generate([test_text],sampling_params)221text0=output[0].outputs[0].text222223print("First output: ",text0)224print("Loading new LLM instance...")225226delllm227228llm=LLM(model=model,229backend="pytorch",230cuda_graph_config=None,231kv_connector_config=kv_connector_config)232233output=llm.generate([test_text],sampling_params)234text1=output[0].outputs[0].text235236print("Second output (using connector cache): ",text1)237238asserttext0==text1239240connector_cache_dir.cleanup()241242243if__name__=="__main__":244main()