Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit83c63d4

Browse files
authored
fix: Improve shutdown procedure of ssh, portforward, wgtunnel cmds (#3354)
* fix: Improve shutdown procedure of ssh, portforward, wgtunnel cmdsWe could turn it into a practice to wrap `cmd.Context()` so that we havemore fine-grained control of cancellation. Sometimes in tests we may berunning commands with a context that is never canceled.Related to#3221* fix: Set ssh session stderr to stderr
1 parent5ae19f0 commit83c63d4

File tree

4 files changed

+74
-52
lines changed

4 files changed

+74
-52
lines changed

‎cli/portforward.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ func portForward() *cobra.Command {
5555
},
5656
),
5757
RunE:func(cmd*cobra.Command,args []string)error {
58+
ctx,cancel:=context.WithCancel(cmd.Context())
59+
defercancel()
60+
5861
specs,err:=parsePortForwards(tcpForwards,udpForwards,unixForwards)
5962
iferr!=nil {
6063
returnxerrors.Errorf("parse port-forward specs: %w",err)
@@ -72,21 +75,21 @@ func portForward() *cobra.Command {
7275
returnerr
7376
}
7477

75-
workspace,agent,err:=getWorkspaceAndAgent(cmd,client,codersdk.Me,args[0],false)
78+
workspace,agent,err:=getWorkspaceAndAgent(ctx,cmd,client,codersdk.Me,args[0],false)
7679
iferr!=nil {
7780
returnerr
7881
}
7982
ifworkspace.LatestBuild.Transition!=codersdk.WorkspaceTransitionStart {
8083
returnxerrors.New("workspace must be in start transition to port-forward")
8184
}
8285
ifworkspace.LatestBuild.Job.CompletedAt==nil {
83-
err=cliui.WorkspaceBuild(cmd.Context(),cmd.ErrOrStderr(),client,workspace.LatestBuild.ID,workspace.CreatedAt)
86+
err=cliui.WorkspaceBuild(ctx,cmd.ErrOrStderr(),client,workspace.LatestBuild.ID,workspace.CreatedAt)
8487
iferr!=nil {
8588
returnerr
8689
}
8790
}
8891

89-
err=cliui.Agent(cmd.Context(),cmd.ErrOrStderr(), cliui.AgentOptions{
92+
err=cliui.Agent(ctx,cmd.ErrOrStderr(), cliui.AgentOptions{
9093
WorkspaceName:workspace.Name,
9194
Fetch:func(ctx context.Context) (codersdk.WorkspaceAgent,error) {
9295
returnclient.WorkspaceAgent(ctx,agent.ID)
@@ -96,15 +99,14 @@ func portForward() *cobra.Command {
9699
returnxerrors.Errorf("await agent: %w",err)
97100
}
98101

99-
conn,err:=client.DialWorkspaceAgent(cmd.Context(),agent.ID,nil)
102+
conn,err:=client.DialWorkspaceAgent(ctx,agent.ID,nil)
100103
iferr!=nil {
101104
returnxerrors.Errorf("dial workspace agent: %w",err)
102105
}
103106
deferconn.Close()
104107

105108
// Start all listeners.
106109
var (
107-
ctx,cancel=context.WithCancel(cmd.Context())
108110
wg=new(sync.WaitGroup)
109111
listeners=make([]net.Listener,len(specs))
110112
closeAllListeners=func() {
@@ -116,11 +118,11 @@ func portForward() *cobra.Command {
116118
}
117119
}
118120
)
119-
defercancel()
121+
defercloseAllListeners()
122+
120123
fori,spec:=rangespecs {
121124
l,err:=listenAndPortForward(ctx,cmd,conn,wg,spec)
122125
iferr!=nil {
123-
closeAllListeners()
124126
returnerr
125127
}
126128
listeners[i]=l
@@ -129,7 +131,10 @@ func portForward() *cobra.Command {
129131
// Wait for the context to be canceled or for a signal and close
130132
// all listeners.
131133
varcloseErrerror
134+
wg.Add(1)
132135
gofunc() {
136+
deferwg.Done()
137+
133138
sigs:=make(chan os.Signal,1)
134139
signal.Notify(sigs,syscall.SIGINT,syscall.SIGTERM)
135140

‎cli/ssh.go

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ func ssh() *cobra.Command {
5151
Short:"SSH into a workspace",
5252
Args:cobra.ArbitraryArgs,
5353
RunE:func(cmd*cobra.Command,args []string)error {
54+
ctx,cancel:=context.WithCancel(cmd.Context())
55+
defercancel()
56+
5457
client,err:=createClient(cmd)
5558
iferr!=nil {
5659
returnerr
@@ -68,14 +71,14 @@ func ssh() *cobra.Command {
6871
}
6972
}
7073

71-
workspace,workspaceAgent,err:=getWorkspaceAndAgent(cmd,client,codersdk.Me,args[0],shuffle)
74+
workspace,workspaceAgent,err:=getWorkspaceAndAgent(ctx,cmd,client,codersdk.Me,args[0],shuffle)
7275
iferr!=nil {
7376
returnerr
7477
}
7578

7679
// OpenSSH passes stderr directly to the calling TTY.
7780
// This is required in "stdio" mode so a connecting indicator can be displayed.
78-
err=cliui.Agent(cmd.Context(),cmd.ErrOrStderr(), cliui.AgentOptions{
81+
err=cliui.Agent(ctx,cmd.ErrOrStderr(), cliui.AgentOptions{
7982
WorkspaceName:workspace.Name,
8083
Fetch:func(ctx context.Context) (codersdk.WorkspaceAgent,error) {
8184
returnclient.WorkspaceAgent(ctx,workspaceAgent.ID)
@@ -85,42 +88,33 @@ func ssh() *cobra.Command {
8588
returnxerrors.Errorf("await agent: %w",err)
8689
}
8790

88-
var (
89-
sshClient*gossh.Client
90-
sshSession*gossh.Session
91-
)
91+
varnewSSHClientfunc() (*gossh.Client,error)
9292

9393
if!wireguard {
94-
conn,err:=client.DialWorkspaceAgent(cmd.Context(),workspaceAgent.ID,nil)
94+
conn,err:=client.DialWorkspaceAgent(ctx,workspaceAgent.ID,nil)
9595
iferr!=nil {
9696
returnerr
9797
}
9898
deferconn.Close()
9999

100-
stopPolling:=tryPollWorkspaceAutostop(cmd.Context(),client,workspace)
100+
stopPolling:=tryPollWorkspaceAutostop(ctx,client,workspace)
101101
deferstopPolling()
102102

103103
ifstdio {
104104
rawSSH,err:=conn.SSH()
105105
iferr!=nil {
106106
returnerr
107107
}
108+
deferrawSSH.Close()
109+
108110
gofunc() {
109111
_,_=io.Copy(cmd.OutOrStdout(),rawSSH)
110112
}()
111113
_,_=io.Copy(rawSSH,cmd.InOrStdin())
112114
returnnil
113115
}
114116

115-
sshClient,err=conn.SSHClient()
116-
iferr!=nil {
117-
returnerr
118-
}
119-
120-
sshSession,err=sshClient.NewSession()
121-
iferr!=nil {
122-
returnerr
123-
}
117+
newSSHClient=conn.SSHClient
124118
}else {
125119
// TODO: more granual control of Tailscale logging.
126120
peerwg.Logf=tslogger.Discard
@@ -133,8 +127,9 @@ func ssh() *cobra.Command {
133127
iferr!=nil {
134128
returnxerrors.Errorf("create wireguard network: %w",err)
135129
}
130+
deferwgn.Close()
136131

137-
err=client.PostWireguardPeer(cmd.Context(),workspace.ID, peerwg.Handshake{
132+
err=client.PostWireguardPeer(ctx,workspace.ID, peerwg.Handshake{
138133
Recipient:workspaceAgent.ID,
139134
NodePublicKey:wgn.NodePrivateKey.Public(),
140135
DiscoPublicKey:wgn.DiscoPublicKey,
@@ -155,10 +150,11 @@ func ssh() *cobra.Command {
155150
}
156151

157152
ifstdio {
158-
rawSSH,err:=wgn.SSH(cmd.Context(),workspaceAgent.IPv6.IP())
153+
rawSSH,err:=wgn.SSH(ctx,workspaceAgent.IPv6.IP())
159154
iferr!=nil {
160155
returnerr
161156
}
157+
deferrawSSH.Close()
162158

163159
gofunc() {
164160
_,_=io.Copy(cmd.OutOrStdout(),rawSSH)
@@ -167,16 +163,29 @@ func ssh() *cobra.Command {
167163
returnnil
168164
}
169165

170-
sshClient,err=wgn.SSHClient(cmd.Context(),workspaceAgent.IPv6.IP())
171-
iferr!=nil {
172-
returnerr
166+
newSSHClient=func() (*gossh.Client,error) {
167+
returnwgn.SSHClient(ctx,workspaceAgent.IPv6.IP())
173168
}
169+
}
174170

175-
sshSession,err=sshClient.NewSession()
176-
iferr!=nil {
177-
returnerr
178-
}
171+
sshClient,err:=newSSHClient()
172+
iferr!=nil {
173+
returnerr
174+
}
175+
defersshClient.Close()
176+
177+
sshSession,err:=sshClient.NewSession()
178+
iferr!=nil {
179+
returnerr
179180
}
181+
defersshSession.Close()
182+
183+
// Ensure context cancellation is propagated to the
184+
// SSH session, e.g. to cancel `Wait()` at the end.
185+
gofunc() {
186+
<-ctx.Done()
187+
_=sshSession.Close()
188+
}()
180189

181190
ifidentityAgent=="" {
182191
identityAgent=os.Getenv("SSH_AUTH_SOCK")
@@ -203,15 +212,18 @@ func ssh() *cobra.Command {
203212
_=term.Restore(int(stdinFile.Fd()),state)
204213
}()
205214

206-
windowChange:=listenWindowSize(cmd.Context())
215+
windowChange:=listenWindowSize(ctx)
207216
gofunc() {
208217
for {
209218
select {
210-
case<-cmd.Context().Done():
219+
case<-ctx.Done():
211220
return
212221
case<-windowChange:
213222
}
214-
width,height,_:=term.GetSize(int(stdoutFile.Fd()))
223+
width,height,err:=term.GetSize(int(stdoutFile.Fd()))
224+
iferr!=nil {
225+
continue
226+
}
215227
_=sshSession.WindowChange(height,width)
216228
}
217229
}()
@@ -224,13 +236,17 @@ func ssh() *cobra.Command {
224236

225237
sshSession.Stdin=cmd.InOrStdin()
226238
sshSession.Stdout=cmd.OutOrStdout()
227-
sshSession.Stderr=cmd.OutOrStdout()
239+
sshSession.Stderr=cmd.ErrOrStderr()
228240

229241
err=sshSession.Shell()
230242
iferr!=nil {
231243
returnerr
232244
}
233245

246+
// Put cancel at the top of the defer stack to initiate
247+
// shutdown of services.
248+
defercancel()
249+
234250
err=sshSession.Wait()
235251
iferr!=nil {
236252
// If the connection drops unexpectedly, we get an ExitMissingError but no other
@@ -259,16 +275,14 @@ func ssh() *cobra.Command {
259275
// getWorkspaceAgent returns the workspace and agent selected using either the
260276
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent
261277
// if `shuffle` is true.
262-
funcgetWorkspaceAndAgent(cmd*cobra.Command,client*codersdk.Client,userIDstring,instring,shufflebool) (codersdk.Workspace, codersdk.WorkspaceAgent,error) {//nolint:revive
263-
ctx:=cmd.Context()
264-
278+
funcgetWorkspaceAndAgent(ctx context.Context,cmd*cobra.Command,client*codersdk.Client,userIDstring,instring,shufflebool) (codersdk.Workspace, codersdk.WorkspaceAgent,error) {//nolint:revive
265279
var (
266280
workspace codersdk.Workspace
267281
workspaceParts=strings.Split(in,".")
268282
errerror
269283
)
270284
ifshuffle {
271-
workspaces,err:=client.Workspaces(cmd.Context(), codersdk.WorkspaceFilter{
285+
workspaces,err:=client.Workspaces(ctx, codersdk.WorkspaceFilter{
272286
Owner:codersdk.Me,
273287
})
274288
iferr!=nil {

‎cli/ssh_test.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ func TestSSH(t *testing.T) {
229229
pty:=ptytest.New(t)
230230
cmd.SetIn(pty.Input())
231231
cmd.SetOut(pty.Output())
232-
cmd.SetErr(io.Discard)
232+
cmd.SetErr(pty.Output())
233233
cmdDone:=tGo(t,func() {
234234
err:=cmd.ExecuteContext(ctx)
235235
assert.NoError(t,err)
@@ -248,9 +248,6 @@ func TestSSH(t *testing.T) {
248248

249249
// And we're done.
250250
pty.WriteLine("exit")
251-
// Read output to prevent hang on macOS, see:
252-
// https://github.com/coder/coder/issues/2122
253-
pty.ExpectMatch("exit")
254251
<-cmdDone
255252
})
256253
}

‎cli/wireguardtunnel.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ func wireguardPortForward() *cobra.Command {
5252
},
5353
),
5454
RunE:func(cmd*cobra.Command,args []string)error {
55+
ctx,cancel:=context.WithCancel(cmd.Context())
56+
defercancel()
57+
5558
specs,err:=parsePortForwards(tcpForwards,nil,nil)
5659
iferr!=nil {
5760
returnxerrors.Errorf("parse port-forward specs: %w",err)
@@ -69,21 +72,21 @@ func wireguardPortForward() *cobra.Command {
6972
returnerr
7073
}
7174

72-
workspace,workspaceAgent,err:=getWorkspaceAndAgent(cmd,client,codersdk.Me,args[0],false)
75+
workspace,workspaceAgent,err:=getWorkspaceAndAgent(ctx,cmd,client,codersdk.Me,args[0],false)
7376
iferr!=nil {
7477
returnerr
7578
}
7679
ifworkspace.LatestBuild.Transition!=codersdk.WorkspaceTransitionStart {
7780
returnxerrors.New("workspace must be in start transition to port-forward")
7881
}
7982
ifworkspace.LatestBuild.Job.CompletedAt==nil {
80-
err=cliui.WorkspaceBuild(cmd.Context(),cmd.ErrOrStderr(),client,workspace.LatestBuild.ID,workspace.CreatedAt)
83+
err=cliui.WorkspaceBuild(ctx,cmd.ErrOrStderr(),client,workspace.LatestBuild.ID,workspace.CreatedAt)
8184
iferr!=nil {
8285
returnerr
8386
}
8487
}
8588

86-
err=cliui.Agent(cmd.Context(),cmd.ErrOrStderr(), cliui.AgentOptions{
89+
err=cliui.Agent(ctx,cmd.ErrOrStderr(), cliui.AgentOptions{
8790
WorkspaceName:workspace.Name,
8891
Fetch:func(ctx context.Context) (codersdk.WorkspaceAgent,error) {
8992
returnclient.WorkspaceAgent(ctx,workspaceAgent.ID)
@@ -101,8 +104,9 @@ func wireguardPortForward() *cobra.Command {
101104
iferr!=nil {
102105
returnxerrors.Errorf("create wireguard network: %w",err)
103106
}
107+
deferwgn.Close()
104108

105-
err=client.PostWireguardPeer(cmd.Context(),workspace.ID, peerwg.Handshake{
109+
err=client.PostWireguardPeer(ctx,workspace.ID, peerwg.Handshake{
106110
Recipient:workspaceAgent.ID,
107111
NodePublicKey:wgn.NodePrivateKey.Public(),
108112
DiscoPublicKey:wgn.DiscoPublicKey,
@@ -124,7 +128,6 @@ func wireguardPortForward() *cobra.Command {
124128

125129
// Start all listeners.
126130
var (
127-
ctx,cancel=context.WithCancel(cmd.Context())
128131
wg=new(sync.WaitGroup)
129132
listeners=make([]net.Listener,len(specs))
130133
closeAllListeners=func() {
@@ -136,11 +139,11 @@ func wireguardPortForward() *cobra.Command {
136139
}
137140
}
138141
)
139-
defercancel()
142+
defercloseAllListeners()
143+
140144
fori,spec:=rangespecs {
141145
l,err:=listenAndPortForwardWireguard(ctx,cmd,wgn,wg,spec,workspaceAgent.IPv6.IP())
142146
iferr!=nil {
143-
closeAllListeners()
144147
returnerr
145148
}
146149
listeners[i]=l
@@ -149,7 +152,10 @@ func wireguardPortForward() *cobra.Command {
149152
// Wait for the context to be canceled or for a signal and close
150153
// all listeners.
151154
varcloseErrerror
155+
wg.Add(1)
152156
gofunc() {
157+
deferwg.Done()
158+
153159
sigs:=make(chan os.Signal,1)
154160
signal.Notify(sigs,syscall.SIGINT,syscall.SIGTERM)
155161

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp