Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit4d040b5

Browse files
authored
[None][chore] ucx establish connection with zmq (NVIDIA#6090)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
1 parent164acfa commit4d040b5

File tree

8 files changed

+206
-38
lines changed

8 files changed

+206
-38
lines changed

‎.gitmodules‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,6 @@
2323
[submodule "3rdparty/nanobind"]
2424
path=3rdparty/nanobind
2525
url=https://github.com/wjakob/nanobind
26+
[submodule "3rdparty/cppzmq"]
27+
path=3rdparty/cppzmq
28+
url=https://github.com/zeromq/cppzmq.git

‎3rdparty/cppzmq‎

Submodulecppzmq added atc94c207

‎cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
159159
{
160160
std::lock_guard<std::mutex>lock(mDllMutex);
161161
mWrapperLibHandle =dllOpen(UCX_WRAPPER_LIB_NAME);
162-
TLLM_CHECK_WITH_INFO(mWrapperLibHandle !=nullptr,"UCX wrapper library is not open correctly.");
162+
TLLM_CHECK_WITH_INFO(
163+
mWrapperLibHandle !=nullptr,"UCX wrapper library is not open correctly. error : %s",dlerror());
163164
auto load_sym = [](void* handle,charconst* name)
164165
{
165166
void* ret =dllGetSym(handle, name);

‎cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/CMakeLists.txt‎

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@ if(ENABLE_UCX)
44
find_package(ucx REQUIRED)
55
find_package(ucxx REQUIRED)
66

7+
include_directories(${3RDPARTY_DIR}/cppzmq)
8+
9+
# Find and link ZMQ
10+
find_package(PkgConfig REQUIRED)
11+
pkg_check_modules(ZMQ REQUIRED libzmq)
12+
# Add the NIXL wrapper target
13+
714
add_library(${UCX_WRAPPER_TARGET} SHARED connection.cpp
815
ucxCacheCommunicator.cpp)
916
set_target_properties(
@@ -20,4 +27,8 @@ if(ENABLE_UCX)
2027
PRIVATE $<LINK_LIBRARY:WHOLE_ARCHIVE,ucxx::ucxx>)
2128
target_link_libraries(${UCX_WRAPPER_TARGET}PUBLIC ucxx::ucxx ucx::ucs)
2229
target_link_libraries(${UCX_WRAPPER_TARGET}PUBLIC${CUDA_RT_LIB})
30+
31+
# Add include directories
32+
target_include_directories(${UCX_WRAPPER_TARGET}PRIVATE${ZMQ_INCLUDE_DIRS})
33+
target_link_libraries(${UCX_WRAPPER_TARGET}PRIVATE${ZMQ_LIBRARIES})
2334
endif()

‎cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.cpp‎

Lines changed: 175 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,56 @@
2222
#include<exception>
2323
#include<iostream>
2424
#include<mutex>
25+
#include<regex>
2526
#include<sys/socket.h>
27+
#include<ucxx/address.h>
2628
#include<ucxx/typedefs.h>
2729
#include<unistd.h>
2830

2931
namespacetensorrt_llm::executor::kv_cache
3032
{
3133

32-
staticvoidlistenerCallback(ucp_conn_request_h connRequest,void* arg)
34+
classUcxCmMessage
3335
{
34-
TLLM_LOG_DEBUG("listenerCallback");
35-
char ipStr[INET6_ADDRSTRLEN];
36-
char portStr[INET6_ADDRSTRLEN];
37-
ucp_conn_request_attr_t attr{};
38-
UcxConnectionManager* connectionManager =reinterpret_cast<UcxConnectionManager*>(arg);
39-
40-
attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR;
41-
ucxx::utils::ucsErrorThrow(ucp_conn_request_query(connRequest, &attr));
42-
ucxx::utils::sockaddr_get_ip_port_str(&attr.client_address, ipStr, portStr, INET6_ADDRSTRLEN);
43-
TLLM_LOG_DEBUG("Server received a connection request from client at address %s:%s", ipStr, portStr);
44-
connectionManager->addConnection(connRequest);
45-
}
36+
public:
37+
enumclassMessageType
38+
{
39+
GET_WORKER_ADDRESS =1,
40+
SERVER_WORKER_ADDRESS =2,
41+
STOP =3,
42+
};
43+
44+
MessageTypemType;
45+
std::optional<std::string>mWorkerAddress;
46+
47+
UcxCmMessage(MessageType type, std::optional<std::string> workerAddress)
48+
:mType(type)
49+
,mWorkerAddress(std::move(workerAddress))
50+
{
51+
}
52+
53+
staticsize_tserializedSize(UcxCmMessageconst& message)
54+
{
55+
namespacesu= tensorrt_llm::executor::serialize_utils;
56+
57+
returnsu::serializedSize(message.mType) +su::serializedSize(message.mWorkerAddress);
58+
}
59+
60+
staticvoidserialize(UcxCmMessageconst& message, std::ostream& os)
61+
{
62+
namespacesu= tensorrt_llm::executor::serialize_utils;
63+
su::serialize(message.mType, os);
64+
su::serialize(message.mWorkerAddress, os);
65+
}
66+
67+
static UcxCmMessagedeserialize(std::istream& is)
68+
{
69+
namespacesu= tensorrt_llm::executor::serialize_utils;
70+
auto type = su::deserialize<MessageType>(is);
71+
auto workerAddress = su::deserialize<std::optional<std::string>>(is);
72+
returnUcxCmMessage(type, workerAddress);
73+
}
74+
};
4675

4776
static std::stringgetLocalIp()
4877
{
@@ -100,6 +129,22 @@ static std::string getLocalIp()
100129
return ip;
101130
}
102131

132+
std::optional<std::pair<std::string,int>>parse_zmq_endpoint(std::stringconst& endpoint)
133+
{
134+
std::regexipv4_regex(R"(tcp://([\d\.]+):(\d+))");
135+
std::regexipv6_regex(R"(tcp://\[([0-9a-fA-F:]+)\]:(\d+))");
136+
std::smatch match;
137+
if (std::regex_match(endpoint, match, ipv4_regex))
138+
{
139+
returnstd::make_pair(match[1].str(),std::stoi(match[2].str()));
140+
}
141+
elseif (std::regex_match(endpoint, match, ipv6_regex))
142+
{
143+
returnstd::make_pair(match[1].str(),std::stoi(match[2].str()));
144+
}
145+
return std::nullopt;
146+
}
147+
103148
UcxConnectionManager::UcxConnectionManager()
104149

105150
{
@@ -120,22 +165,23 @@ UcxConnectionManager::UcxConnectionManager()
120165
std::string error ="Error creating worker and starting progress thread for rank" +std::string(e.what());
121166
TLLM_THROW(error);
122167
}
168+
auto workerAddressPtr =mWorkersPool.front()->getAddress();
169+
mWorkerAddress = workerAddressPtr->getString();
123170

124-
try
125-
{
126-
127-
mListener =mWorkersPool.front()->createListener(0, listenerCallback,this);
128-
}
129-
catch (std::exceptionconst& e)
130-
{
131-
std::string error ="Error creating listener for rank" +std::string(e.what());
132-
TLLM_THROW(error);
133-
}
134-
135-
// Get local IP address
171+
mZmqRepSocket =zmq::socket_t(mZmqContext, zmq::socket_type::rep);
172+
mZmqRepSocket.set(zmq::sockopt::sndhwm,1000);
136173
std::string localIp =getLocalIp();
137-
auto port =mListener->getPort();
138-
SocketState socketState{port, localIp};
174+
mZmqRepSocket.bind("tcp://" + localIp +":*");
175+
mZmqRepEndpoint =mZmqRepSocket.get(zmq::sockopt::last_endpoint);
176+
TLLM_LOG_INFO(mpi::MpiComm::world().getRank(),"UcxConnectionManager::UcxConnectionManager mZmqRepEndpoint: %s",
177+
mZmqRepEndpoint.c_str());
178+
auto parse_result =parse_zmq_endpoint(mZmqRepEndpoint);
179+
TLLM_CHECK_WITH_INFO(parse_result.has_value(),"Failed to parse ZMQ endpoint");
180+
auto [ip, port] = parse_result.value();
181+
TLLM_LOG_INFO(mpi::MpiComm::world().getRank(),"UcxConnectionManager::UcxConnectionManager ip: %s, port: %d",
182+
ip.c_str(), port);
183+
184+
SocketState socketState{static_cast<uint16_t>(port), ip};
139185
std::vector<executor::kv_cache::SocketState>socketStates(mpi::MpiComm::session().getSize());
140186

141187
if (mpi::MpiComm::session().getSize() >1)
@@ -179,6 +225,47 @@ UcxConnectionManager::UcxConnectionManager()
179225
}
180226
mCommState =CommState(socketStates,mpi::MpiComm::session().getRank());
181227
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank()," ***** UCX mCommState: %s",mCommState.toString().c_str());
228+
229+
mZmqRepThread =std::thread(
230+
[this]()
231+
{
232+
while (true)
233+
{
234+
zmq::message_t message;
235+
auto ret =mZmqRepSocket.recv(message);
236+
TLLM_CHECK_WITH_INFO(ret,"mZmqRepSocket.recv failed");
237+
std::stringrecvMessage(static_cast<char*>(message.data()), message.size());
238+
std::istringstreamis(recvMessage);
239+
UcxCmMessage ucxCmessage =UcxCmMessage::deserialize(is);
240+
241+
if (ucxCmessage.mType == UcxCmMessage::MessageType::GET_WORKER_ADDRESS)
242+
{
243+
// add Connection
244+
TLLM_CHECK_WITH_INFO(ucxCmessage.mWorkerAddress.has_value(),"workerAddress is null");
245+
std::string workerAddress = ucxCmessage.mWorkerAddress.value();
246+
std::string selfWorkerAddress =mWorkerAddress;
247+
UcxCmMessageserverMessage(UcxCmMessage::MessageType::SERVER_WORKER_ADDRESS, selfWorkerAddress);
248+
std::ostringstream oStream;
249+
UcxCmMessage::serialize(serverMessage, oStream);
250+
std::string serverMessageStr = oStream.str();
251+
mZmqRepSocket.send(zmq::buffer(serverMessageStr), zmq::send_flags::none);
252+
addConnection(workerAddress);
253+
}
254+
elseif (ucxCmessage.mType == UcxCmMessage::MessageType::STOP)
255+
{
256+
UcxCmMessagestopMessage(UcxCmMessage::MessageType::STOP, std::nullopt);
257+
std::ostringstream oStream;
258+
UcxCmMessage::serialize(stopMessage, oStream);
259+
std::string stopMessageStr = oStream.str();
260+
mZmqRepSocket.send(zmq::buffer(stopMessageStr), zmq::send_flags::none);
261+
break;
262+
}
263+
else
264+
{
265+
TLLM_THROW("Zmq recv unknown message: %s", recvMessage.c_str());
266+
}
267+
}
268+
});
182269
}
183270
catch (std::exceptionconst& e)
184271
{
@@ -195,14 +282,38 @@ UcxConnectionManager::~UcxConnectionManager()
195282
{
196283
worker->stopProgressThread();
197284
}
285+
if (mZmqRepThread.joinable())
286+
{
287+
zmq::socket_tsocket(mZmqContext, zmq::socket_type::req);
288+
socket.connect(mZmqRepEndpoint);
289+
UcxCmMessagestopMessage(UcxCmMessage::MessageType::STOP, std::nullopt);
290+
std::ostringstream oStream;
291+
UcxCmMessage::serialize(stopMessage, oStream);
292+
std::string stopMessageStr = oStream.str();
293+
socket.send(zmq::buffer(stopMessageStr), zmq::send_flags::none);
294+
zmq::message_t reply;
295+
auto ret = socket.recv(reply);
296+
TLLM_CHECK_WITH_INFO(ret,"zmq socket.recv failed");
297+
std::stringreplyStr(static_cast<char*>(reply.data()), reply.size());
298+
std::istringstreamis(replyStr);
299+
UcxCmMessage serverMessage =UcxCmMessage::deserialize(is);
300+
TLLM_CHECK_WITH_INFO(serverMessage.mType == UcxCmMessage::MessageType::STOP,"serverMessage.mType is not STOP");
301+
socket.close();
302+
mZmqRepThread.join();
303+
}
304+
305+
mZmqRepSocket.close();
306+
307+
mZmqContext.close();
198308
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),"END UcxConnectionManager::~UcxConnectionManager");
199309
}
200310

201-
voidUcxConnectionManager::addConnection(ucp_conn_request_h connRequest)
311+
voidUcxConnectionManager::addConnection(std::stringconst& workerAddress)
202312
{
203313
try
204314
{
205-
std::shared_ptr<ucxx::Endpoint> newEp =mListener->createEndpointFromConnRequest(connRequest,true);
315+
auto workerAddressPtr =ucxx::createAddressFromString(workerAddress);
316+
auto newEp =mWorkersPool.front()->createEndpointFromWorkerAddress(workerAddressPtr,true);
206317

207318
UcxConnection::ConnectionIdType connectionId =getNewConnectionId(newEp);
208319
std::scoped_locklock(mConnectionFuturesMutex);
@@ -225,6 +336,23 @@ void UcxConnectionManager::addConnection(ucp_conn_request_h connRequest)
225336
}
226337
}
227338

339+
std::stringbuild_zmq_endpoint(std::stringconst& ip,uint16_t port)
340+
{
341+
std::ostringstream oss;
342+
343+
std::regexipv6_regex(R"([0-9a-fA-F]*:[0-9a-fA-F]*:[0-9a-fA-F]*.*)");
344+
if (std::regex_match(ip, ipv6_regex) && ip.find(':') != std::string::npos)
345+
{
346+
oss <<"tcp://[" << ip <<"]:" << port;
347+
}
348+
else
349+
{
350+
oss <<"tcp://" << ip <<":" << port;
351+
}
352+
353+
return oss.str();
354+
}
355+
228356
UcxConnection::ConnectionIdTypeUcxConnectionManager::addConnection(std::stringconst& ip,uint16_t port)
229357
{
230358
static std::mutexsAddConnectionIPMutex;
@@ -237,7 +365,24 @@ UcxConnection::ConnectionIdType UcxConnectionManager::addConnection(std::string
237365
// This lock ensures that only one thread can create an endpoint from hostname and establish a UCX
238366
// connection at a time, guaranteeing that the only one listener will send connectionId to requester in the
239367
// same time.
240-
std::shared_ptr<ucxx::Endpoint> newEp =mWorkersPool.front()->createEndpointFromHostname(ip, port,true);
368+
auto reqSocket =zmq::socket_t(mZmqContext, zmq::socket_type::req);
369+
reqSocket.connect(build_zmq_endpoint(ip, port));
370+
UcxCmMessagegetWorkerAddressMessage(UcxCmMessage::MessageType::GET_WORKER_ADDRESS,mWorkerAddress);
371+
std::ostringstream oStream;
372+
UcxCmMessage::serialize(getWorkerAddressMessage, oStream);
373+
std::string getWorkerAddressMessageStr = oStream.str();
374+
reqSocket.send(zmq::buffer(getWorkerAddressMessageStr), zmq::send_flags::none);
375+
zmq::message_t reply;
376+
auto ret = reqSocket.recv(reply);
377+
TLLM_CHECK_WITH_INFO(ret,"zmq socket.recv failed");
378+
std::stringreplyStr(static_cast<char*>(reply.data()), reply.size());
379+
std::istringstreamis(replyStr);
380+
UcxCmMessage serverMessage =UcxCmMessage::deserialize(is);
381+
TLLM_CHECK_WITH_INFO(serverMessage.mType == UcxCmMessage::MessageType::SERVER_WORKER_ADDRESS,
382+
"serverMessage.mType is not SERVER_WORKER_ADDRESS");
383+
std::string serverWorkerAddress = serverMessage.mWorkerAddress.value();
384+
auto serverWorkerAddressPtr =ucxx::createAddressFromString(serverWorkerAddress);
385+
auto newEp =mWorkersPool.front()->createEndpointFromWorkerAddress(serverWorkerAddressPtr,true);
241386
connectionId =getNewConnectionId(newEp);
242387
connection = std::make_shared<UcxConnection>(connectionId, newEp,this,true);
243388
}

‎cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.h‎

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include<memory>
3737
#include<string>
3838
#include<vector>
39+
#include<zmq.hpp>
3940

4041
namespacetensorrt_llm::executor::kv_cache
4142
{
@@ -45,16 +46,20 @@ class UcxConnectionManager : public ConnectionManager, public std::enable_shared
4546
private:
4647
std::shared_ptr<ucxx::Context>mUcxCtx;
4748
std::vector<std::shared_ptr<ucxx::Worker>>mWorkersPool;
49+
std::stringmWorkerAddress;
4850
std::map<UcxConnection::ConnectionIdType, std::shared_ptr<UcxConnection>>mConnections;
4951
std::map<UcxConnection::ConnectionIdType, std::future<void>>mConnectionFutures;
5052
std::mutexmConnectionsMutex;
5153
std::mutexmConnectionFuturesMutex;
5254
std::unordered_map<std::string,uint64_t>mAddressToConnectionId;
5355
std::mutexmAddressToConnectionIdMutex;
54-
std::shared_ptr<ucxx::Listener>mListener;
5556
CommStatemCommState;
5657
intmDevice;
5758
std::atomic<UcxConnection::ConnectionIdType>mConnectionIdCounter{1};
59+
zmq::context_tmZmqContext;
60+
zmq::socket_tmZmqRepSocket;
61+
std::stringmZmqRepEndpoint;
62+
std::threadmZmqRepThread;
5863

5964
UcxConnection::ConnectionIdTypegetNewConnectionId(std::shared_ptr<ucxx::Endpoint>const& newEp);
6065
UcxConnection::ConnectionIdTypeaddConnection(std::stringconst& ip,uint16_t port);
@@ -69,7 +74,7 @@ class UcxConnectionManager : public ConnectionManager, public std::enable_shared
6974
return std::make_unique<UcxConnectionManager>();
7075
}
7176

72-
voidaddConnection(ucp_conn_request_h connRequest);
77+
voidaddConnection(std::stringconst& workerAddress);
7378
Connectionconst*recvConnect(DataContextconst& ctx,void* data,size_t size)override;
7479
std::vector<Connectionconst*>getConnections(CommStateconst& state)override;
7580
[[nodiscard]] CommStateconst&getCommState()constoverride;

‎docker/common/install_base.sh‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ init_ubuntu() {
6161
python3-pip \
6262
python-is-python3 \
6363
wget \
64-
pigz
64+
pigz \
65+
libzmq3-dev
6566
if!command -v mpirun&> /dev/null;then
6667
DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends openmpi-bin libopenmpi-dev
6768
fi
@@ -129,6 +130,7 @@ install_gcctoolset_rockylinux() {
129130
openmpi-devel \
130131
pigz \
131132
rdma-core-devel \
133+
zeromq-devel \
132134
-y
133135
echo"source scl_source enable gcc-toolset-11">>"${ENV}"
134136
echo'export PATH=/usr/lib64/openmpi/bin:$PATH'>>"${ENV}"

‎jenkins/current_image_tags.properties‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#
1212
# NB: Typically, the suffix indicates the PR whose CI pipeline generated the images. In case that
1313
# images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead.
14-
LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202507251001-5678
15-
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-aarch64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202507251001-5678
16-
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py310-trt10.11.0.33-skip-tritondevel-202507251001-5678
17-
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py312-trt10.11.0.33-skip-tritondevel-202507251001-5678
14+
LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202508051130-6090
15+
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-aarch64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202508051130-6090
16+
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py310-trt10.11.0.33-skip-tritondevel-202508051130-6090
17+
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py312-trt10.11.0.33-skip-tritondevel-202508051130-6090

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp