Expand Up @@ -10,11 +10,14 @@ import ( "errors" "fmt" "io" "math/rand" "mime" "net" "net/http" "net/http/cookiejar" "net/http/httptest" "net/url" "strconv" "strings" "testing" "time" Expand All @@ -34,9 +37,11 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/util/syncmap" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) type token struct { Expand All @@ -45,6 +50,13 @@ type token struct { exp time.Time } type deviceFlow struct { // userInput is the expected input to authenticate the device flow. userInput string exp time.Time granted bool } // FakeIDP is a functional OIDC provider. // It only supports 1 OIDC client. type FakeIDP struct { Expand Down Expand Up @@ -77,6 +89,8 @@ type FakeIDP struct { refreshTokens *syncmap.Map[string, string] stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims] refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims] // Device flow deviceCode *syncmap.Map[string, deviceFlow] // hooks // hookValidRedirectURL can be used to reject a redirect url from the Expand Down Expand Up @@ -226,6 +240,8 @@ const ( authorizePath = "/oauth2/authorize" keysPath = "/oauth2/keys" userInfoPath = "/oauth2/userinfo" deviceAuth = "/login/device/code" deviceVerify = "/login/device" ) func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { Expand All @@ -246,6 +262,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { refreshTokensUsed: syncmap.New[string, bool](), stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](), refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](), deviceCode: syncmap.New[string, deviceFlow](), hookOnRefresh: func(_ string) error { return nil }, hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil }, hookValidRedirectURL: func(redirectURL string) error { return nil }, Expand Down Expand Up @@ -288,11 +305,12 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { // ProviderJSON is the JSON representation of the OpenID Connect provider // These are all the urls that the IDP will respond to. f.provider = ProviderJSON{ Issuer: issuer, AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(), UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(), Issuer: issuer, AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(), UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(), DeviceCodeURL: u.ResolveReference(&url.URL{Path: deviceAuth}).String(), Algorithms: []string{ "RS256", }, Expand Down Expand Up @@ -467,6 +485,31 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f _ = res.Body.Close() } // DeviceLogin does the oauth2 device flow for external auth providers. func (*FakeIDP) DeviceLogin(t testing.TB, client *codersdk.Client, externalAuthID string) { // First we need to initiate the device flow. This will have Coder hit the // fake IDP and get a device code. device, err := client.ExternalAuthDeviceByID(context.Background(), externalAuthID) require.NoError(t, err) // Now the user needs to go to the fake IDP page and click "allow" and enter // the device code input. For our purposes, we just send an http request to // the verification url. No additional user input is needed. ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() resp, err := client.Request(ctx, http.MethodPost, device.VerificationURI, nil) require.NoError(t, err) defer resp.Body.Close() // Now we need to exchange the device code for an access token. We do this // in this method because it is the user that does the polling for the device // auth flow, not the backend. err = client.ExternalAuthDeviceExchange(context.Background(), externalAuthID, codersdk.ExternalAuthDeviceExchange{ DeviceCode: device.DeviceCode, }) require.NoError(t, err) } // CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing // unit tests, it's easier to skip this step sometimes. It does make an actual // request to the IDP, so it should be equivalent to doing this "manually" with Expand Down Expand Up @@ -536,12 +579,13 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map // ProviderJSON is the .well-known/configuration JSON type ProviderJSON struct { Issuer string `json:"issuer"` AuthURL string `json:"authorization_endpoint"` TokenURL string `json:"token_endpoint"` JWKSURL string `json:"jwks_uri"` UserInfoURL string `json:"userinfo_endpoint"` Algorithms []string `json:"id_token_signing_alg_values_supported"` Issuer string `json:"issuer"` AuthURL string `json:"authorization_endpoint"` TokenURL string `json:"token_endpoint"` JWKSURL string `json:"jwks_uri"` UserInfoURL string `json:"userinfo_endpoint"` DeviceCodeURL string `json:"device_authorization_endpoint"` Algorithms []string `json:"id_token_signing_alg_values_supported"` // This is custom ExternalAuthURL string `json:"external_auth_url"` } Expand Down Expand Up @@ -709,8 +753,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { })) mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { values, err := f.authenticateOIDCClientRequest(t, r) var values url.Values var err error if r.URL.Query().Get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code" { values = r.URL.Query() } else { values, err = f.authenticateOIDCClientRequest(t, r) } f.logger.Info(r.Context(), "http idp call token", slog.F("url", r.URL.String()), slog.F("valid", err == nil), slog.F("grant_type", values.Get("grant_type")), slog.F("values", values.Encode()), Expand Down Expand Up @@ -784,6 +835,37 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { f.refreshTokensUsed.Store(refreshToken, true) // Always invalidate the refresh token after it is used. f.refreshTokens.Delete(refreshToken) case "urn:ietf:params:oauth:grant-type:device_code": // Device flow var resp externalauth.ExchangeDeviceCodeResponse deviceCode := values.Get("device_code") if deviceCode == "" { resp.Error = "invalid_request" resp.ErrorDescription = "missing device_code" httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp) return } deviceFlow, ok := f.deviceCode.Load(deviceCode) if !ok { resp.Error = "invalid_request" resp.ErrorDescription = "device_code provided not found" httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp) return } if !deviceFlow.granted { // Status code ok with the error as pending. resp.Error = "authorization_pending" resp.ErrorDescription = "" httpapi.Write(r.Context(), rw, http.StatusOK, resp) return } // Would be nice to get an actual email here. claims = jwt.MapClaims{ "email": "unknown-dev-auth", } default: t.Errorf("unexpected grant_type %q", values.Get("grant_type")) http.Error(rw, "invalid grant_type", http.StatusBadRequest) Expand All @@ -807,8 +889,30 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { // Store the claims for the next refresh f.refreshIDTokenClaims.Store(refreshToken, claims) rw.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(rw).Encode(token) mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")) if mediaType == "application/x-www-form-urlencoded" { // This val encode might not work for some data structures. // It's good enough for now... rw.Header().Set("Content-Type", "application/x-www-form-urlencoded") vals := url.Values{} for k, v := range token { vals.Set(k, fmt.Sprintf("%v", v)) } _, _ = rw.Write([]byte(vals.Encode())) return } // Default to json since the oauth2 package doesn't use Accept headers. if mediaType == "application/json" || mediaType == "" { rw.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(rw).Encode(token) return } // If we get something we don't support, throw an error. httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ Message: "'Accept' header contains unsupported media type", Detail: fmt.Sprintf("Found %q", mediaType), }) })) validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) { Expand Down Expand Up @@ -886,6 +990,125 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { _ = json.NewEncoder(rw).Encode(set) })) mux.Handle(deviceVerify, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { f.logger.Info(r.Context(), "http call device verify") inputParam := "user_input" userInput := r.URL.Query().Get(inputParam) if userInput == "" { httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid user input", Detail: fmt.Sprintf("Hit this url again with ?%s=<user_code>", inputParam), }) return } deviceCode := r.URL.Query().Get("device_code") if deviceCode == "" { httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid device code", Detail: "Hit this url again with ?device_code=<device_code>", }) return } flow, ok := f.deviceCode.Load(deviceCode) if !ok { httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid device code", Detail: "Device code not found.", }) return } if time.Now().After(flow.exp) { httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid device code", Detail: "Device code expired.", }) return } if strings.TrimSpace(flow.userInput) != strings.TrimSpace(userInput) { httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid device code", Detail: "user code does not match", }) return } f.deviceCode.Store(deviceCode, deviceFlow{ userInput: flow.userInput, exp: flow.exp, granted: true, }) httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{ Message: "Device authenticated!", }) })) mux.Handle(deviceAuth, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { f.logger.Info(r.Context(), "http call device auth") p := httpapi.NewQueryParamParser() p.Required("client_id") clientID := p.String(r.URL.Query(), "", "client_id") _ = p.String(r.URL.Query(), "", "scopes") if len(p.Errors) > 0 { httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid query params", Validations: p.Errors, }) return } if clientID != f.clientID { httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid client id", }) return } deviceCode := uuid.NewString() lifetime := time.Second * 900 flow := deviceFlow{ //nolint:gosec userInput: fmt.Sprintf("%d", rand.Intn(9999999)+1e8), } f.deviceCode.Store(deviceCode, deviceFlow{ userInput: flow.userInput, exp: time.Now().Add(lifetime), }) verifyURL := f.issuerURL.ResolveReference(&url.URL{ Path: deviceVerify, RawQuery: url.Values{ "device_code": {deviceCode}, "user_input": {flow.userInput}, }.Encode(), }).String() if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" { httpapi.Write(r.Context(), rw, http.StatusOK, map[string]any{ "device_code": deviceCode, "user_code": flow.userInput, "verification_uri": verifyURL, "expires_in": int(lifetime.Seconds()), "interval": 3, }) return } // By default, GitHub form encodes these. _, _ = fmt.Fprint(rw, url.Values{ "device_code": {deviceCode}, "user_code": {flow.userInput}, "verification_uri": {verifyURL}, "expires_in": {strconv.Itoa(int(lifetime.Seconds()))}, "interval": {"3"}, }.Encode()) })) mux.NotFound(func(rw http.ResponseWriter, r *http.Request) { f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path)) t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path) Expand Down Expand Up @@ -987,6 +1210,8 @@ type ExternalAuthConfigOptions struct { // completely customize the response. It captures all routes under the /external-auth-validate/* // so the caller can do whatever they want and even add routes. routes map[string]func(email string, rw http.ResponseWriter, r *http.Request) UseDeviceAuth bool } func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions { Expand Down Expand Up @@ -1033,17 +1258,30 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu } } instrumentF := promoauth.NewFactory(prometheus.NewRegistry()) oauthCfg := instrumentF.New(f.clientID, f.OIDCConfig(t, nil)) cfg := &externalauth.Config{ DisplayName: id, InstrumentedOAuth2Config:instrumentF.New(f.clientID, f.OIDCConfig(t, nil)) , InstrumentedOAuth2Config:oauthCfg , ID: id, // No defaults for these fields by omitting the type Type: "", DisplayIcon: f.WellknownConfig().UserInfoURL, // Omit the /user for the validate so we can easily append to it when modifying // the cfg for advanced tests. ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(), DeviceAuth: &externalauth.DeviceAuth{ Config: oauthCfg, ClientID: f.clientID, TokenURL: f.provider.TokenURL, Scopes: []string{}, CodeURL: f.provider.DeviceCodeURL, }, } if !custom.UseDeviceAuth { cfg.DeviceAuth = nil } for _, opt := range opts { opt(cfg) } Expand Down