Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 03a81d4

Browse files
authored
Merge branch 'main' into jaaydenh/workspace-creation-fix
2 parents 0b13269 + b61f0ab, commit 03a81d4

File tree

5 files changed

+191
-57
lines changed

5 files changed

+191
-57
lines changed

‎agent/agent.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,15 +1773,22 @@ func (a *agent) Close() error {
17731773
a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown)
17741774

17751775
// Attempt to gracefully shut down all active SSH connections and
1776-
// stop accepting new ones.
1777-
err:=a.sshServer.Shutdown(a.hardCtx)
1776+
// stop accepting new ones. If all processes have not exited after 5
1777+
// seconds, we just log it and move on as it's more important to run
1778+
// the shutdown scripts. A typical shutdown time for containers is
1779+
// 10 seconds, so this still leaves a bit of time to run the
1780+
// shutdown scripts in the worst-case.
1781+
sshShutdownCtx,sshShutdownCancel:=context.WithTimeout(a.hardCtx,5*time.Second)
1782+
defersshShutdownCancel()
1783+
err:=a.sshServer.Shutdown(sshShutdownCtx)
17781784
iferr!=nil {
1779-
a.logger.Error(a.hardCtx,"ssh server shutdown",slog.Error(err))
1780-
}
1781-
err=a.sshServer.Close()
1782-
iferr!=nil {
1783-
a.logger.Error(a.hardCtx,"ssh server close",slog.Error(err))
1785+
iferrors.Is(err,context.DeadlineExceeded) {
1786+
a.logger.Warn(sshShutdownCtx,"ssh server shutdown timeout",slog.Error(err))
1787+
}else {
1788+
a.logger.Error(sshShutdownCtx,"ssh server shutdown",slog.Error(err))
1789+
}
17841790
}
1791+
17851792
// wait for SSH to shut down before the general graceful cancel, because
17861793
// this triggers a disconnect in the tailnet layer, telling all clients to
17871794
// shut down their wireguard tunnels to us. If SSH sessions are still up,

‎agent/agentssh/agentssh.go

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,12 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []str
582582
func (s*Server)startNonPTYSession(logger slog.Logger,session ssh.Session,magicTypeLabelstring,cmd*exec.Cmd)error {
583583
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel,"no").Add(1)
584584

