Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit eaf8bec

Browse files
authored
fix: Disaggregate serving with attention DP (#4993)
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
1 parent c8fa08d commit eaf8bec

File tree

4 files changed

+15
-4
lines changed

4 files changed

+15
-4
lines changed

‎cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,9 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
517517
// Gather the kv cache transfer time from all workers and update to leader rank
518518
if (!common::getEnvKVCacheTransferOutputPath().empty())
519519
{
520-
updateKVCacheTransferBW(*mMpiGroupComm, it->first);
520+
auto syncComm
521+
= mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupDataComm.get() : mMpiGroupComm;
522+
updateKVCacheTransferBW(*syncComm, it->first);
521523
}
522524
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
523525
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",

‎cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp‎

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,25 +187,28 @@ void MLACacheFormatter::formatOutput(LlmRequest const& llmRequest,
187187
NVTX3_SCOPED_RANGE(sendBufferFun);
188188

189189
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
190+
auto startTime = std::chrono::steady_clock::now();
190191
auto cacheIdx = processIdx % pPDomainSize;
192+
size_t size;
191193
if (cacheIdx < bufferCoverTargetNum)
192194
{
193-
195+
size = outputSplitCaches.at(cacheIdx)->getSizeInBytes();
194196
TransferHelper::sendBuffer(*connections.at(processIdx), *outputSplitCaches.at(cacheIdx), reqId);
195197
}
196198
else if (bufferCoverTargetNum > 0)
197199
{
198200
// copy buffer allocated by cudaMallocAsync to buffer allocated by cudaMalloc before sending
199201
auto sendBufferIdx = cacheIdx % bufferCoverTargetNum;
202+
size = outputSplitCaches.at(sendBufferIdx)->getSizeInBytes();
200203
bufferManager.copy(*outputSplitCaches.at(cacheIdx), *outputSplitCaches.at(sendBufferIdx));
201204
bufferManager.getStream().synchronize();
202205
TransferHelper::sendBuffer(*connections.at(processIdx), *outputSplitCaches.at(sendBufferIdx), reqId);
203206
}
204207
else
205208
{
206-
207209
// bufferCoverTargetNum=0, mSendBuffer size < one outputSlice
208210
// send multiple times
211+
size = targetBufferSize;
209212
size_t remainSendSize = targetBufferSize;
210213
while (remainSendSize > 0)
211214
{
@@ -222,6 +225,10 @@ void MLACacheFormatter::formatOutput(LlmRequest const& llmRequest,
222225
remainSendSize -= sendSize;
223226
}
224227
}
228+
auto endTime = std::chrono::steady_clock::now();
229+
double cacheTransferTime
230+
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
231+
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, cacheTransferTime, size);
225232
};
226233

227234
if (connections.size() >1)

‎cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class MLACacheFormatter final : public BaseCacheFormatter
6464
private:
6565
BaseKVCacheManager* mCacheManager;
6666
CacheTransBufferManager* mCacheTransBufferManager;
67+
KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()};
6768
};
6869

6970
}// namespace tensorrt_llm::batch_manager::kv_cache_manager

‎tensorrt_llm/_torch/pyexecutor/py_executor.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -911,7 +911,8 @@ def _executor_loop(self):
911911

912912
finished_requests = []
913913

914-
if scheduled_batch.batch_size > 0:
914+
if scheduled_batch.batch_size > 0 or (
915+
self.enable_attention_dp and self.dist.tp_size > 1):
915916
if self.kv_cache_transceiver:
916917
# For generation requests which have completed KV cache transfer
917918
self._prepare_disagg_gen_transmission_complete(

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp