@@ -16,9 +16,10 @@ import (
 	"time"
 
 	"github.com/anthropics/anthropic-sdk-go"
-	"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
+	ant_ssestream "github.com/anthropics/anthropic-sdk-go/packages/ssestream"
 	"github.com/charmbracelet/log"
 	"github.com/openai/openai-go"
+	openai_ssestream "github.com/openai/openai-go/packages/ssestream"
 	"golang.org/x/xerrors"
 
 	"github.com/coder/coder/v2/aibridged/proto"
@@ -61,24 +62,13 @@ func NewBridge(addr string, clientFn func() (proto.DRPCAIBridgeDaemonClient, boo
 }
 
 func (b *Bridge) proxyOpenAIRequest(w http.ResponseWriter, r *http.Request) {
-	body, err := io.ReadAll(r.Body)
-	if err != nil {
-		// TODO: error handling.
-		panic(err)
-		return
-	}
-	r.Body.Close()
-
-	var msg openai.ChatCompletionNewParams
-	err = json.Unmarshal(body, &msg)
-	if err != nil {
-		// TODO: error handling.
-		panic(err)
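+	// Acquire the coderd (aibridged) client up front; the prompt and token tracking below depends on it.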
+	coderdClient, ok := b.clientFn()
+	if !ok {
+		// TODO: log issue.
+		http.Error(w, "could not acquire coderd client", http.StatusInternalServerError)
 		return
 	}
 
-	fmt.Println(msg)
-
 	target, err := url.Parse("https://api.openai.com")
 	if err != nil {
 		http.Error(w, "failed to parse OpenAI URL", http.StatusInternalServerError)
@@ -103,8 +93,99 @@ func (b *Bridge) proxyOpenAIRequest(w http.ResponseWriter, r *http.Request) {
 		req.URL.Scheme = target.Scheme
 		req.URL.Host = target.Host
 
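+		// Read and buffer the outgoing request body so it can be inspected here and still be forwarded upstream.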
+		body, err := io.ReadAll(req.Body)
+		if err != nil {
+			http.Error(w, "could not read request body", http.StatusBadRequest)
+			return
+		}
+		_ = req.Body.Close()
+
+		var msg openai.ChatCompletionNewParams
+		err = json.NewDecoder(bytes.NewReader(body)).Decode(&msg)
+		if err != nil {
+			http.Error(w, "could not unmarshal request body", http.StatusBadRequest)
+			return
+		}
+
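+		// Best-effort prompt tracking: report the most recent user message (string content only) to coderd.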
+		// TODO: robustness
+		if len(msg.Messages) > 0 {
+			latest := msg.Messages[len(msg.Messages)-1]
+			if latest.OfUser != nil {
+				if latest.OfUser.Content.OfString.String() != "" {
+					_, _ = coderdClient.TrackUserPrompts(r.Context(), &proto.TrackUserPromptsRequest{
+						Prompt: strings.TrimSpace(latest.OfUser.Content.OfString.String()),
+					})
+				} else {
+					fmt.Println()
+				}
+			}
+		}
+
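+		// Restore the body consumed above so the proxy forwards the original payload unchanged.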
+		req.Body = io.NopCloser(bytes.NewReader(body))
+
 		fmt.Printf("Proxying %s request to: %s\n", req.Method, req.URL.String())
 	}
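+	// Intercept the upstream response to record token usage: non-streaming responses are decoded as JSON,
+	// streaming (SSE) responses are accumulated chunk by chunk before being replayed to the client.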
+	proxy.ModifyResponse = func(response *http.Response) error {
+		body, err := io.ReadAll(response.Body)
+		if err != nil {
+			return xerrors.Errorf("read response body: %w", err)
+		}
+		if err = response.Body.Close(); err != nil {
+			return xerrors.Errorf("close body: %w", err)
+		}
+
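+		// Non-streaming path: decode the JSON body (assumed gzip-encoded; see TODO below) and report usage directly.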
+		if !strings.Contains(response.Header.Get("Content-Type"), "text/event-stream") {
+			var msg openai.ChatCompletion
+
+			// TODO: check content-encoding to handle others.
+			gr, err := gzip.NewReader(bytes.NewReader(body))
+			if err != nil {
+				return xerrors.Errorf("parse gzip-encoded body: %w", err)
+			}
+
+			err = json.NewDecoder(gr).Decode(&msg)
+			if err != nil {
+				return xerrors.Errorf("parse non-streaming body: %w", err)
+			}
+
+			_, _ = coderdClient.TrackTokenUsage(r.Context(), &proto.TrackTokenUsageRequest{
+				MsgId:        msg.ID,
+				InputTokens:  msg.Usage.PromptTokens,
+				OutputTokens: msg.Usage.CompletionTokens,
+			})
+
+			response.Body = io.NopCloser(bytes.NewReader(body))
+			return nil
+		}
+
+		response.Body = io.NopCloser(bytes.NewReader(body))
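+		// Streaming path: replay the buffered SSE body through the SDK decoder and accumulate chunks into a full completion.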
+		stream := openai_ssestream.NewStream[openai.ChatCompletionChunk](openai_ssestream.NewDecoder(response), nil)
+
+		var (
+			inputToks, outputToks int64
+		)
+
+		var msg openai.ChatCompletionAccumulator
+		for stream.Next() {
+			chunk := stream.Current()
+			msg.AddChunk(chunk)
+
+			if msg.Usage.PromptTokens+msg.Usage.CompletionTokens > 0 {
+				inputToks = msg.Usage.PromptTokens
+				outputToks = msg.Usage.CompletionTokens
+			}
+		}
+
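+		// Report whatever usage the accumulator captured once the stream has been fully consumed.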
+		_, _ = coderdClient.TrackTokenUsage(r.Context(), &proto.TrackTokenUsageRequest{
+			MsgId:        msg.ID,
+			InputTokens:  inputToks,
+			OutputTokens: outputToks,
+		})
+
+		response.Body = io.NopCloser(bytes.NewReader(body))
+
+		return nil
+	}
 	proxy.ServeHTTP(w, r)
 }
 
@@ -207,7 +288,7 @@ func (b *Bridge) proxyAnthropicRequest(w http.ResponseWriter, r *http.Request) {
 		}
 
 		response.Body = io.NopCloser(bytes.NewReader(body))
-		stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](ssestream.NewDecoder(response), nil)
+		stream := ant_ssestream.NewStream[anthropic.MessageStreamEventUnion](ant_ssestream.NewDecoder(response), nil)
 
 		var (
 			inputToks, outputToks int64