Expand Up @@ -6,10 +6,8 @@ import ( "encoding/json" "fmt" "io" "net" "net/http" "net/http/httptest" "slices" "strings" "sync/atomic" "testing" Expand All @@ -18,12 +16,13 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/slices" "golang.org/x/oauth2" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbmem " "github.com/coder/coder/v2/coderd/database/dbtestutil " "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" Expand Down Expand Up @@ -83,9 +82,9 @@ func TestAPIKey(t *testing.T) { t.Run("NoCookie", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() db, _ = dbtestutil.NewDB(t ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() ) httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ DB: db, Expand All @@ -99,9 +98,9 @@ func TestAPIKey(t *testing.T) { t.Run("NoCookieRedirects", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() db, _ = dbtestutil.NewDB(t ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() ) httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ DB: db, Expand All @@ -118,9 +117,9 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidFormat", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() db, _ = dbtestutil.NewDB(t ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() ) r.Header.Set(codersdk.SessionTokenHeader, "test-wow-hello") Expand All @@ -136,9 +135,9 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidIDLength", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() db, _ = dbtestutil.NewDB(t ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() ) r.Header.Set(codersdk.SessionTokenHeader, "test-wow") Expand All @@ -154,9 +153,9 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidSecretLength", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() db, _ = dbtestutil.NewDB(t ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() ) r.Header.Set(codersdk.SessionTokenHeader, "testtestid-wow") Expand All @@ -172,7 +171,7 @@ func TestAPIKey(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) id, secret = randomAPIKeyParts() r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() Expand All @@ -191,10 +190,10 @@ func TestAPIKey(t *testing.T) { t.Run("UserLinkNotFound", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() user = dbgen.User(t, db, database.User{ db, _ = dbtestutil.NewDB(t ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() user = dbgen.User(t, db, database.User{ LoginType: database.LoginTypeGithub, }) // Intentionally not inserting any user link Expand All @@ -219,10 +218,10 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidSecret", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() user = dbgen.User(t, db, database.User{}) db, _ = dbtestutil.NewDB(t ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() user = dbgen.User(t, db, database.User{}) // Use a different secret so they don't match! hashed = sha256.Sum256([]byte("differentsecret")) Expand All @@ -244,7 +243,7 @@ func TestAPIKey(t *testing.T) { t.Run("Expired", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -273,7 +272,7 @@ func TestAPIKey(t *testing.T) { t.Run("Valid", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -309,7 +308,7 @@ func TestAPIKey(t *testing.T) { t.Run("ValidWithScope", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -347,7 +346,7 @@ func TestAPIKey(t *testing.T) { t.Run("QueryParameter", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -381,7 +380,7 @@ func TestAPIKey(t *testing.T) { t.Run("ValidUpdateLastUsed", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -412,7 +411,7 @@ func TestAPIKey(t *testing.T) { t.Run("ValidUpdateExpiry", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -443,7 +442,7 @@ func TestAPIKey(t *testing.T) { t.Run("NoRefresh", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -475,7 +474,7 @@ func TestAPIKey(t *testing.T) { t.Run("OAuthNotExpired", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -511,7 +510,7 @@ func TestAPIKey(t *testing.T) { t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -561,7 +560,7 @@ func TestAPIKey(t *testing.T) { t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -607,7 +606,7 @@ func TestAPIKey(t *testing.T) { t.Run("OAuthRefresh", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand All @@ -630,7 +629,7 @@ func TestAPIKey(t *testing.T) { oauthToken := &oauth2.Token{ AccessToken: "wow", RefreshToken: "moo", Expiry:dbtime.Now ().AddDate(0, 0, 1), Expiry:dbtestutil.NowInDefaultTimezone ().AddDate(0, 0, 1), } httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ DB: db, Expand Down Expand Up @@ -665,7 +664,7 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( ctx = testutil.Context(t, testutil.WaitShort) db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -715,7 +714,7 @@ func TestAPIKey(t *testing.T) { t.Run("RemoteIPUpdates", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand All @@ -740,15 +739,15 @@ func TestAPIKey(t *testing.T) { gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t,net.ParseIP( "1.1.1.1") , gotAPIKey.IPAddress.IPNet.IP) require.Equal(t, "1.1.1.1", gotAPIKey.IPAddress.IPNet.IP.String() ) }) t.Run("RedirectToLogin", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() db, _ = dbtestutil.NewDB(t ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() ) httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ Expand All @@ -767,9 +766,9 @@ func TestAPIKey(t *testing.T) { t.Run("Optional", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() db, _ = dbtestutil.NewDB(t ) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() count int64 handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { Expand Down Expand Up @@ -798,7 +797,7 @@ func TestAPIKey(t *testing.T) { t.Run("Tokens", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -831,7 +830,7 @@ func TestAPIKey(t *testing.T) { t.Run("MissingConfig", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) user = dbgen.User(t, db, database.User{}) _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, Expand Down Expand Up @@ -866,7 +865,7 @@ func TestAPIKey(t *testing.T) { t.Run("CustomRoles", func(t *testing.T) { t.Parallel() var ( db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) org = dbgen.Organization(t, db, database.Organization{}) customRole = dbgen.CustomRole(t, db, database.CustomRole{ Name: "custom-role", Expand Down Expand Up @@ -933,7 +932,7 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( roleNotExistsName = "role-not-exists" db = dbmem.New( ) db, _ = dbtestutil.NewDB(t ) org = dbgen.Organization(t, db, database.Organization{}) user = dbgen.User(t, db, database.User{ RBACRoles: []string{ Expand Down