Commit 6ff9a05

fix: close SSH sessions bottom-up if top-down fails (#14678)

1 parent ff1eabe, commit 6ff9a05
2 files changed: 146 additions, 29 deletions

cli/ssh.go

Lines changed: 60 additions & 12 deletions

@@ -37,6 +37,7 @@ import (
 	"github.com/coder/coder/v2/codersdk/workspacesdk"
 	"github.com/coder/coder/v2/cryptorand"
 	"github.com/coder/coder/v2/pty"
+	"github.com/coder/quartz"
 	"github.com/coder/retry"
 	"github.com/coder/serpent"
 )
@@ -48,6 +49,8 @@ const (
 var (
 	workspacePollInterval   = time.Minute
 	autostopNotifyCountdown = []time.Duration{30 * time.Minute}
+	// gracefulShutdownTimeout is the timeout, per item in the stack of things to close
+	gracefulShutdownTimeout = 2 * time.Second
 )
 
 func (r *RootCmd) ssh() *serpent.Command {
@@ -153,7 +156,7 @@ func (r *RootCmd) ssh() *serpent.Command {
 				// log HTTP requests
 				client.SetLogger(logger)
 			}
-			stack := newCloserStack(ctx, logger)
+			stack := newCloserStack(ctx, logger, quartz.NewReal())
 			defer stack.close(nil)
 
 			for _, remoteForward := range remoteForwards {
@@ -936,11 +939,18 @@ type closerStack struct {
 	closed  bool
 	logger  slog.Logger
 	err     error
-	wg      sync.WaitGroup
+	allDone chan struct{}
+
+	// for testing
+	clock quartz.Clock
 }
 
-func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack {
-	cs := &closerStack{logger: logger}
+func newCloserStack(ctx context.Context, logger slog.Logger, clock quartz.Clock) *closerStack {
+	cs := &closerStack{
+		logger:  logger,
+		allDone: make(chan struct{}),
+		clock:   clock,
+	}
 	go cs.closeAfterContext(ctx)
 	return cs
 }
@@ -954,20 +964,58 @@ func (c *closerStack) close(err error) {
 	c.Lock()
 	if c.closed {
 		c.Unlock()
-		c.wg.Wait()
+		<-c.allDone
 		return
 	}
 	c.closed = true
 	c.err = err
-	c.wg.Add(1)
-	defer c.wg.Done()
 	c.Unlock()
+	defer close(c.allDone)
+	if len(c.closers) == 0 {
+		return
+	}
 
-	for i := len(c.closers) - 1; i >= 0; i-- {
-		cwn := c.closers[i]
-		cErr := cwn.closer.Close()
-		c.logger.Debug(context.Background(),
-			"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
+	// We are going to work down the stack in order. If things close quickly, we trigger the
+	// closers serially, in order. `done` is a channel that indicates the nth closer is done
+	// closing, and we should trigger the (n-1) closer. However, if things take too long we don't
+	// want to wait, so we also start a ticker that works down the stack and sends on `done` as
+	// well.
+	next := len(c.closers) - 1
+	// here we make the buffer 2x the number of closers because we could write once for it being
+	// actually done and once via the countdown for each closer
+	done := make(chan int, len(c.closers)*2)
+	startNext := func() {
+		go func(i int) {
+			defer func() { done <- i }()
+			cwn := c.closers[i]
+			cErr := cwn.closer.Close()
+			c.logger.Debug(context.Background(),
+				"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
+		}(next)
+		next--
+	}
+	done <- len(c.closers) // kick us off right away
+
+	// start a ticking countdown in case we hang/don't close quickly
+	countdown := len(c.closers) - 1
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	c.clock.TickerFunc(ctx, gracefulShutdownTimeout, func() error {
+		if countdown < 0 {
+			return nil
+		}
+		done <- countdown
+		countdown--
+		return nil
+	}, "closerStack")
+
+	for n := range done { // the nth closer is done
+		if n == 0 {
+			return
+		}
+		if n-1 == next {
+			startNext()
+		}
 	}
 }

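For illustration, here is a minimal, self-contained sketch of the staggered-close pattern that the new close() implements, using a plain time.Ticker in place of quartz. The names staggeredClose and slowCloser are hypothetical and exist only for this sketch; they are not part of the commit.

package main

import (
	"context"
	"fmt"
	"io"
	"time"
)

// staggeredClose closes items from the top of the stack (the end of the
// slice) down. If a Close() hangs for longer than timeout, a ticker starts
// the next closer anyway, so a stuck item cannot block the items beneath it.
func staggeredClose(closers []io.Closer, timeout time.Duration) {
	if len(closers) == 0 {
		return
	}
	next := len(closers) - 1
	// Buffer 2x: each index can arrive once from a finished closer and once
	// from the timeout countdown.
	done := make(chan int, len(closers)*2)
	startNext := func() {
		go func(i int) {
			defer func() { done <- i }()
			_ = closers[i].Close()
		}(next)
		next--
	}
	done <- len(closers) // kick off the top of the stack immediately

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() {
		ticker := time.NewTicker(timeout)
		defer ticker.Stop()
		for countdown := len(closers) - 1; countdown >= 0; countdown-- {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				done <- countdown
			}
		}
	}()

	for n := range done { // the nth closer finished, or its time ran out
		if n == 0 {
			return // the bottom of the stack is done; stop waiting
		}
		if n-1 == next {
			startNext()
		}
	}
}

type slowCloser struct {
	name string
	d    time.Duration
}

func (s slowCloser) Close() error {
	time.Sleep(s.d)
	fmt.Println("closed", s.name)
	return nil
}

func main() {
	// Index 0 is the bottom of the stack; the top (index 2) closes first.
	stack := []io.Closer{
		slowCloser{"pty (bottom)", 10 * time.Millisecond},
		slowCloser{"ssh session (hung)", time.Hour},
		slowCloser{"port forward (top)", 10 * time.Millisecond},
	}
	// The hung session is skipped after the timeout and the pty still closes.
	staggeredClose(stack, 2*time.Second)
}

Sending an index on done both on completion and on timeout lets a single loop drive the whole stack, which is also why the real code sizes the channel buffer at twice the number of closers.
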
cli/ssh_internal_test.go

Lines changed: 86 additions & 17 deletions

@@ -2,7 +2,9 @@ package cli
 
 import (
 	"context"
+	"fmt"
 	"net/url"
+	"sync"
 	"testing"
 	"time"
 
@@ -12,6 +14,7 @@ import (
 
 	"cdr.dev/slog"
 	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/quartz"
 
 	"github.com/coder/coder/v2/codersdk"
 	"github.com/coder/coder/v2/testutil"
@@ -68,7 +71,7 @@ func TestCloserStack_Mainline(t *testing.T) {
 	t.Parallel()
 	ctx := testutil.Context(t, testutil.WaitShort)
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-	uut := newCloserStack(ctx, logger)
+	uut := newCloserStack(ctx, logger, quartz.NewMock(t))
 	closes := new([]*fakeCloser)
 	fc0 := &fakeCloser{closes: closes}
 	fc1 := &fakeCloser{closes: closes}
@@ -84,13 +87,27 @@
 	require.Equal(t, []*fakeCloser{fc1, fc0}, *closes)
 }
 
+func TestCloserStack_Empty(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	uut := newCloserStack(ctx, logger, quartz.NewMock(t))
+
+	closed := make(chan struct{})
+	go func() {
+		defer close(closed)
+		uut.close(nil)
+	}()
+	testutil.RequireRecvCtx(ctx, t, closed)
+}
+
 func TestCloserStack_Context(t *testing.T) {
 	t.Parallel()
 	ctx := testutil.Context(t, testutil.WaitShort)
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-	uut := newCloserStack(ctx, logger)
+	uut := newCloserStack(ctx, logger, quartz.NewMock(t))
 	closes := new([]*fakeCloser)
 	fc0 := &fakeCloser{closes: closes}
 	fc1 := &fakeCloser{closes: closes}
@@ -111,7 +128,7 @@ func TestCloserStack_PushAfterClose(t *testing.T) {
 	t.Parallel()
 	ctx := testutil.Context(t, testutil.WaitShort)
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
-	uut := newCloserStack(ctx, logger)
+	uut := newCloserStack(ctx, logger, quartz.NewMock(t))
 	closes := new([]*fakeCloser)
 	fc0 := &fakeCloser{closes: closes}
 	fc1 := &fakeCloser{closes: closes}
@@ -134,13 +151,9 @@ func TestCloserStack_CloseAfterContext(t *testing.T) {
 	ctx, cancel := context.WithCancel(testCtx)
 	defer cancel()
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
-	uut := newCloserStack(ctx, logger)
-	ac := &asyncCloser{
-		t:        t,
-		ctx:      testCtx,
-		complete: make(chan struct{}),
-		started:  make(chan struct{}),
-	}
+	uut := newCloserStack(ctx, logger, quartz.NewMock(t))
+	ac := newAsyncCloser(testCtx, t)
+	defer ac.complete()
 	err := uut.push("async", ac)
 	require.NoError(t, err)
 	cancel()
@@ -160,11 +173,53 @@ func TestCloserStack_CloseAfterContext(t *testing.T) {
 		t.Fatal("closed before stack was finished")
 	}
 
-	// complete the asyncCloser
-	close(ac.complete)
+	ac.complete()
 	testutil.RequireRecvCtx(testCtx, t, closed)
 }
 
+func TestCloserStack_Timeout(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+	mClock := quartz.NewMock(t)
+	trap := mClock.Trap().TickerFunc("closerStack")
+	defer trap.Close()
+	uut := newCloserStack(ctx, logger, mClock)
+	var ac [3]*asyncCloser
+	for i := range ac {
+		ac[i] = newAsyncCloser(ctx, t)
+		err := uut.push(fmt.Sprintf("async %d", i), ac[i])
+		require.NoError(t, err)
+	}
+	defer func() {
+		for _, a := range ac {
+			a.complete()
+		}
+	}()
+
+	closed := make(chan struct{})
+	go func() {
+		defer close(closed)
+		uut.close(nil)
+	}()
+	trap.MustWait(ctx).Release()
+	// top starts right away, but it hangs
+	testutil.RequireRecvCtx(ctx, t, ac[2].started)
+	// timer pops and we start the middle one
+	mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
+	testutil.RequireRecvCtx(ctx, t, ac[1].started)
+
+	// middle one finishes
+	ac[1].complete()
+	// bottom starts, but also hangs
+	testutil.RequireRecvCtx(ctx, t, ac[0].started)
+
+	// timer has to pop twice to time out.
+	mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
+	mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
+	testutil.RequireRecvCtx(ctx, t, closed)
+}
+
 type fakeCloser struct {
 	closes *[]*fakeCloser
 	err    error
@@ -176,10 +231,11 @@ func (c *fakeCloser) Close() error {
 }
 
 type asyncCloser struct {
-	t        *testing.T
-	ctx      context.Context
-	started  chan struct{}
-	complete chan struct{}
+	t             *testing.T
+	ctx           context.Context
+	started       chan struct{}
+	isComplete    chan struct{}
+	comepleteOnce sync.Once
 }
 
 func (c *asyncCloser) Close() error {
@@ -188,7 +244,20 @@ func (c *asyncCloser) Close() error {
 	case <-c.ctx.Done():
 		c.t.Error("timed out")
 		return c.ctx.Err()
-	case <-c.complete:
+	case <-c.isComplete:
 		return nil
 	}
 }
+
+func (c *asyncCloser) complete() {
+	c.comepleteOnce.Do(func() { close(c.isComplete) })
+}
+
+func newAsyncCloser(ctx context.Context, t *testing.T) *asyncCloser {
+	return &asyncCloser{
+		t:          t,
+		ctx:        ctx,
+		isComplete: make(chan struct{}),
+		started:    make(chan struct{}),
+	}
+}
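
The timeout test drives the ticker deterministically with quartz's mock clock. As a rough, self-contained sketch of that trap/advance flow outside closerStack (the test name and the ticks channel are hypothetical; the quartz calls mirror the ones used above):

func TestTickerWithMockClock(t *testing.T) {
	ctx := context.Background()
	mClock := quartz.NewMock(t)
	// Trap creation of any TickerFunc tagged "closerStack" so the test can
	// synchronize with the code under test before advancing time.
	trap := mClock.Trap().TickerFunc("closerStack")
	defer trap.Close()

	ticks := make(chan struct{}, 1)
	// A trapped TickerFunc call blocks until released, so run it in a
	// goroutine; this stands in for the code under test.
	go mClock.TickerFunc(ctx, time.Second, func() error {
		ticks <- struct{}{} // the periodic work under test
		return nil
	}, "closerStack")

	// Block until the ticker is registered, then let it proceed.
	trap.MustWait(ctx).Release()
	// Fire exactly one tick and wait for its callback to finish.
	mClock.Advance(time.Second).MustWait(ctx)
	<-ticks
}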

0 commit comments
