Expand Up @@ -10,7 +10,6 @@ import ( "net" "net/http" "sync" "time" "github.com/google/uuid" lru "github.com/hashicorp/golang-lru/v2" Expand Down Expand Up @@ -79,44 +78,50 @@ func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node { return node } func (c *haCoordinator) clientLogger(id, agent uuid.UUID) slog.Logger { return c.log.With(slog.F("client_id", id), slog.F("agent_id", agent)) } func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger { return c.log.With(slog.F("agent_id", agent)) } // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() logger := c.clientLogger(id, agent) c.mutex.Lock() connectionSockets, ok := c.agentToConnectionSockets[agent] if !ok { connectionSockets = map[uuid.UUID]*agpl.TrackedConn{} c.agentToConnectionSockets[agent] = connectionSockets } now :=time.Now().Unix( )tc :=agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0 )// Insert this connection into a map so the agent // can publish node updates. connectionSockets[id] = &agpl.TrackedConn{ Conn: conn, Start: now, LastWrite: now, } connectionSockets[id] = tc // When a new connection is requested, we update it with the latest // node of the agent. This allows the connection to establish. node, ok := c.nodes[agent] c.mutex.Unlock() if ok { data, err := json.Marshal([]*agpl.Node{node}) if err != nil { return xerrors.Errorf("marshal node: %w", err) } _, err = conn.Write(data) err := tc.Enqueue([]*agpl.Node{node}) c.mutex.Unlock() if err != nil { return xerrors.Errorf("write nodes : %w", err) return xerrors.Errorf("enqueue node : %w", err) } } else { c.mutex.Unlock() err := c.publishClientHello(agent) if err != nil { return xerrors.Errorf("publish client hello: %w", err) } } go tc.SendUpdates() defer func() { c.mutex.Lock() Expand Down Expand Up @@ -161,8 +166,9 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js c.nodes[id] = &node // Write the new node from this client to the actively connected agent. agentSocket, ok := c.agentSockets[agent] c.mutex.Unlock() if !ok { c.mutex.Unlock() // If we don't own the agent locally, send it over pubsub to a node that // owns the agent. err := c.publishNodesToAgent(agent, []*agpl.Node{&node}) Expand All @@ -171,67 +177,50 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js } return nil } // Write the new node from this client to the actively // connected agent. data, err := json.Marshal([]*agpl.Node{&node}) if err != nil { return xerrors.Errorf("marshal nodes: %w", err) } _, err = agentSocket.Write(data) err = agentSocket.Enqueue([]*agpl.Node{&node}) c.mutex.Unlock() if err != nil { if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { return nil } return xerrors.Errorf("write json: %w", err) return xerrors.Errorf("enqueu nodes: %w", err) } return nil } // ServeAgent accepts a WebSocket connection to an agent that listens to // incoming connections and publishes node updates. func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() logger := c.agentLogger(id) c.agentNameCache.Add(id, name) // Publish all nodes on this instance that want to connect to this agent. nodes := c.nodesSubscribedToAgent(id) if len(nodes) > 0 { data, err := json.Marshal(nodes) if err != nil { return xerrors.Errorf("marshal json: %w", err) } _, err = conn.Write(data) if err != nil { return xerrors.Errorf("write nodes: %w", err) } } // This uniquely identifies a connection that belongs to this goroutine. unique := uuid.New() now := time.Now().Unix() overwrites := int64(0) // If an old agent socket is connected, we close it // to avoid any leaks. This shouldn't ever occur because // we expect one agent to be running. c.mutex.Lock() overwrites := int64(0) // If an old agent socket is connected, we Close it to avoid any leaks. This // shouldn't ever occur because we expect one agent to be running, but it's // possible for a race condition to happen when an agent is disconnected and // attempts to reconnect before the server realizes the old connection is // dead. oldAgentSocket, ok := c.agentSockets[id] if ok { overwrites = oldAgentSocket.Overwrites + 1 _ = oldAgentSocket.Close() } c.agentSockets[id] = &agpl.TrackedConn{ ID: unique, Conn: conn,// This uniquely identifies a connection that belongs to this goroutine. unique := uuid.New() tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, overwrites) Name: name, Start: now, LastWrite: now, Overwrites: overwrites, // Publish all nodes on this instance that want to connect to this agent. nodes := c.nodesSubscribedToAgent(id) if len(nodes) > 0 { err := tc.Enqueue(nodes) if err != nil { c.mutex.Unlock() return xerrors.Errorf("enqueue nodes: %w", err) } } c.agentSockets[id] = tc c.mutex.Unlock() go tc.SendUpdates() // Tell clients on other instances to send a callmemaybe to us. err := c.publishAgentHello(id) Expand Down Expand Up @@ -269,8 +258,6 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err } func (c *haCoordinator) nodesSubscribedToAgent(agentID uuid.UUID) []*agpl.Node { c.mutex.Lock() defer c.mutex.Unlock() sockets, ok := c.agentToConnectionSockets[agentID] if !ok { return nil Expand Down Expand Up @@ -320,25 +307,11 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) ( return &node, nil } data, err := json.Marshal([]*agpl.Node{&node}) if err != nil { c.mutex.Unlock() return nil, xerrors.Errorf("marshal nodes: %w", err) } // Publish the new node to every listening socket. var wg sync.WaitGroup wg.Add(len(connectionSockets)) for _, connectionSocket := range connectionSockets { connectionSocket := connectionSocket go func() { defer wg.Done() _ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second)) _, _ = connectionSocket.Write(data) }() _ = connectionSocket.Enqueue([]*agpl.Node{&node}) } c.mutex.Unlock() wg.Wait() return &node, nil } Expand Down Expand Up @@ -502,18 +475,19 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte) c.mutex.Lock() agentSocket, ok := c.agentSockets[agentUUID] c.mutex.Unlock() if !ok { c.mutex.Unlock() return } c.mutex.Unlock() // We get a single node over pubsub, so turn into an array. _, err = agentSocket.Write(nodeJSON) // Socket takes a slice of Nodes, so we need to parse the JSON here. var nodes []*agpl.Node err = json.Unmarshal(nodeJSON, &nodes) if err != nil { c.log.Error(ctx, "invalid nodes JSON", slog.F("id", agentID), slog.Error(err), slog.F("node", string(nodeJSON))) } err = agentSocket.Enqueue(nodes) if err != nil { if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { return } c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err)) return } Expand All @@ -536,7 +510,9 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte) return } c.mutex.RLock() nodes := c.nodesSubscribedToAgent(agentUUID) c.mutex.RUnlock() if len(nodes) > 0 { err := c.publishNodesToAgent(agentUUID, nodes) if err != nil { Expand Down