55"fmt"
66"io"
77"net"
8+ "net/netip"
89"slices"
910"sync"
1011"sync/atomic"
@@ -23,6 +24,7 @@ import (
2324"storj.io/drpc/drpcerr"
2425"tailscale.com/tailcfg"
2526"tailscale.com/types/key"
27+ "tailscale.com/util/dnsname"
2628
2729"cdr.dev/slog"
2830"cdr.dev/slog/sloggers/slogtest"
@@ -1344,14 +1346,56 @@ func testUUID(b ...byte) uuid.UUID {
13441346return o
13451347}
13461348
1349+ type fakeDNSSetter struct {
1350+ ctx context.Context
1351+ t testing.TB
1352+ calls chan * setDNSCall
1353+ }
1354+
1355+ type setDNSCall struct {
1356+ hosts map [dnsname.FQDN ][]netip.Addr
1357+ err chan <- error
1358+ }
1359+
1360+ func newFakeDNSSetter (ctx context.Context ,t testing.TB )* fakeDNSSetter {
1361+ return & fakeDNSSetter {
1362+ ctx :ctx ,
1363+ t :t ,
1364+ calls :make (chan * setDNSCall ),
1365+ }
1366+ }
1367+
1368+ func (f * fakeDNSSetter )SetDNSHosts (hosts map [dnsname.FQDN ][]netip.Addr )error {
1369+ f .t .Helper ()
1370+ errs := make (chan error )
1371+ call := & setDNSCall {
1372+ hosts :hosts ,
1373+ err :errs ,
1374+ }
1375+ select {
1376+ case <- f .ctx .Done ():
1377+ f .t .Error ("timed out waiting to send SetDNSHosts() call" )
1378+ return f .ctx .Err ()
1379+ case f .calls <- call :
1380+ // OK
1381+ }
1382+ select {
1383+ case <- f .ctx .Done ():
1384+ f .t .Error ("timed out waiting for SetDNSHosts() call response" )
1385+ return f .ctx .Err ()
1386+ case err := <- errs :
1387+ return err
1388+ }
1389+ }
1390+
13471391func setupConnectedAllWorkspaceUpdatesController (
1348- ctx context.Context ,t testing.TB ,logger slog.Logger ,
1392+ ctx context.Context ,t testing.TB ,logger slog.Logger ,dnsSetter tailnet. DNSHostsSetter ,
13491393) (
13501394* fakeCoordinatorClient ,* fakeWorkspaceUpdateClient ,
13511395) {
13521396fConn := & fakeCoordinatee {}
13531397tsc := tailnet .NewTunnelSrcCoordController (logger ,fConn )
1354- uut := tailnet .NewTunnelAllWorkspaceUpdatesController (logger ,tsc )
1398+ uut := tailnet .NewTunnelAllWorkspaceUpdatesController (logger ,tsc , dnsSetter )
13551399
13561400// connect up a coordinator client, to track adding and removing tunnels
13571401coordC := newFakeCoordinatorClient (ctx ,t )
@@ -1385,7 +1429,8 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
13851429ctx := testutil .Context (t ,testutil .WaitShort )
13861430logger := slogtest .Make (t ,nil ).Leveled (slog .LevelDebug )
13871431
1388- coordC ,updateC := setupConnectedAllWorkspaceUpdatesController (ctx ,t ,logger )
1432+ fDNS := newFakeDNSSetter (ctx ,t )
1433+ coordC ,updateC := setupConnectedAllWorkspaceUpdatesController (ctx ,t ,logger ,fDNS )
13891434
13901435// Initial update contains 2 workspaces with 1 & 2 agents, respectively
13911436w1ID := testUUID (1 )
@@ -1418,14 +1463,25 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
14181463require .Contains (t ,adds ,w1a1ID )
14191464require .Contains (t ,adds ,w2a1ID )
14201465require .Contains (t ,adds ,w2a2ID )
1466+
1467+ // Also triggers setting DNS hosts
1468+ expectedDNS := map [dnsname.FQDN ][]netip.Addr {
1469+ "w1a1.w1.me.coder." : {netip .MustParseAddr ("fd60:627a:a42b:0101::" )},
1470+ "w2a1.w2.me.coder." : {netip .MustParseAddr ("fd60:627a:a42b:0201::" )},
1471+ "w2a2.w2.me.coder." : {netip .MustParseAddr ("fd60:627a:a42b:0202::" )},
1472+ }
1473+ dnsCall := testutil .RequireRecvCtx (ctx ,t ,fDNS .calls )
1474+ require .Equal (t ,expectedDNS ,dnsCall .hosts )
1475+ testutil .RequireSendCtx (ctx ,t ,dnsCall .err ,nil )
14211476}
14221477
14231478func TestTunnelAllWorkspaceUpdatesController_DeleteAgent (t * testing.T ) {
14241479t .Parallel ()
14251480ctx := testutil .Context (t ,testutil .WaitShort )
14261481logger := slogtest .Make (t ,nil ).Leveled (slog .LevelDebug )
14271482
1428- coordC ,updateC := setupConnectedAllWorkspaceUpdatesController (ctx ,t ,logger )
1483+ fDNS := newFakeDNSSetter (ctx ,t )
1484+ coordC ,updateC := setupConnectedAllWorkspaceUpdatesController (ctx ,t ,logger ,fDNS )
14291485
14301486w1ID := testUUID (1 )
14311487w1a1ID := testUUID (1 ,1 )
@@ -1447,6 +1503,14 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
14471503require .Equal (t ,w1a1ID [:],coordCall .req .GetAddTunnel ().GetId ())
14481504testutil .RequireSendCtx (ctx ,t ,coordCall .err ,nil )
14491505
1506+ // DNS for w1a1
1507+ expectedDNS := map [dnsname.FQDN ][]netip.Addr {
1508+ "w1a1.w1.me.coder." : {netip .MustParseAddr ("fd60:627a:a42b:0101::" )},
1509+ }
1510+ dnsCall := testutil .RequireRecvCtx (ctx ,t ,fDNS .calls )
1511+ require .Equal (t ,expectedDNS ,dnsCall .hosts )
1512+ testutil .RequireSendCtx (ctx ,t ,dnsCall .err ,nil )
1513+
14501514// Send update that removes w1a1 and adds w1a2
14511515agentUpdate := & proto.WorkspaceUpdate {
14521516UpsertedAgents : []* proto.Agent {
@@ -1468,6 +1532,60 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
14681532coordCall = testutil .RequireRecvCtx (ctx ,t ,coordC .reqs )
14691533require .Equal (t ,w1a1ID [:],coordCall .req .GetRemoveTunnel ().GetId ())
14701534testutil .RequireSendCtx (ctx ,t ,coordCall .err ,nil )
1535+
1536+ // DNS contains only w1a2
1537+ expectedDNS = map [dnsname.FQDN ][]netip.Addr {
1538+ "w1a2.w1.me.coder." : {netip .MustParseAddr ("fd60:627a:a42b:0102::" )},
1539+ }
1540+ dnsCall = testutil .RequireRecvCtx (ctx ,t ,fDNS .calls )
1541+ require .Equal (t ,expectedDNS ,dnsCall .hosts )
1542+ testutil .RequireSendCtx (ctx ,t ,dnsCall .err ,nil )
1543+ }
1544+
1545+ func TestTunnelAllWorkspaceUpdatesController_DNSError (t * testing.T ) {
1546+ t .Parallel ()
1547+ ctx := testutil .Context (t ,testutil .WaitShort )
1548+ dnsError := xerrors .New ("a bad thing happened" )
1549+ logger := slogtest .Make (t ,
1550+ & slogtest.Options {IgnoredErrorIs : []error {dnsError }}).
1551+ Leveled (slog .LevelDebug )
1552+
1553+ fDNS := newFakeDNSSetter (ctx ,t )
1554+ fConn := & fakeCoordinatee {}
1555+ tsc := tailnet .NewTunnelSrcCoordController (logger ,fConn )
1556+ uut := tailnet .NewTunnelAllWorkspaceUpdatesController (logger ,tsc ,fDNS )
1557+
1558+ updateC := newFakeWorkspaceUpdateClient (ctx ,t )
1559+ updateCW := uut .New (updateC )
1560+
1561+ w1ID := testUUID (1 )
1562+ w1a1ID := testUUID (1 ,1 )
1563+ initUp := & proto.WorkspaceUpdate {
1564+ UpsertedWorkspaces : []* proto.Workspace {
1565+ {Id :w1ID [:],Name :"w1" },
1566+ },
1567+ UpsertedAgents : []* proto.Agent {
1568+ {Id :w1a1ID [:],Name :"w1a1" ,WorkspaceId :w1ID [:]},
1569+ },
1570+ }
1571+ upRecvCall := testutil .RequireRecvCtx (ctx ,t ,updateC .recv )
1572+ testutil .RequireSendCtx (ctx ,t ,upRecvCall .resp ,initUp )
1573+
1574+ // DNS for w1a1
1575+ expectedDNS := map [dnsname.FQDN ][]netip.Addr {
1576+ "w1a1.w1.me.coder." : {netip .MustParseAddr ("fd60:627a:a42b:0101::" )},
1577+ }
1578+ dnsCall := testutil .RequireRecvCtx (ctx ,t ,fDNS .calls )
1579+ require .Equal (t ,expectedDNS ,dnsCall .hosts )
1580+ testutil .RequireSendCtx (ctx ,t ,dnsCall .err ,dnsError )
1581+
1582+ // should trigger a close on the client
1583+ closeCall := testutil .RequireRecvCtx (ctx ,t ,updateC .close )
1584+ testutil .RequireSendCtx (ctx ,t ,closeCall ,io .EOF )
1585+
1586+ // error should be our initial DNS error
1587+ err := testutil .RequireRecvCtx (ctx ,t ,updateCW .Wait ())
1588+ require .ErrorIs (t ,err ,dnsError )
14711589}
14721590
14731591func TestTunnelAllWorkspaceUpdatesController_HandleErrors (t * testing.T ) {
@@ -1562,7 +1680,7 @@ func TestTunnelAllWorkspaceUpdatesController_HandleErrors(t *testing.T) {
15621680
15631681fConn := & fakeCoordinatee {}
15641682tsc := tailnet .NewTunnelSrcCoordController (logger ,fConn )
1565- uut := tailnet .NewTunnelAllWorkspaceUpdatesController (logger ,tsc )
1683+ uut := tailnet .NewTunnelAllWorkspaceUpdatesController (logger ,tsc , nil )
15661684updateC := newFakeWorkspaceUpdateClient (ctx ,t )
15671685updateCW := uut .New (updateC )
15681686