Expand Up @@ -7,17 +7,17 @@ import ( "net/http" "strconv" "strings" "sync" "testing" "time" "github.com/coder/coder/coderd/database/dbfake" "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/xerrors" "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/database/dbfake" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/rbac/regosql" "github.com/coder/coder/codersdk" Expand Down Expand Up @@ -443,7 +443,9 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) { // Always fail auth from this point forward a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil) a.authorizer.Wrapped = &FakeAuthorizer{ AlwaysReturn: rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil), } routeMissing := make(map[string]bool) for k, v := range assertRoute { Expand Down Expand Up @@ -483,7 +485,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck return nil } a.t.Run(name, func(t *testing.T) { a.authorizer.reset () a.authorizer.Reset () routeKey := strings.TrimRight(name, "/") routeAssertions, ok := assertRoute[routeKey] Expand Down Expand Up @@ -514,18 +516,19 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck assert.Equal(t, http.StatusForbidden, resp.StatusCode, "expect unauthorized") } } if a.authorizer.Called != nil { if a.authorizer.lastCall() != nil { last := a.authorizer.lastCall() if routeAssertions.AssertAction != "" { assert.Equal(t, routeAssertions.AssertAction,a.authorizer.Called .Action, "resource action") assert.Equal(t, routeAssertions.AssertAction,last .Action, "resource action") } if routeAssertions.AssertObject.Type != "" { assert.Equal(t, routeAssertions.AssertObject.Type,a.authorizer.Called .Object.Type, "resource type") assert.Equal(t, routeAssertions.AssertObject.Type,last .Object.Type, "resource type") } if routeAssertions.AssertObject.Owner != "" { assert.Equal(t, routeAssertions.AssertObject.Owner,a.authorizer.Called .Object.Owner, "resource owner") assert.Equal(t, routeAssertions.AssertObject.Owner,last .Object.Owner, "resource owner") } if routeAssertions.AssertObject.OrgID != "" { assert.Equal(t, routeAssertions.AssertObject.OrgID,a.authorizer.Called .Object.OrgID, "resource org") assert.Equal(t, routeAssertions.AssertObject.OrgID,last .Object.OrgID, "resource org") } } } else { Expand All @@ -539,52 +542,195 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck } type authCall struct { Subject rbac.Subject Action rbac.Action Object rbac.Object Actor rbac.Subject Action rbac.Action Object rbac.Object asserted bool } var _ rbac.Authorizer = (*RecordingAuthorizer)(nil) // RecordingAuthorizer wraps any rbac.Authorizer and records all Authorize() // calls made. This is useful for testing as these calls can later be asserted. type RecordingAuthorizer struct { Called *authCall AlwaysReturn error sync.RWMutex Called []authCall Wrapped rbac.Authorizer } var _ rbac.Authorizer = (*RecordingAuthorizer)(nil) type ActionObjectPair struct { Action rbac.Action Object rbac.Object } // AuthorizeSQL does not record the call. This matches the postgres behavior // of not calling Authorize() func (r *RecordingAuthorizer) AuthorizeSQL(_ context.Context, _ rbac.Subject, _ rbac.Action, _ rbac.Object) error { return r.AlwaysReturn // Pair is on the RecordingAuthorizer to be easy to find and keep the pkg // interface smaller. func (*RecordingAuthorizer) Pair(action rbac.Action, object rbac.Objecter) ActionObjectPair { return ActionObjectPair{ Action: action, Object: object.RBACObject(), } } func (r *RecordingAuthorizer) Authorize(_ context.Context, subject rbac.Subject, action rbac.Action, object rbac.Object) error { r.Called = &authCall{ Subject: subject, Action: action, Object: object, // AllAsserted returns an error if all calls to Authorize() have not been // asserted and checked. This is useful for testing to ensure that all // Authorize() calls are checked in the unit test. func (r *RecordingAuthorizer) AllAsserted() error { r.RLock() defer r.RUnlock() missed := []authCall{} for _, c := range r.Called { if !c.asserted { missed = append(missed, c) } } return r.AlwaysReturn if len(missed) > 0 { return xerrors.Errorf("missed calls: %+v", missed) } return nil } func (r *RecordingAuthorizer) Prepare(_ context.Context, subject rbac.Subject, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { return &fakePreparedAuthorizer{ Original: r, Subject: subject, Action: action, HardCodedSQLString: "true", // AssertActor asserts in order. If the order of authz calls does not match, // this will fail. func (r *RecordingAuthorizer) AssertActor(t *testing.T, actor rbac.Subject, did ...ActionObjectPair) { r.RLock() defer r.RUnlock() ptr := 0 for i, call := range r.Called { if ptr == len(did) { // Finished all assertions return } if call.Actor.ID == actor.ID { action, object := did[ptr].Action, did[ptr].Object assert.Equalf(t, action, call.Action, "assert action %d", ptr) assert.Equalf(t, object, call.Object, "assert object %d", ptr) r.Called[i].asserted = true ptr++ } } assert.Equalf(t, len(did), ptr, "assert actor: didn't find all actions, %d missing actions", len(did)-ptr) } // recordAuthorize is the internal method that records the Authorize() call. func (r *RecordingAuthorizer) recordAuthorize(subject rbac.Subject, action rbac.Action, object rbac.Object) { r.Lock() defer r.Unlock() r.Called = append(r.Called, authCall{ Actor: subject, Action: action, Object: object, }) } func (r *RecordingAuthorizer) Authorize(ctx context.Context, subject rbac.Subject, action rbac.Action, object rbac.Object) error { r.recordAuthorize(subject, action, object) if r.Wrapped == nil { panic("Developer error: RecordingAuthorizer.Wrapped is nil") } return r.Wrapped.Authorize(ctx, subject, action, object) } func (r *RecordingAuthorizer) Prepare(ctx context.Context, subject rbac.Subject, action rbac.Action, objectType string) (rbac.PreparedAuthorized, error) { r.RLock() defer r.RUnlock() if r.Wrapped == nil { panic("Developer error: RecordingAuthorizer.Wrapped is nil") } prep, err := r.Wrapped.Prepare(ctx, subject, action, objectType) if err != nil { return nil, err } return &PreparedRecorder{ rec: r, prepped: prep, subject: subject, action: action, }, nil } func (r *RecordingAuthorizer) reset() { // Reset clears the recorded Authorize() calls. func (r *RecordingAuthorizer) Reset() { r.Lock() defer r.Unlock() r.Called = nil } // lastCall is implemented to support legacy tests. // Deprecated func (r *RecordingAuthorizer) lastCall() *authCall { r.RLock() defer r.RUnlock() if len(r.Called) == 0 { return nil } return &r.Called[len(r.Called)-1] } // PreparedRecorder is the prepared version of the RecordingAuthorizer. // It records the Authorize() calls to the original recorder. If the caller // uses CompileToSQL, all recording stops. This is to support parity between // memory and SQL backed dbs. type PreparedRecorder struct { rec *RecordingAuthorizer prepped rbac.PreparedAuthorized subject rbac.Subject action rbac.Action rw sync.Mutex usingSQL bool } func (s *PreparedRecorder) Authorize(ctx context.Context, object rbac.Object) error { s.rw.Lock() defer s.rw.Unlock() if !s.usingSQL { s.rec.recordAuthorize(s.subject, s.action, object) } return s.prepped.Authorize(ctx, object) } func (s *PreparedRecorder) CompileToSQL(ctx context.Context, cfg regosql.ConvertConfig) (string, error) { s.rw.Lock() defer s.rw.Unlock() s.usingSQL = true return s.prepped.CompileToSQL(ctx, cfg) } // FakeAuthorizer is an Authorizer that always returns the same error. type FakeAuthorizer struct { // AlwaysReturn is the error that will be returned by Authorize. AlwaysReturn error } var _ rbac.Authorizer = (*FakeAuthorizer)(nil) func (d *FakeAuthorizer) Authorize(_ context.Context, _ rbac.Subject, _ rbac.Action, _ rbac.Object) error { return d.AlwaysReturn } func (d *FakeAuthorizer) Prepare(_ context.Context, subject rbac.Subject, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { return &fakePreparedAuthorizer{ Original: d, Subject: subject, Action: action, }, nil } var _ rbac.PreparedAuthorized = (*fakePreparedAuthorizer)(nil) // fakePreparedAuthorizer is the prepared version of a FakeAuthorizer. It will // return the same error as the original FakeAuthorizer. type fakePreparedAuthorizer struct { Original *RecordingAuthorizer Subject rbac.Subject Action rbac.Action HardCodedSQLString string HardCodedRegoString string sync.RWMutex Original *FakeAuthorizer Subject rbac.Subject Action rbac.Action } func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { Expand All @@ -593,17 +739,6 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje // CompileToSQL returns a compiled version of the authorizer that will work for // in memory databases. This fake version will not work against a SQL database. func (fakePreparedAuthorizer) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) { return "", xerrors.New("not implemented") } func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { return f.Original.AuthorizeSQL(context.Background(), f.Subject, f.Action, object) == nil } func (f fakePreparedAuthorizer) RegoString() string { if f.HardCodedRegoString != "" { return f.HardCodedRegoString } panic("not implemented") func (*fakePreparedAuthorizer) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) { return "not a valid sql string", nil }