Expand Up @@ -39,6 +39,12 @@ import ( "github.com/coder/coder/v2/codersdk" ) type token struct { issued time.Time email string exp time.Time } // FakeIDP is a functional OIDC provider. // It only supports 1 OIDC client. type FakeIDP struct { Expand All @@ -65,7 +71,7 @@ type FakeIDP struct { // That is the various access tokens, refresh tokens, states, etc. codeToStateMap *syncmap.Map[string, string] // Token -> Email accessTokens *syncmap.Map[string,string ] accessTokens *syncmap.Map[string,token ] // Refresh Token -> Email refreshTokensUsed *syncmap.Map[string, bool] refreshTokens *syncmap.Map[string, string] Expand All @@ -89,7 +95,8 @@ type FakeIDP struct { hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error) serve bool // optional middlewares middlewares chi.Middlewares middlewares chi.Middlewares defaultExpire time.Duration } func StatusError(code int, err error) error { Expand Down Expand Up @@ -134,6 +141,23 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) { } } func WithDefaultExpire(d time.Duration) func(*FakeIDP) { return func(f *FakeIDP) { f.defaultExpire = d } } func WithStaticCredentials(id, secret string) func(*FakeIDP) { return func(f *FakeIDP) { if id != "" { f.clientID = id } if secret != "" { f.clientSecret = secret } } } // WithExtra returns extra fields that be accessed on the returned Oauth Token. // These extra fields can override the default fields (id_token, access_token, etc). func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) { Expand All @@ -155,6 +179,12 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { } } func WithLogger(logger slog.Logger) func(*FakeIDP) { return func(f *FakeIDP) { f.logger = logger } } // WithStaticUserInfo is optional, but will return the same user info for // every user on the /userinfo endpoint. func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) { Expand Down Expand Up @@ -211,14 +241,15 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { clientSecret: uuid.NewString(), logger: slog.Make(), codeToStateMap: syncmap.New[string, string](), accessTokens: syncmap.New[string,string ](), accessTokens: syncmap.New[string,token ](), refreshTokens: syncmap.New[string, string](), refreshTokensUsed: syncmap.New[string, bool](), stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](), refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](), hookOnRefresh: func(_ string) error { return nil }, hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil }, hookValidRedirectURL: func(redirectURL string) error { return nil }, defaultExpire: time.Minute * 5, } for _, opt := range opts { Expand Down Expand Up @@ -265,15 +296,31 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { Algorithms: []string{ "RS256", }, ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(), } } // realServer turns the FakeIDP into a real http server. func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { t.Helper() srvURL := "localhost:0" issURL, err := url.Parse(f.issuer) if err == nil { if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" { srvURL = issURL.Host } } l, err := net.Listen("tcp", srvURL) require.NoError(t, err, "failed to create listener") ctx, cancel := context.WithCancel(context.Background()) srv := httptest.NewUnstartedServer(f.handler) srv := &httptest.Server{ Listener: l, Config: &http.Server{Handler: f.handler, ReadHeaderTimeout: time.Second * 5}, } srv.Config.BaseContext = func(_ net.Listener) context.Context { return ctx } Expand Down Expand Up @@ -495,6 +542,8 @@ type ProviderJSON struct { JWKSURL string `json:"jwks_uri"` UserInfoURL string `json:"userinfo_endpoint"` Algorithms []string `json:"id_token_signing_alg_values_supported"` // This is custom ExternalAuthURL string `json:"external_auth_url"` } // newCode enforces the code exchanged is actually a valid code Expand All @@ -507,9 +556,13 @@ func (f *FakeIDP) newCode(state string) string { // newToken enforces the access token exchanged is actually a valid access token // created by the IDP. func (f *FakeIDP) newToken(email string) string { func (f *FakeIDP) newToken(email string, expires time.Time ) string { accessToken := uuid.NewString() f.accessTokens.Store(accessToken, email) f.accessTokens.Store(accessToken, token{ issued: time.Now(), email: email, exp: expires, }) return accessToken } Expand All @@ -525,10 +578,15 @@ func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request auth := req.Header.Get("Authorization") token := strings.TrimPrefix(auth, "Bearer ") _ , ok := f.accessTokens.Load(token)authToken , ok := f.accessTokens.Load(token)if !ok { return "", xerrors.New("invalid access token") } if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) { return "", xerrors.New("access token expired") } return token, nil } Expand Down Expand Up @@ -653,7 +711,8 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { values, err := f.authenticateOIDCClientRequest(t, r) f.logger.Info(r.Context(), "http idp call token", slog.Error(err), slog.F("valid", err == nil), slog.F("grant_type", values.Get("grant_type")), slog.F("values", values.Encode()), ) if err != nil { Expand Down Expand Up @@ -731,15 +790,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { return } exp := time.Now().Add(time.Minute * 5 ) exp := time.Now().Add(f.defaultExpire ) claims["exp"] = exp.UnixMilli() email := getEmail(claims) refreshToken := f.newRefreshTokens(email) token := map[string]interface{}{ "access_token": f.newToken(email), "access_token": f.newToken(email, exp ), "refresh_token": refreshToken, "token_type": "Bearer", "expires_in": int64((time.Minute * 5 ).Seconds()), "expires_in": int64((f.defaultExpire ).Seconds()), "id_token": f.encodeClaims(t, claims), } if f.hookMutateToken != nil { Expand All @@ -754,25 +813,31 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) { token, err := f.authenticateBearerTokenRequest(t, r) f.logger.Info(r.Context(), "http call idp user info", slog.Error(err), slog.F("url", r.URL.String()), ) if err != nil { http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest ) http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusUnauthorized ) return "", false } email , ok = f.accessTokens.Load(token)authToken , ok: = f.accessTokens.Load(token)if !ok { t.Errorf("access token user for user_info has no email to indicate which user") http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest) http.Error(rw, "invalid access token, missing user info", http.StatusUnauthorized) return "", false } if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) { http.Error(rw, "auth token expired", http.StatusUnauthorized) return "", false } return email, true return authToken.email, true } mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { email, ok := validateMW(rw, r) f.logger.Info(r.Context(), "http userinfo endpoint", slog.F("valid", ok), slog.F("email", email), ) if !ok { return } Expand All @@ -790,6 +855,10 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { // should be strict, and this one needs to handle sub routes. mux.Mount("/external-auth-validate/", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { email, ok := validateMW(rw, r) f.logger.Info(r.Context(), "http external auth validate", slog.F("valid", ok), slog.F("email", email), ) if !ok { return } Expand Down Expand Up @@ -941,7 +1010,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu } f.externalProviderID = id f.externalAuthValidate = func(email string, rw http.ResponseWriter, r *http.Request) { newPath := strings.TrimPrefix(r.URL.Path,fmt.Sprintf( "/external-auth-validate/%s", id) ) newPath := strings.TrimPrefix(r.URL.Path, "/external-auth-validate" ) switch newPath { // /user is ALWAYS supported under the `/` path too. case "/user", "/", "": Expand All @@ -965,18 +1034,20 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu } instrumentF := promoauth.NewFactory(prometheus.NewRegistry()) cfg := &externalauth.Config{ DisplayName: id, InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)), ID: id, // No defaults for these fields by omitting the type Type: "", DisplayIcon: f.WellknownConfig().UserInfoURL, // Omit the /user for the validate so we can easily append to it when modifying // the cfg for advanced tests. ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path:fmt.Sprintf( "/external-auth-validate/%s", id) }).String(), ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/" }).String(), } for _, opt := range opts { opt(cfg) } f.updateIssuerURL(t, f.issuer) return cfg } Expand Down