package aibridged_test

import (
	"bytes"
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/coder/aibridge"
	"github.com/coder/coder/v2/aibridged"
	"github.com/coder/coder/v2/coderd/coderdtest"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbauthz"
	"github.com/coder/coder/v2/coderd/database/dbtestutil"
	"github.com/coder/coder/v2/coderd/database/dbtime"
	"github.com/coder/coder/v2/coderd/externalauth"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/testutil"
)

// TestIntegration is not an exhaustive test against the upstream AI providers' SDKs (see coder/aibridge for those).
// This test validates that:
// - intercepted requests can be authenticated/authorized
// - requests can be routed to an appropriate handler
// - responses can be returned as expected
// - interceptions are logged, as well as their related prompt, token, and tool calls
// - MCP server configurations are returned as expected
func TestIntegration(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitLong)

	// Create mock MCP server.
	var mcpTokenReceived string
	mockMCPServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Logf("Mock MCP server received request: %s %s", r.Method, r.URL.Path)

		if r.Method == http.MethodPost && r.URL.Path == "/" {
			// Mark that init was called.
			mcpTokenReceived = r.Header.Get("Authorization")
			t.Log("MCP init request received")

			// Return a basic MCP init response.
			w.Header().Set("Content-Type", "application/json")
			w.Header().Set("Mcp-Session-Id", "test-session-123")
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte(`{
				"jsonrpc": "2.0",
				"id": 1,
				"result": {
					"protocolVersion": "2024-11-05",
					"capabilities": {},
					"serverInfo": {
						"name": "test-mcp-server",
						"version": "1.0.0"
					}
				}
			}`))
		}
	}))
	t.Cleanup(mockMCPServer.Close)
	t.Logf("Mock MCP server running at: %s", mockMCPServer.URL)

	// Set up mock OpenAI server that returns a tool call response.
	mockOpenAI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{
  "id": "chatcmpl-BwkyFElDIr1egmFyfQ9z4vPBto7m2",
  "object": "chat.completion",
  "created": 1753343279,
  "model": "gpt-4.1-2025-04-14",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": null,
        "tool_calls": [
          {
            "id": "call_KjzAbhiZC6nk81tQzL7pwlpc",
            "type": "function",
            "function": {
              "name": "read_file",
              "arguments": "{\"path\":\"README.md\"}"
            }
          }
        ],
        "refusal": null,
        "annotations": []
      },
      "logprobs": null,
      "finish_reason": "tool_calls"
    }
  ],
  "usage": {
    "prompt_tokens": 60,
    "completion_tokens": 15,
    "total_tokens": 75,
    "prompt_tokens_details": {
      "cached_tokens": 0,
      "audio_tokens": 0
    },
    "completion_tokens_details": {
      "reasoning_tokens": 0,
      "audio_tokens": 0,
      "accepted_prediction_tokens": 0,
      "rejected_prediction_tokens": 0
    }
  },
  "service_tier": "default",
  "system_fingerprint": "fp_b3f1157249"
}`))
	}))
	t.Cleanup(mockOpenAI.Close)

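	// Start a coderd instance with a mock external auth provider whose MCP URL points at the mock MCP server above.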
	db, ps := dbtestutil.NewDB(t)
	client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
		Database: db,
		Pubsub:   ps,
		ExternalAuthConfigs: []*externalauth.Config{
			{
				InstrumentedOAuth2Config: &testutil.OAuth2Config{},
				ID:                       "mock",
				Type:                     "mock",
				DisplayName:              "Mock",
				MCPURL:                   mockMCPServer.URL,
			},
		},
	})

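	// Create an owner plus a regular user; the regular user's API key is what authenticates requests through the bridge.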
	firstUser := coderdtest.CreateFirstUser(t, client)
	userClient, user := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)

	// Create an API token for the user.
	apiKey, err := userClient.CreateToken(ctx, "me", codersdk.CreateTokenRequest{
		TokenName: fmt.Sprintf("test-key-%d", time.Now().UnixNano()),
		Lifetime:  time.Hour,
		Scope:     codersdk.APIKeyScopeAll,
	})
	require.NoError(t, err)

	// Create external auth link for the user.
	authLink, err := db.InsertExternalAuthLink(dbauthz.AsSystemRestricted(ctx), database.InsertExternalAuthLinkParams{
		ProviderID:        "mock",
		UserID:            user.ID,
		CreatedAt:         dbtime.Now(),
		UpdatedAt:         dbtime.Now(),
		OAuthAccessToken:  "test-mock-token",
		OAuthRefreshToken: "test-refresh-token",
		OAuthExpiry:       dbtime.Now().Add(time.Hour),
	})
	require.NoError(t, err)

	// Create aibridge server & client.
	aiBridgeClient, err := api.CreateInMemoryAIBridgeServer(ctx)
	require.NoError(t, err)

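	// Point the OpenAI provider at the mock server so no real upstream requests are made.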
	logger := testutil.Logger(t)
	providers := []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.ProviderConfig{BaseURL: mockOpenAI.URL})}
	pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger)
	require.NoError(t, err)

	// Given: aibridged is started.
	srv, err := aibridged.New(t.Context(), pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
		return aiBridgeClient, nil
	}, logger)
	require.NoError(t, err, "create new aibridged")
	t.Cleanup(func() {
		_ = srv.Shutdown(ctx)
	})

	// When: a request is made to aibridged.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/openai/v1/chat/completions", bytes.NewBufferString(`{
  "messages": [
    {
      "role": "user",
      "content": "how large is the README.md file in my current path"
    }
  ],
  "model": "gpt-4.1",
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "read_file",
        "description": "Read the contents of a file at the given path.",
        "parameters": {
          "properties": {
            "path": {
              "type": "string"
            }
          },
          "required": [
            "path"
          ],
          "type": "object"
        }
      }
    }
  ]
}`))
	require.NoError(t, err, "make request to test server")
	req.Header.Add("Authorization", "Bearer "+apiKey.Key)
	req.Header.Add("Accept", "application/json")

	// When: aibridged handles the request.
	rec := httptest.NewRecorder()
	srv.ServeHTTP(rec, req)
	require.Equal(t, http.StatusOK, rec.Code)

	// Then: the interception & related records are stored.
	interceptions, err := db.GetAIBridgeInterceptions(ctx)
	require.NoError(t, err)
	require.Len(t, interceptions, 1)

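	// The user's prompt should be recorded against the interception.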
	prompts, err := db.GetAIBridgeUserPromptsByInterceptionID(ctx, interceptions[0].ID)
	require.NoError(t, err)
	require.Len(t, prompts, 1)
	require.Equal(t, "how large is the README.md file in my current path", prompts[0].Prompt)

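	// Token usage should match the usage block in the mock completion (60 prompt / 15 completion tokens).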
	tokens, err := db.GetAIBridgeTokenUsagesByInterceptionID(ctx, interceptions[0].ID)
	require.NoError(t, err)
	require.Len(t, tokens, 1)
	require.EqualValues(t, 60, tokens[0].InputTokens)
	require.EqualValues(t, 15, tokens[0].OutputTokens)

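	// The read_file tool call came from the client's own request rather than being injected by the bridge.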
	tools, err := db.GetAIBridgeToolUsagesByInterceptionID(ctx, interceptions[0].ID)
	require.NoError(t, err)
	require.Len(t, tools, 1)
	require.False(t, tools[0].Injected)

	// Then: the MCP server was initialized.
	require.Contains(t, mcpTokenReceived, authLink.OAuthAccessToken, "mock MCP server not requested")
}