Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitdcd7d87

Browse files
committed
feat: add MoonshotAI provider with Kimi-K2 model support
- Add MoonshotAIProvider with OpenAI-compatible API- Implements OpenAI-style interface with custom base URL- Supports tool definitions but not strict tool validation- Add moonshotai:kimi-k2-0711-preview as known model- Configure to use OpenAIModel for compatibility- Add comprehensive tests for provider functionality- Update CLI and model name tests
1 parent883e1ea commitdcd7d87

File tree

10 files changed

+204
-2
lines changed

10 files changed

+204
-2
lines changed

‎pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@
286286
'openai:o3-mini-2025-01-31',
287287
'openai:o4-mini',
288288
'openai:o4-mini-2025-04-16',
289+
'openai:computer-use-preview-2025-03-11',
290+
'moonshotai:kimi-k2-0711-preview',
289291
'test',
290292
],
291293
)
@@ -588,6 +590,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
588590
'azure',
589591
'openrouter',
590592
'grok',
593+
'moonshotai',
591594
'fireworks',
592595
'together',
593596
'heroku',

‎pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,16 @@ def __init__(
190190
model_name:OpenAIModelName,
191191
*,
192192
provider:Literal[
193-
'openai','deepseek','azure','openrouter','grok','fireworks','together','heroku','github'
193+
'openai',
194+
'deepseek',
195+
'azure',
196+
'openrouter',
197+
'grok',
198+
'moonshotai',
199+
'fireworks',
200+
'together',
201+
'heroku',
202+
'github',
194203
]
195204
|Provider[AsyncOpenAI]='openai',
196205
profile:ModelProfileSpec|None=None,
@@ -598,7 +607,18 @@ def __init__(
598607
self,
599608
model_name:OpenAIModelName,
600609
*,
601-
provider:Literal['openai','deepseek','azure','openrouter','grok','fireworks','together']
610+
provider:Literal[
611+
'openai',
612+
'deepseek',
613+
'azure',
614+
'openrouter',
615+
'grok',
616+
'moonshotai',
617+
'fireworks',
618+
'together',
619+
'heroku',
620+
'github',
621+
]
602622
|Provider[AsyncOpenAI]='openai',
603623
profile:ModelProfileSpec|None=None,
604624
settings:ModelSettings|None=None,
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from __future__importannotationsas_annotations
2+
3+
from .importModelProfile
4+
5+
6+
defmoonshotai_model_profile(model_name:str)->ModelProfile|None:
7+
"""Get the model profile for a MoonshotAI model."""
8+
returnNone

‎pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
9999
from .grokimportGrokProvider
100100

101101
returnGrokProvider
102+
elifprovider=='moonshotai':
103+
from .moonshotaiimportMoonshotAIProvider
104+
105+
returnMoonshotAIProvider
102106
elifprovider=='fireworks':
103107
from .fireworksimportFireworksProvider
104108

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__importannotationsas_annotations
2+
3+
importos
4+
fromtypingimportoverload
5+
6+
fromhttpximportAsyncClientasAsyncHTTPClient
7+
fromopenaiimportAsyncOpenAI
8+
9+
frompydantic_ai.exceptionsimportUserError
10+
frompydantic_ai.modelsimportcached_async_http_client
11+
frompydantic_ai.profilesimportModelProfile
12+
frompydantic_ai.profiles.moonshotaiimportmoonshotai_model_profile
13+
frompydantic_ai.profiles.openaiimport (
14+
OpenAIJsonSchemaTransformer,
15+
OpenAIModelProfile,
16+
)
17+
frompydantic_ai.providersimportProvider
18+
19+
20+
classMoonshotAIProvider(Provider[AsyncOpenAI]):
21+
"""Provider for MoonshotAI platform (Kimi models)."""
22+
23+
@property
24+
defname(self)->str:
25+
return'moonshotai'
26+
27+
@property
28+
defbase_url(self)->str:
29+
# OpenAI-compatible endpoint, see MoonshotAI docs
30+
return'https://api.moonshot.ai/v1'
31+
32+
@property
33+
defclient(self)->AsyncOpenAI:
34+
returnself._client
35+
36+
defmodel_profile(self,model_name:str)->ModelProfile|None:
37+
profile=moonshotai_model_profile(model_name)
38+
39+
# As the MoonshotAI API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
40+
# unless json_schema_transformer is set explicitly.
41+
# Also, MoonshotAI does not support strict tool definitions
42+
# https://platform.moonshot.ai/docs/guide/migrating-from-openai-to-kimi#about-tool_choice
43+
# "Please note that the current version of Kimi API does not support the tool_choice=required parameter."
44+
returnOpenAIModelProfile(
45+
json_schema_transformer=OpenAIJsonSchemaTransformer,
46+
openai_supports_strict_tool_definition=False,
47+
).update(profile)
48+
49+
# ---------------------------------------------------------------------
50+
# Construction helpers
51+
# ---------------------------------------------------------------------
52+
@overload
53+
def__init__(self)->None: ...
54+
55+
@overload
56+
def__init__(self,*,api_key:str)->None: ...
57+
58+
@overload
59+
def__init__(self,*,api_key:str,http_client:AsyncHTTPClient)->None: ...
60+
61+
@overload
62+
def__init__(self,*,openai_client:AsyncOpenAI|None=None)->None: ...
63+
64+
def__init__(
65+
self,
66+
*,
67+
api_key:str|None=None,
68+
openai_client:AsyncOpenAI|None=None,
69+
http_client:AsyncHTTPClient|None=None,
70+
)->None:
71+
api_key=api_keyoros.getenv('MOONSHOT_API_KEY')
72+
ifnotapi_keyandopenai_clientisNone:
73+
raiseUserError(
74+
'Set the `MOONSHOT_API_KEY` environment variable or pass it via '
75+
'`MoonshotAIProvider(api_key=...)` to use the MoonshotAI provider.'
76+
)
77+
78+
ifopenai_clientisnotNone:
79+
self._client=openai_client
80+
elifhttp_clientisnotNone:
81+
self._client=AsyncOpenAI(base_url=self.base_url,api_key=api_key,http_client=http_client)
82+
else:
83+
http_client=cached_async_http_client(provider='moonshotai')
84+
self._client=AsyncOpenAI(base_url=self.base_url,api_key=api_key,http_client=http_client)

‎tests/models/test_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@
7272
'github',
7373
'OpenAIModel',
7474
),
75+
(
76+
'MOONSHOT_API_KEY',
77+
'moonshotai:kimi-k2-0711-preview',
78+
'kimi-k2-0711-preview',
79+
'moonshotai',
80+
'openai',
81+
'OpenAIModel',
82+
),
7583
]
7684

