Expand Up @@ -7,6 +7,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" "errors" "fmt" "io" "net" Expand Down Expand Up @@ -41,7 +42,7 @@ import ( type FakeIDP struct { issuer string key *rsa.PrivateKey providerproviderJSON providerProviderJSON handler http.Handler cfg *oauth2.Config Expand All @@ -66,7 +67,7 @@ type FakeIDP struct { // IDP -> Application. Almost all IDPs have the concept of // "Authorized Redirect URLs". This can be used to emulate that. hookValidRedirectURL func(redirectURL string) error hookUserInfo func(email string) jwt.MapClaims hookUserInfo func(email string)( jwt.MapClaims, error) fakeCoderd func(req *http.Request) (*http.Response, error) hookOnRefresh func(email string) error // Custom authentication for the client. This is useful if you want Expand All @@ -75,6 +76,26 @@ type FakeIDP struct { serve bool } func StatusError(code int, err error) error { return statusHookError{ Err: err, HTTPStatusCode: code, } } // statusHookError allows a hook to change the returned http status code. type statusHookError struct { Err error HTTPStatusCode int } func (s statusHookError) Error() string { if s.Err == nil { return "" } return s.Err.Error() } type FakeIDPOpt func(idp *FakeIDP) func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) { Expand All @@ -83,9 +104,9 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID } } //WithRefreshHook is called when a refresh token is used. The email is //WithRefresh is called when a refresh token is used. The email is // the email of the user that is being refreshed assuming the claims are correct. funcWithRefreshHook (hook func(email string) error) func(*FakeIDP) { funcWithRefresh (hook func(email string) error) func(*FakeIDP) { return func(f *FakeIDP) { f.hookOnRefresh = hook } Expand All @@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { // every user on the /userinfo endpoint. func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) { return func(f *FakeIDP) { f.hookUserInfo = func(_ string) jwt.MapClaims { return info f.hookUserInfo = func(_ string)( jwt.MapClaims, error) { return info, nil } } } func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) { func WithDynamicUserInfo(userInfoFunc func(email string)( jwt.MapClaims, error) ) func(*FakeIDP) { return func(f *FakeIDP) { f.hookUserInfo = userInfoFunc } Expand Down Expand Up @@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](), refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](), hookOnRefresh: func(_ string) error { return nil }, hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} }, hookUserInfo: func(email string)( jwt.MapClaims, error) { return jwt.MapClaims{}, nil }, hookValidRedirectURL: func(redirectURL string) error { return nil }, } Expand All @@ -181,16 +202,20 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { return idp } func (f *FakeIDP) WellknownConfig() ProviderJSON { return f.provider } func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { t.Helper() u, err := url.Parse(issuer) require.NoError(t, err, "invalid issuer URL") f.issuer = issuer //providerJSON is the JSON representation of the OpenID Connect provider //ProviderJSON is the JSON representation of the OpenID Connect provider // These are all the urls that the IDP will respond to. f.provider =providerJSON { f.provider =ProviderJSON { Issuer: issuer, AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), Expand Down Expand Up @@ -220,6 +245,15 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { return srv } // GenerateAuthenticatedToken skips all oauth2 flows, and just generates a // valid token for some given claims. func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) { state := uuid.NewString() f.stateToIDTokenClaims.Store(state, claims) code := f.newCode(state) return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code) } // Login does the full OIDC flow starting at the "LoginButton". // The client argument is just to get the URL of the Coder instance. // Expand Down Expand Up @@ -333,7 +367,8 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map return resp, nil } type providerJSON struct { // ProviderJSON is the .well-known/configuration JSON type ProviderJSON struct { Issuer string `json:"issuer"` AuthURL string `json:"authorization_endpoint"` TokenURL string `json:"token_endpoint"` Expand Down Expand Up @@ -475,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { err := f.hookValidRedirectURL(redirectURI) if err != nil { t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error()) http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest) http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()),httpErrorCode( http.StatusBadRequest, err) ) return } Expand All @@ -501,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { slog.F("values", values.Encode()), ) if err != nil { http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest) http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()),httpErrorCode( http.StatusBadRequest, err) ) return } getEmail := func(claims jwt.MapClaims) string { Expand Down Expand Up @@ -562,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { claims = idTokenClaims err := f.hookOnRefresh(getEmail(claims)) if err != nil { http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest) http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()),httpErrorCode( http.StatusBadRequest, err) ) return } Expand Down Expand Up @@ -610,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest) return } _ = json.NewEncoder(rw).Encode(f.hookUserInfo(email)) claims, err := f.hookUserInfo(email) if err != nil { http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) return } _ = json.NewEncoder(rw).Encode(claims) })) mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { Expand Down Expand Up @@ -768,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co return cfg } func httpErrorCode(defaultCode int, err error) int { var stautsErr statusHookError status := defaultCode if errors.As(err, &stautsErr) { status = stautsErr.HTTPStatusCode } return status } type fakeRoundTripper struct { roundTrip func(req *http.Request) (*http.Response, error) } Expand Down