1
1
package promoauth_test
2
2
3
3
import (
4
+ "context"
5
+ "fmt"
6
+ "io"
4
7
"net/http"
5
- "net/url"
8
+ "net/http/httptest"
9
+ "strings"
6
10
"testing"
7
11
"time"
8
12
9
13
"github.com/prometheus/client_golang/prometheus"
14
+ "github.com/prometheus/client_golang/prometheus/promhttp"
10
15
ptestutil"github.com/prometheus/client_golang/prometheus/testutil"
16
+ io_prometheus_client"github.com/prometheus/client_model/go"
17
+ "github.com/stretchr/testify/assert"
11
18
"github.com/stretchr/testify/require"
19
+ "golang.org/x/exp/maps"
20
+ "golang.org/x/oauth2"
12
21
13
22
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
14
23
"github.com/coder/coder/v2/coderd/externalauth"
@@ -22,53 +31,197 @@ func TestInstrument(t *testing.T) {
22
31
ctx := testutil .Context (t ,testutil .WaitShort )
23
32
idp := oidctest .NewFakeIDP (t ,oidctest .WithServing ())
24
33
reg := prometheus .NewRegistry ()
25
- count := func ()int {
26
- return ptestutil .CollectAndCount (reg ,"coderd_oauth2_external_requests_total" )
34
+ t .Cleanup (func () {
35
+ if t .Failed () {
36
+ t .Log (registryDump (reg ))
37
+ }
38
+ })
39
+
40
+ const id = "test"
41
+ labels := prometheus.Labels {
42
+ "name" :id ,
43
+ "status_code" :"200" ,
44
+ }
45
+ const metricname = "coderd_oauth2_external_requests_total"
46
+ count := func (source string )int {
47
+ labels ["source" ]= source
48
+ return counterValue (t ,reg ,"coderd_oauth2_external_requests_total" ,labels )
27
49
}
28
50
29
51
factory := promoauth .NewFactory (reg )
30
- const id = "test"
52
+
31
53
cfg := externalauth.Config {
32
54
InstrumentedOAuth2Config :factory .New (id ,idp .OIDCConfig (t , []string {})),
33
55
ID :"test" ,
34
56
ValidateURL :must [* url.URL ](t )(idp .IssuerURL ().Parse ("/oauth2/userinfo" )).String (),
35
57
}
36
58
37
59
// 0 Requests before we start
38
- require .Equal (t ,count ( ),0 )
60
+ require .Nil (t ,metricValue ( t , reg , metricname , labels ),"no metrics at start" )
39
61
40
62
// Exchange should trigger a request
41
63
code := idp .CreateAuthCode (t ,"foo" )
42
64
token ,err := cfg .Exchange (ctx ,code )
43
65
require .NoError (t ,err )
44
- require .Equal (t ,count (),1 )
66
+ require .Equal (t ,count ("Exchange" ),1 )
45
67
46
68
// Force a refresh
47
69
token .Expiry = time .Now ().Add (time .Hour * - 1 )
48
70
src := cfg .TokenSource (ctx ,token )
49
71
refreshed ,err := src .Token ()
50
72
require .NoError (t ,err )
51
73
require .NotEqual (t ,token .AccessToken ,refreshed .AccessToken ,"token refreshed" )
52
- require .Equal (t ,count (),2 )
74
+ require .Equal (t ,count ("TokenSource" ),1 )
53
75
54
76
// Try a validate
55
77
valid ,_ ,err := cfg .ValidateToken (ctx ,refreshed .AccessToken )
56
78
require .NoError (t ,err )
57
79
require .True (t ,valid )
58
- require .Equal (t ,count (),3 )
80
+ require .Equal (t ,count ("ValidateToken" ),1 )
59
81
60
82
// Verify the default client was not broken. This check is added because we
61
83
// extend the http.DefaultTransport. If a `.Clone()` is not done, this can be
62
84
// mis-used. It is cheap to run this quick check.
85
+ snapshot := registryDump (reg )
63
86
req ,err := http .NewRequestWithContext (ctx ,http .MethodGet ,
64
- must [ * url. URL ]( t ) (idp .IssuerURL ().Parse ("/.well-known/openid-configuration" )).String (),nil )
87
+ must (idp .IssuerURL ().Parse ("/.well-known/openid-configuration" )).String (),nil )
65
88
require .NoError (t ,err )
66
89
67
90
resp ,err := http .DefaultClient .Do (req )
68
91
require .NoError (t ,err )
69
92
_ = resp .Body .Close ()
70
93
71
- require .Equal (t ,count (),3 )
94
+ require .NoError (t ,compare (reg ,snapshot ),"no metric changes" )
95
+ }
96
+
97
+ func TestGithubRateLimits (t * testing.T ) {
98
+ t .Parallel ()
99
+
100
+ now := time .Now ()
101
+ cases := []struct {
102
+ Name string
103
+ NoHeaders bool
104
+ Omit []string
105
+ ExpectNoMetrics bool
106
+ Limit int
107
+ Remaining int
108
+ Used int
109
+ Reset time.Time
110
+
111
+ at time.Time
112
+ }{
113
+ {
114
+ Name :"NoHeaders" ,
115
+ NoHeaders :true ,
116
+ ExpectNoMetrics :true ,
117
+ },
118
+ {
119
+ Name :"ZeroHeaders" ,
120
+ ExpectNoMetrics :true ,
121
+ },
122
+ {
123
+ Name :"OverLimit" ,
124
+ Limit :100 ,
125
+ Remaining :0 ,
126
+ Used :500 ,
127
+ Reset :now .Add (time .Hour ),
128
+ at :now ,
129
+ },
130
+ {
131
+ Name :"UnderLimit" ,
132
+ Limit :100 ,
133
+ Remaining :0 ,
134
+ Used :500 ,
135
+ Reset :now .Add (time .Hour ),
136
+ at :now ,
137
+ },
138
+ {
139
+ Name :"Partial" ,
140
+ Omit : []string {"x-ratelimit-remaining" },
141
+ ExpectNoMetrics :true ,
142
+ Limit :100 ,
143
+ Remaining :0 ,
144
+ Used :500 ,
145
+ Reset :now .Add (time .Hour ),
146
+ at :now ,
147
+ },
148
+ }
149
+
150
+ for _ ,c := range cases {
151
+ c := c
152
+ t .Run (c .Name ,func (t * testing.T ) {
153
+ t .Parallel ()
154
+
155
+ reg := prometheus .NewRegistry ()
156
+ idp := oidctest .NewFakeIDP (t ,oidctest .WithMiddlewares (
157
+ func (next http.Handler ) http.Handler {
158
+ return http .HandlerFunc (func (rw http.ResponseWriter ,r * http.Request ) {
159
+ if ! c .NoHeaders {
160
+ rw .Header ().Set ("x-ratelimit-limit" ,fmt .Sprintf ("%d" ,c .Limit ))
161
+ rw .Header ().Set ("x-ratelimit-remaining" ,fmt .Sprintf ("%d" ,c .Remaining ))
162
+ rw .Header ().Set ("x-ratelimit-used" ,fmt .Sprintf ("%d" ,c .Used ))
163
+ rw .Header ().Set ("x-ratelimit-resource" ,"core" )
164
+ rw .Header ().Set ("x-ratelimit-reset" ,fmt .Sprintf ("%d" ,c .Reset .Unix ()))
165
+ for _ ,omit := range c .Omit {
166
+ rw .Header ().Del (omit )
167
+ }
168
+ }
169
+
170
+ next .ServeHTTP (rw ,r )
171
+ })
172
+ }))
173
+
174
+ factory := promoauth .NewFactory (reg )
175
+ if ! c .at .IsZero () {
176
+ factory .Now = func () time.Time {
177
+ return c .at
178
+ }
179
+ }
180
+
181
+ cfg := factory .NewGithub ("test" ,idp .OIDCConfig (t , []string {}))
182
+
183
+ // Do a single oauth2 call
184
+ ctx := testutil .Context (t ,testutil .WaitShort )
185
+ ctx = context .WithValue (ctx ,oauth2 .HTTPClient ,idp .HTTPClient (nil ))
186
+ _ ,err := cfg .Exchange (ctx ,idp .CreateAuthCode (t ,"foo" ))
187
+ require .NoError (t ,err )
188
+
189
+ // Verify
190
+ labels := prometheus.Labels {
191
+ "name" :"test" ,
192
+ "resource" :"core" ,
193
+ }
194
+ pass := true
195
+ if ! c .ExpectNoMetrics {
196
+ pass = pass && assert .Equal (t ,gaugeValue (t ,reg ,"coderd_oauth2_external_requests_rate_limit_total" ,labels ),c .Limit ,"limit" )
197
+ pass = pass && assert .Equal (t ,gaugeValue (t ,reg ,"coderd_oauth2_external_requests_rate_limit_remaining" ,labels ),c .Remaining ,"remaining" )
198
+ pass = pass && assert .Equal (t ,gaugeValue (t ,reg ,"coderd_oauth2_external_requests_rate_limit_used" ,labels ),c .Used ,"used" )
199
+ if ! c .at .IsZero () {
200
+ until := c .Reset .Sub (c .at )
201
+ // Float accuracy is not great, so we allow a delta of 2
202
+ pass = pass && assert .InDelta (t ,gaugeValue (t ,reg ,"coderd_oauth2_external_requests_rate_limit_reset_in_seconds" ,labels ),int (until .Seconds ()),2 ,"reset in" )
203
+ }
204
+ }else {
205
+ pass = pass && assert .Nil (t ,metricValue (t ,reg ,"coderd_oauth2_external_requests_rate_limit_total" ,labels ),"not exists" )
206
+ }
207
+
208
+ // Helpful debugging
209
+ if ! pass {
210
+ t .Log (registryDump (reg ))
211
+ }
212
+ })
213
+ }
214
+ }
215
+
216
+ func registryDump (reg * prometheus.Registry )string {
217
+ h := promhttp .HandlerFor (reg , promhttp.HandlerOpts {})
218
+ rec := httptest .NewRecorder ()
219
+ req ,_ := http .NewRequest (http .MethodGet ,"/" ,nil )
220
+ h .ServeHTTP (rec ,req )
221
+ resp := rec .Result ()
222
+ data ,_ := io .ReadAll (resp .Body )
223
+ _ = resp .Body .Close ()
224
+ return string (data )
72
225
}
73
226
74
227
func must [V any ](t * testing.T )func (v V ,err error )V {
@@ -78,3 +231,39 @@ func must[V any](t *testing.T) func(v V, err error) V {
78
231
return v
79
232
}
80
233
}
234
+
235
+ func gaugeValue (t testing.TB ,reg prometheus.Gatherer ,metricName string ,labels prometheus.Labels )int {
236
+ labeled := metricValue (t ,reg ,metricName ,labels )
237
+ require .NotNilf (t ,labeled ,"metric %q with labels %v not found" ,metricName ,labels )
238
+ return int (labeled .GetGauge ().GetValue ())
239
+ }
240
+
241
+ func counterValue (t testing.TB ,reg prometheus.Gatherer ,metricName string ,labels prometheus.Labels )int {
242
+ labeled := metricValue (t ,reg ,metricName ,labels )
243
+ require .NotNilf (t ,labeled ,"metric %q with labels %v not found" ,metricName ,labels )
244
+ return int (labeled .GetCounter ().GetValue ())
245
+ }
246
+
247
+ func compare (reg prometheus.Gatherer ,compare string )error {
248
+ return ptestutil .GatherAndCompare (reg ,strings .NewReader (compare ))
249
+ }
250
+
251
+ func metricValue (t testing.TB ,reg prometheus.Gatherer ,metricName string ,labels prometheus.Labels )* io_prometheus_client.Metric {
252
+ metrics ,err := reg .Gather ()
253
+ require .NoError (t ,err )
254
+
255
+ for _ ,m := range metrics {
256
+ if m .GetName ()== metricName {
257
+ for _ ,labeled := range m .GetMetric () {
258
+ mLables := make (prometheus.Labels )
259
+ for _ ,v := range labeled .GetLabel () {
260
+ mLables [v .GetName ()]= v .GetValue ()
261
+ }
262
+ if maps .Equal (mLables ,labels ) {
263
+ return labeled
264
+ }
265
+ }
266
+ }
267
+ }
268
+ return nil
269
+ }