1
1
package coderd
2
2
3
3
import (
4
+ "context"
5
+ "fmt"
4
6
"net/http"
5
7
8
+ "github.com/google/uuid"
9
+ "golang.org/x/xerrors"
10
+
6
11
"github.com/coder/coder/v2/coderd/database"
12
+ "github.com/coder/coder/v2/coderd/database/db2sdk"
13
+ "github.com/coder/coder/v2/coderd/database/dbauthz"
7
14
"github.com/coder/coder/v2/coderd/httpapi"
8
- "github.com/coder/coder/v2/coderd/rbac "
9
- "github.com/coder/coder/v2/coderd/rbac/policy "
15
+ "github.com/coder/coder/v2/coderd/httpmw "
16
+ "github.com/coder/coder/v2/coderd/searchquery "
10
17
"github.com/coder/coder/v2/codersdk"
11
18
)
12
19
@@ -15,50 +22,50 @@ const (
15
22
defaultListInterceptionsLimit = 100
16
23
)
17
24
25
+ // aiBridgeListInterceptions returns all AIBridge interceptions a user can read.
26
+ // Optional filters with query params
27
+ //
28
+ // @Summary List AIBridge interceptions
29
+ // @ID list-aibridge-interceptions
30
+ // @Security CoderSessionToken
31
+ // @Produce json
32
+ // @Tags AIBridge
33
+ // @Param q query string false "Search query in the format `key:value`. Available keys are: initiator, provider, model, started_after, started_before."
34
+ // @Param limit query int false "Page limit"
35
+ // @Param offset query int false "Page offset"
36
+ // @Success 200 {object} codersdk.AIBridgeListInterceptionsResponse
37
+ // @Router /api/experimental/aibridge/interceptions [get]
18
38
func (api * API )aiBridgeListInterceptions (rw http.ResponseWriter ,r * http.Request ) {
19
- if ! api .Authorize (r ,policy .ActionRead ,rbac .ResourceAibridgeInterception ) {
20
- httpapi .Forbidden (rw )
21
- return
22
- }
23
-
24
39
ctx := r .Context ()
25
- var req codersdk.AIBridgeListInterceptionsRequest
26
- if ! httpapi .Read (ctx ,rw ,r ,& req ) {
40
+ apiKey := httpmw .APIKey (r )
41
+
42
+ page ,ok := ParsePagination (rw ,r )
43
+ if ! ok {
27
44
return
28
45
}
29
-
30
- if ! req .PeriodStart .IsZero ()&& ! req .PeriodEnd .IsZero ()&& req .PeriodEnd .Before (req .PeriodStart ) {
46
+ if page .Limit == 0 {
47
+ page .Limit = defaultListInterceptionsLimit
48
+ }
49
+ if page .Limit > maxListInterceptionsLimit || page .Limit < 1 {
31
50
httpapi .Write (ctx ,rw ,http .StatusBadRequest , codersdk.Response {
32
- Message :"Invalidtime frame ." ,
33
- Detail :"End of the search period must bebefore start." ,
51
+ Message :"Invalidpagination limit value ." ,
52
+ Detail :fmt . Sprintf ( "Pagination limit must bein range (0, %d]" , maxListInterceptionsLimit ) ,
34
53
})
35
54
return
36
55
}
37
56
38
- if req .Limit == 0 {
39
- req .Limit = defaultListInterceptionsLimit
40
- }
41
-
42
- if req .Limit > maxListInterceptionsLimit || req .Limit < 1 {
57
+ queryStr := r .URL .Query ().Get ("q" )
58
+ filter ,errs := searchquery .AIBridgeInterceptions (ctx ,api .Database ,queryStr ,page ,apiKey .UserID )
59
+ if len (errs )> 0 {
43
60
httpapi .Write (ctx ,rw ,http .StatusBadRequest , codersdk.Response {
44
- Message :"Invalidlimit value ." ,
45
- Detail : "Limit value must be in range <1, 1000>" ,
61
+ Message :"Invalidworkspace search query ." ,
62
+ Validations : errs ,
46
63
})
47
64
return
48
65
}
49
66
50
- // Database returns one row for each tuple (interception, tool, prompt).
51
- // Right now there is a single promp per interception although model allows multiple so this could change in the future.
52
- // There can be multiple tools used in single interception.
53
- // Results are ordered by Interception.StartedAt, Interception.ID, Tool.CreatedAt
54
- rows ,err := api .Database .ListAIBridgeInterceptions (ctx , database.ListAIBridgeInterceptionsParams {
55
- PeriodStart :req .PeriodStart ,
56
- PeriodEnd :req .PeriodEnd ,
57
- CursorTime :req .Cursor .Time ,
58
- CursorID :req .Cursor .ID ,
59
- InitiatorID :req .InitiatorID ,
60
- LimitOpt :req .Limit ,
61
- })
67
+ // This only returns authorized interceptions (when using dbauthz).
68
+ rows ,err := api .Database .ListAIBridgeInterceptions (ctx ,filter )
62
69
if err != nil {
63
70
httpapi .Write (ctx ,rw ,http .StatusInternalServerError , codersdk.Response {
64
71
Message :"Internal error getting AIBridge interceptions." ,
@@ -67,50 +74,61 @@ func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Reques
67
74
return
68
75
}
69
76
70
- resp := prepareResponse (rows )
71
- httpapi .Write (ctx ,rw ,http .StatusOK ,resp )
77
+ // This fetches the other rows associated with the interceptions.
78
+ items ,err := populatedAndConvertAIBridgeInterceptions (ctx ,api .Database ,rows )
79
+ if err != nil {
80
+ httpapi .Write (ctx ,rw ,http .StatusInternalServerError , codersdk.Response {
81
+ Message :"Internal error converting database rows to API response." ,
82
+ Detail :err .Error (),
83
+ })
84
+ return
85
+ }
86
+
87
+ httpapi .Write (ctx ,rw ,http .StatusOK , codersdk.AIBridgeListInterceptionsResponse {
88
+ Results :items ,
89
+ })
72
90
}
73
91
74
- func prepareResponse (rows []database.ListAIBridgeInterceptionsRow ) codersdk.AIBridgeListInterceptionsResponse {
75
- resp := codersdk.AIBridgeListInterceptionsResponse {
76
- Results : []codersdk.AIBridgeListInterceptionsResult {},
92
+ func populatedAndConvertAIBridgeInterceptions (ctx context.Context ,db database.Store ,rows []database.AIBridgeInterception ) ([]codersdk.AIBridgeInterception ,error ) {
93
+ ids := make ([]uuid.UUID ,len (rows ))
94
+ for i ,row := range rows {
95
+ ids [i ]= row .ID
96
+ }
97
+
98
+ //nolint:gocritic // This is a system function until we implement a join for aibridge interceptions. AIBridge interception subresources use the same authorization call as their parent.
99
+ tokenUsagesRows ,err := db .ListAIBridgeTokenUsagesByInterceptionIDs (dbauthz .AsSystemRestricted (ctx ),ids )
100
+ if err != nil {
101
+ return nil ,xerrors .Errorf ("get linked aibridge token usages from database: %w" ,err )
102
+ }
103
+ tokenUsagesMap := make (map [uuid.UUID ][]database.AIBridgeTokenUsage )
104
+ for _ ,row := range tokenUsagesRows {
105
+ tokenUsagesMap [row .InterceptionID ]= append (tokenUsagesMap [row .InterceptionID ],row )
106
+ }
107
+
108
+ //nolint:gocritic // This is a system function until we implement a join for aibridge interceptions. AIBridge interception subresources use the same authorization call as their parent.
109
+ userPromptRows ,err := db .ListAIBridgeUserPromptsByInterceptionIDs (dbauthz .AsSystemRestricted (ctx ),ids )
110
+ if err != nil {
111
+ return nil ,xerrors .Errorf ("get linked aibridge user prompts from database: %w" ,err )
112
+ }
113
+ userPromptsMap := make (map [uuid.UUID ][]database.AIBridgeUserPrompt )
114
+ for _ ,row := range userPromptRows {
115
+ userPromptsMap [row .InterceptionID ]= append (userPromptsMap [row .InterceptionID ],row )
77
116
}
78
117
79
- if len (rows )> 0 {
80
- resp .Cursor .ID = rows [len (rows )- 1 ].ID
81
- resp .Cursor .Time = rows [len (rows )- 1 ].StartedAt .UTC ()
118
+ //nolint:gocritic // This is a system function until we implement a join for aibridge interceptions. AIBridge interception subresources use the same authorization call as their parent.
119
+ toolUsagesRows ,err := db .ListAIBridgeToolUsagesByInterceptionIDs (dbauthz .AsSystemRestricted (ctx ),ids )
120
+ if err != nil {
121
+ return nil ,xerrors .Errorf ("get linked aibridge tool usages from database: %w" ,err )
122
+ }
123
+ toolUsagesMap := make (map [uuid.UUID ][]database.AIBridgeToolUsage )
124
+ for _ ,row := range toolUsagesRows {
125
+ toolUsagesMap [row .InterceptionID ]= append (toolUsagesMap [row .InterceptionID ],row )
82
126
}
83
127
84
- for i := 0 ;i < len (rows ); {
85
- row := rows [i ]
86
- row .StartedAt = row .StartedAt .UTC ()
87
-
88
- result := codersdk.AIBridgeListInterceptionsResult {
89
- InterceptionID :row .ID ,
90
- UserID :row .InitiatorID ,
91
- Provider :row .Provider ,
92
- Model :row .Model ,
93
- Prompt :row .Prompt .String ,
94
- StartedAt :row .StartedAt ,
95
- Tokens : codersdk.AIBridgeListInterceptionsTokens {
96
- Input :row .InputTokens ,
97
- Output :row .OutputTokens ,
98
- },
99
- Tools : []codersdk.AIBridgeListInterceptionsTool {},
100
- }
101
-
102
- interceptionID := row .ID
103
- for ;i < len (rows )&& interceptionID == rows [i ].ID ;i ++ {
104
- if rows [i ].ServerUrl .Valid || rows [i ].Tool .Valid || rows [i ].Input .Valid {
105
- result .Tools = append (result .Tools , codersdk.AIBridgeListInterceptionsTool {
106
- Server :rows [i ].ServerUrl .String ,
107
- Tool :rows [i ].Tool .String ,
108
- Input :rows [i ].Input .String ,
109
- })
110
- }
111
- }
112
-
113
- resp .Results = append (resp .Results ,result )
128
+ items := make ([]codersdk.AIBridgeInterception ,len (rows ))
129
+ for i ,row := range rows {
130
+ items [i ]= db2sdk .AIBridgeInterception (row ,tokenUsagesMap [row .ID ],userPromptsMap [row .ID ],toolUsagesMap [row .ID ])
114
131
}
115
- return resp
132
+
133
+ return items ,nil
116
134
}