@@ -10,7 +10,6 @@ import (
1010"net"
1111"net/http"
1212"sync"
13- "time"
1413
1514"github.com/google/uuid"
1615lru"github.com/hashicorp/golang-lru/v2"
@@ -79,44 +78,50 @@ func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node {
7978return node
8079}
8180
81+ func (c * haCoordinator )clientLogger (id ,agent uuid.UUID ) slog.Logger {
82+ return c .log .With (slog .F ("client_id" ,id ),slog .F ("agent_id" ,agent ))
83+ }
84+
85+ func (c * haCoordinator )agentLogger (agent uuid.UUID ) slog.Logger {
86+ return c .log .With (slog .F ("agent_id" ,agent ))
87+ }
88+
8289// ServeClient accepts a WebSocket connection that wants to connect to an agent
8390// with the specified ID.
8491func (c * haCoordinator )ServeClient (conn net.Conn ,id uuid.UUID ,agent uuid.UUID )error {
92+ ctx ,cancel := context .WithCancel (context .Background ())
93+ defer cancel ()
94+ logger := c .clientLogger (id ,agent )
95+
8596c .mutex .Lock ()
8697connectionSockets ,ok := c .agentToConnectionSockets [agent ]
8798if ! ok {
8899connectionSockets = map [uuid.UUID ]* agpl.TrackedConn {}
89100c .agentToConnectionSockets [agent ]= connectionSockets
90101}
91102
92- now := time . Now (). Unix ( )
103+ tc := agpl . NewTrackedConn ( ctx , cancel , conn , id , logger , 0 )
93104// Insert this connection into a map so the agent
94105// can publish node updates.
95- connectionSockets [id ]= & agpl.TrackedConn {
96- Conn :conn ,
97- Start :now ,
98- LastWrite :now ,
99- }
106+ connectionSockets [id ]= tc
100107
101108// When a new connection is requested, we update it with the latest
102109// node of the agent. This allows the connection to establish.
103110node ,ok := c .nodes [agent ]
104- c .mutex .Unlock ()
105111if ok {
106- data ,err := json .Marshal ([]* agpl.Node {node })
107- if err != nil {
108- return xerrors .Errorf ("marshal node: %w" ,err )
109- }
110- _ ,err = conn .Write (data )
112+ err := tc .Enqueue ([]* agpl.Node {node })
113+ c .mutex .Unlock ()
111114if err != nil {
112- return xerrors .Errorf ("write nodes : %w" ,err )
115+ return xerrors .Errorf ("enqueue node : %w" ,err )
113116}
114117}else {
118+ c .mutex .Unlock ()
115119err := c .publishClientHello (agent )
116120if err != nil {
117121return xerrors .Errorf ("publish client hello: %w" ,err )
118122}
119123}
124+ go tc .SendUpdates ()
120125
121126defer func () {
122127c .mutex .Lock ()
@@ -161,8 +166,9 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js
161166c .nodes [id ]= & node
162167// Write the new node from this client to the actively connected agent.
163168agentSocket ,ok := c .agentSockets [agent ]
164- c . mutex . Unlock ()
169+
165170if ! ok {
171+ c .mutex .Unlock ()
166172// If we don't own the agent locally, send it over pubsub to a node that
167173// owns the agent.
168174err := c .publishNodesToAgent (agent , []* agpl.Node {& node })
@@ -171,67 +177,50 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js
171177}
172178return nil
173179}
174-
175- // Write the new node from this client to the actively
176- // connected agent.
177- data ,err := json .Marshal ([]* agpl.Node {& node })
178- if err != nil {
179- return xerrors .Errorf ("marshal nodes: %w" ,err )
180- }
181-
182- _ ,err = agentSocket .Write (data )
180+ err = agentSocket .Enqueue ([]* agpl.Node {& node })
181+ c .mutex .Unlock ()
183182if err != nil {
184- if errors .Is (err ,io .EOF )|| errors .Is (err ,io .ErrClosedPipe ) {
185- return nil
186- }
187- return xerrors .Errorf ("write json: %w" ,err )
183+ return xerrors .Errorf ("enqueu nodes: %w" ,err )
188184}
189-
190185return nil
191186}
192187
193188// ServeAgent accepts a WebSocket connection to an agent that listens to
194189// incoming connections and publishes node updates.
195190func (c * haCoordinator )ServeAgent (conn net.Conn ,id uuid.UUID ,name string )error {
191+ ctx ,cancel := context .WithCancel (context .Background ())
192+ defer cancel ()
193+ logger := c .agentLogger (id )
196194c .agentNameCache .Add (id ,name )
197195
198- // Publish all nodes on this instance that want to connect to this agent.
199- nodes := c .nodesSubscribedToAgent (id )
200- if len (nodes )> 0 {
201- data ,err := json .Marshal (nodes )
202- if err != nil {
203- return xerrors .Errorf ("marshal json: %w" ,err )
204- }
205- _ ,err = conn .Write (data )
206- if err != nil {
207- return xerrors .Errorf ("write nodes: %w" ,err )
208- }
209- }
210-
211- // This uniquely identifies a connection that belongs to this goroutine.
212- unique := uuid .New ()
213- now := time .Now ().Unix ()
214- overwrites := int64 (0 )
215-
216- // If an old agent socket is connected, we close it
217- // to avoid any leaks. This shouldn't ever occur because
218- // we expect one agent to be running.
219196c .mutex .Lock ()
197+ overwrites := int64 (0 )
198+ // If an old agent socket is connected, we Close it to avoid any leaks. This
199+ // shouldn't ever occur because we expect one agent to be running, but it's
200+ // possible for a race condition to happen when an agent is disconnected and
201+ // attempts to reconnect before the server realizes the old connection is
202+ // dead.
220203oldAgentSocket ,ok := c .agentSockets [id ]
221204if ok {
222205overwrites = oldAgentSocket .Overwrites + 1
223206_ = oldAgentSocket .Close ()
224207}
225- c . agentSockets [ id ] = & agpl. TrackedConn {
226- ID : unique ,
227- Conn : conn ,
208+ // This uniquely identifies a connection that belongs to this goroutine.
209+ unique := uuid . New ()
210+ tc := agpl . NewTrackedConn ( ctx , cancel , conn ,unique , logger , overwrites )
228211
229- Name :name ,
230- Start :now ,
231- LastWrite :now ,
232- Overwrites :overwrites ,
212+ // Publish all nodes on this instance that want to connect to this agent.
213+ nodes := c .nodesSubscribedToAgent (id )
214+ if len (nodes )> 0 {
215+ err := tc .Enqueue (nodes )
216+ if err != nil {
217+ c .mutex .Unlock ()
218+ return xerrors .Errorf ("enqueue nodes: %w" ,err )
219+ }
233220}
221+ c .agentSockets [id ]= tc
234222c .mutex .Unlock ()
223+ go tc .SendUpdates ()
235224
236225// Tell clients on other instances to send a callmemaybe to us.
237226err := c .publishAgentHello (id )
@@ -269,8 +258,6 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err
269258}
270259
271260func (c * haCoordinator )nodesSubscribedToAgent (agentID uuid.UUID ) []* agpl.Node {
272- c .mutex .Lock ()
273- defer c .mutex .Unlock ()
274261sockets ,ok := c .agentToConnectionSockets [agentID ]
275262if ! ok {
276263return nil
@@ -320,25 +307,11 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (
320307return & node ,nil
321308}
322309
323- data ,err := json .Marshal ([]* agpl.Node {& node })
324- if err != nil {
325- c .mutex .Unlock ()
326- return nil ,xerrors .Errorf ("marshal nodes: %w" ,err )
327- }
328-
329310// Publish the new node to every listening socket.
330- var wg sync.WaitGroup
331- wg .Add (len (connectionSockets ))
332311for _ ,connectionSocket := range connectionSockets {
333- connectionSocket := connectionSocket
334- go func () {
335- defer wg .Done ()
336- _ = connectionSocket .SetWriteDeadline (time .Now ().Add (5 * time .Second ))
337- _ ,_ = connectionSocket .Write (data )
338- }()
312+ _ = connectionSocket .Enqueue ([]* agpl.Node {& node })
339313}
340314c .mutex .Unlock ()
341- wg .Wait ()
342315return & node ,nil
343316}
344317
@@ -502,18 +475,19 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte)
502475
503476c .mutex .Lock ()
504477agentSocket ,ok := c .agentSockets [agentUUID ]
478+ c .mutex .Unlock ()
505479if ! ok {
506- c .mutex .Unlock ()
507480return
508481}
509- c .mutex .Unlock ()
510482
511- // We get a single node over pubsub, so turn into an array.
512- _ ,err = agentSocket .Write (nodeJSON )
483+ // Socket takes a slice of Nodes, so we need to parse the JSON here.
484+ var nodes []* agpl.Node
485+ err = json .Unmarshal (nodeJSON ,& nodes )
486+ if err != nil {
487+ c .log .Error (ctx ,"invalid nodes JSON" ,slog .F ("id" ,agentID ),slog .Error (err ),slog .F ("node" ,string (nodeJSON )))
488+ }
489+ err = agentSocket .Enqueue (nodes )
513490if err != nil {
514- if errors .Is (err ,io .EOF )|| errors .Is (err ,io .ErrClosedPipe ) {
515- return
516- }
517491c .log .Error (ctx ,"send callmemaybe to agent" ,slog .Error (err ))
518492return
519493}
@@ -536,7 +510,9 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte)
536510return
537511}
538512
513+ c .mutex .RLock ()
539514nodes := c .nodesSubscribedToAgent (agentUUID )
515+ c .mutex .RUnlock ()
540516if len (nodes )> 0 {
541517err := c .publishNodesToAgent (agentUUID ,nodes )
542518if err != nil {