Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit0971409

Browse files
deansheatherdannykopping
authored andcommitted
PR comments
1 parentdba62c9 commit0971409

File tree

5 files changed

+423
-177
lines changed

5 files changed

+423
-177
lines changed

‎coderd/aibridgedserver/aibridgedserver.go‎

Lines changed: 98 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package aibridgedserver
22

33
import (
44
"context"
5+
"crypto/sha256"
6+
"crypto/subtle"
57
"database/sql"
68
"encoding/json"
79
"net/url"
@@ -22,12 +24,27 @@ import (
2224
"github.com/coder/coder/v2/coderd/database/dbtime"
2325
"github.com/coder/coder/v2/coderd/externalauth"
2426
"github.com/coder/coder/v2/coderd/httpmw"
27+
codermcp"github.com/coder/coder/v2/coderd/mcp"
2528
"github.com/coder/coder/v2/codersdk"
2629
)
2730

2831
var (
2932
ErrExpiredOrInvalidOAuthToken=xerrors.New("expired or invalid OAuth2 token")
3033
ErrNoMCPConfigFound=xerrors.New("no MCP config found")
34+
35+
// These errors are returned by IsAuthorized. Since they're just returned as
36+
// a generic dRPC error, it's difficult to tell them apart without string
37+
// matching.
38+
// TODO: return these errors to the client in a more structured/comparable
39+
// way.
40+
ErrInvalidKey=xerrors.New("invalid key")
41+
ErrUnknownKey=xerrors.New("unknown key")
42+
ErrExpired=xerrors.New("expired")
43+
ErrUnknownUser=xerrors.New("unknown user")
44+
ErrDeletedUser=xerrors.New("deleted user")
45+
ErrSystemUser=xerrors.New("system user")
46+
47+
ErrNoExternalAuthLinkFound=xerrors.New("no external auth link found")
3148
)
3249

3350
var (
@@ -61,6 +78,8 @@ type Server struct {
6178
accessURLstring
6279
externalAuthConfigsmap[string]*externalauth.Config
6380
experiments codersdk.Experiments
81+
82+
coderMCPConfig*proto.MCPServerConfig// may be nil if not available
6483
}
6584

6685
funcNewServer(lifecycleCtx context.Context,storestore,logger slog.Logger,accessURLstring,externalAuthConfigs []*externalauth.Config,experiments codersdk.Experiments) (*Server,error) {
@@ -74,13 +93,19 @@ func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, ac
7493
eac[cfg.ID]=cfg
7594
}
7695

96+
coderMCPConfig,err:=getCoderMCPServerConfig(experiments,accessURL)
97+
iferr!=nil {
98+
logger.Warn(lifecycleCtx,"failed to retrieve coder MCP server config, Coder MCP will not be available",slog.Error(err))
99+
}
100+
77101
return&Server{
78102
lifecycleCtx:lifecycleCtx,
79103
store:store,
80104
logger:logger.Named("aibridgedserver"),
81105
accessURL:accessURL,
82106
externalAuthConfigs:eac,
83107
experiments:experiments,
108+
coderMCPConfig:coderMCPConfig,
84109
},nil
85110
}
86111

@@ -126,7 +151,7 @@ func (s *Server) RecordTokenUsage(ctx context.Context, in *proto.RecordTokenUsag
126151
ProviderResponseID:in.GetMsgId(),
127152
InputTokens:in.GetInputTokens(),
128153
OutputTokens:in.GetOutputTokens(),
129-
Metadata:s.marshalMetadata(in.GetMetadata()),
154+
Metadata:marshalMetadata(ctx,s.logger,in.GetMetadata()),
130155
CreatedAt:in.GetCreatedAt().AsTime(),
131156
})
132157
iferr!=nil {
@@ -149,7 +174,7 @@ func (s *Server) RecordPromptUsage(ctx context.Context, in *proto.RecordPromptUs
149174
InterceptionID:intcID,
150175
ProviderResponseID:in.GetMsgId(),
151176
Prompt:in.GetPrompt(),
152-
Metadata:s.marshalMetadata(in.GetMetadata()),
177+
Metadata:marshalMetadata(ctx,s.logger,in.GetMetadata()),
153178
CreatedAt:in.GetCreatedAt().AsTime(),
154179
})
155180
iferr!=nil {
@@ -175,8 +200,8 @@ func (s *Server) RecordToolUsage(ctx context.Context, in *proto.RecordToolUsageR
175200
Tool:in.GetTool(),
176201
Input:in.GetInput(),
177202
Injected:in.GetInjected(),
178-
InvocationError: sql.NullString{String:in.GetInvocationError(),Valid:in.GetInvocationError()!=""},
179-
Metadata:s.marshalMetadata(in.GetMetadata()),
203+
InvocationError: sql.NullString{String:in.GetInvocationError(),Valid:in.InvocationError!=nil},
204+
Metadata:marshalMetadata(ctx,s.logger,in.GetMetadata()),
180205
CreatedAt:in.GetCreatedAt().AsTime(),
181206
})
182207
iferr!=nil {
@@ -185,25 +210,7 @@ func (s *Server) RecordToolUsage(ctx context.Context, in *proto.RecordToolUsageR
185210
return&proto.RecordToolUsageResponse{},nil
186211
}
187212

188-
func (s*Server)marshalMetadata(inmap[string]*anypb.Any) []byte {
189-
mdMap:=make(map[string]any,len(in))
190-
fork,v:=rangein {
191-
ifv==nil {
192-
continue
193-
}
194-
varsv structpb.Value
195-
iferr:=v.UnmarshalTo(&sv);err==nil {
196-
mdMap[k]=sv.AsInterface()
197-
}
198-
}
199-
out,err:=json.Marshal(mdMap)
200-
iferr!=nil {
201-
s.logger.Warn(s.lifecycleCtx,"failed to marshal metadata",slog.Error(err))
202-
}
203-
returnout
204-
}
205-
206-
func (s*Server)GetMCPServerConfigs(ctx context.Context,_*proto.GetMCPServerConfigsRequest) (*proto.GetMCPServerConfigsResponse,error) {
213+
func (s*Server)GetMCPServerConfigs(_ context.Context,_*proto.GetMCPServerConfigsRequest) (*proto.GetMCPServerConfigsResponse,error) {
207214
cfgs:=make([]*proto.MCPServerConfig,0,len(s.externalAuthConfigs))
208215
for_,eac:=ranges.externalAuthConfigs {
209216
varallowlist,denyliststring
@@ -222,51 +229,25 @@ func (s *Server) GetMCPServerConfigs(ctx context.Context, _ *proto.GetMCPServerC
222229
})
223230
}
224231

225-
coderMCPCfg,err:=s.getCoderMCPServerConfig()
226-
iferr!=nil {
227-
s.logger.Warn(ctx,"failed to retrieve coder MCP server config",slog.Error(err))
228-
}
229-
230232
return&proto.GetMCPServerConfigsResponse{
231-
CoderMcpConfig:coderMCPCfg,
233+
CoderMcpConfig:s.coderMCPConfig,// it's fine if this is nil
232234
ExternalAuthMcpConfigs:cfgs,
233235
},nil
234236
}
235237

236-
func (s*Server)getCoderMCPServerConfig() (*proto.MCPServerConfig,error) {
237-
// Both the MCP & OAuth2 experiments are currently required in order to use our
238-
// internal MCP server.
239-
if!s.experiments.Enabled(codersdk.ExperimentMCPServerHTTP) {
240-
returnnil,xerrors.Errorf("%q experiment not enabled",codersdk.ExperimentMCPServerHTTP)
241-
}
242-
if!s.experiments.Enabled(codersdk.ExperimentOAuth2) {
243-
returnnil,xerrors.Errorf("%q experiment not enabled",codersdk.ExperimentOAuth2)
244-
}
245-
246-
u,err:=url.JoinPath(s.accessURL,"/api/experimental/mcp/http")
247-
iferr!=nil {
248-
returnnil,xerrors.Errorf("build MCP URL with %q: %w",s.accessURL,err)
249-
}
250-
251-
return&proto.MCPServerConfig{
252-
Id:"coder",
253-
Url:u,
254-
},nil
255-
}
256-
257238
func (s*Server)GetMCPServerAccessTokensBatch(ctx context.Context,in*proto.GetMCPServerAccessTokensBatchRequest) (*proto.GetMCPServerAccessTokensBatchResponse,error) {
258239
iflen(in.GetMcpServerConfigIds())==0 {
259240
return&proto.GetMCPServerAccessTokensBatchResponse{},nil
260241
}
261242

262-
id,err:=uuid.Parse(in.GetUserId())
243+
userID,err:=uuid.Parse(in.GetUserId())
263244
iferr!=nil {
264245
returnnil,xerrors.Errorf("parse user_id: %w",err)
265246
}
266247

267248
//nolint:gocritic // AIBridged has specific authz rules.
268249
ctx=dbauthz.AsAIBridged(ctx)
269-
links,err:=s.store.GetExternalAuthLinksByUserID(ctx,id)
250+
links,err:=s.store.GetExternalAuthLinksByUserID(ctx,userID)
270251
iferr!=nil {
271252
returnnil,xerrors.Errorf("fetch external auth links: %w",err)
272253
}
@@ -289,6 +270,7 @@ func (s *Server) GetMCPServerAccessTokensBatch(ctx context.Context, in *proto.Ge
289270
tokenErrs=make(map[string]string)
290271
)
291272

273+
externalAuthLoop:
292274
for_,id:=rangeids {
293275
eac,ok:=s.externalAuthConfigs[id]
294276
if!ok {
@@ -310,26 +292,32 @@ func (s *Server) GetMCPServerAccessTokensBatch(ctx context.Context, in *proto.Ge
310292
deferwg.Done()
311293

312294
// TODO: timeout.
313-
valid,_,err:=eac.ValidateToken(ctx,link.OAuthToken())
295+
valid,_,validateErr:=eac.ValidateToken(ctx,link.OAuthToken())
314296
mu.Lock()
315297
defermu.Unlock()
316298
if!valid {
317299
// TODO: attempt refresh.
318-
s.logger.Warn(ctx,"invalid/expired access token, cannot auto-configure MCP",slog.F("provider",link.ProviderID),slog.Error(err))
300+
s.logger.Warn(ctx,"invalid/expired access token, cannot auto-configure MCP",slog.F("provider",link.ProviderID),slog.Error(validateErr))
319301
tokenErrs[id]=ErrExpiredOrInvalidOAuthToken.Error()
320302
return
321303
}
322304

323-
iferr!=nil {
324-
errs=multierror.Append(errs,err)
325-
tokenErrs[id]=err.Error()
305+
ifvalidateErr!=nil {
306+
errs=multierror.Append(errs,validateErr)
307+
tokenErrs[id]=validateErr.Error()
326308
}else {
327309
tokens[id]=link.OAuthAccessToken
328310
}
329311
}()
330312

331-
break
313+
continue externalAuthLoop
332314
}
315+
316+
// No link found for this external auth config, so include a generic
317+
// error.
318+
mu.Lock()
319+
tokenErrs[id]=ErrNoExternalAuthLinkFound.Error()
320+
mu.Unlock()
333321
}
334322

335323
wg.Wait()
@@ -357,16 +345,15 @@ func (s *Server) IsAuthorized(ctx context.Context, in *proto.IsAuthorizedRequest
357345
ctx=dbauthz.AsAIBridged(ctx)
358346

359347
// Key matches expected format.
360-
id,_,err:=httpmw.SplitAPIToken(in.GetKey())
348+
keyID,keySecret,err:=httpmw.SplitAPIToken(in.GetKey())
361349
iferr!=nil {
362-
s.logger.Warn(ctx,"invalid key provided",slog.Error(err))
363350
returnnil,ErrInvalidKey
364351
}
365352

366353
// Key exists.
367-
key,err:=s.store.GetAPIKeyByID(ctx,id)
354+
key,err:=s.store.GetAPIKeyByID(ctx,keyID)
368355
iferr!=nil {
369-
s.logger.Warn(ctx,"failed to retrieve API key by id",slog.F("id",id),slog.Error(err))
356+
s.logger.Warn(ctx,"failed to retrieve API key by id",slog.F("key_id",keyID),slog.Error(err))
370357
returnnil,ErrUnknownKey
371358
}
372359

@@ -376,10 +363,16 @@ func (s *Server) IsAuthorized(ctx context.Context, in *proto.IsAuthorizedRequest
376363
returnnil,ErrExpired
377364
}
378365

366+
// Key secret matches.
367+
hashedSecret:=sha256.Sum256([]byte(keySecret))
368+
ifsubtle.ConstantTimeCompare(key.HashedSecret,hashedSecret[:])!=1 {
369+
returnnil,ErrInvalidKey
370+
}
371+
379372
// User exists.
380373
user,err:=s.store.GetUserByID(ctx,key.UserID)
381374
iferr!=nil {
382-
s.logger.Warn(ctx,"failed to retrieve API key user",slog.F("user_id",key.UserID),slog.Error(err))
375+
s.logger.Warn(ctx,"failed to retrieve API key user",slog.F("key_id",keyID),slog.F("user_id",key.UserID),slog.Error(err))
383376
returnnil,ErrUnknownUser
384377
}
385378

@@ -396,11 +389,45 @@ func (s *Server) IsAuthorized(ctx context.Context, in *proto.IsAuthorizedRequest
396389
},nil
397390
}
398391

399-
var (
400-
ErrInvalidKey=xerrors.New("invalid key format")
401-
ErrUnknownKey=xerrors.New("unknown key")
402-
ErrExpired=xerrors.New("expired")
403-
ErrUnknownUser=xerrors.New("unknown user")
404-
ErrDeletedUser=xerrors.New("deleted user")
405-
ErrSystemUser=xerrors.New("system user")
406-
)
392+
funcgetCoderMCPServerConfig(experiments codersdk.Experiments,accessURLstring) (*proto.MCPServerConfig,error) {
393+
// Both the MCP & OAuth2 experiments are currently required in order to use our
394+
// internal MCP server.
395+
if!experiments.Enabled(codersdk.ExperimentMCPServerHTTP) {
396+
returnnil,xerrors.Errorf("%q experiment not enabled",codersdk.ExperimentMCPServerHTTP)
397+
}
398+
if!experiments.Enabled(codersdk.ExperimentOAuth2) {
399+
returnnil,xerrors.Errorf("%q experiment not enabled",codersdk.ExperimentOAuth2)
400+
}
401+
402+
u,err:=url.JoinPath(accessURL,codermcp.MCPEndpoint)
403+
iferr!=nil {
404+
returnnil,xerrors.Errorf("build MCP URL with %q: %w",accessURL,err)
405+
}
406+
407+
return&proto.MCPServerConfig{
408+
Id:"coder",
409+
Url:u,
410+
},nil
411+
}
412+
413+
// marshalMetadata attempts to marshal the given metadata map into a
414+
// JSON-encoded byte slice. If the marshaling fails, the function logs a
415+
// warning and returns nil. The supplied context is only used for logging.
416+
funcmarshalMetadata(ctx context.Context,logger slog.Logger,inmap[string]*anypb.Any) []byte {
417+
mdMap:=make(map[string]any,len(in))
418+
fork,v:=rangein {
419+
ifv==nil {
420+
continue
421+
}
422+
varsv structpb.Value
423+
iferr:=v.UnmarshalTo(&sv);err==nil {
424+
mdMap[k]=sv.AsInterface()
425+
}
426+
}
427+
out,err:=json.Marshal(mdMap)
428+
iferr!=nil {
429+
logger.Warn(ctx,"failed to marshal aibridge metadata from proto to JSON",slog.F("metadata",in),slog.Error(err))
430+
returnnil
431+
}
432+
returnout
433+
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp