|
| 1 | +package aibridged |
| 2 | + |
| 3 | +import ( |
| 4 | +"context" |
| 5 | +"fmt" |
| 6 | +"regexp" |
| 7 | +"time" |
| 8 | + |
| 9 | +"golang.org/x/xerrors" |
| 10 | + |
| 11 | +"cdr.dev/slog" |
| 12 | +"github.com/coder/aibridge/mcp" |
| 13 | +"github.com/coder/coder/v2/aibridged/proto" |
| 14 | +) |
| 15 | + |
| 16 | +var ( |
| 17 | +ErrEmptyConfig=xerrors.New("empty config given") |
| 18 | +ErrCompileRegex=xerrors.New("compile tool regex") |
| 19 | +) |
| 20 | + |
| 21 | +typeMCPProxyBuilderinterface { |
| 22 | +// Build creates a [mcp.ServerProxier] for the given request initiator. |
| 23 | +// At minimum, the Coder MCP server will be proxied. |
| 24 | +// The SessionKey from [Request] is used to authenticate against the Coder MCP server. |
| 25 | +// |
| 26 | +// NOTE: the [mcp.ServerProxier] instance may be proxying one or more MCP servers. |
| 27 | +Build(ctx context.Context,reqRequest) (mcp.ServerProxier,error) |
| 28 | +} |
| 29 | + |
| 30 | +var_MCPProxyBuilder=&MCPProxyFactory{} |
| 31 | + |
| 32 | +typeMCPProxyFactorystruct { |
| 33 | +logger slog.Logger |
| 34 | +clientFnClientFunc |
| 35 | +} |
| 36 | + |
| 37 | +funcNewMCPProxyFactory(logger slog.Logger,clientFnClientFunc)*MCPProxyFactory { |
| 38 | +return&MCPProxyFactory{ |
| 39 | +logger:logger, |
| 40 | +clientFn:clientFn, |
| 41 | +} |
| 42 | +} |
| 43 | + |
| 44 | +func (m*MCPProxyFactory)Build(ctx context.Context,reqRequest) (mcp.ServerProxier,error) { |
| 45 | +proxiers,err:=m.retrieveMCPServerConfigs(ctx,req) |
| 46 | +iferr!=nil { |
| 47 | +returnnil,xerrors.Errorf("resolve configs: %w",err) |
| 48 | +} |
| 49 | + |
| 50 | +returnmcp.NewServerProxyManager(proxiers),nil |
| 51 | +} |
| 52 | + |
| 53 | +func (m*MCPProxyFactory)retrieveMCPServerConfigs(ctx context.Context,reqRequest) (map[string]mcp.ServerProxier,error) { |
| 54 | +client,err:=m.clientFn() |
| 55 | +iferr!=nil { |
| 56 | +returnnil,xerrors.Errorf("acquire client: %w",err) |
| 57 | +} |
| 58 | + |
| 59 | +srvCfgCtx,srvCfgCancel:=context.WithTimeout(ctx,time.Second*10) |
| 60 | +defersrvCfgCancel() |
| 61 | + |
| 62 | +// Fetch MCP server configs. |
| 63 | +mcpSrvCfgs,err:=client.GetMCPServerConfigs(srvCfgCtx,&proto.GetMCPServerConfigsRequest{ |
| 64 | +UserId:req.InitiatorID.String(), |
| 65 | +}) |
| 66 | +iferr!=nil { |
| 67 | +returnnil,xerrors.Errorf("get MCP server configs: %w",err) |
| 68 | +} |
| 69 | + |
| 70 | +proxiers:=make(map[string]mcp.ServerProxier,len(mcpSrvCfgs.GetExternalAuthMcpConfigs())+1)// Extra one for Coder MCP server. |
| 71 | + |
| 72 | +ifmcpSrvCfgs.GetCoderMcpConfig()!=nil { |
| 73 | +// Setup the Coder MCP server proxy. |
| 74 | +coderMCPProxy,err:=m.newStreamableHTTPServerProxy(mcpSrvCfgs.GetCoderMcpConfig(),req.SessionKey)// The session key is used to auth against our internal MCP server. |
| 75 | +iferr!=nil { |
| 76 | +m.logger.Warn(ctx,"failed to create MCP server proxy",slog.F("mcp_server_id",mcpSrvCfgs.GetCoderMcpConfig().GetId()),slog.Error(err)) |
| 77 | +}else { |
| 78 | +proxiers["coder"]=coderMCPProxy |
| 79 | +} |
| 80 | +} |
| 81 | + |
| 82 | +iflen(mcpSrvCfgs.GetExternalAuthMcpConfigs())==0 { |
| 83 | +returnproxiers,nil |
| 84 | +} |
| 85 | + |
| 86 | +serverIDs:=make([]string,0,len(mcpSrvCfgs.GetExternalAuthMcpConfigs())) |
| 87 | +for_,cfg:=rangemcpSrvCfgs.GetExternalAuthMcpConfigs() { |
| 88 | +serverIDs=append(serverIDs,cfg.GetId()) |
| 89 | +} |
| 90 | + |
| 91 | +accTokCtx,accTokCancel:=context.WithTimeout(ctx,time.Second*10) |
| 92 | +deferaccTokCancel() |
| 93 | + |
| 94 | +// Request a batch of access tokens, one per given server ID. |
| 95 | +resp,err:=client.GetMCPServerAccessTokensBatch(accTokCtx,&proto.GetMCPServerAccessTokensBatchRequest{ |
| 96 | +UserId:req.InitiatorID.String(), |
| 97 | +McpServerConfigIds:serverIDs, |
| 98 | +}) |
| 99 | +iferr!=nil { |
| 100 | +m.logger.Warn(ctx,"failed to retrieve access token(s)",slog.F("server_ids",serverIDs),slog.Error(err)) |
| 101 | +} |
| 102 | + |
| 103 | +ifresp==nil { |
| 104 | +returnproxiers,nil |
| 105 | +} |
| 106 | +tokens:=resp.GetAccessTokens() |
| 107 | +iflen(tokens)==0 { |
| 108 | +returnproxiers,nil |
| 109 | +} |
| 110 | + |
| 111 | +forid,tokErr:=rangeresp.GetErrors() { |
| 112 | +m.logger.Warn(ctx,"failed to retrieve access token",slog.F("server_id",id),slog.F("error",tokErr)) |
| 113 | +} |
| 114 | + |
| 115 | +// Iterate over all External Auth configurations which are configured for MCP and attempt to setup |
| 116 | +// a [mcp.ServerProxier] for it using the access token retrieved above. |
| 117 | +for_,cfg:=rangemcpSrvCfgs.GetExternalAuthMcpConfigs() { |
| 118 | +iferr,ok:=resp.GetErrors()[cfg.GetId()];ok { |
| 119 | +m.logger.Warn(ctx,"failed to get access token",slog.F("mcp_server_id",cfg.GetId()),slog.F("error",err)) |
| 120 | +continue |
| 121 | +} |
| 122 | + |
| 123 | +token,ok:=tokens[cfg.GetId()] |
| 124 | +if!ok { |
| 125 | +m.logger.Warn(ctx,"no access token found",slog.F("mcp_server_id",cfg.GetId())) |
| 126 | +continue |
| 127 | +} |
| 128 | + |
| 129 | +proxy,err:=m.newStreamableHTTPServerProxy(cfg,token) |
| 130 | +iferr!=nil { |
| 131 | +m.logger.Warn(ctx,"failed to create MCP server proxy",slog.F("mcp_server_id",cfg.GetId()),slog.Error(err)) |
| 132 | +continue |
| 133 | +} |
| 134 | + |
| 135 | +proxiers[cfg.Id]=proxy |
| 136 | +} |
| 137 | +returnproxiers,nil |
| 138 | +} |
| 139 | + |
| 140 | +// newStreamableHTTPServerProxy creates an MCP server capable of proxying requests using the Streamable HTTP transport. |
| 141 | +// |
| 142 | +// TODO: support SSE transport. |
| 143 | +func (m*MCPProxyFactory)newStreamableHTTPServerProxy(cfg*proto.MCPServerConfig,accessTokenstring) (mcp.ServerProxier,error) { |
| 144 | +ifcfg==nil { |
| 145 | +returnnil,ErrEmptyConfig |
| 146 | +} |
| 147 | + |
| 148 | +var ( |
| 149 | +allowlist,denylist*regexp.Regexp |
| 150 | +errerror |
| 151 | +) |
| 152 | +ifcfg.GetToolAllowRegex()!="" { |
| 153 | +allowlist,err=regexp.Compile(cfg.GetToolAllowRegex()) |
| 154 | +iferr!=nil { |
| 155 | +returnnil,ErrCompileRegex |
| 156 | +} |
| 157 | +} |
| 158 | +ifcfg.GetToolDenyRegex()!="" { |
| 159 | +denylist,err=regexp.Compile(cfg.GetToolDenyRegex()) |
| 160 | +iferr!=nil { |
| 161 | +returnnil,ErrCompileRegex |
| 162 | +} |
| 163 | +} |
| 164 | + |
| 165 | +// TODO: future improvement: |
| 166 | +// |
| 167 | +// The access token provided here may expire at any time, or the connection to the MCP server could be severed. |
| 168 | +// Instead of passing through an access token directly, rather provide an interface through which to retrieve |
| 169 | +// an access token imperatively. In the event of a tool call failing, we could Ping() the MCP server to establish |
| 170 | +// whether the connection is still active. If not, this indicates that the access token is probably expired/revoked. |
| 171 | +// (It could also mean the server has a problem, which we should account for.) |
| 172 | +// The proxy could then use its interface to retrieve a new access token and re-establish a connection. |
| 173 | +// For now though, the short TTL of this cache should mostly mask this problem. |
| 174 | +srv,err:=mcp.NewStreamableHTTPServerProxy( |
| 175 | +m.logger.Named(fmt.Sprintf("mcp-server-proxy-%s",cfg.GetId())), |
| 176 | +cfg.GetId(), |
| 177 | +cfg.GetUrl(), |
| 178 | +// See https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization#token-requirements. |
| 179 | +map[string]string{ |
| 180 | +"Authorization":fmt.Sprintf("Bearer %s",accessToken), |
| 181 | +}, |
| 182 | +allowlist, |
| 183 | +denylist, |
| 184 | +) |
| 185 | +iferr!=nil { |
| 186 | +returnnil,xerrors.Errorf("create streamable HTTP MCP server proxy: %w",err) |
| 187 | +} |
| 188 | + |
| 189 | +returnsrv,nil |
| 190 | +} |