Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

feat: set peers lost when disconnected from coordinator#11681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
spikecurtis merged 1 commit intomainfromspike/10533-set-lost-disconnect
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletionstailnet/conn.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -356,6 +356,11 @@ func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error
return nil
}

// SetAllPeersLost marks all peers lost; typically used when we disconnect from a coordinator.
func (c *Conn) SetAllPeersLost() {
c.configMaps.setAllPeersLost()
}

// NodeAddresses returns the addresses of a node from the NetworkMap.
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
return c.configMaps.nodeAddresses(publicKey)
Expand Down
67 changes: 44 additions & 23 deletionstailnet/coordinator.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -97,6 +97,7 @@ type Node struct {
// Conn.
type Coordinatee interface {
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
SetAllPeersLost()
SetNodeCallback(func(*Node))
}

Expand All@@ -107,20 +108,28 @@ type Coordination interface {

type remoteCoordination struct {
sync.Mutex
closed bool
errChan chan error
coordinatee Coordinatee
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
closed bool
errChan chan error
coordinatee Coordinatee
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
respLoopDone chan struct{}
}

func (c *remoteCoordination) Close() error {
func (c *remoteCoordination) Close()(retErrerror) {
c.Lock()
defer c.Unlock()
if c.closed {
return nil
}
c.closed = true
defer func() {
protoErr := c.protocol.Close()
<-c.respLoopDone
if retErr == nil {
retErr = protoErr
}
}()
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil {
return xerrors.Errorf("send disconnect: %w", err)
Expand All@@ -140,6 +149,10 @@ func (c *remoteCoordination) sendErr(err error) {
}

func (c *remoteCoordination) respLoop() {
defer func() {
c.coordinatee.SetAllPeersLost()
close(c.respLoopDone)
}()
for {
resp, err := c.protocol.Recv()
if err != nil {
Expand All@@ -162,10 +175,11 @@ func NewRemoteCoordination(logger slog.Logger,
tunnelTarget uuid.UUID,
) Coordination {
c := &remoteCoordination{
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
protocol: protocol,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
protocol: protocol,
respLoopDone: make(chan struct{}),
}
if tunnelTarget != uuid.Nil {
c.Lock()
Expand DownExpand Up@@ -200,14 +214,15 @@ func NewRemoteCoordination(logger slog.Logger,

type inMemoryCoordination struct {
sync.Mutex
ctx context.Context
errChan chan error
closed bool
closedCh chan struct{}
coordinatee Coordinatee
logger slog.Logger
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
ctx context.Context
errChan chan error
closed bool
closedCh chan struct{}
respLoopDone chan struct{}
coordinatee Coordinatee
logger slog.Logger
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
}

func (c *inMemoryCoordination) sendErr(err error) {
Expand DownExpand Up@@ -238,11 +253,12 @@ func NewInMemoryCoordination(
thisID = clientID
}
c := &inMemoryCoordination{
ctx: ctx,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
closedCh: make(chan struct{}),
ctx: ctx,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
closedCh: make(chan struct{}),
respLoopDone: make(chan struct{}),
}

// use the background context since we will depend exclusively on closing the req channel to
Expand DownExpand Up@@ -285,6 +301,10 @@ func NewInMemoryCoordination(
}

func (c *inMemoryCoordination) respLoop() {
defer func() {
c.coordinatee.SetAllPeersLost()
close(c.respLoopDone)
}()
for {
select {
case <-c.closedCh:
Expand DownExpand Up@@ -315,6 +335,7 @@ func (c *inMemoryCoordination) Close() error {
defer close(c.reqs)
c.closed = true
close(c.closedCh)
<-c.respLoopDone
select {
case <-c.ctx.Done():
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
Expand Down
172 changes: 167 additions & 5 deletionstailnet/coordinator_test.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -6,19 +6,24 @@ import (
"net"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"

"nhooyr.io/websocket"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"
"tailscale.com/types/key"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)
Expand DownExpand Up@@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n
require.True(t, ok)
return client, server
}

func TestInMemoryCoordination(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}

reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
Times(1).Return(reqs, resps)

uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
defer uut.Close()

coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)

select {
case err := <-uut.Error():
require.NoError(t, err)
default:
// OK!
}
}

func TestRemoteCoordination(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}

reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
Times(1).Return(reqs, resps)

var coord tailnet.Coordinator = mCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
svc, err := tailnet.NewClientService(
logger.Named("svc"), &coordPtr,
time.Hour,
func() *tailcfg.DERPMap { panic("not implemented") },
)
require.NoError(t, err)
sC, cC := net.Pipe()

serveErr := make(chan error, 1)
go func() {
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
serveErr <- err
}()

client, err := tailnet.NewDRPCClient(cC)
require.NoError(t, err)
protocol, err := client.Coordinate(ctx)
require.NoError(t, err)

uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
defer uut.Close()

coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)

select {
case err := <-uut.Error():
require.ErrorContains(t, err, "stream terminated by sending close")
default:
// OK!
}
}

// coordinationTest tests that a coordination behaves correctly
func coordinationTest(
ctx context.Context, t *testing.T,
uut tailnet.Coordination, fConn *fakeCoordinatee,
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
agentID uuid.UUID,
) {
// It should add the tunnel, since we configured as a client
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())

// when we call the callback, it should send a node update
require.NotNil(t, fConn.callback)
fConn.callback(&tailnet.Node{PreferredDERP: 1})

req = testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())

// When we send a peer update, it should update the coordinatee
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
updates := []*proto.CoordinateResponse_PeerUpdate{
{
Id: agentID[:],
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Node: &proto.Node{
Id: 2,
Key: nk,
Disco: string(dk),
},
},
}
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
require.Eventually(t, func() bool {
fConn.Lock()
defer fConn.Unlock()
return len(fConn.updates) > 0
}, testutil.WaitShort, testutil.IntervalFast)
require.Len(t, fConn.updates[0], 1)
require.Equal(t, agentID[:], fConn.updates[0][0].Id)

err = uut.Close()
require.NoError(t, err)
uut.Error()

// When we close, it should gracefully disconnect
req = testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, req.Disconnect)

// It should set all peers lost on the coordinatee
require.Equal(t, 1, fConn.setAllPeersLostCalls)
}

type fakeCoordinatee struct {
sync.Mutex
callback func(*tailnet.Node)
updates [][]*proto.CoordinateResponse_PeerUpdate
setAllPeersLostCalls int
}

func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
f.Lock()
defer f.Unlock()
f.updates = append(f.updates, updates)
return nil
}

func (f *fakeCoordinatee) SetAllPeersLost() {
f.Lock()
defer f.Unlock()
f.setAllPeersLostCalls++
}

func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
f.Lock()
defer f.Unlock()
f.callback = callback
}
10 changes: 10 additions & 0 deletionstestutil/ctx.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -22,3 +22,13 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A)
return a
}
}

func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
t.Helper()
select {
case <-ctx.Done():
t.Fatal("timeout")
case c <- a:
// OK!
}
}

[8]ページ先頭

©2009-2025 Movatter.jp