Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

[None][feat] Add Request specific exception#6931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
Shunkangz merged 14 commits intoNVIDIA:mainfromShunkangz:request_specific_exception
Sep 4, 2025
Merged
Show file tree
Hide file tree
Changes from1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
PrevPrevious commit
NextNext commit
Propagate request id in error
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
  • Loading branch information
Shunkang authored andShunkang committedSep 4, 2025
commitb443f7012438dd7fbc76a2b6c2630e2f541e567d
7 changes: 5 additions & 2 deletionscpp/include/tensorrt_llm/common/tllmException.h
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -47,12 +47,15 @@ namespace tensorrt_llm::common
enum class RequestErrorCode : uint32_t
{
// General errors (0-999)
UNKNOWN_ERROR = 0,
kUNKNOWN_ERROR = 0,

// Network and communication errors (1000-1999)
NETWORK_ERROR = 1000,
kNETWORK_ERROR = 1000,
};

/// @brief Constant for unknown request ID
static constexpr uint64_t kUNKNOWN_REQUEST_ID = static_cast<uint64_t>(-1);

class TllmException : public std::runtime_error
{
public:
Expand Down
17 changes: 17 additions & 0 deletionscpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -22,6 +22,7 @@
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/common/utils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <future>
Expand DownExpand Up@@ -190,6 +191,13 @@ class DataResponder::Impl
mSender->release(id);
resp.mPromise.set_value();
}
catch (tensorrt_llm::common::RequestSpecificException const& e)
{
TLLM_LOG_ERROR("Exception in sendAndRemoveResponse: %s ", e.what());
auto new_exception
= NEW_TLLM_REQUEST_SPECIFIC_EXCEPTION_WITH_ERROR_CODE(id, e.getErrorCode(), "%s", e.what());
resp.mPromise.set_exception(std::make_exception_ptr(new_exception));
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR("Exception in sendAndRemoveResponse: %s ", e.what());
Expand DownExpand Up@@ -496,6 +504,15 @@ class DataRequester::Impl
requestSync(*requestAndPromise.mRequest);
requestAndPromise.mPromise->set_value();
}
catch (tensorrt_llm::common::RequestSpecificException const& err)
{
TLLM_LOG_ERROR("Exception in DataRequester request(): request id:%ld , request context id:%ld : %s",
requestAndPromise.mRequest->mRequestId,
requestAndPromise.mRequest->getContextPhaseParams().value().getReqId(), err.what());
auto new_exception = NEW_TLLM_REQUEST_SPECIFIC_EXCEPTION_WITH_ERROR_CODE(
requestAndPromise.mRequest->mRequestId, err.getErrorCode(), "%s", err.what());
requestAndPromise.mPromise->set_exception(std::make_exception_ptr(new_exception));
}
catch (std::exception const& err)
{
TLLM_LOG_ERROR("Exception in DataRequester request(): request id:%ld , request context id:%ld : %s",
Expand Down
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -20,11 +20,16 @@

#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/executor/cache_transmission/ucx_utils/connection.h"

namespace tensorrt_llm::executor::kv_cache
{

// Using declarations to shorten the code
using RequestSpecificException = tensorrt_llm::common::RequestSpecificException;
using RequestErrorCode = tensorrt_llm::common::RequestErrorCode;

UcxConnection::UcxConnection(ConnectionIdType connectionId, std::shared_ptr<ucxx::Endpoint> endpoint,
UcxConnectionManager* manager, bool fromRequester)
: mConnectionId(connectionId)
Expand DownExpand Up@@ -128,56 +133,74 @@ void UcxConnection::sendConnectionId(DataContext const& ctx, void const* data, s

void UcxConnection::send(DataContext const& ctx, void const* data, size_t size) const
{
if (ctx.getTag() == batch_manager::TransceiverTag::kID_TAG)
try
{
sendConnectionId(ctx, data, size);
return;
if (ctx.getTag() == batch_manager::TransceiverTag::kID_TAG)
{
sendConnectionId(ctx, data, size);
return;
}
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"start UcxConnection::send , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);

TLLM_CHECK_WITH_INFO((mEndpoint), "sendBuffer called without established communicator channel.");
std::promise<void> promise;
std::future<void> future = promise.get_future();
auto completionCallback = [&](ucs_status_t, ucxx::RequestCallbackUserData) -> void { promise.set_value(); };
uint64_t sendTag = ((mSendTagPrefix & 0xFFFFFFFF) << 32) | (static_cast<uint64_t>(ctx.getTag()) & (0xFFFFFFFF));

auto req = mEndpoint->tagSend(const_cast<void*>(data), size, ucxx::Tag(sendTag), false, completionCallback);
if (!req->isCompleted())
{
future.get();
}
TLLM_CHECK_WITH_INFO(req->isCompleted(), "send should be completed");
// throw if there is error
req->checkError();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"end UcxConnection::send , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
}
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"start UcxConnection::send , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);

TLLM_CHECK_WITH_INFO((mEndpoint), "sendBuffer called without established communicator channel.");
std::promise<void> promise;
std::future<void> future = promise.get_future();
auto completionCallback = [&](ucs_status_t, ucxx::RequestCallbackUserData) -> void { promise.set_value(); };
uint64_t sendTag = ((mSendTagPrefix & 0xFFFFFFFF) << 32) | (static_cast<uint64_t>(ctx.getTag()) & (0xFFFFFFFF));

auto req = mEndpoint->tagSend(const_cast<void*>(data), size, ucxx::Tag(sendTag), false, completionCallback);
if (!req->isCompleted())
catch (std::exception const& e)
{
future.get();
// Convert any exception to RequestSpecificException
// Use unknown request ID and NETWORK_ERROR as error code
throw RequestSpecificException(
__FILE__, __LINE__, e.what(), tensorrt_llm::common::kUNKNOWN_REQUEST_ID, RequestErrorCode::kNETWORK_ERROR);
}
TLLM_CHECK_WITH_INFO(req->isCompleted(), "send should be completed");
// throw if there is error
req->checkError();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"end UcxConnection::send , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
}

void UcxConnection::recv(DataContext const& ctx, void* data, size_t size) const
{
// Guard to ensure CUDA context is initialized for UCX ops
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"start UcxConnection::recv , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
TLLM_CHECK_WITH_INFO((mEndpoint), "recvBuffer called without established communicator channel.");
std::promise<void> promise;
std::future<void> future = promise.get_future();
auto completionCallback = [&](ucs_status_t, ucxx::RequestCallbackUserData) -> void { promise.set_value(); };
uint64_t recvTag = ((mRecvTagPrefix & 0xFFFFFFFF) << 32) | (static_cast<uint64_t>(ctx.getTag()) & (0xFFFFFFFF));
auto req = mEndpoint->tagRecv(data, size, ucxx::Tag(recvTag), ucxx::TagMaskFull, false, completionCallback);
if (!req->isCompleted())
try
{
future.get();
// Guard to ensure CUDA context is initialized for UCX ops
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"start UcxConnection::recv , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
TLLM_CHECK_WITH_INFO((mEndpoint), "recvBuffer called without established communicator channel.");
std::promise<void> promise;
std::future<void> future = promise.get_future();
auto completionCallback = [&](ucs_status_t, ucxx::RequestCallbackUserData) -> void { promise.set_value(); };
uint64_t recvTag = ((mRecvTagPrefix & 0xFFFFFFFF) << 32) | (static_cast<uint64_t>(ctx.getTag()) & (0xFFFFFFFF));
auto req = mEndpoint->tagRecv(data, size, ucxx::Tag(recvTag), ucxx::TagMaskFull, false, completionCallback);
if (!req->isCompleted())
{
future.get();
}
TLLM_CHECK_WITH_INFO(req->isCompleted(), "recv should be completed");
// throw if there is error
req->checkError();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"end UcxConnection::recv , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
}
catch (std::exception const& e)
{
throw RequestSpecificException(
__FILE__, __LINE__, e.what(), tensorrt_llm::common::kUNKNOWN_REQUEST_ID, RequestErrorCode::kNETWORK_ERROR);
}
TLLM_CHECK_WITH_INFO(req->isCompleted(), "recv should be completed");
// throw if there is error
req->checkError();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"end UcxConnection::recv , mConnectionId: %lu, mConnectionIdInPeer: %lu,fromRequester: %d", mConnectionId,
mConnectionIdInPeer, mFromRequester);
}

} // namespace tensorrt_llm::executor::kv_cache
Expand Down
4 changes: 2 additions & 2 deletionscpp/tensorrt_llm/pybind/common/tllmExceptions.cpp
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -26,8 +26,8 @@ void initExceptionsBindings(py::module_& m)
{
// Bind the RequestErrorCode enum
py::enum_<tc::RequestErrorCode>(m, "RequestErrorCode")
.value("UNKNOWN_ERROR", tc::RequestErrorCode::UNKNOWN_ERROR)
.value("NETWORK_ERROR", tc::RequestErrorCode::NETWORK_ERROR)
.value("UNKNOWN_ERROR", tc::RequestErrorCode::kUNKNOWN_ERROR)
.value("NETWORK_ERROR", tc::RequestErrorCode::kNETWORK_ERROR)
.export_values();

// Create the RequestSpecificException Python exception class
Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp