2222#include < exception>
2323#include < iostream>
2424#include < mutex>
25+ #include < regex>
2526#include < sys/socket.h>
27+ #include < ucxx/address.h>
2628#include < ucxx/typedefs.h>
2729#include < unistd.h>
2830
2931namespace tensorrt_llm ::executor::kv_cache
3032{
3133
32- static void listenerCallback (ucp_conn_request_h connRequest, void * arg)
34+ class UcxCmMessage
3335{
34- TLLM_LOG_DEBUG (" listenerCallback" );
35- char ipStr[INET6_ADDRSTRLEN];
36- char portStr[INET6_ADDRSTRLEN];
37- ucp_conn_request_attr_t attr{};
38- UcxConnectionManager* connectionManager =reinterpret_cast <UcxConnectionManager*>(arg);
39-
40- attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR;
41- ucxx::utils::ucsErrorThrow (ucp_conn_request_query (connRequest, &attr));
42- ucxx::utils::sockaddr_get_ip_port_str (&attr.client_address , ipStr, portStr, INET6_ADDRSTRLEN);
43- TLLM_LOG_DEBUG (" Server received a connection request from client at address %s:%s" , ipStr, portStr);
44- connectionManager->addConnection (connRequest);
45- }
36+ public:
37+ enum class MessageType
38+ {
39+ GET_WORKER_ADDRESS =1 ,
40+ SERVER_WORKER_ADDRESS =2 ,
41+ STOP =3 ,
42+ };
43+
44+ MessageTypemType ;
45+ std::optional<std::string>mWorkerAddress ;
46+
47+ UcxCmMessage (MessageType type, std::optional<std::string> workerAddress)
48+ :mType (type)
49+ ,mWorkerAddress (std::move(workerAddress))
50+ {
51+ }
52+
53+ static size_t serializedSize (UcxCmMessageconst & message)
54+ {
55+ namespace su = tensorrt_llm::executor::serialize_utils;
56+
57+ return su::serializedSize (message.mType ) +su::serializedSize (message.mWorkerAddress );
58+ }
59+
60+ static void serialize (UcxCmMessageconst & message, std::ostream& os)
61+ {
62+ namespace su = tensorrt_llm::executor::serialize_utils;
63+ su::serialize (message.mType , os);
64+ su::serialize (message.mWorkerAddress , os);
65+ }
66+
67+ static UcxCmMessagedeserialize (std::istream& is)
68+ {
69+ namespace su = tensorrt_llm::executor::serialize_utils;
70+ auto type = su::deserialize<MessageType>(is);
71+ auto workerAddress = su::deserialize<std::optional<std::string>>(is);
72+ return UcxCmMessage (type, workerAddress);
73+ }
74+ };
4675
4776static std::stringgetLocalIp ()
4877{
@@ -100,6 +129,22 @@ static std::string getLocalIp()
100129return ip;
101130}
102131
/// Splits a ZMQ TCP endpoint string into host and port.
///
/// Two forms are accepted, and the whole string must match one of them:
///   - `tcp://<ipv4>:<port>`   (host made of digits and dots)
///   - `tcp://[<ipv6>]:<port>` (host made of hex digits and colons; the brackets are stripped)
///
/// @param endpoint Endpoint string to parse.
/// @return (host, port) on success; std::nullopt when the string matches neither form or the port
///         is larger than 65535. Never throws on a malformed or oversized port.
std::optional<std::pair<std::string, int>> parse_zmq_endpoint(std::string const& endpoint)
{
    // Compiled once: constructing a std::regex is costly, and matching against a const regex
    // does not modify it.
    static std::regex const ipv4_regex(R"(tcp://([\d\.]+):(\d+))");
    static std::regex const ipv6_regex(R"(tcp://\[([0-9a-fA-F:]+)\]:(\d+))");

    std::smatch match;
    if (!std::regex_match(endpoint, match, ipv4_regex) && !std::regex_match(endpoint, match, ipv6_regex))
    {
        return std::nullopt;
    }

    // Accumulate the port digit by digit and bail out as soon as it leaves the TCP port range.
    // std::stoi would throw std::out_of_range on a long digit run, and a value above 65535 would
    // later be truncated when the caller narrows it to uint16_t.
    constexpr int kMaxPort = 65535;
    std::string const portDigits = match[2].str();
    int port = 0;
    for (char const digit : portDigits)
    {
        port = port * 10 + (digit - '0');
        if (port > kMaxPort)
        {
            return std::nullopt;
        }
    }

    return std::make_pair(match[1].str(), port);
}
147+
103148UcxConnectionManager::UcxConnectionManager ()
104149
105150{
@@ -120,22 +165,23 @@ UcxConnectionManager::UcxConnectionManager()
120165 std::string error =" Error creating worker and starting progress thread for rank" +std::string (e.what ());
121166TLLM_THROW (error);
122167 }
168+ auto workerAddressPtr =mWorkersPool .front ()->getAddress ();
169+ mWorkerAddress = workerAddressPtr->getString ();
123170
124- try
125- {
126-
127- mListener =mWorkersPool .front ()->createListener (0 , listenerCallback,this );
128- }
129- catch (std::exceptionconst & e)
130- {
131- std::string error =" Error creating listener for rank" +std::string (e.what ());
132- TLLM_THROW (error);
133- }
134-
135- // Get local IP address
171+ mZmqRepSocket =zmq::socket_t (mZmqContext , zmq::socket_type::rep);
172+ mZmqRepSocket .set (zmq::sockopt::sndhwm,1000 );
136173 std::string localIp =getLocalIp ();
137- auto port =mListener ->getPort ();
138- SocketState socketState{port, localIp};
174+ mZmqRepSocket .bind (" tcp://" + localIp +" :*" );
175+ mZmqRepEndpoint =mZmqRepSocket .get (zmq::sockopt::last_endpoint);
176+ TLLM_LOG_INFO (mpi::MpiComm::world ().getRank ()," UcxConnectionManager::UcxConnectionManager mZmqRepEndpoint: %s" ,
177+ mZmqRepEndpoint .c_str ());
178+ auto parse_result =parse_zmq_endpoint (mZmqRepEndpoint );
179+ TLLM_CHECK_WITH_INFO (parse_result.has_value ()," Failed to parse ZMQ endpoint" );
180+ auto [ip, port] = parse_result.value ();
181+ TLLM_LOG_INFO (mpi::MpiComm::world ().getRank ()," UcxConnectionManager::UcxConnectionManager ip: %s, port: %d" ,
182+ ip.c_str (), port);
183+
184+ SocketState socketState{static_cast <uint16_t >(port), ip};
139185 std::vector<executor::kv_cache::SocketState>socketStates (mpi::MpiComm::session ().getSize ());
140186
141187if (mpi::MpiComm::session ().getSize () >1 )
@@ -179,6 +225,47 @@ UcxConnectionManager::UcxConnectionManager()
179225 }
180226mCommState =CommState (socketStates,mpi::MpiComm::session ().getRank ());
181227TLLM_LOG_DEBUG (mpi::MpiComm::world ().getRank ()," ***** UCX mCommState: %s" ,mCommState .toString ().c_str ());
228+
229+ mZmqRepThread =std::thread (
230+ [this ]()
231+ {
232+ while (true )
233+ {
234+ zmq::message_t message;
235+ auto ret =mZmqRepSocket .recv (message);
236+ TLLM_CHECK_WITH_INFO (ret," mZmqRepSocket.recv failed" );
237+ std::stringrecvMessage (static_cast <char *>(message.data ()), message.size ());
238+ std::istringstreamis (recvMessage);
239+ UcxCmMessage ucxCmessage =UcxCmMessage::deserialize (is);
240+
241+ if (ucxCmessage.mType == UcxCmMessage::MessageType::GET_WORKER_ADDRESS)
242+ {
243+ // add Connection
244+ TLLM_CHECK_WITH_INFO (ucxCmessage.mWorkerAddress .has_value ()," workerAddress is null" );
245+ std::string workerAddress = ucxCmessage.mWorkerAddress .value ();
246+ std::string selfWorkerAddress =mWorkerAddress ;
247+ UcxCmMessageserverMessage (UcxCmMessage::MessageType::SERVER_WORKER_ADDRESS, selfWorkerAddress);
248+ std::ostringstream oStream;
249+ UcxCmMessage::serialize (serverMessage, oStream);
250+ std::string serverMessageStr = oStream.str ();
251+ mZmqRepSocket .send (zmq::buffer (serverMessageStr), zmq::send_flags::none);
252+ addConnection (workerAddress);
253+ }
254+ else if (ucxCmessage.mType == UcxCmMessage::MessageType::STOP)
255+ {
256+ UcxCmMessagestopMessage (UcxCmMessage::MessageType::STOP, std::nullopt );
257+ std::ostringstream oStream;
258+ UcxCmMessage::serialize (stopMessage, oStream);
259+ std::string stopMessageStr = oStream.str ();
260+ mZmqRepSocket .send (zmq::buffer (stopMessageStr), zmq::send_flags::none);
261+ break ;
262+ }
263+ else
264+ {
265+ TLLM_THROW (" Zmq recv unknown message: %s" , recvMessage.c_str ());
266+ }
267+ }
268+ });
182269 }
183270catch (std::exceptionconst & e)
184271 {
@@ -195,14 +282,38 @@ UcxConnectionManager::~UcxConnectionManager()
195282 {
196283 worker->stopProgressThread ();
197284 }
285+ if (mZmqRepThread .joinable ())
286+ {
287+ zmq::socket_t socket (mZmqContext , zmq::socket_type::req);
288+ socket.connect (mZmqRepEndpoint );
289+ UcxCmMessagestopMessage (UcxCmMessage::MessageType::STOP, std::nullopt );
290+ std::ostringstream oStream;
291+ UcxCmMessage::serialize (stopMessage, oStream);
292+ std::string stopMessageStr = oStream.str ();
293+ socket.send (zmq::buffer (stopMessageStr), zmq::send_flags::none);
294+ zmq::message_t reply;
295+ auto ret = socket.recv (reply);
296+ TLLM_CHECK_WITH_INFO (ret," zmq socket.recv failed" );
297+ std::stringreplyStr (static_cast <char *>(reply.data ()), reply.size ());
298+ std::istringstreamis (replyStr);
299+ UcxCmMessage serverMessage =UcxCmMessage::deserialize (is);
300+ TLLM_CHECK_WITH_INFO (serverMessage.mType == UcxCmMessage::MessageType::STOP," serverMessage.mType is not STOP" );
301+ socket.close ();
302+ mZmqRepThread .join ();
303+ }
304+
305+ mZmqRepSocket .close ();
306+
307+ mZmqContext .close ();
198308TLLM_LOG_DEBUG (mpi::MpiComm::world ().getRank ()," END UcxConnectionManager::~UcxConnectionManager" );
199309}
200310
201- void UcxConnectionManager::addConnection (ucp_conn_request_h connRequest )
311+ void UcxConnectionManager::addConnection (std::string const & workerAddress )
202312{
203313try
204314 {
205- std::shared_ptr<ucxx::Endpoint> newEp =mListener ->createEndpointFromConnRequest (connRequest,true );
315+ auto workerAddressPtr =ucxx::createAddressFromString (workerAddress);
316+ auto newEp =mWorkersPool .front ()->createEndpointFromWorkerAddress (workerAddressPtr,true );
206317
207318 UcxConnection::ConnectionIdType connectionId =getNewConnectionId (newEp);
208319 std::scoped_locklock (mConnectionFuturesMutex );
@@ -225,6 +336,23 @@ void UcxConnectionManager::addConnection(ucp_conn_request_h connRequest)
225336 }
226337}
227338
/// Builds the ZMQ TCP endpoint string for a host and port.
///
/// @param ip   Host address. It is treated as an IPv6 literal when it starts with two
///             colon-terminated runs of hex digits (either run may be empty, e.g. "::1").
/// @param port TCP port.
/// @return `tcp://[<ip>]:<port>` for an IPv6 literal, otherwise `tcp://<ip>:<port>`.
std::string build_zmq_endpoint(std::string const& ip, uint16_t port)
{
    // Compiled once instead of on every call. The pattern already requires two ':' characters,
    // so no separate search for ':' is needed.
    static std::regex const ipv6_regex(R"([0-9a-fA-F]*:[0-9a-fA-F]*:[0-9a-fA-F]*.*)");

    // An IPv6 host is wrapped in brackets so its colons are not confused with the port separator.
    bool const isIpv6 = std::regex_match(ip, ipv6_regex);
    std::string const host = isIpv6 ? "[" + ip + "]" : ip;

    return "tcp://" + host + ":" + std::to_string(port);
}
355+
228356UcxConnection::ConnectionIdTypeUcxConnectionManager::addConnection (std::stringconst & ip,uint16_t port)
229357{
230358static std::mutexsAddConnectionIPMutex ;
@@ -237,7 +365,24 @@ UcxConnection::ConnectionIdType UcxConnectionManager::addConnection(std::string
237365// This lock ensures that only one thread can create an endpoint from hostname and establish a UCX
238366// connection at a time, guaranteeing that the only one listener will send connectionId to requester in the
239367// same time.
240- std::shared_ptr<ucxx::Endpoint> newEp =mWorkersPool .front ()->createEndpointFromHostname (ip, port,true );
368+ auto reqSocket =zmq::socket_t (mZmqContext , zmq::socket_type::req);
369+ reqSocket.connect (build_zmq_endpoint (ip, port));
370+ UcxCmMessagegetWorkerAddressMessage (UcxCmMessage::MessageType::GET_WORKER_ADDRESS,mWorkerAddress );
371+ std::ostringstream oStream;
372+ UcxCmMessage::serialize (getWorkerAddressMessage, oStream);
373+ std::string getWorkerAddressMessageStr = oStream.str ();
374+ reqSocket.send (zmq::buffer (getWorkerAddressMessageStr), zmq::send_flags::none);
375+ zmq::message_t reply;
376+ auto ret = reqSocket.recv (reply);
377+ TLLM_CHECK_WITH_INFO (ret," zmq socket.recv failed" );
378+ std::stringreplyStr (static_cast <char *>(reply.data ()), reply.size ());
379+ std::istringstreamis (replyStr);
380+ UcxCmMessage serverMessage =UcxCmMessage::deserialize (is);
381+ TLLM_CHECK_WITH_INFO (serverMessage.mType == UcxCmMessage::MessageType::SERVER_WORKER_ADDRESS,
382+ " serverMessage.mType is not SERVER_WORKER_ADDRESS" );
383+ std::string serverWorkerAddress = serverMessage.mWorkerAddress .value ();
384+ auto serverWorkerAddressPtr =ucxx::createAddressFromString (serverWorkerAddress);
385+ auto newEp =mWorkersPool .front ()->createEndpointFromWorkerAddress (serverWorkerAddressPtr,true );
241386 connectionId =getNewConnectionId (newEp);
242387 connection = std::make_shared<UcxConnection>(connectionId, newEp,this ,true );
243388 }