Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 3d85cdf

Browse files
authored
feat: set peers lost when disconnected from coordinator (#11681)
Adds support to Coordination to call SetAllPeersLost() when it is closed. This ensures that when we disconnect from a Coordinator, we set all peers lost. This covers CoderSDK (CLI client) and Agent. Next PR will cover MultiAgent (notably, `wsproxy`).
1 parent 9f6b38c commit 3d85cdf

File tree

4 files changed

+226
-28
lines changed

4 files changed

+226
-28
lines changed

‎tailnet/conn.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,11 @@ func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error
356356
return nil
357357
}
358358

359+
// SetAllPeersLost marks all peers lost; typically used when we disconnect from a coordinator.
360+
func (c *Conn) SetAllPeersLost() {
361+
c.configMaps.setAllPeersLost()
362+
}
363+
359364
// NodeAddresses returns the addresses of a node from the NetworkMap.
360365
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
361366
return c.configMaps.nodeAddresses(publicKey)

‎tailnet/coordinator.go

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ type Node struct {
9797
// Conn.
9898
type Coordinatee interface {
9999
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
100+
SetAllPeersLost()
100101
SetNodeCallback(func(*Node))
101102
}
102103

@@ -107,20 +108,28 @@ type Coordination interface {
107108

108109
type remoteCoordination struct {
109110
sync.Mutex
110-
closedbool
111-
errChanchanerror
112-
coordinateeCoordinatee
113-
logger slog.Logger
114-
protocol proto.DRPCTailnet_CoordinateClient
111+
closedbool
112+
errChanchanerror
113+
coordinateeCoordinatee
114+
logger slog.Logger
115+
protocol proto.DRPCTailnet_CoordinateClient
116+
respLoopDone chan struct{}
115117
}
116118

117-
func (c*remoteCoordination)Close()error {
119+
func (c *remoteCoordination) Close() (retErr error) {
118120
c.Lock()
119121
defer c.Unlock()
120122
ifc.closed {
121123
returnnil
122124
}
123125
c.closed=true
126+
defer func() {
127+
protoErr:=c.protocol.Close()
128+
<-c.respLoopDone
129+
ifretErr==nil {
130+
retErr=protoErr
131+
}
132+
}()
124133
err:=c.protocol.Send(&proto.CoordinateRequest{Disconnect:&proto.CoordinateRequest_Disconnect{}})
125134
iferr!=nil {
126135
returnxerrors.Errorf("send disconnect: %w",err)
@@ -140,6 +149,10 @@ func (c *remoteCoordination) sendErr(err error) {
140149
}
141150

142151
func (c *remoteCoordination) respLoop() {
152+
defer func() {
153+
c.coordinatee.SetAllPeersLost()
154+
close(c.respLoopDone)
155+
}()
143156
for {
144157
resp,err:=c.protocol.Recv()
145158
iferr!=nil {
@@ -162,10 +175,11 @@ func NewRemoteCoordination(logger slog.Logger,
162175
tunnelTarget uuid.UUID,
163176
)Coordination {
164177
c:=&remoteCoordination{
165-
errChan:make(chanerror,1),
166-
coordinatee:coordinatee,
167-
logger:logger,
168-
protocol:protocol,
178+
errChan:make(chanerror,1),
179+
coordinatee:coordinatee,
180+
logger:logger,
181+
protocol:protocol,
182+
respLoopDone:make(chanstruct{}),
169183
}
170184
iftunnelTarget!=uuid.Nil {
171185
c.Lock()
@@ -200,14 +214,15 @@ func NewRemoteCoordination(logger slog.Logger,
200214

201215
typeinMemoryCoordinationstruct {
202216
sync.Mutex
203-
ctx context.Context
204-
errChanchanerror
205-
closedbool
206-
closedChchanstruct{}
207-
coordinateeCoordinatee
208-
logger slog.Logger
209-
resps<-chan*proto.CoordinateResponse
210-
reqschan<-*proto.CoordinateRequest
217+
ctx context.Context
218+
errChanchanerror
219+
closedbool
220+
closedChchanstruct{}
221+
respLoopDone chan struct{}
222+
coordinateeCoordinatee
223+
logger slog.Logger
224+
resps<-chan*proto.CoordinateResponse
225+
reqschan<-*proto.CoordinateRequest
211226
}
212227

213228
func (c*inMemoryCoordination)sendErr(errerror) {
@@ -238,11 +253,12 @@ func NewInMemoryCoordination(
238253
thisID=clientID
239254
}
240255
c:=&inMemoryCoordination{
241-
ctx:ctx,
242-
errChan:make(chanerror,1),
243-
coordinatee:coordinatee,
244-
logger:logger,
245-
closedCh:make(chanstruct{}),
256+
ctx:ctx,
257+
errChan:make(chanerror,1),
258+
coordinatee:coordinatee,
259+
logger:logger,
260+
closedCh:make(chanstruct{}),
261+
respLoopDone:make(chanstruct{}),
246262
}
247263

248264
// use the background context since we will depend exclusively on closing the req channel to
@@ -285,6 +301,10 @@ func NewInMemoryCoordination(
285301
}
286302

287303
func (c *inMemoryCoordination) respLoop() {
304+
defer func() {
305+
c.coordinatee.SetAllPeersLost()
306+
close(c.respLoopDone)
307+
}()
288308
for {
289309
select {
290310
case<-c.closedCh:
@@ -315,6 +335,7 @@ func (c *inMemoryCoordination) Close() error {
315335
defer close(c.reqs)
316336
c.closed=true
317337
close(c.closedCh)
338+
<-c.respLoopDone
318339
select {
319340
case<-c.ctx.Done():
320341
returnxerrors.Errorf("failed to gracefully disconnect: %w",c.ctx.Err())

‎tailnet/coordinator_test.go

Lines changed: 167 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,24 @@ import (
66
"net"
77
"net/http"
88
"net/http/httptest"
9+
"sync"
10+
"sync/atomic"
911
"testing"
1012
"time"
1113

12-
"nhooyr.io/websocket"
13-
14-
"cdr.dev/slog"
15-
"cdr.dev/slog/sloggers/slogtest"
16-
1714
"github.com/google/uuid"
1815
"github.com/stretchr/testify/assert"
1916
"github.com/stretchr/testify/require"
17+
"go.uber.org/mock/gomock"
18+
"nhooyr.io/websocket"
19+
"tailscale.com/tailcfg"
20+
"tailscale.com/types/key"
2021

22+
"cdr.dev/slog"
23+
"cdr.dev/slog/sloggers/slogtest"
2124
"github.com/coder/coder/v2/tailnet"
25+
"github.com/coder/coder/v2/tailnet/proto"
26+
"github.com/coder/coder/v2/tailnet/tailnettest"
2227
"github.com/coder/coder/v2/tailnet/test"
2328
"github.com/coder/coder/v2/testutil"
2429
)
@@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n
400405
require.True(t,ok)
401406
returnclient,server
402407
}
408+
409+
funcTestInMemoryCoordination(t*testing.T) {
410+
t.Parallel()
411+
ctx:=testutil.Context(t,testutil.WaitShort)
412+
logger:=slogtest.Make(t,nil).Leveled(slog.LevelDebug)
413+
clientID:= uuid.UUID{1}
414+
agentID:= uuid.UUID{2}
415+
mCoord:=tailnettest.NewMockCoordinator(gomock.NewController(t))
416+
fConn:=&fakeCoordinatee{}
417+
418+
reqs:=make(chan*proto.CoordinateRequest,100)
419+
resps:=make(chan*proto.CoordinateResponse,100)
420+
mCoord.EXPECT().Coordinate(gomock.Any(),clientID,gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
421+
Times(1).Return(reqs,resps)
422+
423+
uut:=tailnet.NewInMemoryCoordination(ctx,logger,clientID,agentID,mCoord,fConn)
424+
deferuut.Close()
425+
426+
coordinationTest(ctx,t,uut,fConn,reqs,resps,agentID)
427+
428+
select {
429+
caseerr:=<-uut.Error():
430+
require.NoError(t,err)
431+
default:
432+
// OK!
433+
}
434+
}
435+
436+
funcTestRemoteCoordination(t*testing.T) {
437+
t.Parallel()
438+
ctx:=testutil.Context(t,testutil.WaitShort)
439+
logger:=slogtest.Make(t,nil).Leveled(slog.LevelDebug)
440+
clientID:= uuid.UUID{1}
441+
agentID:= uuid.UUID{2}
442+
mCoord:=tailnettest.NewMockCoordinator(gomock.NewController(t))
443+
fConn:=&fakeCoordinatee{}
444+
445+
reqs:=make(chan*proto.CoordinateRequest,100)
446+
resps:=make(chan*proto.CoordinateResponse,100)
447+
mCoord.EXPECT().Coordinate(gomock.Any(),clientID,gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
448+
Times(1).Return(reqs,resps)
449+
450+
varcoord tailnet.Coordinator=mCoord
451+
coordPtr:= atomic.Pointer[tailnet.Coordinator]{}
452+
coordPtr.Store(&coord)
453+
svc,err:=tailnet.NewClientService(
454+
logger.Named("svc"),&coordPtr,
455+
time.Hour,
456+
func()*tailcfg.DERPMap {panic("not implemented") },
457+
)
458+
require.NoError(t,err)
459+
sC,cC:=net.Pipe()
460+
461+
serveErr:=make(chanerror,1)
462+
gofunc() {
463+
err:=svc.ServeClient(ctx,tailnet.CurrentVersion.String(),sC,clientID,agentID)
464+
serveErr<-err
465+
}()
466+
467+
client,err:=tailnet.NewDRPCClient(cC)
468+
require.NoError(t,err)
469+
protocol,err:=client.Coordinate(ctx)
470+
require.NoError(t,err)
471+
472+
uut:=tailnet.NewRemoteCoordination(logger.Named("coordination"),protocol,fConn,agentID)
473+
deferuut.Close()
474+
475+
coordinationTest(ctx,t,uut,fConn,reqs,resps,agentID)
476+
477+
select {
478+
caseerr:=<-uut.Error():
479+
require.ErrorContains(t,err,"stream terminated by sending close")
480+
default:
481+
// OK!
482+
}
483+
}
484+
485+
// coordinationTest tests that a coordination behaves correctly
486+
funccoordinationTest(
487+
ctx context.Context,t*testing.T,
488+
uut tailnet.Coordination,fConn*fakeCoordinatee,
489+
reqschan*proto.CoordinateRequest,respschan*proto.CoordinateResponse,
490+
agentID uuid.UUID,
491+
) {
492+
// It should add the tunnel, since we configured as a client
493+
req:=testutil.RequireRecvCtx(ctx,t,reqs)
494+
require.Equal(t,agentID[:],req.GetAddTunnel().GetId())
495+
496+
// when we call the callback, it should send a node update
497+
require.NotNil(t,fConn.callback)
498+
fConn.callback(&tailnet.Node{PreferredDERP:1})
499+
500+
req=testutil.RequireRecvCtx(ctx,t,reqs)
501+
require.Equal(t,int32(1),req.GetUpdateSelf().GetNode().GetPreferredDerp())
502+
503+
// When we send a peer update, it should update the coordinatee
504+
nk,err:=key.NewNode().Public().MarshalBinary()
505+
require.NoError(t,err)
506+
dk,err:=key.NewDisco().Public().MarshalText()
507+
require.NoError(t,err)
508+
updates:= []*proto.CoordinateResponse_PeerUpdate{
509+
{
510+
Id:agentID[:],
511+
Kind:proto.CoordinateResponse_PeerUpdate_NODE,
512+
Node:&proto.Node{
513+
Id:2,
514+
Key:nk,
515+
Disco:string(dk),
516+
},
517+
},
518+
}
519+
testutil.RequireSendCtx(ctx,t,resps,&proto.CoordinateResponse{PeerUpdates:updates})
520+
require.Eventually(t,func()bool {
521+
fConn.Lock()
522+
deferfConn.Unlock()
523+
returnlen(fConn.updates)>0
524+
},testutil.WaitShort,testutil.IntervalFast)
525+
require.Len(t,fConn.updates[0],1)
526+
require.Equal(t,agentID[:],fConn.updates[0][0].Id)
527+
528+
err=uut.Close()
529+
require.NoError(t,err)
530+
uut.Error()
531+
532+
// When we close, it should gracefully disconnect
533+
req=testutil.RequireRecvCtx(ctx,t,reqs)
534+
require.NotNil(t,req.Disconnect)
535+
536+
// It should set all peers lost on the coordinatee
537+
require.Equal(t,1,fConn.setAllPeersLostCalls)
538+
}
539+
540+
typefakeCoordinateestruct {
541+
sync.Mutex
542+
callbackfunc(*tailnet.Node)
543+
updates [][]*proto.CoordinateResponse_PeerUpdate
544+
setAllPeersLostCallsint
545+
}
546+
547+
func (f*fakeCoordinatee)UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate)error {
548+
f.Lock()
549+
deferf.Unlock()
550+
f.updates=append(f.updates,updates)
551+
returnnil
552+
}
553+
554+
func (f*fakeCoordinatee)SetAllPeersLost() {
555+
f.Lock()
556+
deferf.Unlock()
557+
f.setAllPeersLostCalls++
558+
}
559+
560+
func (f*fakeCoordinatee)SetNodeCallback(callbackfunc(*tailnet.Node)) {
561+
f.Lock()
562+
deferf.Unlock()
563+
f.callback=callback
564+
}

‎testutil/ctx.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,13 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A)
2222
returna
2323
}
2424
}
25+
26+
func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
27+
t.Helper()
28+
select {
29+
case <-ctx.Done():
30+
t.Fatal("timeout")
31+
case c <- a:
32+
// OK!
33+
}
34+
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp