- Notifications
You must be signed in to change notification settings - Fork907
fix: stop extending API key access if OIDC refresh is available#17878
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Changes fromall commits
File filter
Filter by extension
Conversations
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -307,7 +307,7 @@ func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values | ||
// WithLogging is optional, but will log some HTTP calls made to the IDP. | ||
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { | ||
return func(f *FakeIDP) { | ||
f.logger = slogtest.Make(t, options).Named("fakeidp") | ||
} | ||
} | ||
@@ -794,6 +794,7 @@ func (f *FakeIDP) newToken(t testing.TB, email string, expires time.Time) string | ||
func (f *FakeIDP) newRefreshTokens(email string) string { | ||
refreshToken := uuid.NewString() | ||
f.refreshTokens.Store(refreshToken, email) | ||
f.logger.Info(context.Background(), "new refresh token", slog.F("email", email), slog.F("token", refreshToken)) | ||
Emyrk marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
return refreshToken | ||
} | ||
@@ -1003,6 +1004,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { | ||
return | ||
} | ||
f.logger.Info(r.Context(), "http idp call refresh_token", slog.F("token", refreshToken)) | ||
_, ok := f.refreshTokens.Load(refreshToken) | ||
if !assert.True(t, ok, "invalid refresh_token") { | ||
http.Error(rw, "invalid refresh_token", http.StatusBadRequest) | ||
@@ -1026,6 +1028,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { | ||
f.refreshTokensUsed.Store(refreshToken, true) | ||
// Always invalidate the refresh token after it is used. | ||
f.refreshTokens.Delete(refreshToken) | ||
f.logger.Info(r.Context(), "refresh token invalidated", slog.F("token", refreshToken)) | ||
case "urn:ietf:params:oauth:grant-type:device_code": | ||
// Device flow | ||
var resp externalauth.ExchangeDeviceCodeResponse | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -232,16 +232,21 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon | ||
return optionalWrite(http.StatusUnauthorized, resp) | ||
} | ||
now := dbtime.Now() | ||
if key.ExpiresAt.Before(now) { | ||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{ | ||
Message: SignedOutErrorMessage, | ||
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()), | ||
}) | ||
} | ||
spikecurtis marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor | ||
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly | ||
// refreshing the OIDC token. | ||
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC { | ||
var err error | ||
//nolint:gocritic // System needs to fetch UserLink to check if it's valid. | ||
link, err:= cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{ | ||
UserID: key.UserID, | ||
LoginType: key.LoginType, | ||
}) | ||
@@ -258,7 +263,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon | ||
}) | ||
} | ||
// Check if the OAuth token is expired | ||
if!link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) { | ||
spikecurtis marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
if cfg.OAuth2Configs.IsZero() { | ||
return write(http.StatusInternalServerError, codersdk.Response{ | ||
Message: internalErrorMessage, | ||
@@ -267,12 +272,15 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon | ||
}) | ||
} | ||
var friendlyName string | ||
var oauthConfig promoauth.OAuth2Config | ||
switch key.LoginType { | ||
case database.LoginTypeGithub: | ||
oauthConfig = cfg.OAuth2Configs.Github | ||
friendlyName = "GitHub" | ||
case database.LoginTypeOIDC: | ||
oauthConfig = cfg.OAuth2Configs.OIDC | ||
friendlyName = "OpenID Connect" | ||
default: | ||
return write(http.StatusInternalServerError, codersdk.Response{ | ||
Message: internalErrorMessage, | ||
@@ -292,36 +300,53 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon | ||
}) | ||
} | ||
if link.OAuthRefreshToken == "" { | ||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{ | ||
Message: SignedOutErrorMessage, | ||
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()), | ||
}) | ||
} | ||
// We have a refresh token, so let's try it | ||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{ | ||
AccessToken: link.OAuthAccessToken, | ||
RefreshToken: link.OAuthRefreshToken, | ||
Expiry: link.OAuthExpiry, | ||
}).Token() | ||
if err != nil { | ||
return write(http.StatusUnauthorized, codersdk.Response{ | ||
Message: fmt.Sprintf( | ||
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.", | ||
friendlyName), | ||
Detail: err.Error(), | ||
}) | ||
} | ||
link.OAuthAccessToken = token.AccessToken | ||
link.OAuthRefreshToken = token.RefreshToken | ||
link.OAuthExpiry = token.Expiry | ||
//nolint:gocritic // system needs to update user link | ||
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{ | ||
UserID: link.UserID, | ||
LoginType: link.LoginType, | ||
OAuthAccessToken: link.OAuthAccessToken, | ||
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required | ||
OAuthRefreshToken: link.OAuthRefreshToken, | ||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required | ||
OAuthExpiry: link.OAuthExpiry, | ||
// Refresh should keep the same debug context because we use | ||
// the original claims for the group/role sync. | ||
Claims: link.Claims, | ||
}) | ||
if err != nil { | ||
return write(http.StatusInternalServerError, codersdk.Response{ | ||
Message: internalErrorMessage, | ||
Detail: fmt.Sprintf("update user_link: %s.", err.Error()), | ||
}) | ||
} | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Note that | ||
// Tracks if the API key has properties updated | ||
changed := false | ||
// Only update LastUsed once an hour to prevent database spam. | ||
if now.Sub(key.LastUsed) > time.Hour { | ||
@@ -363,29 +388,6 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon | ||
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()), | ||
}) | ||
} | ||
// We only want to update this occasionally to reduce DB write | ||
// load. We update alongside the UserLink and APIKey since it's | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -508,6 +508,102 @@ func TestAPIKey(t *testing.T) { | ||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) | ||
}) | ||
t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) { | ||
t.Parallel() | ||
var ( | ||
db = dbmem.New() | ||
user = dbgen.User(t, db, database.User{}) | ||
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ | ||
UserID: user.ID, | ||
LastUsed: dbtime.Now().AddDate(0, 0, -1), | ||
ExpiresAt: dbtime.Now().AddDate(0, 0, -1), | ||
LoginType: database.LoginTypeOIDC, | ||
}) | ||
_ = dbgen.UserLink(t, db, database.UserLink{ | ||
UserID: user.ID, | ||
LoginType: database.LoginTypeOIDC, | ||
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1), | ||
}) | ||
r = httptest.NewRequest("GET", "/", nil) | ||
rw = httptest.NewRecorder() | ||
) | ||
r.Header.Set(codersdk.SessionTokenHeader, token) | ||
// Include a valid oauth token for refreshing. If this token is invalid, | ||
// it is difficult to tell an auth failure from an expired api key, or | ||
// an expired oauth key. | ||
oauthToken := &oauth2.Token{ | ||
AccessToken: "wow", | ||
RefreshToken: "moo", | ||
Expiry: dbtime.Now().AddDate(0, 0, 1), | ||
} | ||
spikecurtis marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ | ||
DB: db, | ||
OAuth2Configs: &httpmw.OAuth2Configs{ | ||
OIDC: &testutil.OAuth2Config{ | ||
Token: oauthToken, | ||
}, | ||
}, | ||
RedirectToLogin: false, | ||
})(successHandler).ServeHTTP(rw, r) | ||
res := rw.Result() | ||
defer res.Body.Close() | ||
require.Equal(t, http.StatusUnauthorized, res.StatusCode) | ||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) | ||
require.NoError(t, err) | ||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) | ||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) | ||
}) | ||
t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) { | ||
t.Parallel() | ||
var ( | ||
db = dbmem.New() | ||
user = dbgen.User(t, db, database.User{}) | ||
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ | ||
UserID: user.ID, | ||
LastUsed: dbtime.Now().AddDate(0, 0, -1), | ||
ExpiresAt: dbtime.Now().AddDate(0, 0, -1), | ||
LoginType: database.LoginTypeOIDC, | ||
}) | ||
_ = dbgen.UserLink(t, db, database.UserLink{ | ||
UserID: user.ID, | ||
LoginType: database.LoginTypeOIDC, | ||
}) | ||
r = httptest.NewRequest("GET", "/", nil) | ||
rw = httptest.NewRecorder() | ||
) | ||
r.Header.Set(codersdk.SessionTokenHeader, token) | ||
oauthToken := &oauth2.Token{ | ||
AccessToken: "wow", | ||
RefreshToken: "moo", | ||
Expiry: dbtime.Now().AddDate(0, 0, 1), | ||
} | ||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ | ||
DB: db, | ||
OAuth2Configs: &httpmw.OAuth2Configs{ | ||
OIDC: &testutil.OAuth2Config{ | ||
Token: oauthToken, | ||
}, | ||
}, | ||
RedirectToLogin: false, | ||
})(successHandler).ServeHTTP(rw, r) | ||
res := rw.Result() | ||
defer res.Body.Close() | ||
require.Equal(t, http.StatusUnauthorized, res.StatusCode) | ||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) | ||
require.NoError(t, err) | ||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) | ||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) | ||
}) | ||
t.Run("OAuthRefresh", func(t *testing.T) { | ||
t.Parallel() | ||
var ( | ||
@@ -553,7 +649,67 @@ func TestAPIKey(t *testing.T) { | ||
require.NoError(t, err) | ||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) | ||
// Note that OAuth expiry is independent of APIKey expiry, so an OIDC refresh DOES NOT affect the expiry of the | ||
// APIKey | ||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) | ||
gotLink, err := db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{ | ||
UserID: user.ID, | ||
LoginType: database.LoginTypeGithub, | ||
}) | ||
require.NoError(t, err) | ||
require.Equal(t, gotLink.OAuthRefreshToken, "moo") | ||
}) | ||
t.Run("OAuthExpiredNoRefresh", func(t *testing.T) { | ||
t.Parallel() | ||
var ( | ||
ctx = testutil.Context(t, testutil.WaitShort) | ||
db = dbmem.New() | ||
user = dbgen.User(t, db, database.User{}) | ||
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ | ||
UserID: user.ID, | ||
LastUsed: dbtime.Now(), | ||
ExpiresAt: dbtime.Now().AddDate(0, 0, 1), | ||
LoginType: database.LoginTypeGithub, | ||
}) | ||
r = httptest.NewRequest("GET", "/", nil) | ||
rw = httptest.NewRecorder() | ||
) | ||
_, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{ | ||
UserID: user.ID, | ||
LoginType: database.LoginTypeGithub, | ||
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1), | ||
OAuthAccessToken: "letmein", | ||
}) | ||
require.NoError(t, err) | ||
r.Header.Set(codersdk.SessionTokenHeader, token) | ||
oauthToken := &oauth2.Token{ | ||
AccessToken: "wow", | ||
RefreshToken: "moo", | ||
Expiry: dbtime.Now().AddDate(0, 0, 1), | ||
} | ||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ | ||
DB: db, | ||
OAuth2Configs: &httpmw.OAuth2Configs{ | ||
Github: &testutil.OAuth2Config{ | ||
Token: oauthToken, | ||
}, | ||
}, | ||
RedirectToLogin: false, | ||
})(successHandler).ServeHTTP(rw, r) | ||
res := rw.Result() | ||
defer res.Body.Close() | ||
require.Equal(t, http.StatusUnauthorized, res.StatusCode) | ||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) | ||
require.NoError(t, err) | ||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) | ||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) | ||
}) | ||
t.Run("RemoteIPUpdates", func(t *testing.T) { | ||
Uh oh!
There was an error while loading.Please reload this page.