@@ -2,8 +2,140 @@ package ai
2
2
3
3
import (
4
4
"context"
5
+ "fmt"
5
6
7
+ "github.com/anthropics/anthropic-sdk-go"
8
+ anthropicoption"github.com/anthropics/anthropic-sdk-go/option"
9
+ "github.com/coder/coder/v2/codersdk"
6
10
"github.com/kylecarbs/aisdk-go"
11
+ "github.com/openai/openai-go"
12
+ openaioption"github.com/openai/openai-go/option"
13
+ "google.golang.org/genai"
7
14
)
8
15
9
- type Provider func (ctx context.Context ,messages []aisdk.Message ) (aisdk.DataStream ,error )
16
+ type LanguageModel struct {
17
+ codersdk.LanguageModel
18
+ StreamFunc StreamFunc
19
+ }
20
+
21
+ type StreamOptions struct {
22
+ Model string
23
+ Messages []aisdk.Message
24
+ Thinking bool
25
+ Tools []aisdk.Tool
26
+ }
27
+
28
+ type StreamFunc func (ctx context.Context ,options StreamOptions ) (aisdk.DataStream ,error )
29
+
30
+ // LanguageModels is a map of language model ID to language model.
31
+ type LanguageModels map [string ]LanguageModel
32
+
33
+ func ModelsFromConfig (ctx context.Context ,configs []codersdk.AIProviderConfig ) (LanguageModels ,error ) {
34
+ models := make (LanguageModels )
35
+
36
+ for _ ,config := range configs {
37
+ var streamFunc StreamFunc
38
+
39
+ switch config .Type {
40
+ case "openai" :
41
+ client := openai .NewClient (openaioption .WithAPIKey (config .APIKey ))
42
+ streamFunc = func (ctx context.Context ,options StreamOptions ) (aisdk.DataStream ,error ) {
43
+ openaiMessages ,err := aisdk .MessagesToOpenAI (options .Messages )
44
+ if err != nil {
45
+ return nil ,err
46
+ }
47
+ tools := aisdk .ToolsToOpenAI (options .Tools )
48
+ return aisdk .OpenAIToDataStream (client .Chat .Completions .NewStreaming (ctx , openai.ChatCompletionNewParams {
49
+ Messages :openaiMessages ,
50
+ Model :options .Model ,
51
+ Tools :tools ,
52
+ MaxTokens :openai .Int (8192 ),
53
+ })),nil
54
+ }
55
+ if config .Models == nil {
56
+ models ,err := client .Models .List (ctx )
57
+ if err != nil {
58
+ return nil ,err
59
+ }
60
+ config .Models = make ([]string ,len (models .Data ))
61
+ for i ,model := range models .Data {
62
+ config .Models [i ]= model .ID
63
+ }
64
+ }
65
+ break
66
+ case "anthropic" :
67
+ client := anthropic .NewClient (anthropicoption .WithAPIKey (config .APIKey ))
68
+ streamFunc = func (ctx context.Context ,options StreamOptions ) (aisdk.DataStream ,error ) {
69
+ anthropicMessages ,systemMessage ,err := aisdk .MessagesToAnthropic (options .Messages )
70
+ if err != nil {
71
+ return nil ,err
72
+ }
73
+ return aisdk .AnthropicToDataStream (client .Messages .NewStreaming (ctx , anthropic.MessageNewParams {
74
+ Messages :anthropicMessages ,
75
+ Model :options .Model ,
76
+ System :systemMessage ,
77
+ Tools :aisdk .ToolsToAnthropic (options .Tools ),
78
+ MaxTokens :8192 ,
79
+ })),nil
80
+ }
81
+ if config .Models == nil {
82
+ models ,err := client .Models .List (ctx , anthropic.ModelListParams {})
83
+ if err != nil {
84
+ return nil ,err
85
+ }
86
+ config .Models = make ([]string ,len (models .Data ))
87
+ for i ,model := range models .Data {
88
+ config .Models [i ]= model .ID
89
+ }
90
+ }
91
+ break
92
+ case "google" :
93
+ client ,err := genai .NewClient (ctx ,& genai.ClientConfig {
94
+ APIKey :config .APIKey ,
95
+ Backend :genai .BackendGeminiAPI ,
96
+ })
97
+ if err != nil {
98
+ return nil ,err
99
+ }
100
+ streamFunc = func (ctx context.Context ,options StreamOptions ) (aisdk.DataStream ,error ) {
101
+ googleMessages ,err := aisdk .MessagesToGoogle (options .Messages )
102
+ if err != nil {
103
+ return nil ,err
104
+ }
105
+ tools ,err := aisdk .ToolsToGoogle (options .Tools )
106
+ if err != nil {
107
+ return nil ,err
108
+ }
109
+ return aisdk .GoogleToDataStream (client .Models .GenerateContentStream (ctx ,options .Model ,googleMessages ,& genai.GenerateContentConfig {
110
+ Tools :tools ,
111
+ })),nil
112
+ }
113
+ if config .Models == nil {
114
+ models ,err := client .Models .List (ctx ,& genai.ListModelsConfig {})
115
+ if err != nil {
116
+ return nil ,err
117
+ }
118
+ config .Models = make ([]string ,len (models .Items ))
119
+ for i ,model := range models .Items {
120
+ config .Models [i ]= model .Name
121
+ }
122
+ }
123
+ break
124
+ default :
125
+ return nil ,fmt .Errorf ("unsupported model type: %s" ,config .Type )
126
+ }
127
+
128
+ for _ ,model := range config .Models {
129
+ models [model ]= LanguageModel {
130
+ LanguageModel : codersdk.LanguageModel {
131
+ ID :model ,
132
+ DisplayName :model ,
133
+ Provider :config .Type ,
134
+ },
135
+ StreamFunc :streamFunc ,
136
+ }
137
+ }
138
+ }
139
+
140
+ return models ,nil
141
+ }