Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

feat(toolbox-langchain): Support per-invocation auth viaRunnableConfig#291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Draft
anubhav756 wants to merge4 commits intoanubhav-state-li
base:anubhav-state-li
Choose a base branch
Loading
fromanubhav-self-auth-tools
Draft
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -15,6 +15,7 @@
from typing import Any, Callable, Union

from deprecated import deprecated
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from toolbox_core.tool import ToolboxTool as ToolboxCoreTool
from toolbox_core.utils import params_to_pydantic_model
Expand DownExpand Up@@ -52,7 +53,11 @@ def __init__(
def _run(self, **kwargs: Any) -> str:
raise NotImplementedError("Synchronous methods not supported by async tools.")

async def _arun(self, **kwargs: Any) -> str:
async def _arun(
self,
config: RunnableConfig,
**kwargs: Any,
) -> str:
"""
The coroutine that invokes the tool with the given arguments.

Expand All@@ -63,7 +68,33 @@ async def _arun(self, **kwargs: Any) -> str:
A dictionary containing the parsed JSON response from the tool
invocation.
"""
return await self.__core_tool(**kwargs)
tool_to_run = self.__core_tool
if (
config
and "configurable" in config
and "auth_token_getters" in config["configurable"]
):
auth_token_getters = config["configurable"]["auth_token_getters"]
if auth_token_getters:

# The `add_auth_token_getters` method requires that all provided
# getters are used by the tool. To prevent validation errors,
# filter the incoming getters to include only those that this
# specific tool requires.
req_auth_services = set(self.__core_tool._required_authz_tokens)
for auth_list in self.__core_tool._required_authn_params.values():
req_auth_services.update(auth_list)
filtered_getters = {
k: v
for k, v in auth_token_getters.items()
if k in req_auth_services
}
if filtered_getters:
tool_to_run = self.__core_tool.add_auth_token_getters(
filtered_getters
)

return await tool_to_run(**kwargs)

def add_auth_token_getters(
self, auth_token_getters: dict[str, Callable[[], str]]
Expand Down
45 changes: 40 additions & 5 deletionspackages/toolbox-langchain/src/toolbox_langchain/tools.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -16,6 +16,7 @@
from typing import Any, Awaitable, Callable, Mapping, Sequence, Union

from deprecated import deprecated
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool
from toolbox_core.utils import params_to_pydantic_model
Expand DownExpand Up@@ -73,11 +74,45 @@ def _client_headers(
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
return self.__core_tool._client_headers

def _run(self, **kwargs: Any) -> str:
return self.__core_tool(**kwargs)

async def _arun(self, **kwargs: Any) -> str:
return await to_thread(self.__core_tool, **kwargs)
def __get_tool_to_run(self, config: RunnableConfig) -> ToolboxCoreSyncTool:
tool_to_run = self.__core_tool
if (
config
and "configurable" in config
and "auth_token_getters" in config["configurable"]
):
auth_token_getters = config["configurable"]["auth_token_getters"]
if auth_token_getters:

# The `add_auth_token_getters` method requires that all provided
# getters are used by the tool. To prevent validation errors,
# filter the incoming getters to include only those that this
# specific tool requires.
req_auth_services = set(self.__core_tool._required_authz_tokens)
for auth_list in self.__core_tool._required_authn_params.values():
req_auth_services.update(auth_list)
filtered_getters = {
k: v
for k, v in auth_token_getters.items()
if k in req_auth_services
}
if filtered_getters:
tool_to_run = self.__core_tool.add_auth_token_getters(
filtered_getters
)
return tool_to_run

def _run(
self,
config: RunnableConfig,
**kwargs: Any,
) -> str:
tool_to_run = self.__get_tool_to_run(config)
return tool_to_run(**kwargs)

async def _arun(self, config: RunnableConfig, **kwargs: Any) -> str:
tool_to_run = self.__get_tool_to_run(config)
return await to_thread(tool_to_run, **kwargs)

def add_auth_token_getters(
self, auth_token_getters: dict[str, Callable[[], str]]
Expand Down
4 changes: 2 additions & 2 deletionspackages/toolbox-langchain/tests/test_tools.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -286,7 +286,7 @@ def test_toolbox_tool_run(self, toolbox_tool, mock_core_tool):
expected_result = "sync_run_output"
mock_core_tool.return_value = expected_result

result = toolbox_tool._run(**kwargs_to_run)
result = toolbox_tool._run(**kwargs_to_run, config={})

assert result == expected_result
assert mock_core_tool.call_count == 1
Expand All@@ -307,7 +307,7 @@ async def to_thread_side_effect(func, *args, **kwargs_for_func):

mock_to_thread_in_tools.side_effect = to_thread_side_effect

result = await toolbox_tool._arun(**kwargs_to_run)
result = await toolbox_tool._arun(**kwargs_to_run, config={})

assert result == expected_result
mock_to_thread_in_tools.assert_awaited_once_with(
Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp