Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

feat: aibridged mcp handling#19911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
dannykopping merged 1 commit intomainfromdk/aibridged-mcp
Sep 25, 2025
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletionenterprise/x/aibridged/aibridged.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -139,7 +139,7 @@ func (s *Server) GetRequestHandler(ctx context.Context, req Request) (http.Handl
returnnil,xerrors.New("nil requestBridgePool")
}

reqBridge,err:=s.requestBridgePool.Acquire(ctx,req,s.Client)
reqBridge,err:=s.requestBridgePool.Acquire(ctx,req,s.Client,NewMCPProxyFactory(s.logger,s.Client))
iferr!=nil {
returnnil,xerrors.Errorf("acquire request bridge: %w",err)
}
Expand Down
2 changes: 1 addition & 1 deletionenterprise/x/aibridged/aibridged_test.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -122,7 +122,7 @@ func TestServeHTTP_FailureModes(t *testing.T) {
// Should pass authorization.
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.IsAuthorizedResponse{OwnerId: uuid.NewString()}, nil)
// But fail when acquiring a pool instance.
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, xerrors.New("oops"))
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, xerrors.New("oops"))
},
expectedErr: aibridged.ErrAcquireRequestHandler,
expectedStatus: http.StatusInternalServerError,
Expand Down
8 changes: 4 additions & 4 deletionsenterprise/x/aibridged/aibridgedmock/poolmock.go
View file
Open in desktop

Some generated files are not rendered by default. Learn more abouthow customized files appear on GitHub.

191 changes: 191 additions & 0 deletionsenterprise/x/aibridged/mcp.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
package aibridged

import (
"context"
"fmt"
"regexp"
"time"

"golang.org/x/xerrors"

"cdr.dev/slog"
"github.com/coder/aibridge/mcp"
"github.com/coder/coder/v2/enterprise/x/aibridged/proto"
)

var (
ErrEmptyConfig = xerrors.New("empty config given")
ErrCompileRegex = xerrors.New("compile tool regex")
)

const (
InternalMCPServerID = "coder"
)

type MCPProxyBuilder interface {
// Build creates a [mcp.ServerProxier] for the given request initiator.
// At minimum, the Coder MCP server will be proxied.
// The SessionKey from [Request] is used to authenticate against the Coder MCP server.
//
// NOTE: the [mcp.ServerProxier] instance may be proxying one or more MCP servers.
Build(ctx context.Context, req Request) (mcp.ServerProxier, error)
}

var _ MCPProxyBuilder = &MCPProxyFactory{}

type MCPProxyFactory struct {
logger slog.Logger
clientFn ClientFunc
}

func NewMCPProxyFactory(logger slog.Logger, clientFn ClientFunc) *MCPProxyFactory {
return &MCPProxyFactory{
logger: logger,
clientFn: clientFn,
}
}

func (m *MCPProxyFactory) Build(ctx context.Context, req Request) (mcp.ServerProxier, error) {
proxiers, err := m.retrieveMCPServerConfigs(ctx, req)
if err != nil {
return nil, xerrors.Errorf("resolve configs: %w", err)
}

return mcp.NewServerProxyManager(proxiers), nil
}

