Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commite0cfd87

Browse files
committed
fix: Disaggregate serving malfunction when using attention dp
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
1 parent7137cc8 commite0cfd87

File tree

4 files changed

+15
-4
lines changed

4 files changed

+15
-4
lines changed

‎cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,9 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
522522
// Gather the kv cache transfer time from all workers and update to leader rank
523523
if (!common::getEnvKVCacheTransferOutputPath().empty())
524524
{
525-
updateKVCacheTransferBW(*mMpiGroupComm, it->first);
525+
auto syncComm
526+
=mCacheState->getParallelConfig().mEnableAttentionDP ?mMpiGroupDataComm.get() :mMpiGroupComm;
527+
updateKVCacheTransferBW(*syncComm, it->first);
526528
}
527529
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
528530
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",

‎cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp‎

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,25 +185,28 @@ void MLACacheFormatter::formatOutput(LlmRequest const& llmRequest,
185185
NVTX3_SCOPED_RANGE(sendBufferFun);
186186

187187
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
188+
auto startTime =std::chrono::steady_clock::now();
188189
auto cacheIdx = processIdx % pPDomainSize;
190+
size_t size;
189191
if (cacheIdx < bufferCoverTargetNum)
190192
{
191-
193+
size = outputSplitCaches.at(cacheIdx)->getSizeInBytes();
192194
TransferHelper::sendBuffer(*connections.at(processIdx), *outputSplitCaches.at(cacheIdx), reqId);
193195
}
194196
elseif (bufferCoverTargetNum >0)
195197
{
196198
// copy buffer allocated by cudaMallocAsync to buffer allocated by cudaMalloc before sending
197199
auto sendBufferIdx = cacheIdx % bufferCoverTargetNum;
200+
size = outputSplitCaches.at(sendBufferIdx)->getSizeInBytes();
198201
bufferManager.copy(*outputSplitCaches.at(cacheIdx), *outputSplitCaches.at(sendBufferIdx));
199202
bufferManager.getStream().synchronize();
200203
TransferHelper::sendBuffer(*connections.at(processIdx), *outputSplitCaches.at(sendBufferIdx), reqId);
201204
}
202205
else
203206
{
204-
205207
// bufferCoverTargetNum=0, mSendBuffer size < one outputSlice
206208
// send multiple times
209+
size = targetBufferSize;
207210
size_t remainSendSize = targetBufferSize;
208211
while (remainSendSize >0)
209212
{
@@ -220,6 +223,10 @@ void MLACacheFormatter::formatOutput(LlmRequest const& llmRequest,
220223
remainSendSize -= sendSize;
221224
}
222225
}
226+
auto endTime =std::chrono::steady_clock::now();
227+
double cacheTransferTime
228+
=std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
229+
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, cacheTransferTime, size);
223230
};
224231

225232
if (connections.size() >1)

‎cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class MLACacheFormatter final : public IOFormatter
7979
private:
8080
BaseKVCacheManager*mCacheManager{};
8181
CacheTransBufferManager*mCacheTransBufferManager;
82+
KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()};
8283
};
8384

8485
}// namespace tensorrt_llm::batch_manager::kv_cache_manager

‎tensorrt_llm/_torch/pyexecutor/py_executor.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,8 @@ def _executor_loop(self):
895895

896896
finished_requests= []
897897

898-
ifscheduled_batch.batch_size>0:
898+
ifscheduled_batch.batch_size>0or (
899+
self.enable_attention_dpandself.dist.tp_size>1):
899900
ifself.kv_cache_transceiver:
900901
# For generation requests which have completed KV cache transfer
901902
self._prepare_disagg_gen_transmission_complete(

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp