- Notifications
You must be signed in to change notification settings - Fork928
fix: Improve shutdown procedure of ssh, portforward, wgtunnel cmds#3354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Changes fromall commits
File filter
Filter by extension
Conversations
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -51,6 +51,9 @@ func ssh() *cobra.Command { | ||
Short: "SSH into a workspace", | ||
Args: cobra.ArbitraryArgs, | ||
RunE: func(cmd *cobra.Command, args []string) error { | ||
ctx, cancel := context.WithCancel(cmd.Context()) | ||
defer cancel() | ||
client, err := createClient(cmd) | ||
if err != nil { | ||
return err | ||
@@ -68,14 +71,14 @@ func ssh() *cobra.Command { | ||
} | ||
} | ||
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx,cmd, client, codersdk.Me, args[0], shuffle) | ||
if err != nil { | ||
return err | ||
} | ||
// OpenSSH passes stderr directly to the calling TTY. | ||
// This is required in "stdio" mode so a connecting indicator can be displayed. | ||
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ | ||
WorkspaceName: workspace.Name, | ||
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { | ||
return client.WorkspaceAgent(ctx, workspaceAgent.ID) | ||
@@ -85,42 +88,33 @@ func ssh() *cobra.Command { | ||
return xerrors.Errorf("await agent: %w", err) | ||
} | ||
var newSSHClient func() (*gossh.Client, error) | ||
if !wireguard { | ||
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) | ||
if err != nil { | ||
return err | ||
} | ||
defer conn.Close() | ||
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace) | ||
defer stopPolling() | ||
if stdio { | ||
rawSSH, err := conn.SSH() | ||
if err != nil { | ||
return err | ||
} | ||
defer rawSSH.Close() | ||
go func() { | ||
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH) | ||
}() | ||
_, _ = io.Copy(rawSSH, cmd.InOrStdin()) | ||
return nil | ||
} | ||
newSSHClient = conn.SSHClient | ||
} else { | ||
// TODO: more granual control of Tailscale logging. | ||
peerwg.Logf = tslogger.Discard | ||
@@ -133,8 +127,9 @@ func ssh() *cobra.Command { | ||
if err != nil { | ||
return xerrors.Errorf("create wireguard network: %w", err) | ||
} | ||
defer wgn.Close() | ||
err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{ | ||
Recipient: workspaceAgent.ID, | ||
NodePublicKey: wgn.NodePrivateKey.Public(), | ||
DiscoPublicKey: wgn.DiscoPublicKey, | ||
@@ -155,10 +150,11 @@ func ssh() *cobra.Command { | ||
} | ||
if stdio { | ||
rawSSH, err := wgn.SSH(ctx, workspaceAgent.IPv6.IP()) | ||
if err != nil { | ||
return err | ||
} | ||
defer rawSSH.Close() | ||
go func() { | ||
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH) | ||
@@ -167,16 +163,29 @@ func ssh() *cobra.Command { | ||
return nil | ||
} | ||
newSSHClient = func() (*gossh.Client, error) { | ||
return wgn.SSHClient(ctx, workspaceAgent.IPv6.IP()) | ||
} | ||
} | ||
sshClient, err := newSSHClient() | ||
if err != nil { | ||
return err | ||
} | ||
defer sshClient.Close() | ||
sshSession, err := sshClient.NewSession() | ||
if err != nil { | ||
return err | ||
} | ||
defer sshSession.Close() | ||
// Ensure context cancellation is propagated to the | ||
// SSH session, e.g. to cancel `Wait()` at the end. | ||
go func() { | ||
<-ctx.Done() | ||
_ = sshSession.Close() | ||
}() | ||
if identityAgent == "" { | ||
identityAgent = os.Getenv("SSH_AUTH_SOCK") | ||
@@ -203,15 +212,18 @@ func ssh() *cobra.Command { | ||
_ = term.Restore(int(stdinFile.Fd()), state) | ||
}() | ||
windowChange := listenWindowSize(ctx) | ||
go func() { | ||
for { | ||
select { | ||
case <-ctx.Done(): | ||
return | ||
case <-windowChange: | ||
} | ||
width, height, err := term.GetSize(int(stdoutFile.Fd())) | ||
if err != nil { | ||
continue | ||
} | ||
_ = sshSession.WindowChange(height, width) | ||
} | ||
}() | ||
@@ -224,13 +236,17 @@ func ssh() *cobra.Command { | ||
sshSession.Stdin = cmd.InOrStdin() | ||
sshSession.Stdout = cmd.OutOrStdout() | ||
sshSession.Stderr = cmd.ErrOrStderr() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. This is a drive-by change I did. It seemed wrong but perhaps I didn't understand the purpose? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Seems reasonable to me! | ||
err = sshSession.Shell() | ||
if err != nil { | ||
return err | ||
} | ||
// Put cancel at the top of the defer stack to initiate | ||
// shutdown of services. | ||
defer cancel() | ||
err = sshSession.Wait() | ||
if err != nil { | ||
// If the connection drops unexpectedly, we get an ExitMissingError but no other | ||
@@ -259,16 +275,14 @@ func ssh() *cobra.Command { | ||
// getWorkspaceAgent returns the workspace and agent selected using either the | ||
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent | ||
// if `shuffle` is true. | ||
func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive | ||
var ( | ||
workspace codersdk.Workspace | ||
workspaceParts = strings.Split(in, ".") | ||
err error | ||
) | ||
if shuffle { | ||
workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{ | ||
Owner: codersdk.Me, | ||
}) | ||
if err != nil { | ||