@@ -2,21 +2,21 @@ package coderd
2
2
3
3
import (
4
4
"encoding/json"
5
+ "io"
5
6
"net/http"
6
7
"time"
7
8
9
+ "github.com/google/uuid"
10
+ "github.com/kylecarbs/aisdk-go"
11
+
8
12
"github.com/coder/coder/v2/coderd/ai"
9
13
"github.com/coder/coder/v2/coderd/database"
10
14
"github.com/coder/coder/v2/coderd/database/db2sdk"
11
15
"github.com/coder/coder/v2/coderd/database/dbtime"
12
16
"github.com/coder/coder/v2/coderd/httpapi"
13
17
"github.com/coder/coder/v2/coderd/httpmw"
14
18
"github.com/coder/coder/v2/codersdk"
15
- codermcp"github.com/coder/coder/v2/mcp"
16
- "github.com/google/uuid"
17
- "github.com/kylecarbs/aisdk-go"
18
- "github.com/mark3labs/mcp-go/mcp"
19
- "github.com/mark3labs/mcp-go/server"
19
+ "github.com/coder/coder/v2/codersdk/toolsdk"
20
20
)
21
21
22
22
// postChats creates a new chat.
@@ -142,9 +142,10 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
142
142
Message :"Failed to get chat messages" ,
143
143
Detail :err .Error (),
144
144
})
145
+ return
145
146
}
146
147
147
- messages := make ([]aisdk.Message ,len (dbMessages ))
148
+ messages := make ([]aisdk.Message ,0 , len (dbMessages ))
148
149
for i ,message := range dbMessages {
149
150
err = json .Unmarshal (message .Content ,& messages [i ])
150
151
if err != nil {
@@ -157,31 +158,17 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
157
158
}
158
159
messages = append (messages ,req .Message )
159
160
160
- toolMap := codermcp .AllTools ()
161
- toolsByName := make (map [string ]server.ToolHandlerFunc )
162
161
client := codersdk .New (api .AccessURL )
163
162
client .SetSessionToken (httpmw .APITokenFromRequest (r ))
164
- toolDeps := codermcp.ToolDeps {
165
- Client :client ,
166
- Logger :& api .Logger ,
167
- }
168
- for _ ,tool := range toolMap {
169
- toolsByName [tool .Tool .Name ]= tool .MakeHandler (toolDeps )
170
- }
171
- convertedTools := make ([]aisdk.Tool ,len (toolMap ))
172
- for i ,tool := range toolMap {
173
- schema := aisdk.Schema {
174
- Required :tool .Tool .InputSchema .Required ,
175
- Properties :tool .Tool .InputSchema .Properties ,
176
- }
177
- if tool .Tool .InputSchema .Required == nil {
178
- schema .Required = []string {}
179
- }
180
- convertedTools [i ]= aisdk.Tool {
181
- Name :tool .Tool .Name ,
182
- Description :tool .Tool .Description ,
183
- Schema :schema ,
163
+
164
+ tools := make ([]aisdk.Tool ,len (toolsdk .All ))
165
+ handlers := map [string ]toolsdk.GenericHandlerFunc {}
166
+ for i ,tool := range toolsdk .All {
167
+ if tool .Tool .Schema .Required == nil {
168
+ tool .Tool .Schema .Required = []string {}
184
169
}
170
+ tools [i ]= tool .Tool
171
+ handlers [tool .Tool .Name ]= tool .Handler
185
172
}
186
173
187
174
provider ,ok := api .LanguageModels [req .Model ]
@@ -192,6 +179,44 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
192
179
return
193
180
}
194
181
182
+ // If it's the user's first message, generate a title for the chat.
183
+ if len (messages )== 1 {
184
+ var acc aisdk.DataStreamAccumulator
185
+ stream ,err := provider .StreamFunc (ctx , ai.StreamOptions {
186
+ Model :req .Model ,
187
+ SystemPrompt :`- You will generate a short title based on the user's message.
188
+ - It should be maximum of 40 characters.
189
+ - Do not use quotes, colons, special characters, or emojis.` ,
190
+ Messages :messages ,
191
+ Tools :tools ,
192
+ })
193
+ if err != nil {
194
+ httpapi .Write (ctx ,w ,http .StatusInternalServerError , codersdk.Response {
195
+ Message :"Failed to create stream" ,
196
+ Detail :err .Error (),
197
+ })
198
+ }
199
+ stream = stream .WithAccumulator (& acc )
200
+ err = stream .Pipe (io .Discard )
201
+ if err != nil {
202
+ httpapi .Write (ctx ,w ,http .StatusInternalServerError , codersdk.Response {
203
+ Message :"Failed to pipe stream" ,
204
+ Detail :err .Error (),
205
+ })
206
+ }
207
+ err = api .Database .UpdateChatByID (ctx , database.UpdateChatByIDParams {
208
+ ID :chat .ID ,
209
+ Title :acc .Messages ()[0 ].Content ,
210
+ })
211
+ if err != nil {
212
+ httpapi .Write (ctx ,w ,http .StatusInternalServerError , codersdk.Response {
213
+ Message :"Failed to update chat title" ,
214
+ Detail :err .Error (),
215
+ })
216
+ return
217
+ }
218
+ }
219
+
195
220
// Write headers for the data stream!
196
221
aisdk .WriteDataStreamHeaders (w )
197
222
@@ -219,12 +244,20 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
219
244
return
220
245
}
221
246
247
+ deps := toolsdk.Deps {
248
+ CoderClient :client ,
249
+ }
250
+
222
251
for {
223
252
var acc aisdk.DataStreamAccumulator
224
253
stream ,err := provider .StreamFunc (ctx , ai.StreamOptions {
225
254
Model :req .Model ,
226
255
Messages :messages ,
227
- Tools :convertedTools ,
256
+ Tools :tools ,
257
+ SystemPrompt :`You are a chat assistant for Coder. You will attempt to resolve the user's
258
+ request to the maximum utilization of your tools.
259
+
260
+ Try your best to not ask the user for help - solve the task with your tools!` ,
228
261
})
229
262
if err != nil {
230
263
httpapi .Write (ctx ,w ,http .StatusInternalServerError , codersdk.Response {
@@ -234,28 +267,21 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
234
267
return
235
268
}
236
269
stream = stream .WithToolCalling (func (toolCall aisdk.ToolCall )any {
237
- tool ,ok := toolsByName [toolCall .Name ]
270
+ tool ,ok := handlers [toolCall .Name ]
238
271
if ! ok {
239
272
return nil
240
273
}
241
- result ,err := tool (ctx , mcp.CallToolRequest {
242
- Params :struct {
243
- Name string "json:\" name\" "
244
- Arguments map [string ]interface {}"json:\" arguments,omitempty\" "
245
- Meta * struct {
246
- ProgressToken mcp.ProgressToken "json:\" progressToken,omitempty\" "
247
- }"json:\" _meta,omitempty\" "
248
- }{
249
- Name :toolCall .Name ,
250
- Arguments :toolCall .Args ,
251
- },
252
- })
274
+ toolArgs ,err := json .Marshal (toolCall .Args )
275
+ if err != nil {
276
+ return nil
277
+ }
278
+ result ,err := tool (ctx ,deps ,toolArgs )
253
279
if err != nil {
254
280
return map [string ]any {
255
281
"error" :err .Error (),
256
282
}
257
283
}
258
- return result . Content
284
+ return result
259
285
}).WithAccumulator (& acc )
260
286
261
287
err = stream .Pipe (w )