@@ -13,6 +13,8 @@ import (
1313"github.com/stretchr/testify/require"
1414"go.uber.org/mock/gomock"
1515"golang.org/x/xerrors"
16+ "google.golang.org/protobuf/types/known/durationpb"
17+ "google.golang.org/protobuf/types/known/timestamppb"
1618"storj.io/drpc"
1719"storj.io/drpc/drpcerr"
1820"tailscale.com/tailcfg"
@@ -24,6 +26,7 @@ import (
2426"github.com/coder/coder/v2/tailnet/proto"
2527"github.com/coder/coder/v2/tailnet/tailnettest"
2628"github.com/coder/coder/v2/testutil"
29+ "github.com/coder/quartz"
2730)
2831
2932func TestInMemoryCoordination (t * testing.T ) {
@@ -507,3 +510,171 @@ type fakeTelemetryCall struct {
507510req * proto.TelemetryRequest
508511errCh chan error
509512}
513+
514+ func TestBasicResumeTokenController_Mainline (t * testing.T ) {
515+ t .Parallel ()
516+ ctx := testutil .Context (t ,testutil .WaitShort )
517+ logger := slogtest .Make (t ,nil ).Leveled (slog .LevelDebug )
518+ fr := newFakeResumeTokenClient (ctx )
519+ mClock := quartz .NewMock (t )
520+ trp := mClock .Trap ().TimerReset ("basicResumeTokenRefresher" ,"refresh" )
521+ defer trp .Close ()
522+
523+ uut := tailnet .NewBasicResumeTokenController (logger ,mClock )
524+ _ ,ok := uut .Token ()
525+ require .False (t ,ok )
526+
527+ cwCh := make (chan tailnet.CloserWaiter ,1 )
528+ go func () {
529+ cwCh <- uut .New (fr )
530+ }()
531+ call := testutil .RequireRecvCtx (ctx ,t ,fr .calls )
532+ testutil .RequireSendCtx (ctx ,t ,call .resp ,& proto.RefreshResumeTokenResponse {
533+ Token :"test token 1" ,
534+ RefreshIn :durationpb .New (100 * time .Second ),
535+ ExpiresAt :timestamppb .New (mClock .Now ().Add (200 * time .Second )),
536+ })
537+ trp .MustWait (ctx ).Release ()// initial refresh done
538+ token ,ok := uut .Token ()
539+ require .True (t ,ok )
540+ require .Equal (t ,"test token 1" ,token )
541+ cw := testutil .RequireRecvCtx (ctx ,t ,cwCh )
542+
543+ w := mClock .Advance (100 * time .Second )
544+ call = testutil .RequireRecvCtx (ctx ,t ,fr .calls )
545+ testutil .RequireSendCtx (ctx ,t ,call .resp ,& proto.RefreshResumeTokenResponse {
546+ Token :"test token 2" ,
547+ RefreshIn :durationpb .New (50 * time .Second ),
548+ ExpiresAt :timestamppb .New (mClock .Now ().Add (200 * time .Second )),
549+ })
550+ resetCall := trp .MustWait (ctx )
551+ require .Equal (t ,resetCall .Duration ,50 * time .Second )
552+ resetCall .Release ()
553+ w .MustWait (ctx )
554+ token ,ok = uut .Token ()
555+ require .True (t ,ok )
556+ require .Equal (t ,"test token 2" ,token )
557+
558+ err := cw .Close (ctx )
559+ require .NoError (t ,err )
560+ err = testutil .RequireRecvCtx (ctx ,t ,cw .Wait ())
561+ require .NoError (t ,err )
562+
563+ token ,ok = uut .Token ()
564+ require .True (t ,ok )
565+ require .Equal (t ,"test token 2" ,token )
566+
567+ mClock .Advance (201 * time .Second ).MustWait (ctx )
568+ _ ,ok = uut .Token ()
569+ require .False (t ,ok )
570+ }
571+
572+ func TestBasicResumeTokenController_NewWhileRefreshing (t * testing.T ) {
573+ t .Parallel ()
574+ ctx := testutil .Context (t ,testutil .WaitShort )
575+ logger := slogtest .Make (t ,nil ).Leveled (slog .LevelDebug )
576+ mClock := quartz .NewMock (t )
577+ trp := mClock .Trap ().TimerReset ("basicResumeTokenRefresher" ,"refresh" )
578+ defer trp .Close ()
579+
580+ uut := tailnet .NewBasicResumeTokenController (logger ,mClock )
581+ _ ,ok := uut .Token ()
582+ require .False (t ,ok )
583+
584+ fr1 := newFakeResumeTokenClient (ctx )
585+ cwCh1 := make (chan tailnet.CloserWaiter ,1 )
586+ go func () {
587+ cwCh1 <- uut .New (fr1 )
588+ }()
589+ call1 := testutil .RequireRecvCtx (ctx ,t ,fr1 .calls )
590+
591+ fr2 := newFakeResumeTokenClient (ctx )
592+ cwCh2 := make (chan tailnet.CloserWaiter ,1 )
593+ go func () {
594+ cwCh2 <- uut .New (fr2 )
595+ }()
596+ call2 := testutil .RequireRecvCtx (ctx ,t ,fr2 .calls )
597+
598+ testutil .RequireSendCtx (ctx ,t ,call2 .resp ,& proto.RefreshResumeTokenResponse {
599+ Token :"test token 2.0" ,
600+ RefreshIn :durationpb .New (102 * time .Second ),
601+ ExpiresAt :timestamppb .New (mClock .Now ().Add (200 * time .Second )),
602+ })
603+
604+ cw2 := testutil .RequireRecvCtx (ctx ,t ,cwCh2 )// this ensures Close was called on 1
605+
606+ testutil .RequireSendCtx (ctx ,t ,call1 .resp ,& proto.RefreshResumeTokenResponse {
607+ Token :"test token 1" ,
608+ RefreshIn :durationpb .New (101 * time .Second ),
609+ ExpiresAt :timestamppb .New (mClock .Now ().Add (200 * time .Second )),
610+ })
611+
612+ trp .MustWait (ctx ).Release ()
613+
614+ token ,ok := uut .Token ()
615+ require .True (t ,ok )
616+ require .Equal (t ,"test token 2.0" ,token )
617+
618+ // refresher 1 should already be closed.
619+ cw1 := testutil .RequireRecvCtx (ctx ,t ,cwCh1 )
620+ err := testutil .RequireRecvCtx (ctx ,t ,cw1 .Wait ())
621+ require .NoError (t ,err )
622+
623+ w := mClock .Advance (102 * time .Second )
624+ call := testutil .RequireRecvCtx (ctx ,t ,fr2 .calls )
625+ testutil .RequireSendCtx (ctx ,t ,call .resp ,& proto.RefreshResumeTokenResponse {
626+ Token :"test token 2.1" ,
627+ RefreshIn :durationpb .New (50 * time .Second ),
628+ ExpiresAt :timestamppb .New (mClock .Now ().Add (200 * time .Second )),
629+ })
630+ resetCall := trp .MustWait (ctx )
631+ require .Equal (t ,resetCall .Duration ,50 * time .Second )
632+ resetCall .Release ()
633+ w .MustWait (ctx )
634+ token ,ok = uut .Token ()
635+ require .True (t ,ok )
636+ require .Equal (t ,"test token 2.1" ,token )
637+
638+ err = cw2 .Close (ctx )
639+ require .NoError (t ,err )
640+ err = testutil .RequireRecvCtx (ctx ,t ,cw2 .Wait ())
641+ require .NoError (t ,err )
642+ }
643+
644+ func newFakeResumeTokenClient (ctx context.Context )* fakeResumeTokenClient {
645+ return & fakeResumeTokenClient {
646+ ctx :ctx ,
647+ calls :make (chan * fakeResumeTokenCall ),
648+ }
649+ }
650+
651+ type fakeResumeTokenClient struct {
652+ ctx context.Context
653+ calls chan * fakeResumeTokenCall
654+ }
655+
656+ func (f * fakeResumeTokenClient )RefreshResumeToken (_ context.Context ,_ * proto.RefreshResumeTokenRequest ) (* proto.RefreshResumeTokenResponse ,error ) {
657+ call := & fakeResumeTokenCall {
658+ resp :make (chan * proto.RefreshResumeTokenResponse ),
659+ errCh :make (chan error ),
660+ }
661+ select {
662+ case <- f .ctx .Done ():
663+ return nil ,f .ctx .Err ()
664+ case f .calls <- call :
665+ // OK
666+ }
667+ select {
668+ case <- f .ctx .Done ():
669+ return nil ,f .ctx .Err ()
670+ case err := <- call .errCh :
671+ return nil ,err
672+ case resp := <- call .resp :
673+ return resp ,nil
674+ }
675+ }
676+
677+ type fakeResumeTokenCall struct {
678+ resp chan * proto.RefreshResumeTokenResponse
679+ errCh chan error
680+ }