1
1
package promoauth_test
2
2
3
3
import (
4
+ "context"
5
+ "fmt"
6
+ "io"
4
7
"net/http"
8
+ "net/http/httptest"
9
+ "strings"
5
10
"testing"
6
11
"time"
7
12
8
13
"github.com/prometheus/client_golang/prometheus"
14
+ "github.com/prometheus/client_golang/prometheus/promhttp"
9
15
ptestutil"github.com/prometheus/client_golang/prometheus/testutil"
16
+ io_prometheus_client"github.com/prometheus/client_model/go"
17
+ "github.com/stretchr/testify/assert"
10
18
"github.com/stretchr/testify/require"
19
+ "golang.org/x/exp/maps"
20
+ "golang.org/x/oauth2"
11
21
12
22
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
13
23
"github.com/coder/coder/v2/coderd/externalauth"
@@ -21,44 +31,58 @@ func TestInstrument(t *testing.T) {
21
31
ctx := testutil .Context (t ,testutil .WaitShort )
22
32
idp := oidctest .NewFakeIDP (t ,oidctest .WithServing ())
23
33
reg := prometheus .NewRegistry ()
24
- count := func ()int {
25
- return ptestutil .CollectAndCount (reg ,"coderd_oauth2_external_requests_total" )
34
+ t .Cleanup (func () {
35
+ if t .Failed () {
36
+ t .Log (registryDump (reg ))
37
+ }
38
+ })
39
+
40
+ const id = "test"
41
+ labels := prometheus.Labels {
42
+ "name" :id ,
43
+ "status_code" :"200" ,
44
+ }
45
+ const metricname = "coderd_oauth2_external_requests_total"
46
+ count := func (source string )int {
47
+ labels ["source" ]= source
48
+ return counterValue (t ,reg ,"coderd_oauth2_external_requests_total" ,labels )
26
49
}
27
50
28
51
factory := promoauth .NewFactory (reg )
29
- const id = "test"
52
+
30
53
cfg := externalauth.Config {
31
54
InstrumentedOAuth2Config :factory .New (id ,idp .OIDCConfig (t , []string {})),
32
55
ID :"test" ,
33
56
ValidateURL :must (idp .IssuerURL ().Parse ("/oauth2/userinfo" )).String (),
34
57
}
35
58
36
59
// 0 Requests before we start
37
- require .Equal (t ,count ( ),0 )
60
+ require .Nil (t ,metricValue ( t , reg , metricname , labels ),"no metrics at start" )
38
61
39
62
// Exchange should trigger a request
40
63
code := idp .CreateAuthCode (t ,"foo" )
41
64
token ,err := cfg .Exchange (ctx ,code )
42
65
require .NoError (t ,err )
43
- require .Equal (t ,count (),1 )
66
+ require .Equal (t ,count ("Exchange" ),1 )
44
67
45
68
// Force a refresh
46
69
token .Expiry = time .Now ().Add (time .Hour * - 1 )
47
70
src := cfg .TokenSource (ctx ,token )
48
71
refreshed ,err := src .Token ()
49
72
require .NoError (t ,err )
50
73
require .NotEqual (t ,token .AccessToken ,refreshed .AccessToken ,"token refreshed" )
51
- require .Equal (t ,count (),2 )
74
+ require .Equal (t ,count ("TokenSource" ),1 )
52
75
53
76
// Try a validate
54
77
valid ,_ ,err := cfg .ValidateToken (ctx ,refreshed .AccessToken )
55
78
require .NoError (t ,err )
56
79
require .True (t ,valid )
57
- require .Equal (t ,count (),3 )
80
+ require .Equal (t ,count ("ValidateToken" ),1 )
58
81
59
82
// Verify the default client was not broken. This check is added because we
60
83
// extend the http.DefaultTransport. If a `.Clone()` is not done, this can be
61
84
// mis-used. It is cheap to run this quick check.
85
+ snapshot := registryDump (reg )
62
86
req ,err := http .NewRequest (http .MethodGet ,
63
87
must (idp .IssuerURL ().Parse ("/.well-known/openid-configuration" )).String (),nil )
64
88
require .NoError (t ,err )
@@ -68,7 +92,137 @@ func TestInstrument(t *testing.T) {
68
92
require .NoError (t ,err )
69
93
_ = resp .Body .Close ()
70
94
71
- require .Equal (t ,count (),3 )
95
+ require .NoError (t ,compare (reg ,snapshot ),"no metric changes" )
96
+ }
97
+
98
+ func TestGithubRateLimits (t * testing.T ) {
99
+ t .Parallel ()
100
+
101
+ now := time .Now ()
102
+ cases := []struct {
103
+ Name string
104
+ NoHeaders bool
105
+ Omit []string
106
+ ExpectNoMetrics bool
107
+ Limit int
108
+ Remaining int
109
+ Used int
110
+ Reset time.Time
111
+
112
+ at time.Time
113
+ }{
114
+ {
115
+ Name :"NoHeaders" ,
116
+ NoHeaders :true ,
117
+ ExpectNoMetrics :true ,
118
+ },
119
+ {
120
+ Name :"ZeroHeaders" ,
121
+ ExpectNoMetrics :true ,
122
+ },
123
+ {
124
+ Name :"OverLimit" ,
125
+ Limit :100 ,
126
+ Remaining :0 ,
127
+ Used :500 ,
128
+ Reset :now .Add (time .Hour ),
129
+ at :now ,
130
+ },
131
+ {
132
+ Name :"UnderLimit" ,
133
+ Limit :100 ,
134
+ Remaining :0 ,
135
+ Used :500 ,
136
+ Reset :now .Add (time .Hour ),
137
+ at :now ,
138
+ },
139
+ {
140
+ Name :"Partial" ,
141
+ Omit : []string {"x-ratelimit-remaining" },
142
+ ExpectNoMetrics :true ,
143
+ Limit :100 ,
144
+ Remaining :0 ,
145
+ Used :500 ,
146
+ Reset :now .Add (time .Hour ),
147
+ at :now ,
148
+ },
149
+ }
150
+
151
+ for _ ,c := range cases {
152
+ c := c
153
+ t .Run (c .Name ,func (t * testing.T ) {
154
+ t .Parallel ()
155
+
156
+ reg := prometheus .NewRegistry ()
157
+ idp := oidctest .NewFakeIDP (t ,oidctest .WithMiddlewares (
158
+ func (next http.Handler ) http.Handler {
159
+ return http .HandlerFunc (func (rw http.ResponseWriter ,r * http.Request ) {
160
+ if ! c .NoHeaders {
161
+ rw .Header ().Set ("x-ratelimit-limit" ,fmt .Sprintf ("%d" ,c .Limit ))
162
+ rw .Header ().Set ("x-ratelimit-remaining" ,fmt .Sprintf ("%d" ,c .Remaining ))
163
+ rw .Header ().Set ("x-ratelimit-used" ,fmt .Sprintf ("%d" ,c .Used ))
164
+ rw .Header ().Set ("x-ratelimit-resource" ,"core" )
165
+ rw .Header ().Set ("x-ratelimit-reset" ,fmt .Sprintf ("%d" ,c .Reset .Unix ()))
166
+ for _ ,omit := range c .Omit {
167
+ rw .Header ().Del (omit )
168
+ }
169
+ }
170
+
171
+ next .ServeHTTP (rw ,r )
172
+ })
173
+ }))
174
+
175
+ factory := promoauth .NewFactory (reg )
176
+ if ! c .at .IsZero () {
177
+ factory .Now = func () time.Time {
178
+ return c .at
179
+ }
180
+ }
181
+
182
+ cfg := factory .NewGithub ("test" ,idp .OIDCConfig (t , []string {}))
183
+
184
+ // Do a single oauth2 call
185
+ ctx := testutil .Context (t ,testutil .WaitShort )
186
+ ctx = context .WithValue (ctx ,oauth2 .HTTPClient ,idp .HTTPClient (nil ))
187
+ _ ,err := cfg .Exchange (ctx ,idp .CreateAuthCode (t ,"foo" ))
188
+ require .NoError (t ,err )
189
+
190
+ // Verify
191
+ labels := prometheus.Labels {
192
+ "name" :"test" ,
193
+ "resource" :"core" ,
194
+ }
195
+ pass := true
196
+ if ! c .ExpectNoMetrics {
197
+ pass = pass && assert .Equal (t ,gaugeValue (t ,reg ,"coderd_oauth2_external_requests_rate_limit_total" ,labels ),c .Limit ,"limit" )
198
+ pass = pass && assert .Equal (t ,gaugeValue (t ,reg ,"coderd_oauth2_external_requests_rate_limit_remaining" ,labels ),c .Remaining ,"remaining" )
199
+ pass = pass && assert .Equal (t ,gaugeValue (t ,reg ,"coderd_oauth2_external_requests_rate_limit_used" ,labels ),c .Used ,"used" )
200
+ if ! c .at .IsZero () {
201
+ until := c .Reset .Sub (c .at )
202
+ // Float accuracy is not great, so we allow a delta of 2
203
+ pass = pass && assert .InDelta (t ,gaugeValue (t ,reg ,"coderd_oauth2_external_requests_rate_limit_reset_in_seconds" ,labels ),int (until .Seconds ()),2 ,"reset in" )
204
+ }
205
+ }else {
206
+ pass = pass && assert .Nil (t ,metricValue (t ,reg ,"coderd_oauth2_external_requests_rate_limit_total" ,labels ),"not exists" )
207
+ }
208
+
209
+ // Helpful debugging
210
+ if ! pass {
211
+ t .Log (registryDump (reg ))
212
+ }
213
+ })
214
+ }
215
+ }
216
+
217
+ func registryDump (reg * prometheus.Registry )string {
218
+ h := promhttp .HandlerFor (reg , promhttp.HandlerOpts {})
219
+ rec := httptest .NewRecorder ()
220
+ req ,_ := http .NewRequest (http .MethodGet ,"/" ,nil )
221
+ h .ServeHTTP (rec ,req )
222
+ resp := rec .Result ()
223
+ data ,_ := io .ReadAll (resp .Body )
224
+ _ = resp .Body .Close ()
225
+ return string (data )
72
226
}
73
227
74
228
func must [V any ](v V ,err error )V {
@@ -77,3 +231,39 @@ func must[V any](v V, err error) V {
77
231
}
78
232
return v
79
233
}
234
+
235
+ func gaugeValue (t testing.TB ,reg prometheus.Gatherer ,metricName string ,labels prometheus.Labels )int {
236
+ labeled := metricValue (t ,reg ,metricName ,labels )
237
+ require .NotNilf (t ,labeled ,"metric %q with labels %v not found" ,metricName ,labels )
238
+ return int (labeled .GetGauge ().GetValue ())
239
+ }
240
+
241
+ func counterValue (t testing.TB ,reg prometheus.Gatherer ,metricName string ,labels prometheus.Labels )int {
242
+ labeled := metricValue (t ,reg ,metricName ,labels )
243
+ require .NotNilf (t ,labeled ,"metric %q with labels %v not found" ,metricName ,labels )
244
+ return int (labeled .GetCounter ().GetValue ())
245
+ }
246
+
247
+ func compare (reg prometheus.Gatherer ,compare string )error {
248
+ return ptestutil .GatherAndCompare (reg ,strings .NewReader (compare ))
249
+ }
250
+
251
+ func metricValue (t testing.TB ,reg prometheus.Gatherer ,metricName string ,labels prometheus.Labels )* io_prometheus_client.Metric {
252
+ metrics ,err := reg .Gather ()
253
+ require .NoError (t ,err )
254
+
255
+ for _ ,m := range metrics {
256
+ if m .GetName ()== metricName {
257
+ for _ ,labeled := range m .GetMetric () {
258
+ mLables := make (prometheus.Labels )
259
+ for _ ,v := range labeled .GetLabel () {
260
+ mLables [v .GetName ()]= v .GetValue ()
261
+ }
262
+ if maps .Equal (mLables ,labels ) {
263
+ return labeled
264
+ }
265
+ }
266
+ }
267
+ }
268
+ return nil
269
+ }