Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commiteb33f3f

Browse files
committed
WIP: Temporal Agent
1 parent4193208 commiteb33f3f

File tree

8 files changed

+373
-9
lines changed

8 files changed

+373
-9
lines changed
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
from __future__importannotations
2+
3+
fromcollections.abcimportAsyncIterator
4+
fromcontextlibimportasynccontextmanager
5+
fromdataclassesimportdataclass
6+
fromdatetimeimporttimedelta
7+
fromtypingimportAny,Callable
8+
9+
frommcpimporttypesasmcp_types
10+
frompydanticimportConfigDict,with_config
11+
fromtemporalioimportactivity,workflow
12+
fromtemporalio.commonimportPriority,RetryPolicy
13+
fromtemporalio.workflowimportActivityCancellationType,VersioningIntent
14+
15+
frompydantic_ai.mcpimportMCPServer,MCPServerStdio,ToolResult
16+
frompydantic_ai.toolsets.functionimportFunctionToolset
17+
18+
from ._run_contextimportAgentDepsT,RunContext
19+
from .messagesimport (
20+
ModelMessage,
21+
ModelResponse,
22+
)
23+
from .modelsimportKnownModelName,Model,ModelRequestParameters,StreamedResponse
24+
from .models.wrapperimportWrapperModel
25+
from .settingsimportModelSettings
26+
from .toolsetsimportToolsetTool
27+
28+
__all__= ('TemporalModel',)
29+
30+
31+
@dataclass
32+
classTemporalSettings:
33+
task_queue:str|None=None
34+
schedule_to_close_timeout:timedelta|None=None
35+
schedule_to_start_timeout:timedelta|None=None
36+
start_to_close_timeout:timedelta|None=None
37+
heartbeat_timeout:timedelta|None=None
38+
retry_policy:RetryPolicy|None=None
39+
cancellation_type:ActivityCancellationType=ActivityCancellationType.TRY_CANCEL
40+
activity_id:str|None=None
41+
versioning_intent:VersioningIntent|None=None
42+
summary:str|None=None
43+
priority:Priority=Priority.default
44+
45+
46+
definitialize_temporal():
47+
frompydantic_ai.messagesimport (# noqa F401
48+
ModelResponse,# pyright: ignore[reportUnusedImport]
49+
ImageUrl,# pyright: ignore[reportUnusedImport]
50+
AudioUrl,# pyright: ignore[reportUnusedImport]
51+
DocumentUrl,# pyright: ignore[reportUnusedImport]
52+
VideoUrl,# pyright: ignore[reportUnusedImport]
53+
BinaryContent,# pyright: ignore[reportUnusedImport]
54+
UserContent,# pyright: ignore[reportUnusedImport]
55+
)
56+
57+
58+
@dataclass
59+
@with_config(ConfigDict(arbitrary_types_allowed=True))
60+
classModelRequestParams:
61+
messages:list[ModelMessage]
62+
model_settings:ModelSettings|None
63+
model_request_parameters:ModelRequestParameters
64+
65+
66+
@dataclass
67+
classTemporalModel(WrapperModel):
68+
temporal_settings:TemporalSettings
69+
70+
def__init__(
71+
self,
72+
wrapped:Model|KnownModelName,
73+
temporal_settings:TemporalSettings|None=None,
74+
)->None:
75+
super().__init__(wrapped)
76+
self.temporal_settings=temporal_settingsorTemporalSettings()
77+
78+
@activity.defn
79+
asyncdefrequest_activity(params:ModelRequestParams)->ModelResponse:
80+
returnawaitself.wrapped.request(params.messages,params.model_settings,params.model_request_parameters)
81+
82+
self.request_activity=request_activity
83+
84+
asyncdefrequest(
85+
self,
86+
messages:list[ModelMessage],
87+
model_settings:ModelSettings|None,
88+
model_request_parameters:ModelRequestParameters,
89+
)->ModelResponse:
90+
returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownMemberType]
91+
activity=self.request_activity,
92+
arg=ModelRequestParams(
93+
messages=messages,model_settings=model_settings,model_request_parameters=model_request_parameters
94+
),
95+
**self.temporal_settings.__dict__,
96+
)
97+
98+
@asynccontextmanager
99+
asyncdefrequest_stream(
100+
self,
101+
messages:list[ModelMessage],
102+
model_settings:ModelSettings|None,
103+
model_request_parameters:ModelRequestParameters,
104+
)->AsyncIterator[StreamedResponse]:
105+
raiseNotImplementedError('Cannot stream with temporal yet')
106+
yield
107+
108+
109+
@dataclass
110+
@with_config(ConfigDict(arbitrary_types_allowed=True))
111+
classTemporalRunContext:
112+
@classmethod
113+
defserialize_run_context(cls,ctx:RunContext[AgentDepsT])->dict[str,Any]:
114+
return {
115+
'deps':ctx.deps,
116+
'usage':ctx.usage,
117+
'prompt':ctx.prompt,
118+
'messages':ctx.messages,
119+
'retries':ctx.retries,
120+
'tool_call_id':ctx.tool_call_id,
121+
'tool_name':ctx.tool_name,
122+
'retry':ctx.retry,
123+
'run_step':ctx.run_step,
124+
}
125+
126+
@classmethod
127+
defdeserialize_run_context(cls,ctx:dict[str,Any])->RunContext[AgentDepsT]:
128+
# TODO: Error on anything but deps
129+
returnRunContext(
130+
deps=ctx['deps'],
131+
model=None,# TODO: Add model
132+
usage=ctx['usage'],
133+
prompt=ctx['prompt'],
134+
messages=ctx['messages'],
135+
retries=ctx['retries'],
136+
tool_call_id=ctx['tool_call_id'],
137+
tool_name=ctx['tool_name'],
138+
retry=ctx['retry'],
139+
run_step=ctx['run_step'],
140+
)
141+
142+
143+
@dataclass
144+
@with_config(ConfigDict(arbitrary_types_allowed=True))
145+
classMCPCallToolParams:
146+
name:str
147+
tool_args:dict[str,Any]
148+
metadata:dict[str,Any]|None=None
149+
150+
151+
@dataclass
152+
@with_config(ConfigDict(arbitrary_types_allowed=True))
153+
classFunctionCallToolParams:
154+
name:str
155+
tool_args:dict[str,Any]
156+
serialized_run_context:Any
157+
158+
159+
classTemporalMCPServer(MCPServerStdio):
160+
temporal_settings:TemporalSettings
161+
162+
def__init__(self,*args:Any,temporal_settings:TemporalSettings|None=None,**kwargs:Any):
163+
super().__init__(*args,**kwargs)
164+
self.temporal_settings=temporal_settingsorTemporalSettings()
165+
166+
@activity.defn(name='mcp_server_list_tools')
167+
asyncdeflist_tools_activity()->list[mcp_types.Tool]:
168+
returnawaitMCPServer.list_tools(self)
169+
170+
self.list_tools_activity=list_tools_activity
171+
172+
@activity.defn(name='mcp_server_call_tool')
173+
asyncdefcall_tool_activity(params:MCPCallToolParams)->ToolResult:
174+
returnawaitMCPServer.direct_call_tool(self,params.name,params.tool_args,params.metadata)
175+
176+
self.call_tool_activity=call_tool_activity
177+
178+
asyncdeflist_tools(self)->list[mcp_types.Tool]:
179+
returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
180+
activity=self.list_tools_activity,
181+
**self.temporal_settings.__dict__,
182+
)
183+
184+
asyncdefdirect_call_tool(
185+
self,
186+
name:str,
187+
args:dict[str,Any],
188+
metadata:dict[str,Any]|None=None,
189+
)->ToolResult:
190+
returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownMemberType]
191+
activity=self.call_tool_activity,
192+
arg=MCPCallToolParams(name=name,tool_args=args,metadata=metadata),
193+
**self.temporal_settings.__dict__,
194+
)
195+
196+
197+
classTemporalFunctionToolset(FunctionToolset[AgentDepsT]):
198+
# TODO: Drop args, kwargs, use actual typed args
199+
def__init__(
200+
self,
201+
*args:Any,
202+
temporal_settings:TemporalSettings|None=None,
203+
serialize_run_context:Callable[[RunContext[AgentDepsT]],Any]|None=None,
204+
deserialize_run_context:Callable[[Any],RunContext[AgentDepsT]]|None=None,
205+
**kwargs:Any,
206+
):
207+
super().__init__(*args,**kwargs)
208+
self.temporal_settings=temporal_settingsorTemporalSettings()
209+
self.serialize_run_context=serialize_run_contextorTemporalRunContext.serialize_run_context
210+
self.deserialize_run_context=deserialize_run_contextorTemporalRunContext.deserialize_run_context
211+
212+
@activity.defn(name='function_toolset_call_tool')
213+
asyncdefcall_tool_activity(params:FunctionCallToolParams)->Any:
214+
ctx=self.deserialize_run_context(params.serialized_run_context)
215+
tool= (awaitself.get_tools(ctx))[
216+
params.name
217+
]# TODO: Possibly problematic as tools could've been added/removed/renamed during the run
218+
returnawaitFunctionToolset.call_tool(self,params.name,params.tool_args,ctx,tool)
219+
220+
self.call_tool_activity=call_tool_activity
221+
222+
asyncdefcall_tool(
223+
self,name:str,tool_args:dict[str,Any],ctx:RunContext[AgentDepsT],tool:ToolsetTool[AgentDepsT]
224+
)->Any:
225+
serialized_run_context=self.serialize_run_context(ctx)
226+
returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownMemberType]
227+
activity=self.call_tool_activity,
228+
arg=FunctionCallToolParams(name=name,tool_args=tool_args,serialized_run_context=serialized_run_context),
229+
**self.temporal_settings.__dict__,
230+
)