func (m *MCPProxyFactory) retrieveMCPServerConfigs(ctx context.Context, req Request) (map[string]mcp.ServerProxier, error) {
client, err := m.clientFn()
if err != nil {
return nil, xerrors.Errorf("acquire client: %w", err)
}

srvCfgCtx, srvCfgCancel := context.WithTimeout(ctx, time.Second*10)
defer srvCfgCancel()

// Fetch MCP server configs.
mcpSrvCfgs, err := client.GetMCPServerConfigs(srvCfgCtx, &proto.GetMCPServerConfigsRequest{
UserId: req.InitiatorID.String(),
})
if err != nil {
return nil, xerrors.Errorf("get MCP server configs: %w", err)
}

proxiers := make(map[string]mcp.ServerProxier, len(mcpSrvCfgs.GetExternalAuthMcpConfigs())+1) // Extra one for Coder MCP server.

if mcpSrvCfgs.GetCoderMcpConfig() != nil {
// Setup the Coder MCP server proxy.
coderMCPProxy, err := m.newStreamableHTTPServerProxy(mcpSrvCfgs.GetCoderMcpConfig(), req.SessionKey) // The session key is used to auth against our internal MCP server.
if err != nil {
m.logger.Warn(ctx, "failed to create MCP server proxy", slog.F("mcp_server_id", mcpSrvCfgs.GetCoderMcpConfig().GetId()), slog.Error(err))
} else {
proxiers[InternalMCPServerID] = coderMCPProxy
}
}

if len(mcpSrvCfgs.GetExternalAuthMcpConfigs()) == 0 {
return proxiers, nil
}

serverIDs := make([]string, 0, len(mcpSrvCfgs.GetExternalAuthMcpConfigs()))
for _, cfg := range mcpSrvCfgs.GetExternalAuthMcpConfigs() {
serverIDs = append(serverIDs, cfg.GetId())
}

accTokCtx, accTokCancel := context.WithTimeout(ctx, time.Second*10)
defer accTokCancel()

// Request a batch of access tokens, one per given server ID.
resp, err := client.GetMCPServerAccessTokensBatch(accTokCtx, &proto.GetMCPServerAccessTokensBatchRequest{
UserId: req.InitiatorID.String(),
McpServerConfigIds: serverIDs,
})
if err != nil {
m.logger.Warn(ctx, "failed to retrieve access token(s)", slog.F("server_ids", serverIDs), slog.Error(err))
}

if resp == nil {
m.logger.Warn(ctx, "nil response given to mcp access tokens call")
return proxiers, nil
}
tokens := resp.GetAccessTokens()
if len(tokens) == 0 {
return proxiers, nil
}

// Iterate over all External Auth configurations which are configured for MCP and attempt to setup
// a [mcp.ServerProxier] for it using the access token retrieved above.
for _, cfg := range mcpSrvCfgs.GetExternalAuthMcpConfigs() {
if err, ok := resp.GetErrors()[cfg.GetId()]; ok {
m.logger.Debug(ctx, "failed to get access token", slog.F("mcp_server_id", cfg.GetId()), slog.F("error", err))
continue
}

token, ok := tokens[cfg.GetId()]
if !ok {
m.logger.Warn(ctx, "no access token found", slog.F("mcp_server_id", cfg.GetId()))
continue
}

proxy, err := m.newStreamableHTTPServerProxy(cfg, token)
if err != nil {
m.logger.Warn(ctx, "failed to create MCP server proxy", slog.F("mcp_server_id", cfg.GetId()), slog.Error(err))
continue
}

proxiers[cfg.Id] = proxy
}
return proxiers, nil
}

// newStreamableHTTPServerProxy creates an MCP server capable of proxying requests using the Streamable HTTP transport.
//
// TODO: support SSE transport.
func (m *MCPProxyFactory) newStreamableHTTPServerProxy(cfg *proto.MCPServerConfig, accessToken string) (mcp.ServerProxier, error) {
if cfg == nil {
return nil, ErrEmptyConfig
}

var (
allowlist, denylist *regexp.Regexp
err error
)
if cfg.GetToolAllowRegex() != "" {
allowlist, err = regexp.Compile(cfg.GetToolAllowRegex())
if err != nil {
return nil, ErrCompileRegex
}
}
if cfg.GetToolDenyRegex() != "" {
denylist, err = regexp.Compile(cfg.GetToolDenyRegex())
if err != nil {
return nil, ErrCompileRegex
}
}

// TODO: future improvement:
//
// The access token provided here may expire at any time, or the connection to the MCP server could be severed.
// Instead of passing through an access token directly, rather provide an interface through which to retrieve
// an access token imperatively. In the event of a tool call failing, we could Ping() the MCP server to establish
// whether the connection is still active. If not, this indicates that the access token is probably expired/revoked.
// (It could also mean the server has a problem, which we should account for.)
// The proxy could then use its interface to retrieve a new access token and re-establish a connection.
// For now though, the short TTL of this cache should mostly mask this problem.
srv, err := mcp.NewStreamableHTTPServerProxy(
m.logger.Named(fmt.Sprintf("mcp-server-proxy-%s", cfg.GetId())),
cfg.GetId(),
cfg.GetUrl(),
// See https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization#token-requirements.
map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", accessToken),
},
allowlist,
denylist,
)
if err != nil {
return nil, xerrors.Errorf("create streamable HTTP MCP server proxy: %w", err)
}

return srv, nil
}
61 changes: 61 additions & 0 deletionsenterprise/x/aibridged/mcp_internal_test.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
package aibridged

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/enterprise/x/aibridged/proto"
"github.com/coder/coder/v2/testutil"
)

func TestMCPRegex(t *testing.T) {
t.Parallel()

cases := []struct {
name string
allowRegex, denyRegex string
expectedErr error
}{
{
name: "invalid allow regex",
allowRegex: `\`,
expectedErr: ErrCompileRegex,
},
{
name: "invalid deny regex",
denyRegex: `+`,
expectedErr: ErrCompileRegex,
},
{
name: "valid empty",
},
{
name: "valid",
allowRegex: "(allowed|allowed2)",
denyRegex: ".*disallowed.*",
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

logger := testutil.Logger(t)
f := NewMCPProxyFactory(logger, nil)

_, err := f.newStreamableHTTPServerProxy(&proto.MCPServerConfig{
Id: "mock",
Url: "mock/mcp",
ToolAllowRegex: tc.allowRegex,
ToolDenyRegex: tc.denyRegex,
}, "")

if tc.expectedErr == nil {
require.NoError(t, err)
} else {
require.ErrorIs(t, err, tc.expectedErr)
}
})
}
}
25 changes: 22 additions & 3 deletionsenterprise/x/aibridged/pool.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -14,6 +14,7 @@ import (
"cdr.dev/slog"

"github.com/coder/aibridge"
"github.com/coder/aibridge/mcp"
)

const (
Expand All@@ -23,7 +24,7 @@ const (
// Pooler describes a pool of [*aibridge.RequestBridge] instances from which instances can be retrieved.
// One [*aibridge.RequestBridge] instance is created per given key.
type Pooler interface {
Acquire(ctx context.Context, req Request, clientFn ClientFunc) (http.Handler, error)
Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
Shutdown(ctx context.Context) error
}

Expand DownExpand Up@@ -102,7 +103,7 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
//
// Each returned [*aibridge.RequestBridge] is safe for concurrent use.
// Each [*aibridge.RequestBridge] is stateful because it has MCP clients which maintain sessions to the configured MCP server.
func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn ClientFunc) (http.Handler, error) {
func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpProxyFactory MCPProxyBuilder) (http.Handler, error) {
if err := ctx.Err(); err != nil {
return nil, xerrors.Errorf("acquire: %w", err)
}
Expand DownExpand Up@@ -141,7 +142,25 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
// Creating an *aibridge.RequestBridge may take some time, so gate all subsequent callers behind the initial request and return the resulting value.
// TODO: track startup time since it adds latency to first request (histogram count will also help us see how often this occurs).
instance, err, _ := p.singleflight.Do(req.InitiatorID.String(), func() (*aibridge.RequestBridge, error) {
bridge, err := aibridge.NewRequestBridge(ctx, p.providers, p.logger, recorder, nil)
var (
mcpServers mcp.ServerProxier
err error
)

mcpServers, err = mcpProxyFactory.Build(ctx, req)
if err != nil {
p.logger.Warn(ctx, "failed to create MCP server proxiers", slog.Error(err))
// Don't fail here; MCP server injection can gracefully degrade.
}

if mcpServers != nil {
// This will block while connections are established with upstream MCP server(s), and tools are listed.
if err := mcpServers.Init(ctx); err != nil {
p.logger.Warn(ctx, "failed to initialize MCP server proxier(s)", slog.Error(err))
}
}

bridge, err := aibridge.NewRequestBridge(ctx, p.providers, p.logger, recorder, mcpServers)
if err != nil {
return nil, xerrors.Errorf("create new request bridge: %w", err)
}
Expand Down
Loading
Loading

[8]ページ先頭

©2009-2025 Movatter.jp