@@ -6,19 +6,24 @@ import (
66"net"
77"net/http"
88"net/http/httptest"
9+ "sync"
10+ "sync/atomic"
911"testing"
1012"time"
1113
12- "nhooyr.io/websocket"
13-
14- "cdr.dev/slog"
15- "cdr.dev/slog/sloggers/slogtest"
16-
1714"github.com/google/uuid"
1815"github.com/stretchr/testify/assert"
1916"github.com/stretchr/testify/require"
17+ "go.uber.org/mock/gomock"
18+ "nhooyr.io/websocket"
19+ "tailscale.com/tailcfg"
20+ "tailscale.com/types/key"
2021
22+ "cdr.dev/slog"
23+ "cdr.dev/slog/sloggers/slogtest"
2124"github.com/coder/coder/v2/tailnet"
25+ "github.com/coder/coder/v2/tailnet/proto"
26+ "github.com/coder/coder/v2/tailnet/tailnettest"
2227"github.com/coder/coder/v2/tailnet/test"
2328"github.com/coder/coder/v2/testutil"
2429)
@@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n
400405require .True (t ,ok )
401406return client ,server
402407}
408+
409+ func TestInMemoryCoordination (t * testing.T ) {
410+ t .Parallel ()
411+ ctx := testutil .Context (t ,testutil .WaitShort )
412+ logger := slogtest .Make (t ,nil ).Leveled (slog .LevelDebug )
413+ clientID := uuid.UUID {1 }
414+ agentID := uuid.UUID {2 }
415+ mCoord := tailnettest .NewMockCoordinator (gomock .NewController (t ))
416+ fConn := & fakeCoordinatee {}
417+
418+ reqs := make (chan * proto.CoordinateRequest ,100 )
419+ resps := make (chan * proto.CoordinateResponse ,100 )
420+ mCoord .EXPECT ().Coordinate (gomock .Any (),clientID ,gomock .Any (), tailnet.ClientTunnelAuth {agentID }).
421+ Times (1 ).Return (reqs ,resps )
422+
423+ uut := tailnet .NewInMemoryCoordination (ctx ,logger ,clientID ,agentID ,mCoord ,fConn )
424+ defer uut .Close ()
425+
426+ coordinationTest (ctx ,t ,uut ,fConn ,reqs ,resps ,agentID )
427+
428+ select {
429+ case err := <- uut .Error ():
430+ require .NoError (t ,err )
431+ default :
432+ // OK!
433+ }
434+ }
435+
436+ func TestRemoteCoordination (t * testing.T ) {
437+ t .Parallel ()
438+ ctx := testutil .Context (t ,testutil .WaitShort )
439+ logger := slogtest .Make (t ,nil ).Leveled (slog .LevelDebug )
440+ clientID := uuid.UUID {1 }
441+ agentID := uuid.UUID {2 }
442+ mCoord := tailnettest .NewMockCoordinator (gomock .NewController (t ))
443+ fConn := & fakeCoordinatee {}
444+
445+ reqs := make (chan * proto.CoordinateRequest ,100 )
446+ resps := make (chan * proto.CoordinateResponse ,100 )
447+ mCoord .EXPECT ().Coordinate (gomock .Any (),clientID ,gomock .Any (), tailnet.ClientTunnelAuth {agentID }).
448+ Times (1 ).Return (reqs ,resps )
449+
450+ var coord tailnet.Coordinator = mCoord
451+ coordPtr := atomic.Pointer [tailnet.Coordinator ]{}
452+ coordPtr .Store (& coord )
453+ svc ,err := tailnet .NewClientService (
454+ logger .Named ("svc" ),& coordPtr ,
455+ time .Hour ,
456+ func ()* tailcfg.DERPMap {panic ("not implemented" ) },
457+ )
458+ require .NoError (t ,err )
459+ sC ,cC := net .Pipe ()
460+
461+ serveErr := make (chan error ,1 )
462+ go func () {
463+ err := svc .ServeClient (ctx ,tailnet .CurrentVersion .String (),sC ,clientID ,agentID )
464+ serveErr <- err
465+ }()
466+
467+ client ,err := tailnet .NewDRPCClient (cC )
468+ require .NoError (t ,err )
469+ protocol ,err := client .Coordinate (ctx )
470+ require .NoError (t ,err )
471+
472+ uut := tailnet .NewRemoteCoordination (logger .Named ("coordination" ),protocol ,fConn ,agentID )
473+ defer uut .Close ()
474+
475+ coordinationTest (ctx ,t ,uut ,fConn ,reqs ,resps ,agentID )
476+
477+ select {
478+ case err := <- uut .Error ():
479+ require .ErrorContains (t ,err ,"stream terminated by sending close" )
480+ default :
481+ // OK!
482+ }
483+ }
484+
485+ // coordinationTest tests that a coordination behaves correctly
486+ func coordinationTest (
487+ ctx context.Context ,t * testing.T ,
488+ uut tailnet.Coordination ,fConn * fakeCoordinatee ,
489+ reqs chan * proto.CoordinateRequest ,resps chan * proto.CoordinateResponse ,
490+ agentID uuid.UUID ,
491+ ) {
492+ // It should add the tunnel, since we configured as a client
493+ req := testutil .RequireRecvCtx (ctx ,t ,reqs )
494+ require .Equal (t ,agentID [:],req .GetAddTunnel ().GetId ())
495+
496+ // when we call the callback, it should send a node update
497+ require .NotNil (t ,fConn .callback )
498+ fConn .callback (& tailnet.Node {PreferredDERP :1 })
499+
500+ req = testutil .RequireRecvCtx (ctx ,t ,reqs )
501+ require .Equal (t ,int32 (1 ),req .GetUpdateSelf ().GetNode ().GetPreferredDerp ())
502+
503+ // When we send a peer update, it should update the coordinatee
504+ nk ,err := key .NewNode ().Public ().MarshalBinary ()
505+ require .NoError (t ,err )
506+ dk ,err := key .NewDisco ().Public ().MarshalText ()
507+ require .NoError (t ,err )
508+ updates := []* proto.CoordinateResponse_PeerUpdate {
509+ {
510+ Id :agentID [:],
511+ Kind :proto .CoordinateResponse_PeerUpdate_NODE ,
512+ Node :& proto.Node {
513+ Id :2 ,
514+ Key :nk ,
515+ Disco :string (dk ),
516+ },
517+ },
518+ }
519+ testutil .RequireSendCtx (ctx ,t ,resps ,& proto.CoordinateResponse {PeerUpdates :updates })
520+ require .Eventually (t ,func ()bool {
521+ fConn .Lock ()
522+ defer fConn .Unlock ()
523+ return len (fConn .updates )> 0
524+ },testutil .WaitShort ,testutil .IntervalFast )
525+ require .Len (t ,fConn .updates [0 ],1 )
526+ require .Equal (t ,agentID [:],fConn .updates [0 ][0 ].Id )
527+
528+ err = uut .Close ()
529+ require .NoError (t ,err )
530+ uut .Error ()
531+
532+ // When we close, it should gracefully disconnect
533+ req = testutil .RequireRecvCtx (ctx ,t ,reqs )
534+ require .NotNil (t ,req .Disconnect )
535+
536+ // It should set all peers lost on the coordinatee
537+ require .Equal (t ,1 ,fConn .setAllPeersLostCalls )
538+ }
539+
540+ type fakeCoordinatee struct {
541+ sync.Mutex
542+ callback func (* tailnet.Node )
543+ updates [][]* proto.CoordinateResponse_PeerUpdate
544+ setAllPeersLostCalls int
545+ }
546+
547+ func (f * fakeCoordinatee )UpdatePeers (updates []* proto.CoordinateResponse_PeerUpdate )error {
548+ f .Lock ()
549+ defer f .Unlock ()
550+ f .updates = append (f .updates ,updates )
551+ return nil
552+ }
553+
554+ func (f * fakeCoordinatee )SetAllPeersLost () {
555+ f .Lock ()
556+ defer f .Unlock ()
557+ f .setAllPeersLostCalls ++
558+ }
559+
560+ func (f * fakeCoordinatee )SetNodeCallback (callback func (* tailnet.Node )) {
561+ f .Lock ()
562+ defer f .Unlock ()
563+ f .callback = callback
564+ }