- Notifications
You must be signed in to change notification settings - Fork913
Fix API key refresh for short expiration times#18351
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Changes fromall commits
File filter
Filter by extension
Conversations
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -446,7 +446,7 @@ func (r *RootCmd) configSSH() *serpent.Command { | ||
if !bytes.Equal(configRaw, configModified) { | ||
sshDir := filepath.Dir(sshConfigFile) | ||
if err := os.MkdirAll(sshDir,0o700); err != nil { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Unrelated change | ||
return xerrors.Errorf("failed to create directory %q: %w", sshDir, err) | ||
} | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -207,7 +207,7 @@ func TestConfigSSH_MissingDirectory(t *testing.T) { | ||
// Check that the directory has proper permissions (0700) | ||
sshDirInfo, err := os.Stat(sshDir) | ||
require.NoError(t, err) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. All changes in this file are unrelated | ||
require.Equal(t, os.FileMode(0o700), sshDirInfo.Mode().Perm(), "directory should have 0700 permissions") | ||
} | ||
func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { | ||
@@ -358,7 +358,8 @@ func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { | ||
strings.Join([]string{ | ||
headerEnd, | ||
"", | ||
}, "\n"), | ||
}, | ||
}, | ||
args: []string{"--ssh-option", "ForwardAgent=yes"}, | ||
matches: []match{ | ||
@@ -383,7 +384,8 @@ func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { | ||
strings.Join([]string{ | ||
headerEnd, | ||
"", | ||
}, "\n"), | ||
}, | ||
}, | ||
args: []string{"--ssh-option", "ForwardAgent=yes"}, | ||
matches: []match{ | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -369,7 +369,13 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon | ||
// We extend the ExpiresAt to reduce re-authentication. | ||
if !cfg.DisableSessionExpiryRefresh { | ||
apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second | ||
// For short-lived keys (< 1 hour), refresh when half the lifetime remains. | ||
// For longer keys, refresh when 1 hour remains. | ||
refreshThreshold := apiKeyLifetime - time.Hour | ||
if apiKeyLifetime < time.Hour { | ||
refreshThreshold = apiKeyLifetime / 2 | ||
} | ||
if key.ExpiresAt.Sub(now) <= refreshThreshold { | ||
Comment on lines +372 to +378 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. This still seems wrong? For example, if the lifetime of a key is 1 hour 1 second, then the threshold is 1 and it will not refresh until there is 1 second left, rather than when there are 30 minutes left like the comment implies. Also if the lifetime is exactly one hour it will only refresh if the request is made exactly at the time of expiration. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. You're absolutely right! The current logic has a bug for keys with lifetime just over 1 hour. Let me fix this. The issue is that for a key with 1 hour 1 second lifetime:
I think a better approach would be:
This would ensure that even very long-lived keys get refreshed at reasonable intervals. I'll update the logic and tests. | ||
key.ExpiresAt = now.Add(apiKeyLifetime) | ||
changed = true | ||
} | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -991,3 +991,237 @@ func TestAPIKey(t *testing.T) { | ||
require.Equal(t, http.StatusOK, res.StatusCode) | ||
}) | ||
} | ||
func TestAPIKeyExpiryRefresh(t *testing.T) { | ||
t.Parallel() | ||
t.Run("ShortLivedKeyRefresh", func(t *testing.T) { | ||
t.Parallel() | ||
var ( | ||
db, _ = dbtestutil.NewDB(t) | ||
user = dbgen.User(t, db, database.User{}) | ||
now = time.Now() | ||
// 2 minute lifetime | ||
lifetime = int64(120) | ||
// Key expires in 30 seconds (less than half the 2-minute lifetime) | ||
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. All the other tests use | ||
UserID: user.ID, | ||
LastUsed: dbtime.Now().Add(-time.Hour), | ||
ExpiresAt: dbtime.Time(now.Add(30 * time.Second)), | ||
LifetimeSeconds: lifetime, | ||
}) | ||
r = httptest.NewRequest("GET", "/", nil) | ||
rw = httptest.NewRecorder() | ||
) | ||
r.Header.Set(codersdk.SessionTokenHeader, token) | ||
// Mock time to be exactly now | ||
successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { | ||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{ | ||
Message: "It worked!", | ||
}) | ||
}) | ||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ | ||
DB: db, | ||
RedirectToLogin: false, | ||
})(successHandler).ServeHTTP(rw, r) | ||
res := rw.Result() | ||
defer res.Body.Close() | ||
require.Equal(t, http.StatusOK, res.StatusCode) | ||
Comment on lines +1018 to +1031 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Could we just make another API call instead of mocking all this? Fetch | ||
// Verify the key was refreshed | ||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) | ||
require.NoError(t, err) | ||
// For a 2-minute key, refresh threshold is 1 minute (half the lifetime) | ||
// Since the key expires in 30 seconds and threshold is 1 minute, | ||
// it should be refreshed | ||
require.True(t, gotAPIKey.ExpiresAt.After(sentAPIKey.ExpiresAt), | ||
"API key should have been refreshed for short-lived key") | ||
// Should be extended by the full lifetime (2 minutes) | ||
expectedExpiry := now.Add(time.Duration(lifetime) * time.Second) | ||
require.WithinDuration(t, expectedExpiry, gotAPIKey.ExpiresAt, time.Second, | ||
"API key should be extended by full lifetime") | ||
}) | ||
t.Run("ShortLivedKeyNoRefresh", func(t *testing.T) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. I feel like these would all make more sense in a test table, the only difference is the lifetime. Also feel like maybe it can use a range. | ||
t.Parallel() | ||
var ( | ||
db, _ = dbtestutil.NewDB(t) | ||
user = dbgen.User(t, db, database.User{}) | ||
now = time.Now() | ||
// 2 minute lifetime | ||
lifetime = int64(120) | ||
// Key expires in 90 seconds (more than half the lifetime) | ||
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ | ||
UserID: user.ID, | ||
LastUsed: dbtime.Now().Add(-time.Hour), | ||
ExpiresAt: dbtime.Time(now.Add(90 * time.Second)), | ||
LifetimeSeconds: lifetime, | ||
}) | ||
r = httptest.NewRequest("GET", "/", nil) | ||
rw = httptest.NewRecorder() | ||
) | ||
r.Header.Set(codersdk.SessionTokenHeader, token) | ||
successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { | ||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{ | ||
Message: "It worked!", | ||
}) | ||
}) | ||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ | ||
DB: db, | ||
RedirectToLogin: false, | ||
})(successHandler).ServeHTTP(rw, r) | ||
res := rw.Result() | ||
defer res.Body.Close() | ||
require.Equal(t, http.StatusOK, res.StatusCode) | ||
// Verify the key was NOT refreshed | ||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) | ||
require.NoError(t, err) | ||
// For a 2-minute key with 90 seconds remaining (> 60 second threshold), | ||
// it should NOT be refreshed | ||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt, | ||
"API key should NOT have been refreshed when above threshold") | ||
}) | ||
t.Run("LongLivedKeyRefresh", func(t *testing.T) { | ||
t.Parallel() | ||
var ( | ||
db, _ = dbtestutil.NewDB(t) | ||
user = dbgen.User(t, db, database.User{}) | ||
now = time.Now() | ||
// 2 hour lifetime | ||
lifetime = int64(7200) | ||
// Key expires in 30 minutes (less than 1 hour threshold) | ||
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ | ||
UserID: user.ID, | ||
LastUsed: dbtime.Now().Add(-time.Hour), | ||
ExpiresAt: dbtime.Time(now.Add(30 * time.Minute)), | ||
LifetimeSeconds: lifetime, | ||
}) | ||
r = httptest.NewRequest("GET", "/", nil) | ||
rw = httptest.NewRecorder() | ||
) | ||
r.Header.Set(codersdk.SessionTokenHeader, token) | ||
successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { | ||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{ | ||
Message: "It worked!", | ||
}) | ||
}) | ||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ | ||
DB: db, | ||
RedirectToLogin: false, | ||
})(successHandler).ServeHTTP(rw, r) | ||
res := rw.Result() | ||
defer res.Body.Close() | ||
require.Equal(t, http.StatusOK, res.StatusCode) | ||
// Verify the key was refreshed | ||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) | ||
require.NoError(t, err) | ||
// For a 2-hour key with 30 minutes remaining (< 1 hour threshold), | ||
// it should be refreshed | ||
require.True(t, gotAPIKey.ExpiresAt.After(sentAPIKey.ExpiresAt), | ||
"API key should have been refreshed for long-lived key") | ||
// Should be extended by the full lifetime (2 hours) | ||
expectedExpiry := now.Add(time.Duration(lifetime) * time.Second) | ||
require.WithinDuration(t, expectedExpiry, gotAPIKey.ExpiresAt, time.Second, | ||
"API key should be extended by full lifetime") | ||
}) | ||
t.Run("LongLivedKeyNoRefresh", func(t *testing.T) { | ||
t.Parallel() | ||
var ( | ||
db, _ = dbtestutil.NewDB(t) | ||
user = dbgen.User(t, db, database.User{}) | ||
now = time.Now() | ||
// 2 hour lifetime | ||
lifetime = int64(7200) | ||
// Key expires in 90 minutes (more than 1 hour threshold) | ||
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ | ||
UserID: user.ID, | ||
LastUsed: dbtime.Now().Add(-time.Hour), | ||
ExpiresAt: dbtime.Time(now.Add(90 * time.Minute)), | ||
LifetimeSeconds: lifetime, | ||
}) | ||
r = httptest.NewRequest("GET", "/", nil) | ||
rw = httptest.NewRecorder() | ||
) | ||
r.Header.Set(codersdk.SessionTokenHeader, token) | ||
successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { | ||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{ | ||
Message: "It worked!", | ||
}) | ||
}) | ||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ | ||
DB: db, | ||
RedirectToLogin: false, | ||
})(successHandler).ServeHTTP(rw, r) | ||
res := rw.Result() | ||
defer res.Body.Close() | ||
require.Equal(t, http.StatusOK, res.StatusCode) | ||
// Verify the key was NOT refreshed | ||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) | ||
require.NoError(t, err) | ||
// For a 2-hour key with 90 minutes remaining (> 1 hour threshold), | ||
// it should NOT be refreshed | ||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt, | ||
"API key should NOT have been refreshed when above threshold") | ||
}) | ||
t.Run("RefreshDisabled", func(t *testing.T) { | ||
t.Parallel() | ||
var ( | ||
db, _ = dbtestutil.NewDB(t) | ||
user = dbgen.User(t, db, database.User{}) | ||
now = time.Now() | ||
// 2 minute lifetime | ||
lifetime = int64(120) | ||
// Key expires in 30 seconds (well below threshold) | ||
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ | ||
UserID: user.ID, | ||
LastUsed: dbtime.Now().Add(-time.Hour), | ||
ExpiresAt: dbtime.Time(now.Add(30 * time.Second)), | ||
LifetimeSeconds: lifetime, | ||
}) | ||
r = httptest.NewRequest("GET", "/", nil) | ||
rw = httptest.NewRecorder() | ||
) | ||
r.Header.Set(codersdk.SessionTokenHeader, token) | ||
successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { | ||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{ | ||
Message: "It worked!", | ||
}) | ||
}) | ||
// Disable session expiry refresh | ||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ | ||
DB: db, | ||
RedirectToLogin: false, | ||
DisableSessionExpiryRefresh: true, | ||
})(successHandler).ServeHTTP(rw, r) | ||
res := rw.Result() | ||
defer res.Body.Close() | ||
require.Equal(t, http.StatusOK, res.StatusCode) | ||
// Verify the key was NOT refreshed even though it should have been | ||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) | ||
require.NoError(t, err) | ||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt, | ||
"API key should NOT have been refreshed when refresh is disabled") | ||
}) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -94,7 +94,6 @@ func (i *dispatchInterceptor) Dispatcher(payload types.MessagePayload, title, bo | ||
} | ||
retryable, err = deliveryFn(ctx, msgID) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Unrelated change | ||
if err != nil { | ||
i.err.Add(1) | ||
i.lastErr.Store(err) | ||
Uh oh!
There was an error while loading.Please reload this page.