@@ -15,7 +15,6 @@ import (
15
15
"golang.org/x/xerrors"
16
16
17
17
"github.com/coder/coder/v2/aibridged/proto"
18
- "github.com/coder/coder/v2/coderd/database"
19
18
"github.com/coder/coder/v2/codersdk"
20
19
)
21
20
@@ -49,16 +48,17 @@ type Server struct {
49
48
// shuttingDownCh will receive when we start graceful shutdown
50
49
shuttingDownCh chan struct {}
51
50
52
- bridge * Bridge
51
+ bridge * Bridge
53
52
}
54
53
55
- func New (store database.Store ,rpcDialer Dialer ,httpAddr string ,logger slog.Logger ) (* Server ,error ) {
54
+ var _ proto.DRPCAIBridgeDaemonServer = & Server {}
55
+
56
+ func New (rpcDialer Dialer ,httpAddr string ,logger slog.Logger ) (* Server ,error ) {
56
57
if rpcDialer == nil {
57
58
return nil ,xerrors .Errorf ("nil rpcDialer given" )
58
59
}
59
60
60
61
ctx ,cancel := context .WithCancel (context .Background ())
61
- bridge := NewBridge (httpAddr ,store )
62
62
daemon := & Server {
63
63
logger :logger ,
64
64
clientDialer :rpcDialer ,
@@ -68,9 +68,11 @@ func New(store database.Store, rpcDialer Dialer, httpAddr string, logger slog.Lo
68
68
closedCh :make (chan struct {}),
69
69
shuttingDownCh :make (chan struct {}),
70
70
initConnectionCh :make (chan struct {}),
71
-
72
- bridge :bridge ,
73
71
}
72
+
73
+ bridge := NewBridge (httpAddr ,daemon .client )
74
+ daemon .bridge = bridge
75
+
74
76
go daemon .connect ()
75
77
go func () {
76
78
err := bridge .Serve ()
@@ -164,6 +166,26 @@ func (s *Server) AuditPrompt(ctx context.Context, in *proto.AuditPromptRequest)
164
166
return out ,nil
165
167
}
166
168
169
+ func (s * Server )TrackTokenUsage (ctx context.Context ,in * proto.TrackTokenUsageRequest ) (* proto.TrackTokenUsageResponse ,error ) {
170
+ out ,err := clientDoWithRetries (ctx ,s .client ,func (ctx context.Context ,client proto.DRPCAIBridgeDaemonClient ) (* proto.TrackTokenUsageResponse ,error ) {
171
+ return client .TrackTokenUsage (ctx ,in )
172
+ })
173
+ if err != nil {
174
+ return nil ,err
175
+ }
176
+ return out ,nil
177
+ }
178
+
179
+ func (s * Server )TrackUserPrompts (ctx context.Context ,in * proto.TrackUserPromptsRequest ) (* proto.TrackUserPromptsResponse ,error ) {
180
+ out ,err := clientDoWithRetries (ctx ,s .client ,func (ctx context.Context ,client proto.DRPCAIBridgeDaemonClient ) (* proto.TrackUserPromptsResponse ,error ) {
181
+ return client .TrackUserPrompts (ctx ,in )
182
+ })
183
+ if err != nil {
184
+ return nil ,err
185
+ }
186
+ return out ,nil
187
+ }
188
+
167
189
//func (s *Server) ChatCompletions(payload *proto.JSONPayload, stream proto.DRPCOpenAIService_ChatCompletionsStream) error {
168
190
//// TODO: call OpenAI API.
169
191
//