@@ -2,6 +2,8 @@ package aibridgedserver
2
2
3
3
import (
4
4
"context"
5
+ "crypto/sha256"
6
+ "crypto/subtle"
5
7
"database/sql"
6
8
"encoding/json"
7
9
"net/url"
@@ -22,12 +24,27 @@ import (
22
24
"github.com/coder/coder/v2/coderd/database/dbtime"
23
25
"github.com/coder/coder/v2/coderd/externalauth"
24
26
"github.com/coder/coder/v2/coderd/httpmw"
27
+ codermcp"github.com/coder/coder/v2/coderd/mcp"
25
28
"github.com/coder/coder/v2/codersdk"
26
29
)
27
30
28
31
var (
29
32
ErrExpiredOrInvalidOAuthToken = xerrors .New ("expired or invalid OAuth2 token" )
30
33
ErrNoMCPConfigFound = xerrors .New ("no MCP config found" )
34
+
35
+ // These errors are returned by IsAuthorized. Since they're just returned as
36
+ // a generic dRPC error, it's difficult to tell them apart without string
37
+ // matching.
38
+ // TODO: return these errors to the client in a more structured/comparable
39
+ // way.
40
+ ErrInvalidKey = xerrors .New ("invalid key" )
41
+ ErrUnknownKey = xerrors .New ("unknown key" )
42
+ ErrExpired = xerrors .New ("expired" )
43
+ ErrUnknownUser = xerrors .New ("unknown user" )
44
+ ErrDeletedUser = xerrors .New ("deleted user" )
45
+ ErrSystemUser = xerrors .New ("system user" )
46
+
47
+ ErrNoExternalAuthLinkFound = xerrors .New ("no external auth link found" )
31
48
)
32
49
33
50
var (
@@ -61,6 +78,8 @@ type Server struct {
61
78
accessURL string
62
79
externalAuthConfigs map [string ]* externalauth.Config
63
80
experiments codersdk.Experiments
81
+
82
+ coderMCPConfig * proto.MCPServerConfig // may be nil if not available
64
83
}
65
84
66
85
func NewServer (lifecycleCtx context.Context ,store store ,logger slog.Logger ,accessURL string ,externalAuthConfigs []* externalauth.Config ,experiments codersdk.Experiments ) (* Server ,error ) {
@@ -74,13 +93,19 @@ func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, ac
74
93
eac [cfg .ID ]= cfg
75
94
}
76
95
96
+ coderMCPConfig ,err := getCoderMCPServerConfig (experiments ,accessURL )
97
+ if err != nil {
98
+ logger .Warn (lifecycleCtx ,"failed to retrieve coder MCP server config, Coder MCP will not be available" ,slog .Error (err ))
99
+ }
100
+
77
101
return & Server {
78
102
lifecycleCtx :lifecycleCtx ,
79
103
store :store ,
80
104
logger :logger .Named ("aibridgedserver" ),
81
105
accessURL :accessURL ,
82
106
externalAuthConfigs :eac ,
83
107
experiments :experiments ,
108
+ coderMCPConfig :coderMCPConfig ,
84
109
},nil
85
110
}
86
111
@@ -126,7 +151,7 @@ func (s *Server) RecordTokenUsage(ctx context.Context, in *proto.RecordTokenUsag
126
151
ProviderResponseID :in .GetMsgId (),
127
152
InputTokens :in .GetInputTokens (),
128
153
OutputTokens :in .GetOutputTokens (),
129
- Metadata :s . marshalMetadata (in .GetMetadata ()),
154
+ Metadata :marshalMetadata (ctx , s . logger , in .GetMetadata ()),
130
155
CreatedAt :in .GetCreatedAt ().AsTime (),
131
156
})
132
157
if err != nil {
@@ -149,7 +174,7 @@ func (s *Server) RecordPromptUsage(ctx context.Context, in *proto.RecordPromptUs
149
174
InterceptionID :intcID ,
150
175
ProviderResponseID :in .GetMsgId (),
151
176
Prompt :in .GetPrompt (),
152
- Metadata :s . marshalMetadata (in .GetMetadata ()),
177
+ Metadata :marshalMetadata (ctx , s . logger , in .GetMetadata ()),
153
178
CreatedAt :in .GetCreatedAt ().AsTime (),
154
179
})
155
180
if err != nil {
@@ -175,8 +200,8 @@ func (s *Server) RecordToolUsage(ctx context.Context, in *proto.RecordToolUsageR
175
200
Tool :in .GetTool (),
176
201
Input :in .GetInput (),
177
202
Injected :in .GetInjected (),
178
- InvocationError : sql.NullString {String :in .GetInvocationError (),Valid :in .GetInvocationError () != "" },
179
- Metadata :s . marshalMetadata (in .GetMetadata ()),
203
+ InvocationError : sql.NullString {String :in .GetInvocationError (),Valid :in .InvocationError != nil },
204
+ Metadata :marshalMetadata (ctx , s . logger , in .GetMetadata ()),
180
205
CreatedAt :in .GetCreatedAt ().AsTime (),
181
206
})
182
207
if err != nil {
@@ -185,25 +210,7 @@ func (s *Server) RecordToolUsage(ctx context.Context, in *proto.RecordToolUsageR
185
210
return & proto.RecordToolUsageResponse {},nil
186
211
}
187
212
188
- func (s * Server )marshalMetadata (in map [string ]* anypb.Any ) []byte {
189
- mdMap := make (map [string ]any ,len (in ))
190
- for k ,v := range in {
191
- if v == nil {
192
- continue
193
- }
194
- var sv structpb.Value
195
- if err := v .UnmarshalTo (& sv );err == nil {
196
- mdMap [k ]= sv .AsInterface ()
197
- }
198
- }
199
- out ,err := json .Marshal (mdMap )
200
- if err != nil {
201
- s .logger .Warn (s .lifecycleCtx ,"failed to marshal metadata" ,slog .Error (err ))
202
- }
203
- return out
204
- }
205
-
206
- func (s * Server )GetMCPServerConfigs (ctx context.Context ,_ * proto.GetMCPServerConfigsRequest ) (* proto.GetMCPServerConfigsResponse ,error ) {
213
+ func (s * Server )GetMCPServerConfigs (_ context.Context ,_ * proto.GetMCPServerConfigsRequest ) (* proto.GetMCPServerConfigsResponse ,error ) {
207
214
cfgs := make ([]* proto.MCPServerConfig ,0 ,len (s .externalAuthConfigs ))
208
215
for _ ,eac := range s .externalAuthConfigs {
209
216
var allowlist ,denylist string
@@ -222,51 +229,25 @@ func (s *Server) GetMCPServerConfigs(ctx context.Context, _ *proto.GetMCPServerC
222
229
})
223
230
}
224
231
225
- coderMCPCfg ,err := s .getCoderMCPServerConfig ()
226
- if err != nil {
227
- s .logger .Warn (ctx ,"failed to retrieve coder MCP server config" ,slog .Error (err ))
228
- }
229
-
230
232
return & proto.GetMCPServerConfigsResponse {
231
- CoderMcpConfig :coderMCPCfg ,
233
+ CoderMcpConfig :s . coderMCPConfig , // it's fine if this is nil
232
234
ExternalAuthMcpConfigs :cfgs ,
233
235
},nil
234
236
}
235
237
236
- func (s * Server )getCoderMCPServerConfig () (* proto.MCPServerConfig ,error ) {
237
- // Both the MCP & OAuth2 experiments are currently required in order to use our
238
- // internal MCP server.
239
- if ! s .experiments .Enabled (codersdk .ExperimentMCPServerHTTP ) {
240
- return nil ,xerrors .Errorf ("%q experiment not enabled" ,codersdk .ExperimentMCPServerHTTP )
241
- }
242
- if ! s .experiments .Enabled (codersdk .ExperimentOAuth2 ) {
243
- return nil ,xerrors .Errorf ("%q experiment not enabled" ,codersdk .ExperimentOAuth2 )
244
- }
245
-
246
- u ,err := url .JoinPath (s .accessURL ,"/api/experimental/mcp/http" )
247
- if err != nil {
248
- return nil ,xerrors .Errorf ("build MCP URL with %q: %w" ,s .accessURL ,err )
249
- }
250
-
251
- return & proto.MCPServerConfig {
252
- Id :"coder" ,
253
- Url :u ,
254
- },nil
255
- }
256
-
257
238
func (s * Server )GetMCPServerAccessTokensBatch (ctx context.Context ,in * proto.GetMCPServerAccessTokensBatchRequest ) (* proto.GetMCPServerAccessTokensBatchResponse ,error ) {
258
239
if len (in .GetMcpServerConfigIds ())== 0 {
259
240
return & proto.GetMCPServerAccessTokensBatchResponse {},nil
260
241
}
261
242
262
- id ,err := uuid .Parse (in .GetUserId ())
243
+ userID ,err := uuid .Parse (in .GetUserId ())
263
244
if err != nil {
264
245
return nil ,xerrors .Errorf ("parse user_id: %w" ,err )
265
246
}
266
247
267
248
//nolint:gocritic // AIBridged has specific authz rules.
268
249
ctx = dbauthz .AsAIBridged (ctx )
269
- links ,err := s .store .GetExternalAuthLinksByUserID (ctx ,id )
250
+ links ,err := s .store .GetExternalAuthLinksByUserID (ctx ,userID )
270
251
if err != nil {
271
252
return nil ,xerrors .Errorf ("fetch external auth links: %w" ,err )
272
253
}
@@ -289,6 +270,7 @@ func (s *Server) GetMCPServerAccessTokensBatch(ctx context.Context, in *proto.Ge
289
270
tokenErrs = make (map [string ]string )
290
271
)
291
272
273
+ externalAuthLoop:
292
274
for _ ,id := range ids {
293
275
eac ,ok := s .externalAuthConfigs [id ]
294
276
if ! ok {
@@ -310,26 +292,32 @@ func (s *Server) GetMCPServerAccessTokensBatch(ctx context.Context, in *proto.Ge
310
292
defer wg .Done ()
311
293
312
294
// TODO: timeout.
313
- valid ,_ ,err := eac .ValidateToken (ctx ,link .OAuthToken ())
295
+ valid ,_ ,validateErr := eac .ValidateToken (ctx ,link .OAuthToken ())
314
296
mu .Lock ()
315
297
defer mu .Unlock ()
316
298
if ! valid {
317
299
// TODO: attempt refresh.
318
- s .logger .Warn (ctx ,"invalid/expired access token, cannot auto-configure MCP" ,slog .F ("provider" ,link .ProviderID ),slog .Error (err ))
300
+ s .logger .Warn (ctx ,"invalid/expired access token, cannot auto-configure MCP" ,slog .F ("provider" ,link .ProviderID ),slog .Error (validateErr ))
319
301
tokenErrs [id ]= ErrExpiredOrInvalidOAuthToken .Error ()
320
302
return
321
303
}
322
304
323
- if err != nil {
324
- errs = multierror .Append (errs ,err )
325
- tokenErrs [id ]= err .Error ()
305
+ if validateErr != nil {
306
+ errs = multierror .Append (errs ,validateErr )
307
+ tokenErrs [id ]= validateErr .Error ()
326
308
}else {
327
309
tokens [id ]= link .OAuthAccessToken
328
310
}
329
311
}()
330
312
331
- break
313
+ continue externalAuthLoop
332
314
}
315
+
316
+ // No link found for this external auth config, so include a generic
317
+ // error.
318
+ mu .Lock ()
319
+ tokenErrs [id ]= ErrNoExternalAuthLinkFound .Error ()
320
+ mu .Unlock ()
333
321
}
334
322
335
323
wg .Wait ()
@@ -357,16 +345,15 @@ func (s *Server) IsAuthorized(ctx context.Context, in *proto.IsAuthorizedRequest
357
345
ctx = dbauthz .AsAIBridged (ctx )
358
346
359
347
// Key matches expected format.
360
- id , _ ,err := httpmw .SplitAPIToken (in .GetKey ())
348
+ keyID , keySecret ,err := httpmw .SplitAPIToken (in .GetKey ())
361
349
if err != nil {
362
- s .logger .Warn (ctx ,"invalid key provided" ,slog .Error (err ))
363
350
return nil ,ErrInvalidKey
364
351
}
365
352
366
353
// Key exists.
367
- key ,err := s .store .GetAPIKeyByID (ctx ,id )
354
+ key ,err := s .store .GetAPIKeyByID (ctx ,keyID )
368
355
if err != nil {
369
- s .logger .Warn (ctx ,"failed to retrieve API key by id" ,slog .F ("id " ,id ),slog .Error (err ))
356
+ s .logger .Warn (ctx ,"failed to retrieve API key by id" ,slog .F ("key_id " ,keyID ),slog .Error (err ))
370
357
return nil ,ErrUnknownKey
371
358
}
372
359
@@ -376,10 +363,16 @@ func (s *Server) IsAuthorized(ctx context.Context, in *proto.IsAuthorizedRequest
376
363
return nil ,ErrExpired
377
364
}
378
365
366
+ // Key secret matches.
367
+ hashedSecret := sha256 .Sum256 ([]byte (keySecret ))
368
+ if subtle .ConstantTimeCompare (key .HashedSecret ,hashedSecret [:])!= 1 {
369
+ return nil ,ErrInvalidKey
370
+ }
371
+
379
372
// User exists.
380
373
user ,err := s .store .GetUserByID (ctx ,key .UserID )
381
374
if err != nil {
382
- s .logger .Warn (ctx ,"failed to retrieve API key user" ,slog .F ("user_id" ,key .UserID ),slog .Error (err ))
375
+ s .logger .Warn (ctx ,"failed to retrieve API key user" ,slog .F ("key_id" , keyID ), slog . F ( " user_id" ,key .UserID ),slog .Error (err ))
383
376
return nil ,ErrUnknownUser
384
377
}
385
378
@@ -396,11 +389,45 @@ func (s *Server) IsAuthorized(ctx context.Context, in *proto.IsAuthorizedRequest
396
389
},nil
397
390
}
398
391
399
- var (
400
- ErrInvalidKey = xerrors .New ("invalid key format" )
401
- ErrUnknownKey = xerrors .New ("unknown key" )
402
- ErrExpired = xerrors .New ("expired" )
403
- ErrUnknownUser = xerrors .New ("unknown user" )
404
- ErrDeletedUser = xerrors .New ("deleted user" )
405
- ErrSystemUser = xerrors .New ("system user" )
406
- )
392
+ func getCoderMCPServerConfig (experiments codersdk.Experiments ,accessURL string ) (* proto.MCPServerConfig ,error ) {
393
+ // Both the MCP & OAuth2 experiments are currently required in order to use our
394
+ // internal MCP server.
395
+ if ! experiments .Enabled (codersdk .ExperimentMCPServerHTTP ) {
396
+ return nil ,xerrors .Errorf ("%q experiment not enabled" ,codersdk .ExperimentMCPServerHTTP )
397
+ }
398
+ if ! experiments .Enabled (codersdk .ExperimentOAuth2 ) {
399
+ return nil ,xerrors .Errorf ("%q experiment not enabled" ,codersdk .ExperimentOAuth2 )
400
+ }
401
+
402
+ u ,err := url .JoinPath (accessURL ,codermcp .MCPEndpoint )
403
+ if err != nil {
404
+ return nil ,xerrors .Errorf ("build MCP URL with %q: %w" ,accessURL ,err )
405
+ }
406
+
407
+ return & proto.MCPServerConfig {
408
+ Id :"coder" ,
409
+ Url :u ,
410
+ },nil
411
+ }
412
+
413
+ // marshalMetadata attempts to marshal the given metadata map into a
414
+ // JSON-encoded byte slice. If the marshaling fails, the function logs a
415
+ // warning and returns nil. The supplied context is only used for logging.
416
+ func marshalMetadata (ctx context.Context ,logger slog.Logger ,in map [string ]* anypb.Any ) []byte {
417
+ mdMap := make (map [string ]any ,len (in ))
418
+ for k ,v := range in {
419
+ if v == nil {
420
+ continue
421
+ }
422
+ var sv structpb.Value
423
+ if err := v .UnmarshalTo (& sv );err == nil {
424
+ mdMap [k ]= sv .AsInterface ()
425
+ }
426
+ }
427
+ out ,err := json .Marshal (mdMap )
428
+ if err != nil {
429
+ logger .Warn (ctx ,"failed to marshal aibridge metadata from proto to JSON" ,slog .F ("metadata" ,in ),slog .Error (err ))
430
+ return nil
431
+ }
432
+ return out
433
+ }