|
1 | 1 | package oidctest
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | +"context" |
4 | 5 | "database/sql"
|
5 | 6 | "encoding/json"
|
6 | 7 | "net/http"
|
| 8 | +"net/url" |
7 | 9 | "testing"
|
8 | 10 | "time"
|
9 | 11 |
|
10 | 12 | "github.com/golang-jwt/jwt/v4"
|
11 | 13 | "github.com/stretchr/testify/require"
|
| 14 | +"golang.org/x/xerrors" |
12 | 15 |
|
13 | 16 | "github.com/coder/coder/v2/coderd/database"
|
14 | 17 | "github.com/coder/coder/v2/coderd/database/dbauthz"
|
@@ -114,3 +117,51 @@ func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *coders
|
114 | 117 | _,err:=user.User(testutil.Context(t,testutil.WaitShort),"me")
|
115 | 118 | require.NoError(t,err,"user must be able to be fetched")
|
116 | 119 | }
|
| 120 | + |
| 121 | +// OAuth2GetCode emulates a user clicking "allow" on the IDP page. When doing |
| 122 | +// unit tests, it's easier to skip this step sometimes. It does make an actual |
| 123 | +// request to the IDP, so it should be equivalent to doing this "manually" with |
| 124 | +// actual requests. |
| 125 | +// |
| 126 | +// TODO: Is state param optional? Can we grab it from the authURL? |
| 127 | +funcOAuth2GetCode(authURLstring,statestring,doRequestfunc(req*http.Request) (*http.Response,error)) (string,error) { |
| 128 | +// We need to store some claims, because this is also an OIDC provider, and |
| 129 | +// it expects some claims to be present. |
| 130 | +// TODO: POST or GET method? |
| 131 | +r,err:=http.NewRequestWithContext(context.Background(),http.MethodGet,authURL,nil) |
| 132 | +iferr!=nil { |
| 133 | +return"",xerrors.Errorf("failed to create auth request: %w",err) |
| 134 | +} |
| 135 | + |
| 136 | +expCode:=http.StatusTemporaryRedirect |
| 137 | +resp,err:=doRequest(r) |
| 138 | +iferr!=nil { |
| 139 | +return"",xerrors.Errorf("request: %w",err) |
| 140 | +} |
| 141 | +deferresp.Body.Close() |
| 142 | + |
| 143 | +ifresp.StatusCode!=expCode { |
| 144 | +return"",codersdk.ReadBodyAsError(resp) |
| 145 | +} |
| 146 | + |
| 147 | +to:=resp.Header.Get("Location") |
| 148 | +ifto=="" { |
| 149 | +return"",xerrors.Errorf("expected redirect location") |
| 150 | +} |
| 151 | + |
| 152 | +toURL,err:=url.Parse(to) |
| 153 | +iferr!=nil { |
| 154 | +return"",xerrors.Errorf("failed to parse redirect location: %w",err) |
| 155 | +} |
| 156 | + |
| 157 | +code:=toURL.Query().Get("code") |
| 158 | +ifcode=="" { |
| 159 | +return"",xerrors.Errorf("expected code in redirect location") |
| 160 | +} |
| 161 | + |
| 162 | +newState:=toURL.Query().Get("state") |
| 163 | +ifnewState!=state { |
| 164 | +return"",xerrors.Errorf("expected state %q, got %q",state,newState) |
| 165 | +} |
| 166 | +returncode,nil |
| 167 | +} |