585+
// Create a process group and send SIGHUP to child processes,
586+
// otherwise context cancellation will not propagate properly
587+
// and SSH server close may be delayed.
588+
cmd.SysProcAttr=cmdSysProcAttr()
589+
cmd.Cancel=cmdCancel(session.Context(),logger,cmd)
590+
585591
cmd.Stdout=session
586592
cmd.Stderr=session.Stderr()
587593
// This blocks forever until stdin is received if we don't
@@ -926,7 +932,12 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
926932
// Serve starts the server to handle incoming connections on the provided listener.
927933
// It returns an error if no host keys are set or if there is an issue accepting connections.
928934
func (s*Server)Serve(l net.Listener) (retErrerror) {
929-
iflen(s.srv.HostSigners)==0 {
935+
// Ensure we're not mutating HostSigners as we're reading it.
936+
s.mu.RLock()
937+
noHostKeys:=len(s.srv.HostSigners)==0
938+
s.mu.RUnlock()
939+
940+
ifnoHostKeys {
930941
returnxerrors.New("no host keys set")
931942
}
932943

@@ -1054,43 +1065,72 @@ func (s *Server) Close() error {
10541065
}
10551066
s.closing=make(chanstruct{})
10561067

1068+
ctx:=context.Background()
1069+
1070+
s.logger.Debug(ctx,"closing server")
1071+
1072+
// Stop accepting new connections.
1073+
s.logger.Debug(ctx,"closing all active listeners",slog.F("count",len(s.listeners)))
1074+
forl:=ranges.listeners {
1075+
_=l.Close()
1076+
}
1077+
10571078
// Close all active sessions to gracefully
10581079
// terminate client connections.
1080+
s.logger.Debug(ctx,"closing all active sessions",slog.F("count",len(s.sessions)))
10591081
forss:=ranges.sessions {
10601082
// We call Close on the underlying channel here because we don't
10611083
// want to send an exit status to the client (via Exit()).
10621084
// Typically OpenSSH clients will return 255 as the exit status.
10631085
_=ss.Close()
10641086
}
1065-
1066-
// Close all active listeners and connections.
1067-
forl:=ranges.listeners {
1068-
_=l.Close()
1069-
}
1087+
s.logger.Debug(ctx,"closing all active connections",slog.F("count",len(s.conns)))
10701088
forc:=ranges.conns {
10711089
_=c.Close()
10721090
}
10731091

1074-
// Close the underlyingSSH server.
1092+
s.logger.Debug(ctx,"closingSSH server")
10751093
err:=s.srv.Close()
10761094

10771095
s.mu.Unlock()
1096+
1097+
s.logger.Debug(ctx,"waiting for all goroutines to exit")
10781098
s.wg.Wait()// Wait for all goroutines to exit.
10791099

10801100
s.mu.Lock()
10811101
close(s.closing)
10821102
s.closing=nil
10831103
s.mu.Unlock()
10841104

1105+
s.logger.Debug(ctx,"closing server done")
1106+
10851107
returnerr
10861108
}
10871109

1088-
// Shutdown gracefully closes all active SSH connections and stops
1089-
// accepting new connections.
1090-
//
1091-
// Shutdown is not implemented.
1092-
func (*Server)Shutdown(_ context.Context)error {
1093-
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
1110+
// Shutdown stops accepting new connections. The current implementation
1111+
// calls Close() for simplicity instead of waiting for existing
1112+
// connections to close. If the context times out, Shutdown will return
1113+
// but Close() may not have completed.
1114+
func (s*Server)Shutdown(ctx context.Context)error {
1115+
ch:=make(chanerror,1)
1116+
gofunc() {
1117+
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
1118+
// For now we just close the server.
1119+
ch<-s.Close()
1120+
}()
1121+
varerrerror
1122+
select {
1123+
case<-ctx.Done():
1124+
err=ctx.Err()
1125+
caseerr=<-ch:
1126+
}
1127+
// Re-check for context cancellation precedence.
1128+
ifctx.Err()!=nil {
1129+
err=ctx.Err()
1130+
}
1131+
iferr!=nil {
1132+
returnxerrors.Errorf("close server: %w",err)
1133+
}
10941134
returnnil
10951135
}
10961136

‎agent/agentssh/agentssh_test.go

Lines changed: 79 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"go.uber.org/goleak"
2222
"golang.org/x/crypto/ssh"
2323

24+
"cdr.dev/slog"
2425
"cdr.dev/slog/sloggers/slogtest"
2526

2627
"github.com/coder/coder/v2/agent/agentexec"
@@ -147,51 +148,92 @@ func (*fakeEnvInfoer) ModifyCommand(cmd string, args ...string) (string, []strin
147148
funcTestNewServer_CloseActiveConnections(t*testing.T) {
148149
t.Parallel()
149150

150-
ctx:=context.Background()
151-
logger:=slogtest.Make(t,&slogtest.Options{IgnoreErrors:true})
152-
s,err:=agentssh.NewServer(ctx,logger,prometheus.NewRegistry(),afero.NewMemMapFs(),agentexec.DefaultExecer,nil)
153-
require.NoError(t,err)
154-
defers.Close()
155-
err=s.UpdateHostSigner(42)
156-
assert.NoError(t,err)
151+
prepare:=func(ctx context.Context,t*testing.T) (*agentssh.Server,func()) {
152+
t.Helper()
153+
logger:=slogtest.Make(t,&slogtest.Options{IgnoreErrors:true}).Leveled(slog.LevelDebug)
154+
s,err:=agentssh.NewServer(ctx,logger,prometheus.NewRegistry(),afero.NewMemMapFs(),agentexec.DefaultExecer,nil)
155+
require.NoError(t,err)
156+
defers.Close()
157+
err=s.UpdateHostSigner(42)
158+
assert.NoError(t,err)
157159

158-
ln,err:=net.Listen("tcp","127.0.0.1:0")
159-
require.NoError(t,err)
160+
ln,err:=net.Listen("tcp","127.0.0.1:0")
161+
require.NoError(t,err)
160162

161-
varwg sync.WaitGroup
162-
wg.Add(2)
163-
gofunc() {
164-
deferwg.Done()
165-
err:=s.Serve(ln)
166-
assert.Error(t,err)// Server is closed.
167-
}()
163+
waitConns:=make([]chanstruct{},4)
168164

169-
pty:=ptytest.New(t)
165+
varwg sync.WaitGroup
166+
wg.Add(1+len(waitConns))
170167

171-
doClose:=make(chanstruct{})
172-
gofunc() {
173-
deferwg.Done()
174-
c:=sshClient(t,ln.Addr().String())
175-
sess,err:=c.NewSession()
176-
assert.NoError(t,err)
177-
sess.Stdin=pty.Input()
178-
sess.Stdout=pty.Output()
179-
sess.Stderr=pty.Output()
168+
gofunc() {
169+
deferwg.Done()
170+
err:=s.Serve(ln)
171+
assert.Error(t,err)// Server is closed.
172+
}()
180173

181-
assert.NoError(t,err)
182-
err=sess.Start("")
183-
assert.NoError(t,err)
174+
fori:=0;i<len(waitConns);i++ {
175+
waitConns[i]=make(chanstruct{})
176+
gofunc(chchanstruct{}) {
177+
deferwg.Done()
178+
c:=sshClient(t,ln.Addr().String())
179+
sess,err:=c.NewSession()
180+
assert.NoError(t,err)
181+
pty:=ptytest.New(t)
182+
sess.Stdin=pty.Input()
183+
sess.Stdout=pty.Output()
184+
sess.Stderr=pty.Output()
185+
186+
// Every other session will request a PTY.
187+
ifi%2==0 {
188+
err=sess.RequestPty("xterm",80,80,nil)
189+
assert.NoError(t,err)
190+
}
191+
// The 60 seconds here is intended to be longer than the
192+
// test. The shutdown should propagate.
193+
err=sess.Start("/bin/bash -c 'trap\"sleep 60\" SIGTERM; sleep 60'")
194+
assert.NoError(t,err)
195+
196+
close(ch)
197+
err=sess.Wait()
198+
assert.Error(t,err)
199+
}(waitConns[i])
200+
}
184201

185-
close(doClose)
186-
err=sess.Wait()
187-
assert.Error(t,err)
188-
}()
202+
for_,ch:=rangewaitConns {
203+
<-ch
204+
}
189205

190-
<-doClose
191-
err=s.Close()
192-
require.NoError(t,err)
206+
returns,wg.Wait
207+
}
208+
209+
t.Run("Close",func(t*testing.T) {
210+
t.Parallel()
211+
ctx:=testutil.Context(t,testutil.WaitMedium)
212+
s,wait:=prepare(ctx,t)
213+
err:=s.Close()
214+
require.NoError(t,err)
215+
wait()
216+
})
193217

194-
wg.Wait()
218+
t.Run("Shutdown",func(t*testing.T) {
219+
t.Parallel()
220+
ctx:=testutil.Context(t,testutil.WaitMedium)
221+
s,wait:=prepare(ctx,t)
222+
err:=s.Shutdown(ctx)
223+
require.NoError(t,err)
224+
wait()
225+
})
226+
227+
t.Run("Shutdown Early",func(t*testing.T) {
228+
t.Parallel()
229+
ctx:=testutil.Context(t,testutil.WaitMedium)
230+
s,wait:=prepare(ctx,t)
231+
ctx,cancel:=context.WithCancel(ctx)
232+
cancel()
233+
err:=s.Shutdown(ctx)
234+
require.ErrorIs(t,err,context.Canceled)
235+
wait()
236+
})
195237
}
196238

197239
funcTestNewServer_Signal(t*testing.T) {

‎agent/agentssh/exec_other.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//go:build !windows
2+
3+
package agentssh
4+
5+
import (
6+
"context"
7+
"os/exec"
8+
"syscall"
9+
10+
"cdr.dev/slog"
11+
)
12+
13+
// cmdSysProcAttr returns the process attributes used for non-PTY session
// commands on non-Windows platforms. Setsid starts the command in a new
// session, so the command becomes the leader of its own process group and
// cmdCancel can signal the whole group via the negative PID.
func cmdSysProcAttr() *syscall.SysProcAttr {
	return &syscall.SysProcAttr{
		Setsid: true,
	}
}
18+
19+
funccmdCancel(ctx context.Context,logger slog.Logger,cmd*exec.Cmd)func()error {
20+
returnfunc()error {
21+
logger.Debug(ctx,"cmdCancel: sending SIGHUP to process and children",slog.F("pid",cmd.Process.Pid))
22+
returnsyscall.Kill(-cmd.Process.Pid,syscall.SIGHUP)
23+
}
24+
}

‎agent/agentssh/exec_windows.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package agentssh
2+
3+
import (
4+
"context"
5+
"os"
6+
"os/exec"
7+
"syscall"
8+
9+
"cdr.dev/slog"
10+
)
11+
12+
// cmdSysProcAttr returns the process attributes used for non-PTY session
// commands on Windows. No special attributes are set; an empty value is
// returned so callers can assign it unconditionally.
func cmdSysProcAttr() *syscall.SysProcAttr {
	return &syscall.SysProcAttr{}
}
15+
16+
funccmdCancel(ctx context.Context,logger slog.Logger,cmd*exec.Cmd)func()error {
17+
returnfunc()error {
18+
logger.Debug(ctx,"cmdCancel: sending interrupt to process",slog.F("pid",cmd.Process.Pid))
19+
returncmd.Process.Signal(os.Interrupt)
20+
}
21+
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp