- Notifications
You must be signed in to change notification settings - Fork1k
feat: aibridged mcp handling#19911
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Changes fromall commits
File filter
Filter by extension
Conversations
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more abouthow customized files appear on GitHub.
Uh oh!
There was an error while loading.Please reload this page.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
package aibridged | ||
import ( | ||
"context" | ||
"fmt" | ||
"regexp" | ||
"time" | ||
"golang.org/x/xerrors" | ||
"cdr.dev/slog" | ||
"github.com/coder/aibridge/mcp" | ||
"github.com/coder/coder/v2/enterprise/x/aibridged/proto" | ||
) | ||
var ( | ||
ErrEmptyConfig = xerrors.New("empty config given") | ||
ErrCompileRegex = xerrors.New("compile tool regex") | ||
) | ||
const ( | ||
InternalMCPServerID = "coder" | ||
) | ||
type MCPProxyBuilder interface { | ||
// Build creates a [mcp.ServerProxier] for the given request initiator. | ||
// At minimum, the Coder MCP server will be proxied. | ||
// The SessionKey from [Request] is used to authenticate against the Coder MCP server. | ||
// | ||
// NOTE: the [mcp.ServerProxier] instance may be proxying one or more MCP servers. | ||
Build(ctx context.Context, req Request) (mcp.ServerProxier, error) | ||
} | ||
var _ MCPProxyBuilder = &MCPProxyFactory{} | ||
type MCPProxyFactory struct { | ||
logger slog.Logger | ||
clientFn ClientFunc | ||
} | ||
func NewMCPProxyFactory(logger slog.Logger, clientFn ClientFunc) *MCPProxyFactory { | ||
return &MCPProxyFactory{ | ||
logger: logger, | ||
clientFn: clientFn, | ||
} | ||
} | ||
func (m *MCPProxyFactory) Build(ctx context.Context, req Request) (mcp.ServerProxier, error) { | ||
proxiers, err := m.retrieveMCPServerConfigs(ctx, req) | ||
if err != nil { | ||
return nil, xerrors.Errorf("resolve configs: %w", err) | ||
} | ||
return mcp.NewServerProxyManager(proxiers), nil | ||
} | ||
func (m *MCPProxyFactory) retrieveMCPServerConfigs(ctx context.Context, req Request) (map[string]mcp.ServerProxier, error) { | ||
client, err := m.clientFn() | ||
if err != nil { | ||
return nil, xerrors.Errorf("acquire client: %w", err) | ||
} | ||
srvCfgCtx, srvCfgCancel := context.WithTimeout(ctx, time.Second*10) | ||
defer srvCfgCancel() | ||
// Fetch MCP server configs. | ||
mcpSrvCfgs, err := client.GetMCPServerConfigs(srvCfgCtx, &proto.GetMCPServerConfigsRequest{ | ||
UserId: req.InitiatorID.String(), | ||
}) | ||
if err != nil { | ||
return nil, xerrors.Errorf("get MCP server configs: %w", err) | ||
} | ||
proxiers := make(map[string]mcp.ServerProxier, len(mcpSrvCfgs.GetExternalAuthMcpConfigs())+1) // Extra one for Coder MCP server. | ||
if mcpSrvCfgs.GetCoderMcpConfig() != nil { | ||
// Setup the Coder MCP server proxy. | ||
coderMCPProxy, err := m.newStreamableHTTPServerProxy(mcpSrvCfgs.GetCoderMcpConfig(), req.SessionKey) // The session key is used to auth against our internal MCP server. | ||
if err != nil { | ||
m.logger.Warn(ctx, "failed to create MCP server proxy", slog.F("mcp_server_id", mcpSrvCfgs.GetCoderMcpConfig().GetId()), slog.Error(err)) | ||
} else { | ||
proxiers[InternalMCPServerID] = coderMCPProxy | ||
} | ||
} | ||
if len(mcpSrvCfgs.GetExternalAuthMcpConfigs()) == 0 { | ||
return proxiers, nil | ||
} | ||
serverIDs := make([]string, 0, len(mcpSrvCfgs.GetExternalAuthMcpConfigs())) | ||
for _, cfg := range mcpSrvCfgs.GetExternalAuthMcpConfigs() { | ||
serverIDs = append(serverIDs, cfg.GetId()) | ||
} | ||
accTokCtx, accTokCancel := context.WithTimeout(ctx, time.Second*10) | ||
defer accTokCancel() | ||
// Request a batch of access tokens, one per given server ID. | ||
resp, err := client.GetMCPServerAccessTokensBatch(accTokCtx, &proto.GetMCPServerAccessTokensBatchRequest{ | ||
UserId: req.InitiatorID.String(), | ||
McpServerConfigIds: serverIDs, | ||
}) | ||
if err != nil { | ||
m.logger.Warn(ctx, "failed to retrieve access token(s)", slog.F("server_ids", serverIDs), slog.Error(err)) | ||
} | ||
if resp == nil { | ||
m.logger.Warn(ctx, "nil response given to mcp access tokens call") | ||
return proxiers, nil | ||
} | ||
tokens := resp.GetAccessTokens() | ||
if len(tokens) == 0 { | ||
return proxiers, nil | ||
} | ||
// Iterate over all External Auth configurations which are configured for MCP and attempt to setup | ||
// a [mcp.ServerProxier] for it using the access token retrieved above. | ||
for _, cfg := range mcpSrvCfgs.GetExternalAuthMcpConfigs() { | ||
if err, ok := resp.GetErrors()[cfg.GetId()]; ok { | ||
m.logger.Debug(ctx, "failed to get access token", slog.F("mcp_server_id", cfg.GetId()), slog.F("error", err)) | ||
continue | ||
} | ||
token, ok := tokens[cfg.GetId()] | ||
if !ok { | ||
m.logger.Warn(ctx, "no access token found", slog.F("mcp_server_id", cfg.GetId())) | ||
continue | ||
} | ||
proxy, err := m.newStreamableHTTPServerProxy(cfg, token) | ||
if err != nil { | ||
m.logger.Warn(ctx, "failed to create MCP server proxy", slog.F("mcp_server_id", cfg.GetId()), slog.Error(err)) | ||
continue | ||
} | ||
proxiers[cfg.Id] = proxy | ||
} | ||
return proxiers, nil | ||
} | ||
// newStreamableHTTPServerProxy creates an MCP server capable of proxying requests using the Streamable HTTP transport. | ||
// | ||
// TODO: support SSE transport. | ||
func (m *MCPProxyFactory) newStreamableHTTPServerProxy(cfg *proto.MCPServerConfig, accessToken string) (mcp.ServerProxier, error) { | ||
if cfg == nil { | ||
return nil, ErrEmptyConfig | ||
} | ||
var ( | ||
allowlist, denylist *regexp.Regexp | ||
err error | ||
) | ||
if cfg.GetToolAllowRegex() != "" { | ||
allowlist, err = regexp.Compile(cfg.GetToolAllowRegex()) | ||
if err != nil { | ||
return nil, ErrCompileRegex | ||
} | ||
} | ||
if cfg.GetToolDenyRegex() != "" { | ||
denylist, err = regexp.Compile(cfg.GetToolDenyRegex()) | ||
if err != nil { | ||
return nil, ErrCompileRegex | ||
} | ||
} | ||
// TODO: future improvement: | ||
// | ||
// The access token provided here may expire at any time, or the connection to the MCP server could be severed. | ||
// Instead of passing through an access token directly, rather provide an interface through which to retrieve | ||
// an access token imperatively. In the event of a tool call failing, we could Ping() the MCP server to establish | ||
// whether the connection is still active. If not, this indicates that the access token is probably expired/revoked. | ||
// (It could also mean the server has a problem, which we should account for.) | ||
// The proxy could then use its interface to retrieve a new access token and re-establish a connection. | ||
// For now though, the short TTL of this cache should mostly mask this problem. | ||
srv, err := mcp.NewStreamableHTTPServerProxy( | ||
m.logger.Named(fmt.Sprintf("mcp-server-proxy-%s", cfg.GetId())), | ||
cfg.GetId(), | ||
cfg.GetUrl(), | ||
// See https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization#token-requirements. | ||
map[string]string{ | ||
"Authorization": fmt.Sprintf("Bearer %s", accessToken), | ||
}, | ||
allowlist, | ||
denylist, | ||
) | ||
if err != nil { | ||
return nil, xerrors.Errorf("create streamable HTTP MCP server proxy: %w", err) | ||
} | ||
return srv, nil | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package aibridged | ||
import ( | ||
"testing" | ||
"github.com/stretchr/testify/require" | ||
"github.com/coder/coder/v2/enterprise/x/aibridged/proto" | ||
"github.com/coder/coder/v2/testutil" | ||
) | ||
func TestMCPRegex(t *testing.T) { | ||
t.Parallel() | ||
cases := []struct { | ||
name string | ||
allowRegex, denyRegex string | ||
expectedErr error | ||
}{ | ||
{ | ||
name: "invalid allow regex", | ||
allowRegex: `\`, | ||
expectedErr: ErrCompileRegex, | ||
}, | ||
{ | ||
name: "invalid deny regex", | ||
denyRegex: `+`, | ||
expectedErr: ErrCompileRegex, | ||
}, | ||
{ | ||
name: "valid empty", | ||
}, | ||
{ | ||
name: "valid", | ||
allowRegex: "(allowed|allowed2)", | ||
denyRegex: ".*disallowed.*", | ||
}, | ||
} | ||
for _, tc := range cases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
t.Parallel() | ||
logger := testutil.Logger(t) | ||
f := NewMCPProxyFactory(logger, nil) | ||
_, err := f.newStreamableHTTPServerProxy(&proto.MCPServerConfig{ | ||
Id: "mock", | ||
Url: "mock/mcp", | ||
ToolAllowRegex: tc.allowRegex, | ||
ToolDenyRegex: tc.denyRegex, | ||
}, "") | ||
if tc.expectedErr == nil { | ||
require.NoError(t, err) | ||
} else { | ||
require.ErrorIs(t, err, tc.expectedErr) | ||
} | ||
}) | ||
} | ||
} |
Uh oh!
There was an error while loading.Please reload this page.
Uh oh!
There was an error while loading.Please reload this page.