7785

‎tests/models/test_model_names.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
4949
f'google-vertex:{n}'forninget_model_names(GeminiModelName)
5050
]
5151
groq_names= [f'groq:{n}'forninget_model_names(GroqModelName)]
52+
moonshotai_names= ['moonshotai:kimi-k2-0711-preview']
5253
mistral_names= [f'mistral:{n}'forninget_model_names(MistralModelName)]
5354
openai_names= [f'openai:{n}'forninget_model_names(OpenAIModelName)]+ [
5455
nforninget_model_names(OpenAIModelName)ifn.startswith('o1')orn.startswith('gpt')orn.startswith('o3')
@@ -64,6 +65,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
6465
+cohere_names
6566
+google_names
6667
+groq_names
68+
+moonshotai_names
6769
+mistral_names
6870
+openai_names
6971
+bedrock_names

‎tests/providers/test_moonshotai.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
importre
2+
3+
importhttpx
4+
importpytest
5+
6+
frompydantic_ai.exceptionsimportUserError
7+
frompydantic_ai.profiles.openaiimportOpenAIJsonSchemaTransformer,OpenAIModelProfile
8+
9+
from ..conftestimportTestEnv,try_import
10+
11+
withtry_import()asimports_successful:
12+
importopenai
13+
14+
frompydantic_ai.models.openaiimportOpenAIModel
15+
frompydantic_ai.providers.moonshotaiimportMoonshotAIProvider
16+
17+
pytestmark=pytest.mark.skipif(notimports_successful(),reason='openai not installed')
18+
19+
20+
deftest_moonshotai_provider():
21+
"""Test basic MoonshotAI provider initialization."""
22+
provider=MoonshotAIProvider(api_key='api-key')
23+
assertprovider.name=='moonshotai'
24+
assertprovider.base_url=='https://api.moonshot.ai/v1'
25+
assertisinstance(provider.client,openai.AsyncOpenAI)
26+
assertprovider.client.api_key=='api-key'
27+
28+
29+
deftest_moonshotai_provider_need_api_key(env:TestEnv)->None:
30+
"""Test that MoonshotAI provider requires an API key."""
31+
env.remove('MOONSHOT_API_KEY')
32+
withpytest.raises(
33+
UserError,
34+
match=re.escape(
35+
'Set the `MOONSHOT_API_KEY` environment variable or pass it via `MoonshotAIProvider(api_key=...)`'
36+
' to use the MoonshotAI provider.'
37+
),
38+
):
39+
MoonshotAIProvider()
40+
41+
42+
deftest_moonshotai_provider_pass_http_client()->None:
43+
"""Test passing a custom HTTP client to MoonshotAI provider."""
44+
http_client=httpx.AsyncClient()
45+
provider=MoonshotAIProvider(http_client=http_client,api_key='api-key')
46+
assertprovider.client._client==http_client# type: ignore[reportPrivateUsage]
47+
48+
49+
deftest_moonshotai_pass_openai_client()->None:
50+
"""Test passing a custom OpenAI client to MoonshotAI provider."""
51+
openai_client=openai.AsyncOpenAI(api_key='api-key')
52+
provider=MoonshotAIProvider(openai_client=openai_client)
53+
assertprovider.client==openai_client
54+
55+
56+
deftest_moonshotai_provider_with_cached_http_client()->None:
57+
"""Test MoonshotAI provider using cached HTTP client (covers line 76)."""
58+
# This should use the else branch with cached_async_http_client
59+
provider=MoonshotAIProvider(api_key='api-key')
60+
assertisinstance(provider.client,openai.AsyncOpenAI)
61+
assertprovider.client.api_key=='api-key'
62+
63+
64+
deftest_moonshotai_model_profile():
65+
provider=MoonshotAIProvider(api_key='api-key')
66+
model=OpenAIModel('kimi-k2-0711-preview',provider=provider)
67+
assertisinstance(model.profile,OpenAIModelProfile)
68+
assertmodel.profile.json_schema_transformer==OpenAIJsonSchemaTransformer
69+
assertmodel.profile.openai_supports_strict_tool_definitionisFalse

‎tests/providers/test_provider_names.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
frompydantic_ai.providers.groqimportGroqProvider
2727
frompydantic_ai.providers.herokuimportHerokuProvider
2828
frompydantic_ai.providers.mistralimportMistralProvider
29+
frompydantic_ai.providers.moonshotaiimportMoonshotAIProvider
2930
frompydantic_ai.providers.openaiimportOpenAIProvider
3031
frompydantic_ai.providers.openrouterimportOpenRouterProvider
3132
frompydantic_ai.providers.togetherimportTogetherProvider
@@ -42,6 +43,7 @@
4243
('groq',GroqProvider,'GROQ_API_KEY'),
4344
('mistral',MistralProvider,'MISTRAL_API_KEY'),
4445
('grok',GrokProvider,'GROK_API_KEY'),
46+
('moonshotai',MoonshotAIProvider,'MOONSHOT_API_KEY'),
4547
('fireworks',FireworksProvider,'FIREWORKS_API_KEY'),
4648
('together',TogetherProvider,'TOGETHER_API_KEY'),
4749
('heroku',HerokuProvider,'HEROKU_INFERENCE_KEY'),

‎tests/test_cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ def test_list_models(capfd: CaptureFixture[str]):
144144
'cohere',
145145
'deepseek',
146146
'heroku',
147+
'grok',
148+
'moonshotai',
147149
'huggingface',
148150
)
149151
models= {line.strip().split(' ')[0]forlineinoutput[3:]}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp