Expand Up @@ -14,6 +14,8 @@ import ( "strings" "time" "golang.org/x/sync/errgroup" "github.com/google/uuid" "golang.org/x/xerrors" "nhooyr.io/websocket" Expand Down Expand Up @@ -317,142 +319,28 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, q := coordinateURL.Query() q.Add("version", proto.CurrentVersion.String()) coordinateURL.RawQuery = q.Encode() closedCoordinator := make(chan struct{}) // Must only ever be used once, send error OR close to avoid // reassignment race. Buffered so we don't hang in goroutine. firstCoordinator := make(chan error, 1) go func() { defer close(closedCoordinator) isFirst := true for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { options.Logger.Debug(ctx, "connecting") // nolint:bodyclose ws, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ HTTPClient: c.HTTPClient, HTTPHeader: headers, // Need to disable compression to avoid a data-race. CompressionMode: websocket.CompressionDisabled, }) if isFirst { if res != nil && res.StatusCode == http.StatusConflict { firstCoordinator <- ReadBodyAsError(res) return } isFirst = false close(firstCoordinator) } if err != nil { if errors.Is(err, context.Canceled) { return } options.Logger.Debug(ctx, "failed to dial", slog.Error(err)) continue } client, err := tailnet.NewDRPCClient(websocket.NetConn(ctx, ws, websocket.MessageBinary)) if err != nil { options.Logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err)) _ = ws.Close(websocket.StatusInternalError, "") continue } coordinate, err := client.Coordinate(ctx) if err != nil { options.Logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err)) _ = ws.Close(websocket.StatusInternalError, "") continue } coordination := tailnet.NewRemoteCoordination(options.Logger, coordinate, conn, agentID) options.Logger.Debug(ctx, "serving coordinator") err = <-coordination.Error() if errors.Is(err, context.Canceled) { _ = ws.Close(websocket.StatusGoingAway, "") return } if err != nil { options.Logger.Debug(ctx, "error serving coordinator", slog.Error(err)) _ = ws.Close(websocket.StatusGoingAway, "") continue } _ = ws.Close(websocket.StatusGoingAway, "") } }() derpMapURL, err := c.URL.Parse("/api/v2/derp-map") if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } closedDerpMap := make(chan struct{}) // Must only ever be used once, send error OR close to avoid // reassignment race. Buffered so we don't hang in goroutine. firstDerpMap := make(chan error, 1) go func() { defer close(closedDerpMap) isFirst := true for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { options.Logger.Debug(ctx, "connecting to server for derp map updates") // nolint:bodyclose ws, res, err := websocket.Dial(ctx, derpMapURL.String(), &websocket.DialOptions{ HTTPClient: c.HTTPClient, HTTPHeader: headers, // Need to disable compression to avoid a data-race. CompressionMode: websocket.CompressionDisabled, }) if isFirst { if res != nil && res.StatusCode == http.StatusConflict { firstDerpMap <- ReadBodyAsError(res) return } isFirst = false close(firstDerpMap) } if err != nil { if errors.Is(err, context.Canceled) { return } options.Logger.Debug(ctx, "failed to dial", slog.Error(err)) continue } var ( nconn = websocket.NetConn(ctx, ws, websocket.MessageBinary) dec = json.NewDecoder(nconn) ) for { var derpMap tailcfg.DERPMap err := dec.Decode(&derpMap) if xerrors.Is(err, context.Canceled) { _ = ws.Close(websocket.StatusGoingAway, "") return } if err != nil { options.Logger.Debug(ctx, "failed to decode derp map", slog.Error(err)) _ = ws.Close(websocket.StatusGoingAway, "") return } if !tailnet.CompareDERPMaps(conn.DERPMap(), &derpMap) { options.Logger.Debug(ctx, "updating derp map due to detected changes") conn.SetDERPMap(&derpMap) } } } }() for firstCoordinator != nil || firstDerpMap != nil { select { case <-dialCtx.Done(): return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err()) case err = <-firstCoordinator: if err != nil { return nil, xerrors.Errorf("start coordinator: %w", err) } firstCoordinator = nil case err = <-firstDerpMap: if err != nil { return nil, xerrors.Errorf("receive derp map: %w", err) } firstDerpMap = nil connector := runTailnetAPIConnector(ctx, options.Logger, agentID, coordinateURL.String(), &websocket.DialOptions{ HTTPClient: c.HTTPClient, HTTPHeader: headers, // Need to disable compression to avoid a data-race. CompressionMode: websocket.CompressionDisabled, }, conn, ) options.Logger.Debug(ctx, "running tailnet API v2+ connector") select { case <-dialCtx.Done(): return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err()) case err = <-connector.connected: if err != nil { options.Logger.Error(ctx, "failed to connect to tailnet v2+ API", slog.Error(err)) return nil, xerrors.Errorf("start connector: %w", err) } options.Logger.Debug(ctx, "connected to tailnet v2+ API") } agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{ Expand All @@ -464,8 +352,7 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, AgentIP: WorkspaceAgentIP, CloseFunc: func() error { cancel() <-closedCoordinator <-closedDerpMap <-connector.closed return conn.Close() }, }) Expand All @@ -478,6 +365,171 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, return agentConn, nil } // tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to // // 1) run the Coordinate API and pass node information back and forth // 2) stream DERPMap updates and program the Conn // // These functions share the same websocket, and so are combined here so that if we hit a problem // we tear the whole thing down and start over with a new websocket. // // @typescript-ignore tailnetAPIConnector type tailnetAPIConnector struct { ctx context.Context logger slog.Logger agentID uuid.UUID coordinateURL string dialOptions *websocket.DialOptions conn *tailnet.Conn connected chan error isFirst bool closed chan struct{} } // runTailnetAPIConnector creates and runs a tailnetAPIConnector func runTailnetAPIConnector( ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions, conn *tailnet.Conn, ) *tailnetAPIConnector { tac := &tailnetAPIConnector{ ctx: ctx, logger: logger, agentID: agentID, coordinateURL: coordinateURL, dialOptions: dialOptions, conn: conn, connected: make(chan error, 1), closed: make(chan struct{}), } go tac.run() return tac } func (tac *tailnetAPIConnector) run() { tac.isFirst = true defer close(tac.closed) for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); { tailnetClient, err := tac.dial() if err != nil { continue } tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client") tac.coordinateAndDERPMap(tailnetClient) tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost") } } func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) { tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API") // nolint:bodyclose ws, res, err := websocket.Dial(tac.ctx, tac.coordinateURL, tac.dialOptions) if tac.isFirst { if res != nil && res.StatusCode == http.StatusConflict { err = ReadBodyAsError(res) tac.connected <- err return nil, err } tac.isFirst = false close(tac.connected) } if err != nil { if !errors.Is(err, context.Canceled) { tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err)) } return nil, err } client, err := tailnet.NewDRPCClient(websocket.NetConn(tac.ctx, ws, websocket.MessageBinary)) if err != nil { tac.logger.Debug(tac.ctx, "failed to create DRPCClient", slog.Error(err)) _ = ws.Close(websocket.StatusInternalError, "") return nil, err } return client, err } // coordinateAndDERPMap uses the provided client to coordinate and stream DERP Maps. It is combined // into one function so that a problem with one tears down the other and triggers a retry (if // appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same // fate. func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetClient) { defer func() { conn := client.DRPCConn() closeErr := conn.Close() if closeErr != nil && !xerrors.Is(closeErr, io.EOF) && !xerrors.Is(closeErr, context.Canceled) && !xerrors.Is(closeErr, context.DeadlineExceeded) { tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr)) <-conn.Closed() } }() eg, egCtx := errgroup.WithContext(tac.ctx) eg.Go(func() error { return tac.coordinate(egCtx, client) }) eg.Go(func() error { return tac.derpMap(egCtx, client) }) err := eg.Wait() if err != nil && !xerrors.Is(err, io.EOF) && !xerrors.Is(err, context.Canceled) && !xerrors.Is(err, context.DeadlineExceeded) { tac.logger.Error(tac.ctx, "error while connected to tailnet v2+ API") } } func (tac *tailnetAPIConnector) coordinate(ctx context.Context, client proto.DRPCTailnetClient) error { coord, err := client.Coordinate(ctx) if err != nil { return xerrors.Errorf("failed to connect to Coordinate RPC: %w", err) } defer func() { cErr := coord.Close() if cErr != nil { tac.logger.Debug(ctx, "error closing Coordinate RPC", slog.Error(cErr)) } }() coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID) tac.logger.Debug(ctx, "serving coordinator") err = <-coordination.Error() if err != nil && !xerrors.Is(err, io.EOF) && !xerrors.Is(err, context.Canceled) && !xerrors.Is(err, context.DeadlineExceeded) { return xerrors.Errorf("remote coordination error: %w", err) } return nil } func (tac *tailnetAPIConnector) derpMap(ctx context.Context, client proto.DRPCTailnetClient) error { s, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{}) if err != nil { return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err) } defer func() { cErr := s.Close() if cErr != nil { tac.logger.Debug(ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr)) } }() for { dmp, err := s.Recv() if err != nil { if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { return nil } return xerrors.Errorf("error receiving DERP Map: %w", err) } tac.logger.Debug(ctx, "got new DERP Map", slog.F("derp_map", dmp)) dm := tailnet.DERPMapFromProto(dmp) tac.conn.SetDERPMap(dm) } } // WatchWorkspaceAgentMetadata watches the metadata of a workspace agent. // The returned channel will be closed when the context is canceled. Exactly // one error will be sent on the error channel. The metadata channel is never closed. Expand Down