Expand Up @@ -31,7 +31,6 @@ import ( "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" "storj.io/drpc" "tailscale.com/net/speedtest" "tailscale.com/tailcfg" "tailscale.com/types/netlogtype" Expand Down Expand Up @@ -94,7 +93,9 @@ type Options struct { } type Client interface { ConnectRPC(ctx context.Context) (drpc.Conn, error) ConnectRPC23(ctx context.Context) ( proto.DRPCAgentClient23, tailnetproto.DRPCTailnetClient23, error, ) RewriteDERPMap(derpMap *tailcfg.DERPMap) } Expand Down Expand Up @@ -410,7 +411,7 @@ func (t *trySingleflight) Do(key string, fn func()) { fn() } func (a *agent) reportMetadata(ctx context.Context,conn drpc.Conn ) error { func (a *agent) reportMetadata(ctx context.Context,aAPI proto.DRPCAgentClient23 ) error { tickerDone := make(chan struct{}) collectDone := make(chan struct{}) ctx, cancel := context.WithCancel(ctx) Expand Down Expand Up @@ -572,7 +573,6 @@ func (a *agent) reportMetadata(ctx context.Context, conn drpc.Conn) error { reportTimeout = 30 * time.Second reportError = make(chan error, 1) reportInFlight = false aAPI = proto.NewDRPCAgentClient(conn) ) for { Expand Down Expand Up @@ -627,8 +627,7 @@ func (a *agent) reportMetadata(ctx context.Context, conn drpc.Conn) error { // reportLifecycle reports the current lifecycle state once. All state // changes are reported in order. func (a *agent) reportLifecycle(ctx context.Context, conn drpc.Conn) error { aAPI := proto.NewDRPCAgentClient(conn) func (a *agent) reportLifecycle(ctx context.Context, aAPI proto.DRPCAgentClient23) error { for { select { case <-a.lifecycleUpdate: Expand Down Expand Up @@ -710,8 +709,7 @@ func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) { // fetchServiceBannerLoop fetches the service banner on an interval. It will // not be fetched immediately; the expectation is that it is primed elsewhere // (and must be done before the session actually starts). func (a *agent) fetchServiceBannerLoop(ctx context.Context, conn drpc.Conn) error { aAPI := proto.NewDRPCAgentClient(conn) func (a *agent) fetchServiceBannerLoop(ctx context.Context, aAPI proto.DRPCAgentClient23) error { ticker := time.NewTicker(a.announcementBannersRefreshInterval) defer ticker.Stop() for { Expand All @@ -737,7 +735,7 @@ func (a *agent) fetchServiceBannerLoop(ctx context.Context, conn drpc.Conn) erro } func (a *agent) run() (retErr error) { // This allows the agent to refreshit's token if necessary. // This allows the agent to refreshits token if necessary. // For instance identity this is required, since the instance // may not have re-provisioned, but a new agent ID was created. sessionToken, err := a.exchangeToken(a.hardCtx) Expand All @@ -747,12 +745,12 @@ func (a *agent) run() (retErr error) { a.sessionToken.Store(&sessionToken) // ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs conn, err := a.client.ConnectRPC (a.hardCtx)aAPI, tAPI, err := a.client.ConnectRPC23 (a.hardCtx)if err != nil { return err } defer func() { cErr :=conn .Close() cErr :=aAPI.DRPCConn() .Close() if cErr != nil { a.logger.Debug(a.hardCtx, "error closing drpc connection", slog.Error(err)) } Expand All @@ -761,11 +759,10 @@ func (a *agent) run() (retErr error) { // A lot of routines need the agent API / tailnet API connection. We run them in their own // goroutines in parallel, but errors in any routine will cause them all to exit so we can // redial the coder server and retry. connMan := newAPIConnRoutineManager(a.gracefulCtx, a.hardCtx, a.logger,conn ) connMan := newAPIConnRoutineManager(a.gracefulCtx, a.hardCtx, a.logger,aAPI, tAPI ) connMan.start("init notification banners", gracefulShutdownBehaviorStop, func(ctx context.Context, conn drpc.Conn) error { aAPI := proto.NewDRPCAgentClient(conn) connMan.startAgentAPI("init notification banners", gracefulShutdownBehaviorStop, func(ctx context.Context, aAPI proto.DRPCAgentClient23) error { bannersProto, err := aAPI.GetAnnouncementBanners(ctx, &proto.GetAnnouncementBannersRequest{}) if err != nil { return xerrors.Errorf("fetch service banner: %w", err) Expand All @@ -781,9 +778,9 @@ func (a *agent) run() (retErr error) { // sending logs gets gracefulShutdownBehaviorRemain because we want to send logs generated by // shutdown scripts. connMan.start ("send logs", gracefulShutdownBehaviorRemain, func(ctx context.Context,conn drpc.Conn ) error { err := a.logSender.SendLoop(ctx,proto.NewDRPCAgentClient(conn) ) connMan.startAgentAPI ("send logs", gracefulShutdownBehaviorRemain, func(ctx context.Context,aAPI proto.DRPCAgentClient23 ) error { err := a.logSender.SendLoop(ctx,aAPI ) if xerrors.Is(err, agentsdk.LogLimitExceededError) { // we don't want this error to tear down the API connection and propagate to the // other routines that use the API. The LogSender has already dropped a warning Expand All @@ -795,10 +792,10 @@ func (a *agent) run() (retErr error) { // part of graceful shut down is reporting the final lifecycle states, e.g "ShuttingDown" so the // lifecycle reporting has to be via gracefulShutdownBehaviorRemain connMan.start ("report lifecycle", gracefulShutdownBehaviorRemain, a.reportLifecycle) connMan.startAgentAPI ("report lifecycle", gracefulShutdownBehaviorRemain, a.reportLifecycle) // metadata reporting can cease as soon as we start gracefully shutting down connMan.start ("report metadata", gracefulShutdownBehaviorStop, a.reportMetadata) connMan.startAgentAPI ("report metadata", gracefulShutdownBehaviorStop, a.reportMetadata) // channels to sync goroutines below // handle manifest Expand All @@ -819,55 +816,55 @@ func (a *agent) run() (retErr error) { networkOK := newCheckpoint(a.logger) manifestOK := newCheckpoint(a.logger) connMan.start ("handle manifest", gracefulShutdownBehaviorStop, a.handleManifest(manifestOK)) connMan.startAgentAPI ("handle manifest", gracefulShutdownBehaviorStop, a.handleManifest(manifestOK)) connMan.start ("app health reporter", gracefulShutdownBehaviorStop, func(ctx context.Context,conn drpc.Conn ) error { connMan.startAgentAPI ("app health reporter", gracefulShutdownBehaviorStop, func(ctx context.Context,aAPI proto.DRPCAgentClient23 ) error { if err := manifestOK.wait(ctx); err != nil { return xerrors.Errorf("no manifest: %w", err) } manifest := a.manifest.Load() NewWorkspaceAppHealthReporter( a.logger, manifest.Apps, agentsdk.AppHealthPoster(proto.NewDRPCAgentClient(conn) ), a.logger, manifest.Apps, agentsdk.AppHealthPoster(aAPI ), )(ctx) return nil }) connMan.start ("create or update network", gracefulShutdownBehaviorStop, connMan.startAgentAPI ("create or update network", gracefulShutdownBehaviorStop, a.createOrUpdateNetwork(manifestOK, networkOK)) connMan.start ("coordination", gracefulShutdownBehaviorStop, func(ctx context.Context,conn drpc.Conn ) error { connMan.startTailnetAPI ("coordination", gracefulShutdownBehaviorStop, func(ctx context.Context,tAPI tailnetproto.DRPCTailnetClient23 ) error { if err := networkOK.wait(ctx); err != nil { return xerrors.Errorf("no network: %w", err) } return a.runCoordinator(ctx,conn , a.network) return a.runCoordinator(ctx,tAPI , a.network) }, ) connMan.start ("derp map subscriber", gracefulShutdownBehaviorStop, func(ctx context.Context,conn drpc.Conn ) error { connMan.startTailnetAPI ("derp map subscriber", gracefulShutdownBehaviorStop, func(ctx context.Context,tAPI tailnetproto.DRPCTailnetClient23 ) error { if err := networkOK.wait(ctx); err != nil { return xerrors.Errorf("no network: %w", err) } return a.runDERPMapSubscriber(ctx,conn , a.network) return a.runDERPMapSubscriber(ctx,tAPI , a.network) }) connMan.start ("fetch service banner loop", gracefulShutdownBehaviorStop, a.fetchServiceBannerLoop) connMan.startAgentAPI ("fetch service banner loop", gracefulShutdownBehaviorStop, a.fetchServiceBannerLoop) connMan.start ("stats report loop", gracefulShutdownBehaviorStop, func(ctx context.Context,conn drpc.Conn ) error { connMan.startAgentAPI ("stats report loop", gracefulShutdownBehaviorStop, func(ctx context.Context,aAPI proto.DRPCAgentClient23 ) error { if err := networkOK.wait(ctx); err != nil { return xerrors.Errorf("no network: %w", err) } return a.statsReporter.reportLoop(ctx,proto.NewDRPCAgentClient(conn) ) return a.statsReporter.reportLoop(ctx,aAPI ) }) return connMan.wait() } // handleManifest returns a function that fetches and processes the manifest func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,conn drpc.Conn ) error { return func(ctx context.Context,conn drpc.Conn ) error { func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,aAPI proto.DRPCAgentClient23 ) error { return func(ctx context.Context,aAPI proto.DRPCAgentClient23 ) error { var ( sentResult = false err error Expand All @@ -877,7 +874,6 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, manifestOK.complete(err) } }() aAPI := proto.NewDRPCAgentClient(conn) mp, err := aAPI.GetManifest(ctx, &proto.GetManifestRequest{}) if err != nil { return xerrors.Errorf("fetch metadata: %w", err) Expand Down Expand Up @@ -977,8 +973,8 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, // createOrUpdateNetwork waits for the manifest to be set using manifestOK, then creates or updates // the tailnet using the information in the manifest func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(context.Context,drpc.Conn ) error { return func(ctx context.Context, _drpc.Conn ) (retErr error) { func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(context.Context,proto.DRPCAgentClient23 ) error { return func(ctx context.Context, _proto.DRPCAgentClient23 ) (retErr error) { if err := manifestOK.wait(ctx); err != nil { return xerrors.Errorf("no manifest: %w", err) } Expand Down Expand Up @@ -1325,9 +1321,8 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t // runCoordinator runs a coordinator and returns whether a reconnect // should occur. func (a *agent) runCoordinator(ctx context.Context,conn drpc.Conn , network *tailnet.Conn) error { func (a *agent) runCoordinator(ctx context.Context,tClient tailnetproto.DRPCTailnetClient23 , network *tailnet.Conn) error { defer a.logger.Debug(ctx, "disconnected from coordination RPC") tClient := tailnetproto.NewDRPCTailnetClient(conn) // we run the RPC on the hardCtx so that we have a chance to send the disconnect message if we // gracefully shut down. coordinate, err := tClient.Coordinate(a.hardCtx) Expand Down Expand Up @@ -1373,11 +1368,10 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai } // runDERPMapSubscriber runs a coordinator and returns if a reconnect should occur. func (a *agent) runDERPMapSubscriber(ctx context.Context,conn drpc.Conn , network *tailnet.Conn) error { func (a *agent) runDERPMapSubscriber(ctx context.Context,tClient tailnetproto.DRPCTailnetClient23 , network *tailnet.Conn) error { defer a.logger.Debug(ctx, "disconnected from derp map RPC") ctx, cancel := context.WithCancel(ctx) defer cancel() tClient := tailnetproto.NewDRPCTailnetClient(conn) stream, err := tClient.StreamDERPMaps(ctx, &tailnetproto.StreamDERPMapsRequest{}) if err != nil { return xerrors.Errorf("stream DERP Maps: %w", err) Expand Down Expand Up @@ -1981,13 +1975,17 @@ const ( type apiConnRoutineManager struct { logger slog.Logger conn drpc.Conn aAPI proto.DRPCAgentClient23 tAPI tailnetproto.DRPCTailnetClient23 eg *errgroup.Group stopCtx context.Context remainCtx context.Context } func newAPIConnRoutineManager(gracefulCtx, hardCtx context.Context, logger slog.Logger, conn drpc.Conn) *apiConnRoutineManager { func newAPIConnRoutineManager( gracefulCtx, hardCtx context.Context, logger slog.Logger, aAPI proto.DRPCAgentClient23, tAPI tailnetproto.DRPCTailnetClient23, ) *apiConnRoutineManager { // routines that remain in operation during graceful shutdown use the remainCtx. They'll still // exit if the errgroup hits an error, which usually means a problem with the conn. eg, remainCtx := errgroup.WithContext(hardCtx) Expand All @@ -2007,17 +2005,60 @@ func newAPIConnRoutineManager(gracefulCtx, hardCtx context.Context, logger slog. stopCtx := eitherContext(remainCtx, gracefulCtx) return &apiConnRoutineManager{ logger: logger, conn: conn, aAPI: aAPI, tAPI: tAPI, eg: eg, stopCtx: stopCtx, remainCtx: remainCtx, } } func (a *apiConnRoutineManager) start(name string, b gracefulShutdownBehavior, f func(context.Context, drpc.Conn) error) { // startAgentAPI starts a routine that uses the Agent API. c.f. startTailnetAPI which is the same // but for Tailnet. func (a *apiConnRoutineManager) startAgentAPI( name string, behavior gracefulShutdownBehavior, f func(context.Context, proto.DRPCAgentClient23) error, ) { logger := a.logger.With(slog.F("name", name)) var ctx context.Context switch behavior { case gracefulShutdownBehaviorStop: ctx = a.stopCtx case gracefulShutdownBehaviorRemain: ctx = a.remainCtx default: panic("unknown behavior") } a.eg.Go(func() error { logger.Debug(ctx, "starting agent routine") err := f(ctx, a.aAPI) if xerrors.Is(err, context.Canceled) && ctx.Err() != nil { logger.Debug(ctx, "swallowing context canceled") // Don't propagate context canceled errors to the error group, because we don't want the // graceful context being canceled to halt the work of routines with // gracefulShutdownBehaviorRemain. Note that we check both that the error is // context.Canceled and that *our* context is currently canceled, because when Coderd // unilaterally closes the API connection (for example if the build is outdated), it can // sometimes show up as context.Canceled in our RPC calls. return nil } logger.Debug(ctx, "routine exited", slog.Error(err)) if err != nil { return xerrors.Errorf("error in routine %s: %w", name, err) } return nil }) } // startTailnetAPI starts a routine that uses the Tailnet API. c.f. startAgentAPI which is the same // but for the Agent API. func (a *apiConnRoutineManager) startTailnetAPI( name string, behavior gracefulShutdownBehavior, f func(context.Context, tailnetproto.DRPCTailnetClient23) error, ) { logger := a.logger.With(slog.F("name", name)) var ctx context.Context switchb { switchbehavior { case gracefulShutdownBehaviorStop: ctx = a.stopCtx case gracefulShutdownBehaviorRemain: Expand All @@ -2026,8 +2067,8 @@ func (a *apiConnRoutineManager) start(name string, b gracefulShutdownBehavior, f panic("unknown behavior") } a.eg.Go(func() error { logger.Debug(ctx, "starting routine") err := f(ctx, a.conn ) logger.Debug(ctx, "startingtailnet routine") err := f(ctx, a.tAPI ) if xerrors.Is(err, context.Canceled) && ctx.Err() != nil { logger.Debug(ctx, "swallowing context canceled") // Don't propagate context canceled errors to the error group, because we don't want the Expand Down