@@ -3,12 +3,10 @@ package agent
 import (
 	"bytes"
 	"context"
-	"encoding/binary"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"net"
 	"net/http"
 	"net/netip"
 	"os"
@@ -216,8 +214,8 @@ type agent struct {
 	portCacheDuration time.Duration
 	subsystems        []codersdk.AgentSubsystem

-	reconnectingPTYs       sync.Map
 	reconnectingPTYTimeout time.Duration
+	reconnectingPTYServer  *reconnectingpty.Server

 	// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
 	// to start gracefully shutting down and "hard" which is Done when it is time to close
@@ -252,8 +250,6 @@ type agent struct {
 	statsReporter *statsReporter
 	logSender     *agentsdk.LogSender

-	connCountReconnectingPTY atomic.Int64
-
 	prometheusRegistry *prometheus.Registry
 	// metrics are prometheus registered metrics that will be collected and
 	// labeled in Coder with the agent + workspace.
@@ -297,6 +293,13 @@ func (a *agent) init() {
 	// Register runner metrics. If the prom registry is nil, the metrics
 	// will not report anywhere.
 	a.scriptRunner.RegisterMetrics(a.prometheusRegistry)
+	a.reconnectingPTYServer = reconnectingpty.NewServer(
+		a.logger.Named("reconnecting-pty"),
+		a.sshServer,
+		a.metrics.connectionsTotal,
+		a.metrics.reconnectingPTYErrors,
+		a.reconnectingPTYTimeout,
+	)

 	go a.runLoop()
 }
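The `reconnectingpty.Server` type itself is outside this diff, so for review context here is a minimal sketch of the surface these call sites imply. The constructor arguments are read straight off the `NewServer` call above, and `ConnCount` anticipates the stats change at the bottom of the diff; every field name and type here is an assumption, not the package's actual code:

```go
// Sketch of the reconnectingpty.Server surface the call sites imply.
// Assumed, not the PR's real implementation: field names, the agentssh
// type for the command creator, and the exact metric types.
package reconnectingpty

import (
	"sync"
	"sync/atomic"
	"time"

	"github.com/prometheus/client_golang/prometheus"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/agent/agentssh"
)

type Server struct {
	logger           slog.Logger
	commandCreator   *agentssh.Server
	connectionsTotal prometheus.Counter
	errorsTotal      *prometheus.CounterVec
	timeout          time.Duration

	connCount        atomic.Int64
	reconnectingPTYs sync.Map
}

// NewServer mirrors the call in (*agent).init above: it takes over the
// logger, SSH command creator, metrics, and timeout the agent previously
// threaded through its inline handler.
func NewServer(logger slog.Logger, commandCreator *agentssh.Server,
	connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec,
	timeout time.Duration,
) *Server {
	return &Server{
		logger:           logger,
		commandCreator:   commandCreator,
		connectionsTotal: connectionsTotal,
		errorsTotal:      errorsTotal,
		timeout:          timeout,
	}
}

// ConnCount replaces the agent's old connCountReconnectingPTY counter;
// Collect reads it in the last hunk of this diff.
func (s *Server) ConnCount() int64 {
	return s.connCount.Load()
}
```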
@@ -1181,55 +1184,12 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap
 		}
 	}()

 	if err = a.trackGoroutine(func() {
-		logger := a.logger.Named("reconnecting-pty")
-		var wg sync.WaitGroup
-		for {
-			conn, err := reconnectingPTYListener.Accept()
-			if err != nil {
-				if !a.isClosed() {
-					logger.Debug(ctx, "accept pty failed", slog.Error(err))
-				}
-				break
-			}
-			clog := logger.With(
-				slog.F("remote", conn.RemoteAddr().String()),
-				slog.F("local", conn.LocalAddr().String()))
-			clog.Info(ctx, "accepted conn")
-			wg.Add(1)
-			closed := make(chan struct{})
-			go func() {
-				select {
-				case <-closed:
-				case <-a.hardCtx.Done():
-					_ = conn.Close()
-				}
-				wg.Done()
-			}()
-			go func() {
-				defer close(closed)
-				// This cannot use a JSON decoder, since that can
-				// buffer additional data that is required for the PTY.
-				rawLen := make([]byte, 2)
-				_, err = conn.Read(rawLen)
-				if err != nil {
-					return
-				}
-				length := binary.LittleEndian.Uint16(rawLen)
-				data := make([]byte, length)
-				_, err = conn.Read(data)
-				if err != nil {
-					return
-				}
-				var msg workspacesdk.AgentReconnectingPTYInit
-				err = json.Unmarshal(data, &msg)
-				if err != nil {
-					logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
-					return
-				}
-				_ = a.handleReconnectingPTY(ctx, clog, msg, conn)
-			}()
-		}
-		wg.Wait()
+		rPTYServeErr := a.reconnectingPTYServer.Serve(a.gracefulCtx, a.hardCtx, reconnectingPTYListener)
+		if rPTYServeErr != nil &&
+			a.gracefulCtx.Err() == nil &&
+			!strings.Contains(rPTYServeErr.Error(), "use of closed network connection") {
+			a.logger.Error(ctx, "error serving reconnecting PTY", slog.Error(rPTYServeErr))
+		}
 	}); err != nil {
 		return nil, err
 	}
@@ -1308,9 +1268,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap
 			_ = server.Close()
 		}()

-		err := server.Serve(apiListener)
-		if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
-			a.logger.Critical(ctx, "serve HTTP API server", slog.Error(err))
+		apiServErr := server.Serve(apiListener)
+		if apiServErr != nil && !xerrors.Is(apiServErr, http.ErrServerClosed) && !strings.Contains(apiServErr.Error(), "use of closed network connection") {
+			a.logger.Critical(ctx, "serve HTTP API server", slog.Error(apiServErr))
 		}
 	}); err != nil {
 		return nil, err
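The framing comment in the removed accept loop above is worth preserving for reviewers: the init message cannot be read with a `json.Decoder`, because the decoder would buffer bytes past the JSON document, and those bytes belong to the PTY stream that follows on the same `net.Conn`. The wire format is a 2-byte little-endian length prefix followed by that many bytes of JSON encoding a `workspacesdk.AgentReconnectingPTYInit`. A standalone sketch of that read, assuming the new package keeps the same format (existing clients still send it); note `io.ReadFull` also closes the short-read gap the old bare `conn.Read` calls had:

```go
package framing // illustrative sketch, not code from this PR

import (
	"encoding/binary"
	"encoding/json"
	"io"
	"net"

	"golang.org/x/xerrors"

	"github.com/coder/coder/v2/codersdk/workspacesdk"
)

// readReconnectingPTYInit reads the length-prefixed JSON init frame the
// removed loop parsed by hand. Reading exactly `length` bytes leaves the
// remainder of the stream untouched for the PTY, which is why a buffering
// json.Decoder cannot be used here.
func readReconnectingPTYInit(conn net.Conn) (workspacesdk.AgentReconnectingPTYInit, error) {
	var msg workspacesdk.AgentReconnectingPTYInit
	rawLen := make([]byte, 2)
	// Unlike a single conn.Read, io.ReadFull keeps reading until both
	// length bytes have arrived.
	if _, err := io.ReadFull(conn, rawLen); err != nil {
		return msg, xerrors.Errorf("read length prefix: %w", err)
	}
	data := make([]byte, binary.LittleEndian.Uint16(rawLen))
	if _, err := io.ReadFull(conn, data); err != nil {
		return msg, xerrors.Errorf("read init frame: %w", err)
	}
	if err := json.Unmarshal(data, &msg); err != nil {
		return msg, xerrors.Errorf("unmarshal init: %w", err)
	}
	return msg, nil
}
```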
@@ -1394,87 +1354,6 @@ func (a *agent) runDERPMapSubscriber(ctx context.Context, tClient tailnetproto.DRPCTailnetClient) error {
 	}
 }

-func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg workspacesdk.AgentReconnectingPTYInit, conn net.Conn) (retErr error) {
-	defer conn.Close()
-	a.metrics.connectionsTotal.Add(1)
-	a.connCountReconnectingPTY.Add(1)
-	defer a.connCountReconnectingPTY.Add(-1)
-
-	connectionID := uuid.NewString()
-	connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
-	connLogger.Debug(ctx, "starting handler")
-
-	defer func() {
-		if err := retErr; err != nil {
-			a.closeMutex.Lock()
-			closed := a.isClosed()
-			a.closeMutex.Unlock()
-			// If the agent is closed, we don't want to
-			// log this as an error since it's expected.
-			if closed {
-				connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
-			} else {
-				connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
-			}
-		}
-		connLogger.Info(ctx, "reconnecting pty connection closed")
-	}()
-
-	var rpty reconnectingpty.ReconnectingPTY
-	sendConnected := make(chan reconnectingpty.ReconnectingPTY, 1)
-	// On store, reserve this ID to prevent multiple concurrent new connections.
-	waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
-	if ok {
-		close(sendConnected) // Unused.
-		connLogger.Debug(ctx, "connecting to existing reconnecting pty")
-		c, ok := waitReady.(chan reconnectingpty.ReconnectingPTY)
-		if !ok {
-			return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
-		}
-		rpty, ok = <-c
-		if !ok || rpty == nil {
-			return xerrors.Errorf("reconnecting pty closed before connection")
-		}
-		c <- rpty // Put it back for the next reconnect.
-	} else {
-		connLogger.Debug(ctx, "creating new reconnecting pty")
-
-		connected := false
-		defer func() {
-			if !connected && retErr != nil {
-				a.reconnectingPTYs.Delete(msg.ID)
-				close(sendConnected)
-			}
-		}()
-
-		// Empty command will default to the users shell!
-		cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
-		if err != nil {
-			a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
-			return xerrors.Errorf("create command: %w", err)
-		}
-
-		rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{
-			Timeout: a.reconnectingPTYTimeout,
-			Metrics: a.metrics.reconnectingPTYErrors,
-		}, logger.With(slog.F("message_id", msg.ID)))
-
-		if err = a.trackGoroutine(func() {
-			rpty.Wait()
-			a.reconnectingPTYs.Delete(msg.ID)
-		}); err != nil {
-			rpty.Close(err)
-			return xerrors.Errorf("start routine: %w", err)
-		}
-
-		connected = true
-		sendConnected <- rpty
-	}
-	return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
-}
-
 // Collect collects additional stats from the agent
 func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
 	a.logger.Debug(context.Background(), "computing stats report")
@@ -1496,7 +1375,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
 	stats.SessionCountVscode = sshStats.VSCode
 	stats.SessionCountJetbrains = sshStats.JetBrains
-	stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load()
+	stats.SessionCountReconnectingPty = a.reconnectingPTYServer.ConnCount()

 	// Compute the median connection latency!
 	a.logger.Debug(ctx, "starting peer latency measurement for stats")
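One subtlety in the deleted `handleReconnectingPTY` deserves a callout, since the new package must reproduce it: the `sync.Map` `LoadOrStore` dance that lets any number of clients attach to the same PTY per message ID while guaranteeing only the first one creates it. The first caller reserves the ID by storing a 1-buffered channel before doing any slow work; later callers block on that channel, then put the value back for the next reconnect. Distilled into a standalone sketch with hypothetical names (`Session`, `registry`, `getOrCreate`) rather than the PR's real types:

```go
package reconnectingpty // illustrative sketch of the removed pattern

import (
	"sync"

	"golang.org/x/xerrors"
)

type Session struct{ /* stands in for a reconnecting PTY */ }

type registry struct {
	sessions sync.Map // message ID -> chan *Session
}

// getOrCreate returns the one Session per id, creating it at most once
// even under concurrent callers.
func (r *registry) getOrCreate(id string, create func() (*Session, error)) (*Session, error) {
	sendReady := make(chan *Session, 1)
	waitReady, loaded := r.sessions.LoadOrStore(id, sendReady)
	if loaded {
		close(sendReady) // Lost the race; our channel is unused.
		c, ok := waitReady.(chan *Session)
		if !ok {
			return nil, xerrors.Errorf("invalid type in session map: %T", waitReady)
		}
		sess, ok := <-c
		if !ok || sess == nil {
			// Creator failed and closed the channel before sending.
			return nil, xerrors.New("session closed before connection")
		}
		c <- sess // Put it back for the next waiter.
		return sess, nil
	}
	sess, err := create()
	if err != nil {
		// Roll back the reservation so a later attempt can retry, then
		// close so current waiters fail fast instead of hanging.
		r.sessions.Delete(id)
		close(sendReady)
		return nil, err
	}
	sendReady <- sess
	return sess, nil
}
```

In the deleted code the rollback additionally had to check whether `Attach` had already been reached (the `connected` flag) because cleanup ran in a deferred function; the sketch inlines it at the single failure point instead.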