Add tool_choice to ModelSettings #825


Draft

webcoderz wants to merge 40 commits into pydantic:main from webcoderz:main
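For context, the branch threads a single `tool_choice` setting from `ModelSettings` down into each provider's request builder. A rough sketch of the call-site this would enable, assuming the branch's public surface matches the diffs below (a `tool_choice` key accepting `'auto'`, `'required'`, `'none'`, or a `ForcedFunctionToolChoice`); none of this is in released pydantic-ai:

```python
from pydantic_ai import Agent
from pydantic_ai.settings import ModelSettings

# Only on this branch; constructor signature assumed from the `.name` access in the diffs.
from pydantic_ai.settings import ForcedFunctionToolChoice

agent = Agent('anthropic:claude-3-5-sonnet-latest')

# 'required' forces some tool call; each provider maps it to its own wire format below.
result = agent.run_sync(
    'What is the weather in Paris?',
    model_settings=ModelSettings(tool_choice='required'),
)

# Forcing one specific tool by name (Cohere rejects this with a UserError, per its diff).
forced = ModelSettings(tool_choice=ForcedFunctionToolChoice(name='get_weather'))
```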
Commits (40)
3777d06  Update openai.py (webcoderz, Jan 31, 2025)
6746781  Update pyproject.toml (webcoderz, Jan 31, 2025)
7d2c2ef  adding to model settings, removing monkeypatch (webcoderz, Jan 31, 2025)
2567bbc  Update openai.py (webcoderz, Jan 31, 2025)
4bc72f9  Update settings.py (webcoderz, Feb 4, 2025)
ac10ee1  backing this out (webcoderz, Feb 6, 2025)
ce38756  Update pyproject.toml (webcoderz, Feb 6, 2025)
fe341b1  Merge branch 'main' into webcoderz-model-settings (webcoderz, Feb 7, 2025)
12015c9  Update groq.py (webcoderz, Feb 9, 2025)
6dd9987  removing fallback comment (webcoderz, Feb 12, 2025)
3250aff  adding as per reccomendation (webcoderz, Feb 12, 2025)
a0b7454  removing tool_choice from ModelSettings and placing in each individu… (webcoderz, Feb 12, 2025)
79acaf3  the conditional checking tool_choice was not evaluating when i added … (webcoderz, Feb 12, 2025)
67e6ac3  adding _get_tool_choice to groq, cohere, openai (webcoderz, Feb 12, 2025)
9aa905c  unsure if these are necessary since seem supported already in mistral… (webcoderz, Feb 12, 2025)
9349412  fixing tool_choice across all models (webcoderz, Feb 21, 2025)
1bd0cf3  Merge branch 'pydantic:main' into webcoderz-model-settings (webcoderz, Feb 22, 2025)
a03fcfe  moving to top level settings (webcoderz, Feb 24, 2025)
89946a9  Merge branch 'pydantic:main' into webcoderz-model-settings (webcoderz, Feb 24, 2025)
20d8c8c  Merge branch 'webcoderz-model-settings' (webcoderz, Feb 24, 2025)
396b89c  Merge branch 'pydantic:main' into webcoderz-model-settings (webcoderz, Feb 25, 2025)
61b360b  Merge branch 'webcoderz-model-settings' (webcoderz, Feb 25, 2025)
e6df5fb  Merge branch 'pydantic:main' into webcoderz-model-settings (webcoderz, Feb 28, 2025)
1e4b0c5  Merge branch 'pydantic:main' into webcoderz-model-settings (webcoderz, Mar 5, 2025)
1419b5a  Merge branch 'webcoderz-model-settings' (webcoderz, Mar 5, 2025)
8b6f102  fixing ChatCompletionNamedToolChoiceParam (webcoderz, Mar 5, 2025)
1af96f3  Merge branch 'webcoderz-model-settings' (webcoderz, Mar 5, 2025)
0f47da9  Update cohere.py (webcoderz, Mar 5, 2025)
a23c014  Update openai.py (webcoderz, Mar 5, 2025)
9cefe9e  Update openai.py (webcoderz, Mar 5, 2025)
445c7ba  Update openai.py (webcoderz, Mar 5, 2025)
6ac1857  Merge remote-tracking branch 'origin/main' into webcoderz/main (Kludex, Mar 7, 2025)
2e238a2  Refactor (Kludex, Mar 7, 2025)
a86f5ce  Add Anthropic (Kludex, Mar 7, 2025)
9159aeb  full implementation (Kludex, Mar 7, 2025)
5cb233b  merge (Kludex, Mar 31, 2025)
059f92e  Merge remote-tracking branch 'origin/main' into webcoderz/main (Kludex, Apr 7, 2025)
30560b5  Merge remote-tracking branch 'origin/main' into webcoderz/main (Kludex, Apr 15, 2025)
ae1ab1c  Make GeminiModelSettings total=False (Kludex, Apr 15, 2025)
32cf6a6  Check safety settings on gemini properly (Kludex, Apr 15, 2025)
pydantic_ai_slim/pydantic_ai/models/anthropic.py (28 additions, 14 deletions)
@@ -29,7 +29,7 @@
     UserPromptPart,
 )
 from ..providers import Provider, infer_provider
-from ..settings import ModelSettings
+from ..settings import ForcedFunctionToolChoice, ModelSettings
 from ..tools import ToolDefinition
 from . import (
     Model,
@@ -209,19 +209,7 @@ async def _messages_create(
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
         tools = self._get_tools(model_request_parameters)
-        tool_choice: ToolChoiceParam | None
-
-        if not tools:
-            tool_choice = None
-        else:
-            if not model_request_parameters.allow_text_output:
-                tool_choice = {'type': 'any'}
-            else:
-                tool_choice = {'type': 'auto'}
-
-            if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
-                tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
-
+        tool_choice = self._map_tool_choice(model_settings, model_request_parameters, tools)
         system_prompt, anthropic_messages = await self._map_message(messages)

         try:
@@ -281,6 +269,32 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T
         tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools

+    @staticmethod
+    def _map_tool_choice(
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
+        tools: list[ToolParam],
+    ) -> ToolChoiceParam | None:
+        """Determine the `tool_choice` setting for the model.
+
+        Anthropic only supports `'auto'`, `'any'`, `'none'`, and a named tool.
+        """
+        tool_choice = model_settings.get('tool_choice', 'auto')
+        disable_parallel_tool_use = not model_settings.get('parallel_tool_calls', True)
+
+        if tool_choice == 'auto' and tools and not model_request_parameters.allow_text_output:
+            return {'type': 'any', 'disable_parallel_tool_use': disable_parallel_tool_use}
+        elif tool_choice == 'required':
+            return {'type': 'any', 'disable_parallel_tool_use': disable_parallel_tool_use}
+        elif tool_choice == 'auto':
+            return {'type': 'auto', 'disable_parallel_tool_use': disable_parallel_tool_use}
+        elif tool_choice == 'none':
+            return {'type': 'none'}
+        elif isinstance(tool_choice, ForcedFunctionToolChoice):
+            return {'type': 'tool', 'name': tool_choice.name, 'disable_parallel_tool_use': disable_parallel_tool_use}
+        else:
+            assert_never(tool_choice)
+
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt: str = ''
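To make the Anthropic mapping easy to eyeball, here is a standalone restatement of the `_map_tool_choice` decision table as a pure function, with the provider types stubbed out. This is an illustrative sketch, not the PR's code; `ForcedFunctionToolChoice` is stubbed since it only exists on this branch:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ForcedFunctionToolChoice:  # stub of the branch's settings type
    name: str


def map_anthropic_tool_choice(
    tool_choice: Any,
    has_tools: bool,
    allow_text_output: bool,
    parallel_tool_calls: bool = True,
) -> dict[str, Any] | None:
    # Mirrors _map_tool_choice above: 'required' -> 'any', forced tool -> 'tool'.
    disable = not parallel_tool_calls
    if tool_choice == 'auto' and has_tools and not allow_text_output:
        return {'type': 'any', 'disable_parallel_tool_use': disable}
    if tool_choice == 'required':
        return {'type': 'any', 'disable_parallel_tool_use': disable}
    if tool_choice == 'auto':
        return {'type': 'auto', 'disable_parallel_tool_use': disable}
    if tool_choice == 'none':
        return {'type': 'none'}
    if isinstance(tool_choice, ForcedFunctionToolChoice):
        return {'type': 'tool', 'name': tool_choice.name, 'disable_parallel_tool_use': disable}
    raise ValueError(tool_choice)


# Structured output (allow_text_output=False) upgrades 'auto' to 'any'.
assert map_anthropic_tool_choice('auto', True, False) == {'type': 'any', 'disable_parallel_tool_use': False}
assert map_anthropic_tool_choice('none', True, True) == {'type': 'none'}
assert map_anthropic_tool_choice(ForcedFunctionToolChoice('get_weather'), True, True)['type'] == 'tool'
```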
pydantic_ai_slim/pydantic_ai/models/bedrock.py (34 additions, 21 deletions)
@@ -33,7 +33,7 @@
 )
 from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client
 from pydantic_ai.providers import Provider, infer_provider
-from pydantic_ai.settings import ModelSettings
+from pydantic_ai.settings import ForcedFunctionToolChoice, ModelSettings
 from pydantic_ai.tools import ToolDefinition

 if TYPE_CHECKING:
@@ -54,7 +54,7 @@
         PerformanceConfigurationTypeDef,
         PromptVariableValuesTypeDef,
         SystemContentBlockTypeDef,
-        ToolChoiceTypeDef,
+        ToolConfigurationTypeDef,
         ToolTypeDef,
         VideoBlockTypeDef,
     )
@@ -275,36 +275,28 @@ async def _messages_create(
         self,
         messages: list[ModelMessage],
         stream: Literal[True],
-        model_settings: BedrockModelSettings | None,
+        model_settings: BedrockModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> EventStream[ConverseStreamOutputTypeDef]:
-        pass
+    ) -> EventStream[ConverseStreamOutputTypeDef]: ...

     @overload
     async def _messages_create(
         self,
         messages: list[ModelMessage],
         stream: Literal[False],
-        model_settings: BedrockModelSettings | None,
+        model_settings: BedrockModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> ConverseResponseTypeDef:
-        pass
+    ) -> ConverseResponseTypeDef: ...

     async def _messages_create(
         self,
         messages: list[ModelMessage],
         stream: bool,
-        model_settings: BedrockModelSettings | None,
+        model_settings: BedrockModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
         tools = self._get_tools(model_request_parameters)
-        support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
-        if not tools or not support_tools_choice:
-            tool_choice: ToolChoiceTypeDef = {}
-        elif not model_request_parameters.allow_text_output:
-            tool_choice = {'any': {}}
-        else:
-            tool_choice = {'auto': {}}
+        tool_config = self._get_tool_config(model_settings, model_request_parameters, tools)

         system_prompt, bedrock_messages = await self._map_messages(messages)
         inference_config = self._map_inference_config(model_settings)
@@ -315,6 +307,8 @@ async def _messages_create(
             'system': system_prompt,
             'inferenceConfig': inference_config,
         }
+        if tool_config:
+            params['toolConfig'] = tool_config

         # Bedrock supports a set of specific extra parameters
         if model_settings:
@@ -333,18 +327,37 @@ async def _messages_create(
             if prompt_variables := model_settings.get('bedrock_prompt_variables', None):
                 params['promptVariables'] = prompt_variables

-        if tools:
-            params['toolConfig'] = {'tools': tools}
-            if tool_choice:
-                params['toolConfig']['toolChoice'] = tool_choice
-
         if stream:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
             model_response = model_response['stream']
         else:
            model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
         return model_response

+    def _get_tool_config(
+        self,
+        model_settings: BedrockModelSettings,
+        model_request_parameters: ModelRequestParameters,
+        tools: list[ToolTypeDef],
+    ) -> ToolConfigurationTypeDef | None:
+        support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
+        tool_choice = model_settings.get('tool_choice', 'auto')
+
+        if not tools or not support_tools_choice:
+            return None
+        elif tool_choice == 'auto' and not model_request_parameters.allow_text_output:
+            return {'tools': tools, 'toolChoice': {'any': {}}}
+        elif tool_choice == 'auto':
+            return {'tools': tools, 'toolChoice': {'auto': {}}}
+        elif tool_choice == 'none':
+            return None
+        elif tool_choice == 'required':
+            return {'tools': tools, 'toolChoice': {'any': {}}}
+        elif isinstance(tool_choice, ForcedFunctionToolChoice):
+            return {'tools': tools, 'toolChoice': {'tool': {'name': tool_choice.name}}}
+        else:
+            assert_never(tool_choice)
+
     @staticmethod
     def _map_inference_config(
         model_settings: ModelSettings | None,
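One wrinkle in the Bedrock path: `toolChoice` is only honoured for Anthropic-hosted model IDs, and `'none'` is expressed by omitting the tool config entirely rather than by sending a mode. A small illustrative check of that gating logic (stubbed types, forced-tool branch omitted for brevity; not the PR's code):

```python
from typing import Any


def bedrock_tool_config(
    model_name: str,
    tool_choice: str,
    tools: list[dict[str, Any]],
    allow_text_output: bool = True,
) -> dict[str, Any] | None:
    # Mirrors _get_tool_config: only Anthropic-hosted model IDs get a toolChoice.
    supports_choice = model_name.startswith(('anthropic', 'us.anthropic'))
    if not tools or not supports_choice:
        return None  # no toolConfig is sent at all
    if tool_choice == 'auto' and not allow_text_output:
        return {'tools': tools, 'toolChoice': {'any': {}}}
    if tool_choice == 'auto':
        return {'tools': tools, 'toolChoice': {'auto': {}}}
    if tool_choice == 'none':
        return None  # 'none' is expressed by omitting the toolConfig entirely
    if tool_choice == 'required':
        return {'tools': tools, 'toolChoice': {'any': {}}}
    raise ValueError(tool_choice)


tools = [{'toolSpec': {'name': 'get_weather'}}]
# Non-Anthropic model IDs fall back to sending no tool configuration.
assert bedrock_tool_config('mistral.mistral-large-2402-v1:0', 'required', tools) is None
assert bedrock_tool_config('us.anthropic.claude-3-5-sonnet-20241022-v2:0', 'required', tools) == {
    'tools': tools,
    'toolChoice': {'any': {}},
}
```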
pydantic_ai_slim/pydantic_ai/models/cohere.py (32 additions, 6 deletions)
@@ -7,6 +7,8 @@

 from typing_extensions import assert_never

+from pydantic_ai.exceptions import UserError
+
 from .. import ModelHTTPError, usage
 from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
@@ -22,13 +24,9 @@
     UserPromptPart,
 )
 from ..providers import Provider, infer_provider
-from ..settings import ModelSettings
+from ..settings import ForcedFunctionToolChoice, ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    check_allow_model_requests,
-)
+from . import Model, ModelRequestParameters, check_allow_model_requests

 try:
     from cohere import (
@@ -44,6 +42,7 @@
         ToolV2,
         ToolV2Function,
         UserChatMessageV2,
+        V2ChatRequestToolChoice,
     )
     from cohere.core.api_error import ApiError
     from cohere.v2.client import OMIT
@@ -156,12 +155,14 @@ async def _chat(
         model_request_parameters: ModelRequestParameters,
     ) -> ChatResponse:
         tools = self._get_tools(model_request_parameters)
+        tool_choice = self._map_tool_choice(model_settings, model_request_parameters, tools)
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
         try:
             return await self.client.chat(
                 model=self._model_name,
                 messages=cohere_messages,
                 tools=tools or OMIT,
+                tool_choice=tool_choice or OMIT,
                 max_tokens=model_settings.get('max_tokens', OMIT),
                 stop_sequences=model_settings.get('stop_sequences', OMIT),
                 temperature=model_settings.get('temperature', OMIT),
@@ -223,6 +224,31 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T
         tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools

+    @staticmethod
+    def _map_tool_choice(
+        model_settings: CohereModelSettings, model_request_parameters: ModelRequestParameters, tools: list[ToolV2]
+    ) -> V2ChatRequestToolChoice | None:
+        """Determine the `tool_choice` setting for the model.
+
+        Cohere only supports `'REQUIRED'` and `'NONE'` for tool choice.
+        See [Cohere's docs](https://docs.cohere.com/v2/docs/tool-use-usage-patterns#forcing-tool-usage) for more details.
+        """
+        tool_choice = model_settings.get('tool_choice', 'auto')
+
+        if tool_choice == 'auto' and tools and not model_request_parameters.allow_text_output:
+            return 'REQUIRED'
+        elif tool_choice == 'auto':
+            return None
+        elif isinstance(tool_choice, ForcedFunctionToolChoice):
+            raise UserError(
+                'Cohere does not support forcing a specific tool. '
+                'Please choose a different value for the `tool_choice` parameter in the model settings.'
+            )
+        elif tool_choice in ('none', 'required'):
+            return tool_choice.upper()
+        else:
+            assert_never(tool_choice)
+
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
         return ToolCallV2(
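Because Cohere's v2 chat API only accepts `'REQUIRED'` and `'NONE'` (there is no named-tool forcing), the PR raises a `UserError` for `ForcedFunctionToolChoice`. An illustrative sketch of how a caller would hit that path, with both types stubbed (not the PR's code):

```python
from dataclasses import dataclass


class UserError(Exception):
    """Stub for pydantic_ai.exceptions.UserError."""


@dataclass
class ForcedFunctionToolChoice:  # stub of the branch's settings type
    name: str


def map_cohere_tool_choice(tool_choice, has_tools: bool, allow_text_output: bool) -> str | None:
    # Mirrors the decision table in _map_tool_choice above.
    if tool_choice == 'auto' and has_tools and not allow_text_output:
        return 'REQUIRED'
    if tool_choice == 'auto':
        return None  # parameter is omitted (OMIT); Cohere decides freely
    if isinstance(tool_choice, ForcedFunctionToolChoice):
        raise UserError('Cohere does not support forcing a specific tool.')
    if tool_choice in ('none', 'required'):
        return tool_choice.upper()  # Cohere's API expects upper-case literals
    raise ValueError(tool_choice)


assert map_cohere_tool_choice('required', True, True) == 'REQUIRED'
assert map_cohere_tool_choice('auto', True, False) == 'REQUIRED'  # structured output still forces a call
try:
    map_cohere_tool_choice(ForcedFunctionToolChoice('get_weather'), True, True)
except UserError as exc:
    print(exc)
```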
pydantic_ai_slim/pydantic_ai/models/gemini.py (39 additions, 16 deletions)
@@ -36,7 +36,7 @@
     UserPromptPart,
     VideoUrl,
 )
-from ..settings import ModelSettings
+from ..settings import ForcedFunctionToolChoice, ModelSettings
 from ..tools import ToolDefinition
 from . import (
     Model,
@@ -72,7 +72,7 @@
 """


-class GeminiModelSettings(ModelSettings):
+class GeminiModelSettings(ModelSettings, total=False):
     """Settings used for a Gemini model request.

     ALL FIELDS MUST BE `gemini_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
@@ -180,15 +180,35 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _Gemin
         tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools]
         return _GeminiTools(function_declarations=tools) if tools else None

+    @staticmethod
     def _get_tool_config(
-        self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
+        model_settings: GeminiModelSettings,
+        model_request_parameters: ModelRequestParameters,
+        tools: _GeminiTools | None,
     ) -> _GeminiToolConfig | None:
-        if model_request_parameters.allow_text_output:
+        """Determine the `tool_choice` setting for the model.
+
+        AUTO: The default model behavior. The model decides to predict either a function call or a natural language response.
+        ANY: The model is constrained to always predict a function call. If allowed_function_names is not provided,
+            the model picks from all of the available function declarations. If allowed_function_names is provided,
+            the model picks from the set of allowed functions.
+        NONE: The model won't predict a function call. In this case, the model behavior is the same as if you don't
+            pass any function declarations.
+        """
+        tool_choice = model_settings.get('tool_choice', 'auto')
+
+        if tool_choice == 'auto' and tools and not model_request_parameters.allow_text_output:
+            return {'function_calling_config': {'mode': 'ANY'}}
+        elif tool_choice == 'auto':
             return None
-        elif tools:
-            return _tool_config([t['name'] for t in tools['function_declarations']])
+        elif tool_choice == 'none':
+            return {'function_calling_config': {'mode': 'NONE'}}
+        elif tool_choice == 'required':
+            return {'function_calling_config': {'mode': 'ANY'}}
+        elif isinstance(tool_choice, ForcedFunctionToolChoice):
+            return {'function_calling_config': {'mode': 'ANY', 'allowed_function_names': [tool_choice.name]}}
         else:
-            return _tool_config([])
+            assert_never(tool_choice)

     @asynccontextmanager
     async def _make_request(
@@ -199,7 +219,7 @@ async def _make_request(
         model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[HTTPResponse]:
         tools = self._get_tools(model_request_parameters)
-        tool_config = self._get_tool_config(model_request_parameters, tools)
+        tool_config = self._get_tool_config(model_settings, model_request_parameters, tools)
         sys_prompt_parts, contents = await self._message_to_gemini_content(messages)

         request_data = _GeminiRequest(contents=contents)
@@ -222,7 +242,7 @@ async def _make_request(
             generation_config['presence_penalty'] = presence_penalty
         if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
             generation_config['frequency_penalty'] = frequency_penalty
-        if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) != []:
+        if gemini_safety_settings := model_settings.get('gemini_safety_settings'):
             request_data['safety_settings'] = gemini_safety_settings
         if generation_config:
             request_data['generation_config'] = generation_config
@@ -666,15 +686,18 @@ class _GeminiToolConfig(TypedDict):
     function_calling_config: _GeminiFunctionCallingConfig


-def _tool_config(function_names: list[str]) -> _GeminiToolConfig:
-    return _GeminiToolConfig(
-        function_calling_config=_GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=function_names)
-    )
+class _GeminiFunctionCallingConfig(TypedDict):
+    """The function calling config for the Gemini API.
+
+    See <https://ai.google.dev/gemini-api/docs/function-calling>
+    """

-class _GeminiFunctionCallingConfig(TypedDict):
-    mode: Literal['ANY', 'AUTO']
-    allowed_function_names: list[str]
+    mode: Literal['ANY', 'AUTO', 'NONE']
+    allowed_function_names: NotRequired[list[str]]
+    """If not provided, all functions are allowed.
+
+    It can only be used with `mode` set to `'ANY'`.
+    """


 @pydantic.with_config(pydantic.ConfigDict(defer_build=True))
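Finally, the equivalent decision table for Gemini, where everything funnels into `function_calling_config` modes (`AUTO`, `ANY`, `NONE`) and forcing a single tool becomes `ANY` plus `allowed_function_names`. An illustrative restatement with a stub type, not the PR's code:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ForcedFunctionToolChoice:  # stub of the branch's settings type
    name: str


def gemini_tool_config(tool_choice: Any, has_tools: bool, allow_text_output: bool) -> dict[str, Any] | None:
    # Mirrors _get_tool_config above; returning None omits the config (Gemini defaults to AUTO).
    if tool_choice == 'auto' and has_tools and not allow_text_output:
        return {'function_calling_config': {'mode': 'ANY'}}
    if tool_choice == 'auto':
        return None
    if tool_choice == 'none':
        return {'function_calling_config': {'mode': 'NONE'}}
    if tool_choice == 'required':
        return {'function_calling_config': {'mode': 'ANY'}}
    if isinstance(tool_choice, ForcedFunctionToolChoice):
        # ANY restricted to a single name is how Gemini forces one specific function.
        return {'function_calling_config': {'mode': 'ANY', 'allowed_function_names': [tool_choice.name]}}
    raise ValueError(tool_choice)


assert gemini_tool_config('required', True, True) == {'function_calling_config': {'mode': 'ANY'}}
assert gemini_tool_config(ForcedFunctionToolChoice('get_weather'), True, True) == {
    'function_calling_config': {'mode': 'ANY', 'allowed_function_names': ['get_weather']}
}
```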