@@ -7,17 +7,17 @@ import (
7
7
"net/http"
8
8
"strconv"
9
9
"strings"
10
+ "sync"
10
11
"testing"
11
12
"time"
12
13
13
- "github.com/coder/coder/coderd/database/dbfake"
14
-
15
14
"github.com/go-chi/chi/v5"
16
15
"github.com/stretchr/testify/assert"
17
16
"github.com/stretchr/testify/require"
18
17
"golang.org/x/xerrors"
19
18
20
19
"github.com/coder/coder/coderd"
20
+ "github.com/coder/coder/coderd/database/dbfake"
21
21
"github.com/coder/coder/coderd/rbac"
22
22
"github.com/coder/coder/coderd/rbac/regosql"
23
23
"github.com/coder/coder/codersdk"
@@ -443,7 +443,9 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a
443
443
444
444
func (a * AuthTester )Test (ctx context.Context ,assertRoute map [string ]RouteCheck ,skipRoutes map [string ]string ) {
445
445
// Always fail auth from this point forward
446
- a .authorizer .AlwaysReturn = rbac .ForbiddenWithInternal (xerrors .New ("fake implementation" ),nil ,nil )
446
+ a .authorizer .Wrapped = & FakeAuthorizer {
447
+ AlwaysReturn :rbac .ForbiddenWithInternal (xerrors .New ("fake implementation" ),nil ,nil ),
448
+ }
447
449
448
450
routeMissing := make (map [string ]bool )
449
451
for k ,v := range assertRoute {
@@ -483,7 +485,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
483
485
return nil
484
486
}
485
487
a .t .Run (name ,func (t * testing.T ) {
486
- a .authorizer .reset ()
488
+ a .authorizer .Reset ()
487
489
routeKey := strings .TrimRight (name ,"/" )
488
490
489
491
routeAssertions ,ok := assertRoute [routeKey ]
@@ -514,18 +516,19 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
514
516
assert .Equal (t ,http .StatusForbidden ,resp .StatusCode ,"expect unauthorized" )
515
517
}
516
518
}
517
- if a .authorizer .Called != nil {
519
+ if a .authorizer .lastCall ()!= nil {
520
+ last := a .authorizer .lastCall ()
518
521
if routeAssertions .AssertAction != "" {
519
- assert .Equal (t ,routeAssertions .AssertAction ,a . authorizer . Called .Action ,"resource action" )
522
+ assert .Equal (t ,routeAssertions .AssertAction ,last .Action ,"resource action" )
520
523
}
521
524
if routeAssertions .AssertObject .Type != "" {
522
- assert .Equal (t ,routeAssertions .AssertObject .Type ,a . authorizer . Called .Object .Type ,"resource type" )
525
+ assert .Equal (t ,routeAssertions .AssertObject .Type ,last .Object .Type ,"resource type" )
523
526
}
524
527
if routeAssertions .AssertObject .Owner != "" {
525
- assert .Equal (t ,routeAssertions .AssertObject .Owner ,a . authorizer . Called .Object .Owner ,"resource owner" )
528
+ assert .Equal (t ,routeAssertions .AssertObject .Owner ,last .Object .Owner ,"resource owner" )
526
529
}
527
530
if routeAssertions .AssertObject .OrgID != "" {
528
- assert .Equal (t ,routeAssertions .AssertObject .OrgID ,a . authorizer . Called .Object .OrgID ,"resource org" )
531
+ assert .Equal (t ,routeAssertions .AssertObject .OrgID ,last .Object .OrgID ,"resource org" )
529
532
}
530
533
}
531
534
}else {
@@ -539,52 +542,195 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
539
542
}
540
543
541
544
type authCall struct {
542
- Subject rbac.Subject
543
- Action rbac.Action
544
- Object rbac.Object
545
+ Actor rbac.Subject
546
+ Action rbac.Action
547
+ Object rbac.Object
548
+
549
+ asserted bool
545
550
}
546
551
552
+ var _ rbac.Authorizer = (* RecordingAuthorizer )(nil )
553
+
554
+ // RecordingAuthorizer wraps any rbac.Authorizer and records all Authorize()
555
+ // calls made. This is useful for testing as these calls can later be asserted.
547
556
type RecordingAuthorizer struct {
548
- Called * authCall
549
- AlwaysReturn error
557
+ sync.RWMutex
558
+ Called []authCall
559
+ Wrapped rbac.Authorizer
550
560
}
551
561
552
- var _ rbac.Authorizer = (* RecordingAuthorizer )(nil )
562
+ type ActionObjectPair struct {
563
+ Action rbac.Action
564
+ Object rbac.Object
565
+ }
553
566
554
- // AuthorizeSQL does not record the call. This matches the postgres behavior
555
- // of not calling Authorize()
556
- func (r * RecordingAuthorizer )AuthorizeSQL (_ context.Context ,_ rbac.Subject ,_ rbac.Action ,_ rbac.Object )error {
557
- return r .AlwaysReturn
567
+ // Pair is on the RecordingAuthorizer to be easy to find and keep the pkg
568
+ // interface smaller.
569
+ func (* RecordingAuthorizer )Pair (action rbac.Action ,object rbac.Objecter )ActionObjectPair {
570
+ return ActionObjectPair {
571
+ Action :action ,
572
+ Object :object .RBACObject (),
573
+ }
558
574
}
559
575
560
- func (r * RecordingAuthorizer )Authorize (_ context.Context ,subject rbac.Subject ,action rbac.Action ,object rbac.Object )error {
561
- r .Called = & authCall {
562
- Subject :subject ,
563
- Action :action ,
564
- Object :object ,
576
+ // AllAsserted returns an error if all calls to Authorize() have not been
577
+ // asserted and checked. This is useful for testing to ensure that all
578
+ // Authorize() calls are checked in the unit test.
579
+ func (r * RecordingAuthorizer )AllAsserted ()error {
580
+ r .RLock ()
581
+ defer r .RUnlock ()
582
+ missed := []authCall {}
583
+ for _ ,c := range r .Called {
584
+ if ! c .asserted {
585
+ missed = append (missed ,c )
586
+ }
565
587
}
566
- return r .AlwaysReturn
588
+
589
+ if len (missed )> 0 {
590
+ return xerrors .Errorf ("missed calls: %+v" ,missed )
591
+ }
592
+ return nil
567
593
}
568
594
569
- func (r * RecordingAuthorizer )Prepare (_ context.Context ,subject rbac.Subject ,action rbac.Action ,_ string ) (rbac.PreparedAuthorized ,error ) {
570
- return & fakePreparedAuthorizer {
571
- Original :r ,
572
- Subject :subject ,
573
- Action :action ,
574
- HardCodedSQLString :"true" ,
595
+ // AssertActor asserts in order. If the order of authz calls does not match,
596
+ // this will fail.
597
+ func (r * RecordingAuthorizer )AssertActor (t * testing.T ,actor rbac.Subject ,did ... ActionObjectPair ) {
598
+ r .RLock ()
599
+ defer r .RUnlock ()
600
+ ptr := 0
601
+ for i ,call := range r .Called {
602
+ if ptr == len (did ) {
603
+ // Finished all assertions
604
+ return
605
+ }
606
+ if call .Actor .ID == actor .ID {
607
+ action ,object := did [ptr ].Action ,did [ptr ].Object
608
+ assert .Equalf (t ,action ,call .Action ,"assert action %d" ,ptr )
609
+ assert .Equalf (t ,object ,call .Object ,"assert object %d" ,ptr )
610
+ r .Called [i ].asserted = true
611
+ ptr ++
612
+ }
613
+ }
614
+
615
+ assert .Equalf (t ,len (did ),ptr ,"assert actor: didn't find all actions, %d missing actions" ,len (did )- ptr )
616
+ }
617
+
618
+ // recordAuthorize is the internal method that records the Authorize() call.
619
+ func (r * RecordingAuthorizer )recordAuthorize (subject rbac.Subject ,action rbac.Action ,object rbac.Object ) {
620
+ r .Lock ()
621
+ defer r .Unlock ()
622
+ r .Called = append (r .Called ,authCall {
623
+ Actor :subject ,
624
+ Action :action ,
625
+ Object :object ,
626
+ })
627
+ }
628
+
629
+ func (r * RecordingAuthorizer )Authorize (ctx context.Context ,subject rbac.Subject ,action rbac.Action ,object rbac.Object )error {
630
+ r .recordAuthorize (subject ,action ,object )
631
+ if r .Wrapped == nil {
632
+ panic ("Developer error: RecordingAuthorizer.Wrapped is nil" )
633
+ }
634
+ return r .Wrapped .Authorize (ctx ,subject ,action ,object )
635
+ }
636
+
637
+ func (r * RecordingAuthorizer )Prepare (ctx context.Context ,subject rbac.Subject ,action rbac.Action ,objectType string ) (rbac.PreparedAuthorized ,error ) {
638
+ r .RLock ()
639
+ defer r .RUnlock ()
640
+ if r .Wrapped == nil {
641
+ panic ("Developer error: RecordingAuthorizer.Wrapped is nil" )
642
+ }
643
+
644
+ prep ,err := r .Wrapped .Prepare (ctx ,subject ,action ,objectType )
645
+ if err != nil {
646
+ return nil ,err
647
+ }
648
+ return & PreparedRecorder {
649
+ rec :r ,
650
+ prepped :prep ,
651
+ subject :subject ,
652
+ action :action ,
575
653
},nil
576
654
}
577
655
578
- func (r * RecordingAuthorizer )reset () {
656
+ // Reset clears the recorded Authorize() calls.
657
+ func (r * RecordingAuthorizer )Reset () {
658
+ r .Lock ()
659
+ defer r .Unlock ()
579
660
r .Called = nil
580
661
}
581
662
663
+ // lastCall is implemented to support legacy tests.
664
+ // Deprecated
665
+ func (r * RecordingAuthorizer )lastCall ()* authCall {
666
+ r .RLock ()
667
+ defer r .RUnlock ()
668
+ if len (r .Called )== 0 {
669
+ return nil
670
+ }
671
+ return & r .Called [len (r .Called )- 1 ]
672
+ }
673
+
674
+ // PreparedRecorder is the prepared version of the RecordingAuthorizer.
675
+ // It records the Authorize() calls to the original recorder. If the caller
676
+ // uses CompileToSQL, all recording stops. This is to support parity between
677
+ // memory and SQL backed dbs.
678
+ type PreparedRecorder struct {
679
+ rec * RecordingAuthorizer
680
+ prepped rbac.PreparedAuthorized
681
+ subject rbac.Subject
682
+ action rbac.Action
683
+
684
+ rw sync.Mutex
685
+ usingSQL bool
686
+ }
687
+
688
+ func (s * PreparedRecorder )Authorize (ctx context.Context ,object rbac.Object )error {
689
+ s .rw .Lock ()
690
+ defer s .rw .Unlock ()
691
+
692
+ if ! s .usingSQL {
693
+ s .rec .recordAuthorize (s .subject ,s .action ,object )
694
+ }
695
+ return s .prepped .Authorize (ctx ,object )
696
+ }
697
+ func (s * PreparedRecorder )CompileToSQL (ctx context.Context ,cfg regosql.ConvertConfig ) (string ,error ) {
698
+ s .rw .Lock ()
699
+ defer s .rw .Unlock ()
700
+
701
+ s .usingSQL = true
702
+ return s .prepped .CompileToSQL (ctx ,cfg )
703
+ }
704
+
705
+ // FakeAuthorizer is an Authorizer that always returns the same error.
706
+ type FakeAuthorizer struct {
707
+ // AlwaysReturn is the error that will be returned by Authorize.
708
+ AlwaysReturn error
709
+ }
710
+
711
+ var _ rbac.Authorizer = (* FakeAuthorizer )(nil )
712
+
713
+ func (d * FakeAuthorizer )Authorize (_ context.Context ,_ rbac.Subject ,_ rbac.Action ,_ rbac.Object )error {
714
+ return d .AlwaysReturn
715
+ }
716
+
717
+ func (d * FakeAuthorizer )Prepare (_ context.Context ,subject rbac.Subject ,action rbac.Action ,_ string ) (rbac.PreparedAuthorized ,error ) {
718
+ return & fakePreparedAuthorizer {
719
+ Original :d ,
720
+ Subject :subject ,
721
+ Action :action ,
722
+ },nil
723
+ }
724
+
725
+ var _ rbac.PreparedAuthorized = (* fakePreparedAuthorizer )(nil )
726
+
727
+ // fakePreparedAuthorizer is the prepared version of a FakeAuthorizer. It will
728
+ // return the same error as the original FakeAuthorizer.
582
729
type fakePreparedAuthorizer struct {
583
- Original * RecordingAuthorizer
584
- Subject rbac.Subject
585
- Action rbac.Action
586
- HardCodedSQLString string
587
- HardCodedRegoString string
730
+ sync.RWMutex
731
+ Original * FakeAuthorizer
732
+ Subject rbac.Subject
733
+ Action rbac.Action
588
734
}
589
735
590
736
func (f * fakePreparedAuthorizer )Authorize (ctx context.Context ,object rbac.Object )error {
@@ -593,17 +739,6 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje
593
739
594
740
// CompileToSQL returns a compiled version of the authorizer that will work for
595
741
// in memory databases. This fake version will not work against a SQL database.
596
- func (fakePreparedAuthorizer )CompileToSQL (_ context.Context ,_ regosql.ConvertConfig ) (string ,error ) {
597
- return "" ,xerrors .New ("not implemented" )
598
- }
599
-
600
- func (f * fakePreparedAuthorizer )Eval (object rbac.Object )bool {
601
- return f .Original .AuthorizeSQL (context .Background (),f .Subject ,f .Action ,object )== nil
602
- }
603
-
604
- func (f fakePreparedAuthorizer )RegoString ()string {
605
- if f .HardCodedRegoString != "" {
606
- return f .HardCodedRegoString
607
- }
608
- panic ("not implemented" )
742
+ func (* fakePreparedAuthorizer )CompileToSQL (_ context.Context ,_ regosql.ConvertConfig ) (string ,error ) {
743
+ return "not a valid sql string" ,nil
609
744
}