@@ -32,6 +32,10 @@ func newTestServer(t *testing.T) (*aibridged.Server, *mock.MockDRPCClient, *mock
32
32
client := mock .NewMockDRPCClient (ctrl )
33
33
pool := mock .NewMockPooler (ctrl )
34
34
35
+ conn := & mockDRPCConn {}
36
+ client .EXPECT ().DRPCConn ().AnyTimes ().Return (conn )
37
+ pool .EXPECT ().Shutdown (gomock .Any ()).MinTimes (1 ).Return (nil )
38
+
35
39
srv ,err := aibridged .New (
36
40
t .Context (),
37
41
pool ,
@@ -40,6 +44,9 @@ func newTestServer(t *testing.T) (*aibridged.Server, *mock.MockDRPCClient, *mock
40
44
},
41
45
logger )
42
46
require .NoError (t ,err ,"create new aibridged" )
47
+ t .Cleanup (func () {
48
+ srv .Shutdown (context .Background ())
49
+ })
43
50
44
51
return srv ,client ,pool
45
52
}
@@ -53,6 +60,7 @@ func (*mockDRPCConn) Transport() drpc.Transport { return nil }
53
60
func (* mockDRPCConn )Invoke (ctx context.Context ,rpc string ,enc drpc.Encoding ,in ,out drpc.Message )error {
54
61
return nil
55
62
}
63
+
56
64
func (* mockDRPCConn )NewStream (ctx context.Context ,rpc string ,enc drpc.Encoding ) (drpc.Stream ,error ) {
57
65
// nolint:nilnil // Chillchill.
58
66
return nil ,nil
@@ -61,9 +69,7 @@ func (*mockDRPCConn) NewStream(ctx context.Context, rpc string, enc drpc.Encodin
61
69
func TestServeHTTP_FailureModes (t * testing.T ) {
62
70
t .Parallel ()
63
71
64
- var (
65
- defaultHeaders = map [string ]string {"Authorization" :"Bearer key" }
66
- )
72
+ defaultHeaders := map [string ]string {"Authorization" :"Bearer key" }
67
73
68
74
cases := []struct {
69
75
name string
@@ -116,7 +122,7 @@ func TestServeHTTP_FailureModes(t *testing.T) {
116
122
// Should pass authorization.
117
123
client .EXPECT ().IsAuthorized (gomock .Any (),gomock .Any ()).AnyTimes ().Return (& proto.IsAuthorizedResponse {OwnerId :uuid .NewString ()},nil )
118
124
// But fail when acquiring a pool instance.
119
- pool .EXPECT ().Acquire (gomock .Any (),gomock .Any (),gomock .Any ()).AnyTimes ().Return (nil ,xerrors .New ("oops" ))
125
+ pool .EXPECT ().Acquire (gomock .Any (),gomock .Any (),gomock .Any (), gomock . Any () ).AnyTimes ().Return (nil ,xerrors .New ("oops" ))
120
126
},
121
127
expectedErr :aibridged .ErrAcquireRequestHandler ,
122
128
expectedStatus :http .StatusInternalServerError ,
@@ -229,38 +235,6 @@ func (*mockHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
229
235
_ ,_ = rw .Write ([]byte (r .URL .Path ))
230
236
}
231
237
232
- // TestPoolHandler validates that an http.Handler can be acquired from a given [aibridged.Pooler]
233
- // and have its HTTP handler invoked.
234
- //
235
- // We're not actually testing the routing, since that is being tested by [aibridge.RequestBridge].
236
- //
237
- // We're validating that a request can be successfully processed by aibridged
238
- // (i.e. authn/z, acquire pool instance) and what happens thereafter is a black box to aibridged.
239
- func TestPoolHandler (t * testing.T ) {
240
- t .Parallel ()
241
-
242
- srv ,client ,pool := newTestServer (t )
243
-
244
- conn := & mockDRPCConn {}
245
- client .EXPECT ().DRPCConn ().AnyTimes ().Return (conn )
246
- // Authorize all requests.
247
- client .EXPECT ().IsAuthorized (gomock .Any (),gomock .Any ()).AnyTimes ().Return (& proto.IsAuthorizedResponse {OwnerId :uuid .NewString ()},nil )
248
- pool .EXPECT ().Acquire (gomock .Any (),gomock .Any (),gomock .Any ()).AnyTimes ().Return (& mockHandler {},nil )
249
-
250
- ctx := testutil .Context (t ,testutil .WaitShort )
251
- path := "/irrelevant"
252
- req ,err := http .NewRequestWithContext (ctx ,http .MethodPost ,path ,nil )
253
- require .NoError (t ,err ,"make request to test server" )
254
- req .Header .Add ("Authorization" ,"Bearer key" )
255
-
256
- rec := httptest .NewRecorder ()
257
- srv .ServeHTTP (rec ,req )
258
-
259
- require .Equal (t ,http .StatusOK ,rec .Code )
260
- require .NotNil (t ,rec .Body )
261
- require .Equal (t ,path ,rec .Body .String ())
262
- }
263
-
264
238
// TestRouting validates that a request which originates with aibridged will be handled
265
239
// by coder/aibridge's handling logic in a provider-specific manner.
266
240
// We must validate that logic that pertains to coder/coder is exercised.
@@ -285,13 +259,13 @@ func TestRouting(t *testing.T) {
285
259
{
286
260
name :"openai chat completions" ,
287
261
path :"/openai/v1/chat/completions" ,
288
- expectedStatus :http .StatusOK ,
262
+ expectedStatus :http .StatusTeapot , // Nonsense status to indicate server was hit.
289
263
expectedHits :1 ,
290
264
},
291
265
{
292
266
name :"anthropic messages" ,
293
267
path :"/anthropic/v1/messages" ,
294
- expectedStatus :http .StatusOK ,
268
+ expectedStatus :http .StatusTeapot , // Nonsense status to indicate server was hit.
295
269
expectedHits :1 ,
296
270
},
297
271
}
@@ -311,14 +285,12 @@ func TestRouting(t *testing.T) {
311
285
logger := slogtest .Make (t ,& slogtest.Options {IgnoreErrors :true })
312
286
ctrl := gomock .NewController (t )
313
287
client := mock .NewMockDRPCClient (ctrl )
314
- pool ,err := aibridged .NewCachedBridgePool (10 , aibridge.Config {
315
- OpenAI : aibridge.ProviderConfig {
316
- BaseURL :openaiSrv .URL ,
317
- },
318
- Anthropic : aibridge.ProviderConfig {
319
- BaseURL :antSrv .URL ,
320
- },
321
- },logger )
288
+
289
+ providers := []aibridge.Provider {
290
+ aibridge .NewOpenAIProvider (aibridge.ProviderConfig {BaseURL :openaiSrv .URL }),
291
+ aibridge .NewAnthropicProvider (aibridge.ProviderConfig {BaseURL :antSrv .URL }),
292
+ }
293
+ pool ,err := aibridged .NewCachedBridgePool (aibridged .DefaultPoolOptions ,providers ,logger )
322
294
require .NoError (t ,err )
323
295
conn := & mockDRPCConn {}
324
296
client .EXPECT ().DRPCConn ().AnyTimes ().Return (conn )
@@ -332,9 +304,6 @@ func TestRouting(t *testing.T) {
332
304
interceptionID = in .GetId ()
333
305
return & proto.RecordInterceptionResponse {},nil
334
306
})
335
- client .EXPECT ().RecordPromptUsage (gomock .Any (),gomock .Any ()).AnyTimes ().Return (& proto.RecordPromptUsageResponse {},nil )
336
- client .EXPECT ().RecordTokenUsage (gomock .Any (),gomock .Any ()).AnyTimes ().Return (& proto.RecordTokenUsageResponse {},nil )
337
- client .EXPECT ().RecordToolUsage (gomock .Any (),gomock .Any ()).AnyTimes ().Return (& proto.RecordToolUsageResponse {},nil )
338
307
339
308
// Given: aibridged is started.
340
309
srv ,err := aibridged .New (t .Context (),pool ,func (ctx context.Context ) (aibridged.DRPCClient ,error ) {
@@ -357,6 +326,8 @@ func TestRouting(t *testing.T) {
357
326
srv .ServeHTTP (rec ,req )
358
327
359
328
// Then: the upstream server will have received a number of hits.
329
+ // NOTE: we *expect* the interceptions to fail because [mockAIUpstreamServer] returns a nonsense status code.
330
+ // We only need to test that the request was routed, NOT processed.
360
331
require .Equal (t ,tc .expectedStatus ,rec .Code )
361
332
assert .EqualValues (t ,tc .expectedHits ,upstreamSrv .Hits ())
362
333
if tc .expectedHits > 0 {