Expand Up @@ -13,12 +13,8 @@ import ( "github.com/hashicorp/yamux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/xerrors" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" "nhooyr.io/websocket" "storj.io/drpc" "storj.io/drpc/drpcerr" "tailscale.com/tailcfg" "cdr.dev/slog" Expand Down Expand Up @@ -385,7 +381,12 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) { DERPMapUpdateFrequency: time.Millisecond, DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { testutil.RequireSendCtx(ctx, t, eventCh, batch) select { case <-ctx.Done(): t.Error("timeout sending telemetry event") case eventCh <- batch: t.Log("sent telemetry batch") } }, ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), }) Expand All @@ -409,11 +410,10 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) { uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{}) uut.runConnector(fConn) require.Eventually(t, func() bool { uut.clientMu.Lock() defer uut.clientMu.Unlock() return uut.client != nil }, testutil.WaitShort, testutil.IntervalFast) // Coordinate calls happen _after_ telemetry is connected up, so we use this // to ensure telemetry is connected before sending our event cc := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls) defer close(cc.Resps) uut.SendTelemetryEvent(&proto.TelemetryEvent{ Id: []byte("test event"), Expand All @@ -425,86 +425,6 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) { require.Equal(t, []byte("test event"), testEvents[0].Id) } func TestTailnetAPIConnector_TelemetryUnimplemented(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) agentID := uuid.UUID{0x55} fConn := newFakeTailnetConn() fakeDRPCClient := newFakeDRPCClient() uut := &tailnetAPIConnector{ ctx: ctx, logger: logger, agentID: agentID, coordinateURL: "", clock: quartz.NewReal(), dialOptions: &websocket.DialOptions{}, connected: make(chan error, 1), closed: make(chan struct{}), customDialFn: func() (proto.DRPCTailnetClient, error) { return fakeDRPCClient, nil }, } uut.runConnector(fConn) require.Eventually(t, func() bool { uut.clientMu.Lock() defer uut.clientMu.Unlock() return uut.client != nil }, testutil.WaitShort, testutil.IntervalFast) fakeDRPCClient.telemetryError = drpcerr.WithCode(xerrors.New("Unimplemented"), 0) uut.SendTelemetryEvent(&proto.TelemetryEvent{}) require.False(t, uut.telemetryUnavailable.Load()) require.Equal(t, int64(1), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls)) fakeDRPCClient.telemetryError = drpcerr.WithCode(xerrors.New("Unimplemented"), drpcerr.Unimplemented) uut.SendTelemetryEvent(&proto.TelemetryEvent{}) require.True(t, uut.telemetryUnavailable.Load()) uut.SendTelemetryEvent(&proto.TelemetryEvent{}) require.Equal(t, int64(2), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls)) } func TestTailnetAPIConnector_TelemetryNotRecognised(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) agentID := uuid.UUID{0x55} fConn := newFakeTailnetConn() fakeDRPCClient := newFakeDRPCClient() uut := &tailnetAPIConnector{ ctx: ctx, logger: logger, agentID: agentID, coordinateURL: "", clock: quartz.NewReal(), dialOptions: &websocket.DialOptions{}, connected: make(chan error, 1), closed: make(chan struct{}), customDialFn: func() (proto.DRPCTailnetClient, error) { return fakeDRPCClient, nil }, } uut.runConnector(fConn) require.Eventually(t, func() bool { uut.clientMu.Lock() defer uut.clientMu.Unlock() return uut.client != nil }, testutil.WaitShort, testutil.IntervalFast) fakeDRPCClient.telemetryError = drpc.ProtocolError.New("Protocol Error") uut.SendTelemetryEvent(&proto.TelemetryEvent{}) require.False(t, uut.telemetryUnavailable.Load()) require.Equal(t, int64(1), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls)) fakeDRPCClient.telemetryError = drpc.ProtocolError.New("unknown rpc: /coder.tailnet.v2.Tailnet/PostTelemetry") uut.SendTelemetryEvent(&proto.TelemetryEvent{}) require.True(t, uut.telemetryUnavailable.Load()) uut.SendTelemetryEvent(&proto.TelemetryEvent{}) require.Equal(t, int64(2), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls)) } type fakeTailnetConn struct{} func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error { Expand All @@ -524,65 +444,6 @@ func newFakeTailnetConn() *fakeTailnetConn { return &fakeTailnetConn{} } type fakeDRPCClient struct { postTelemetryCalls int64 refreshTokenFn func(context.Context, *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) telemetryError error fakeDRPPCMapStream } var _ proto.DRPCTailnetClient = &fakeDRPCClient{} func newFakeDRPCClient() *fakeDRPCClient { return &fakeDRPCClient{ postTelemetryCalls: 0, fakeDRPPCMapStream: fakeDRPPCMapStream{ fakeDRPCStream: fakeDRPCStream{ ch: make(chan struct{}), }, }, } } // Coordinate implements proto.DRPCTailnetClient. func (f *fakeDRPCClient) Coordinate(_ context.Context) (proto.DRPCTailnet_CoordinateClient, error) { return &f.fakeDRPCStream, nil } // DRPCConn implements proto.DRPCTailnetClient. func (*fakeDRPCClient) DRPCConn() drpc.Conn { return &fakeDRPCConn{} } // PostTelemetry implements proto.DRPCTailnetClient. func (f *fakeDRPCClient) PostTelemetry(_ context.Context, _ *proto.TelemetryRequest) (*proto.TelemetryResponse, error) { atomic.AddInt64(&f.postTelemetryCalls, 1) return nil, f.telemetryError } // StreamDERPMaps implements proto.DRPCTailnetClient. func (f *fakeDRPCClient) StreamDERPMaps(_ context.Context, _ *proto.StreamDERPMapsRequest) (proto.DRPCTailnet_StreamDERPMapsClient, error) { return &f.fakeDRPPCMapStream, nil } // RefreshResumeToken implements proto.DRPCTailnetClient. func (f *fakeDRPCClient) RefreshResumeToken(_ context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) { if f.refreshTokenFn != nil { return f.refreshTokenFn(context.Background(), nil) } return &proto.RefreshResumeTokenResponse{ Token: "test", RefreshIn: durationpb.New(30 * time.Minute), ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)), }, nil } // WorkspaceUpdates implements proto.DRPCTailnetClient. func (*fakeDRPCClient) WorkspaceUpdates(context.Context, *proto.WorkspaceUpdatesRequest) (proto.DRPCTailnet_WorkspaceUpdatesClient, error) { panic("unimplemented") } type fakeDRPCConn struct{} var _ drpc.Conn = &fakeDRPCConn{} Expand Down