‎pydantic_ai_slim/pydantic_ai/toolsets/abstract.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ async def call_tool(
115115
"""
116116
raiseNotImplementedError()
117117

118-
defapply(self,visitor:Callable[[AbstractToolset[AgentDepsT]],Any])->Any:
118+
defapply(self,visitor:Callable[[AbstractToolset[AgentDepsT]],None])->None:
119119
"""Run a visitor function on all concrete toolsets that are not wrappers (i.e. they implement their own tool listing and calling)."""
120-
returnvisitor(self)
120+
visitor(self)
121121

122122
deffiltered(
123123
self,filter_func:Callable[[RunContext[AgentDepsT],ToolDefinition],bool]

‎pydantic_ai_slim/pydantic_ai/toolsets/combined.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ async def call_tool(
8383
assertisinstance(tool,_CombinedToolsetTool)
8484
returnawaittool.source_toolset.call_tool(name,tool_args,ctx,tool.source_tool)
8585

86-
defapply(self,visitor:Callable[[AbstractToolset[AgentDepsT]],Any])->Any:
86+
defapply(self,visitor:Callable[[AbstractToolset[AgentDepsT]],None])->None:
87+
# TODO: Let this be used to replace toolsets with temporal-specific versions, see David's visit-and-replace-toolsets branch
8788
fortoolsetinself.toolsets:
8889
toolset.apply(visitor)

‎pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@ async def call_tool(
3333
)->Any:
3434
returnawaitself.wrapped.call_tool(name,tool_args,ctx,tool)
3535

36-
defapply(self,visitor:Callable[[AbstractToolset[AgentDepsT]],Any])->Any:
37-
returnself.wrapped.apply(visitor)
36+
defapply(self,visitor:Callable[[AbstractToolset[AgentDepsT]],None])->None:
37+
self.wrapped.apply(visitor)

‎pydantic_ai_slim/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ mcp = ["mcp>=1.9.4; python_version >= '3.10'"]
8282
evals = ["pydantic-evals=={{ version }}"]
8383
# A2A
8484
a2a = ["fasta2a>=0.4.1"]
85+
# Temporal
86+
temporal = ["temporalio>=1.13.0"]
8587

8688
[dependency-groups]
8789
dev = [

‎pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ requires-python = ">=3.9"
4747

4848
[tool.hatch.metadata.hooks.uv-dynamic-versioning]
4949
dependencies = [
50-
"pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals]=={{ version }}",
50+
"pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,temporal]=={{ version }}",
5151
]
5252

5353
[tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies]

‎temporal.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# /// script
2+
# dependencies = [
3+
# "temporalio",
4+
# "logfire",
5+
# ]
6+
# ///
7+
importasyncio
8+
importrandom
9+
fromdatetimeimporttimedelta
10+
11+
fromtemporalioimportworkflow
12+
fromtemporalio.clientimportClient
13+
fromtemporalio.contrib.opentelemetryimportTracingInterceptor
14+
fromtemporalio.contrib.pydanticimportpydantic_data_converter
15+
fromtemporalio.runtimeimportOpenTelemetryConfig,Runtime,TelemetryConfig
16+
fromtemporalio.workerimportWorker
17+
18+
withworkflow.unsafe.imports_passed_through():
19+
frompydantic_aiimportAgent
20+
frompydantic_ai.temporalimport (
21+
TemporalFunctionToolset,
22+
TemporalMCPServer,
23+
TemporalModel,
24+
TemporalSettings,
25+
initialize_temporal,
26+
)
27+
28+
initialize_temporal()
29+
30+
temporal_settings=TemporalSettings(start_to_close_timeout=timedelta(seconds=60))
31+
32+
model=TemporalModel(
33+
'openai:gpt-4o',
34+
temporal_settings=temporal_settings,
35+
)
36+
37+
defget_uv_index(location:str)->int:
38+
return3
39+
40+
toolset=TemporalFunctionToolset(tools=[get_uv_index],temporal_settings=temporal_settings)
41+
mcp_server=TemporalMCPServer(
42+
'python',
43+
['-m','tests.mcp_server'],
44+
timeout=20,
45+
temporal_settings=temporal_settings,
46+
)
47+
48+
my_agent=Agent(model=model,instructions='be helpful',toolsets=[toolset,mcp_server])
49+
50+
51+
definit_runtime_with_telemetry()->Runtime:
52+
# import logfire
53+
54+
# logfire.configure(send_to_logfire=True, service_version='0.0.1', console=False)
55+
# logfire.instrument_pydantic_ai()
56+
# logfire.instrument_httpx(capture_all=True)
57+
58+
# Setup SDK metrics to OTel endpoint
59+
returnRuntime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318')))
60+
61+
62+
# Basic workflow that logs and invokes an activity
63+
@workflow.defn
64+
classMyAgentWorkflow:
65+
@workflow.run
66+
asyncdefrun(self,prompt:str)->str:
67+
return (awaitmy_agent.run(prompt)).output
68+
69+
70+
asyncdefmain():
71+
client=awaitClient.connect(
72+
'localhost:7233',
73+
interceptors=[TracingInterceptor()],
74+
data_converter=pydantic_data_converter,
75+
runtime=init_runtime_with_telemetry(),
76+
)
77+
78+
asyncwithWorker(
79+
client,
80+
task_queue='my-agent-task-queue',
81+
workflows=[MyAgentWorkflow],
82+
activities=[
83+
model.request_activity,
84+
toolset.call_tool_activity,
85+
mcp_server.list_tools_activity,
86+
mcp_server.call_tool_activity,
87+
],
88+
):
89+
result=awaitclient.execute_workflow(# pyright: ignore[reportUnknownMemberType]
90+
MyAgentWorkflow.run,
91+
'what is 2 plus the UV Index in Mexico City? and what is the product name?',
92+
id=f'my-agent-workflow-id-{random.random()}',
93+
task_queue='my-agent-task-queue',
94+
)
95+
print(f'Result:{result!r}')
96+
97+
98+
if__name__=='__main__':
99+
asyncio.run(main())

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp