Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit2abccad

Browse files
committed
WIP
1 parent01c550c commit2abccad

File tree

8 files changed

+540
-4
lines changed

8 files changed

+540
-4
lines changed
Lines changed: 387 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,387 @@
1+
from __future__importannotations
2+
3+
fromcollections.abcimportAsyncIterator
4+
fromcontextlibimportasynccontextmanager
5+
fromdataclassesimportdataclass
6+
fromdatetimeimporttimedelta
7+
fromtypingimportAny,Callable
8+
9+
frommcpimporttypesasmcp_types
10+
frompydanticimportConfigDict,with_config
11+
fromtemporalioimportactivity,workflow
12+
fromtemporalio.commonimportPriority,RetryPolicy
13+
fromtemporalio.workflowimportActivityCancellationType,VersioningIntent
14+
15+
frompydantic_ai.agentimportAgent
16+
frompydantic_ai.mcpimportMCPServer,ToolResult
17+
frompydantic_ai.toolsets.abstractimportAbstractToolset
18+
frompydantic_ai.toolsets.functionimportFunctionToolset
19+
20+
from ._run_contextimportAgentDepsT,RunContext
21+
from .messagesimport (
22+
ModelMessage,
23+
ModelResponse,
24+
)
25+
from .modelsimportModel,ModelRequestParameters,StreamedResponse
26+
from .settingsimportModelSettings
27+
from .toolsetsimportToolsetTool
28+
29+
30+
@dataclass
31+
classTemporalSettings:
32+
task_queue:str|None=None
33+
schedule_to_close_timeout:timedelta|None=None
34+
schedule_to_start_timeout:timedelta|None=None
35+
start_to_close_timeout:timedelta|None=None
36+
heartbeat_timeout:timedelta|None=None
37+
retry_policy:RetryPolicy|None=None
38+
cancellation_type:ActivityCancellationType=ActivityCancellationType.TRY_CANCEL
39+
activity_id:str|None=None
40+
versioning_intent:VersioningIntent|None=None
41+
summary:str|None=None
42+
priority:Priority=Priority.default
43+
44+
45+
definitialize_temporal():
46+
frompydantic_ai.messagesimport (# noqa F401
47+
ModelResponse,# pyright: ignore[reportUnusedImport]
48+
ImageUrl,# pyright: ignore[reportUnusedImport]
49+
AudioUrl,# pyright: ignore[reportUnusedImport]
50+
DocumentUrl,# pyright: ignore[reportUnusedImport]
51+
VideoUrl,# pyright: ignore[reportUnusedImport]
52+
BinaryContent,# pyright: ignore[reportUnusedImport]
53+
UserContent,# pyright: ignore[reportUnusedImport]
54+
)
55+
56+
57+
@dataclass
58+
@with_config(ConfigDict(arbitrary_types_allowed=True))
59+
classModelRequestParams:
60+
messages:list[ModelMessage]
61+
model_settings:ModelSettings|None
62+
model_request_parameters:ModelRequestParameters
63+
64+
65+
deftemporalize_model(model:Model,temporal_settings:TemporalSettings|None=None)->list[Callable[...,Any]]:
66+
ifactivities:=getattr(model,'__temporal_activities',None):
67+
returnactivities
68+
69+
temporal_settings=temporal_settingsorTemporalSettings()
70+
71+
original_request=model.request
72+
73+
@activity.defn(name='model_request')
74+
asyncdefrequest_activity(params:ModelRequestParams)->ModelResponse:
75+
returnawaitoriginal_request(params.messages,params.model_settings,params.model_request_parameters)
76+
77+
asyncdefrequest(
78+
messages:list[ModelMessage],
79+
model_settings:ModelSettings|None,
80+
model_request_parameters:ModelRequestParameters,
81+
)->ModelResponse:
82+
returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownMemberType]
83+
activity=request_activity,
84+
arg=ModelRequestParams(
85+
messages=messages,model_settings=model_settings,model_request_parameters=model_request_parameters
86+
),
87+
**temporal_settings.__dict__,
88+
)
89+
90+
@asynccontextmanager
91+
asyncdefrequest_stream(
92+
messages:list[ModelMessage],
93+
model_settings:ModelSettings|None,
94+
model_request_parameters:ModelRequestParameters,
95+
)->AsyncIterator[StreamedResponse]:
96+
raiseNotImplementedError('Cannot stream with temporal yet')
97+
yield
98+
99+
model.request=request
100+
model.request_stream=request_stream
101+
102+
activities= [request_activity]
103+
setattr(model,'__temporal_activities',activities)
104+
returnactivities
105+
106+
107+
# @dataclass
108+
# class TemporalModel(WrapperModel):
109+
# temporal_settings: TemporalSettings
110+
111+
# def __init__(
112+
# self,
113+
# wrapped: Model | KnownModelName,
114+
# temporal_settings: TemporalSettings | None = None,
115+
# ) -> None:
116+
# super().__init__(wrapped)
117+
# self.temporal_settings = temporal_settings or TemporalSettings()
118+
119+
# @activity.defn
120+
# async def request_activity(params: ModelRequestParams) -> ModelResponse:
121+
# return await self.wrapped.request(params.messages, params.model_settings, params.model_request_parameters)
122+
123+
# self.request_activity = request_activity
124+
125+
# async def request(
126+
# self,
127+
# messages: list[ModelMessage],
128+
# model_settings: ModelSettings | None,
129+
# model_request_parameters: ModelRequestParameters,
130+
# ) -> ModelResponse:
131+
# return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
132+
# activity=self.request_activity,
133+
# arg=ModelRequestParams(
134+
# messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters
135+
# ),
136+
# **self.temporal_settings.__dict__,
137+
# )
138+
139+
# @asynccontextmanager
140+
# async def request_stream(
141+
# self,
142+
# messages: list[ModelMessage],
143+
# model_settings: ModelSettings | None,
144+
# model_request_parameters: ModelRequestParameters,
145+
# ) -> AsyncIterator[StreamedResponse]:
146+
# raise NotImplementedError('Cannot stream with temporal yet')
147+
# yield
148+
149+
150+
classTemporalRunContext(RunContext[AgentDepsT]):
151+
_data:dict[str,Any]
152+
153+
def__init__(self,**kwargs:Any):
154+
self._data=kwargs
155+
setattr(
156+
self,
157+
'__dataclass_fields__',
158+
{name:fieldforname,fieldinRunContext.__dataclass_fields__.items()ifnameinkwargs},
159+
)
160+
161+
def__getattribute__(self,name:str)->Any:
162+
try:
163+
returnsuper().__getattribute__(name)
164+
exceptAttributeErrorase:
165+
data=super().__getattribute__('_data')
166+
ifnameindata:
167+
returndata[name]
168+
raisee# TODO: Explain how to make a new run context attribute available
169+
170+
@classmethod
171+
defserialize_run_context(cls,ctx:RunContext[AgentDepsT])->dict[str,Any]:
172+
return {
173+
'deps':ctx.deps,
174+
'retries':ctx.retries,
175+
'tool_call_id':ctx.tool_call_id,
176+
'tool_name':ctx.tool_name,
177+
'retry':ctx.retry,
178+
'run_step':ctx.run_step,
179+
}
180+
181+
@classmethod
182+
defdeserialize_run_context(cls,ctx:dict[str,Any])->RunContext[AgentDepsT]:
183+
returncls(**ctx)
184+
185+
186+
@dataclass
187+
@with_config(ConfigDict(arbitrary_types_allowed=True))
188+
classMCPCallToolParams:
189+
name:str
190+
tool_args:dict[str,Any]
191+
metadata:dict[str,Any]|None=None
192+
193+
194+
@dataclass
195+
@with_config(ConfigDict(arbitrary_types_allowed=True))
196+
classFunctionCallToolParams:
197+
name:str
198+
tool_args:dict[str,Any]
199+
serialized_run_context:Any
200+
201+
202+
deftemporalize_mcp_server(
203+
server:MCPServer,temporal_settings:TemporalSettings|None=None
204+
)->list[Callable[...,Any]]:
205+
ifactivities:=getattr(server,'__temporal_activities',None):
206+
returnactivities
207+
208+
temporal_settings=temporal_settingsorTemporalSettings()
209+
210+
original_list_tools=server.list_tools
211+
original_direct_call_tool=server.direct_call_tool
212+
213+
@activity.defn(
214+
name='mcp_server_list_tools'
215+
)# TODO: Require a name to be passed to TemporalMCPServer? If we get toolsets from a lib, what do we do? Strongly recommend setting a name?
216+
asyncdeflist_tools_activity()->list[mcp_types.Tool]:
217+
returnawaitoriginal_list_tools()
218+
219+
@activity.defn(name='mcp_server_call_tool')
220+
asyncdefcall_tool_activity(params:MCPCallToolParams)->ToolResult:
221+
returnawaitoriginal_direct_call_tool(params.name,params.tool_args,params.metadata)
222+
223+
asyncdeflist_tools()->list[mcp_types.Tool]:
224+
returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
225+
activity=list_tools_activity,
226+
**temporal_settings.__dict__,
227+
)
228+
229+
asyncdefdirect_call_tool(
230+
name:str,
231+
args:dict[str,Any],
232+
metadata:dict[str,Any]|None=None,
233+
)->ToolResult:
234+
returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownMemberType]
235+
activity=call_tool_activity,
236+
arg=MCPCallToolParams(name=name,tool_args=args,metadata=metadata),
237+
**temporal_settings.__dict__,
238+
)
239+
240+
server.list_tools=list_tools
241+
server.direct_call_tool=direct_call_tool
242+
243+
activities= [list_tools_activity,call_tool_activity]
244+
setattr(server,'__temporal_activities',activities)
245+
returnactivities
246+
247+
248+
# class TemporalMCPServer(WrapperToolset[Any]):
249+
# temporal_settings: TemporalSettings
250+
251+
# @property
252+
# def wrapped_server(self) -> MCPServer:
253+
# assert isinstance(self.wrapped, MCPServer)
254+
# return self.wrapped
255+
256+
# def __init__(self, wrapped: MCPServer, temporal_settings: TemporalSettings | None = None):
257+
# assert isinstance(self.wrapped, MCPServer)
258+
# super().__init__(wrapped)
259+
# self.temporal_settings = temporal_settings or TemporalSettings()
260+
261+
# @activity.defn(name='mcp_server_list_tools')
262+
# async def list_tools_activity() -> list[mcp_types.Tool]:
263+
# return await self.wrapped_server.list_tools()
264+
265+
# self.list_tools_activity = list_tools_activity
266+
267+
# @activity.defn(name='mcp_server_call_tool')
268+
# async def call_tool_activity(params: MCPCallToolParams) -> ToolResult:
269+
# return await self.wrapped_server.direct_call_tool(params.name, params.tool_args, params.metadata)
270+
271+
# self.call_tool_activity = call_tool_activity
272+
273+
# async def list_tools(self) -> list[mcp_types.Tool]:
274+
# return await workflow.execute_activity( # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
275+
# activity=self.list_tools_activity,
276+
# **self.temporal_settings.__dict__,
277+
# )
278+
279+
# async def direct_call_tool(
280+
# self,
281+
# name: str,
282+
# args: dict[str, Any],
283+
# metadata: dict[str, Any] | None = None,
284+
# ) -> ToolResult:
285+
# return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
286+
# activity=self.call_tool_activity,
287+
# arg=MCPCallToolParams(name=name, tool_args=args, metadata=metadata),
288+
# **self.temporal_settings.__dict__,
289+
# )
290+
291+
292+
deftemporalize_function_toolset(
293+
toolset:FunctionToolset[AgentDepsT],
294+
temporal_settings:TemporalSettings|None=None,
295+
serialize_run_context:Callable[[RunContext[AgentDepsT]],Any]|None=None,
296+
deserialize_run_context:Callable[[Any],RunContext[AgentDepsT]]|None=None,
297+
)->list[Callable[...,Any]]:
298+
ifactivities:=getattr(toolset,'__temporal_activities',None):
299+
returnactivities
300+
301+
temporal_settings=temporal_settingsorTemporalSettings()
302+
# TODO: Settings per tool name
303+
serialize_run_context=serialize_run_contextorTemporalRunContext[AgentDepsT].serialize_run_context
304+
deserialize_run_context=deserialize_run_contextorTemporalRunContext[AgentDepsT].deserialize_run_context
305+
306+
original_call_tool=toolset.call_tool
307+
308+
@activity.defn(name='function_toolset_call_tool')
309+
asyncdefcall_tool_activity(params:FunctionCallToolParams)->Any:
310+
ctx=deserialize_run_context(params.serialized_run_context)
311+
tool= (awaittoolset.get_tools(ctx))[params.name]
312+
returnawaitoriginal_call_tool(params.name,params.tool_args,ctx,tool)
313+
314+
asyncdefcall_tool(
315+
name:str,tool_args:dict[str,Any],ctx:RunContext[AgentDepsT],tool:ToolsetTool[AgentDepsT]
316+
)->Any:
317+
serialized_run_context=serialize_run_context(ctx)
318+
returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownMemberType]
319+
activity=call_tool_activity,
320+
arg=FunctionCallToolParams(name=name,tool_args=tool_args,serialized_run_context=serialized_run_context),
321+
**temporal_settings.__dict__,
322+
)
323+
324+
toolset.call_tool=call_tool
325+
326+
activities= [call_tool_activity]
327+
setattr(toolset,'__temporal_activities',activities)
328+
returnactivities
329+
330+
331+
# class TemporalFunctionToolset(FunctionToolset[AgentDepsT]):
332+
# def __init__(
333+
# self,
334+
# tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
335+
# max_retries: int = 1,
336+
# temporal_settings: TemporalSettings | None = None,
337+
# serialize_run_context: Callable[[RunContext[AgentDepsT]], Any] | None = None,
338+
# deserialize_run_context: Callable[[Any], RunContext[AgentDepsT]] | None = None,
339+
# ):
340+
# super().__init__(tools, max_retries)
341+
# self.temporal_settings = temporal_settings or TemporalSettings()
342+
# self.serialize_run_context = serialize_run_context or TemporalRunContext[AgentDepsT].serialize_run_context
343+
# self.deserialize_run_context = deserialize_run_context or TemporalRunContext[AgentDepsT].deserialize_run_context
344+
345+
# @activity.defn(name='function_toolset_call_tool')
346+
# async def call_tool_activity(params: FunctionCallToolParams) -> Any:
347+
# ctx = self.deserialize_run_context(params.serialized_run_context)
348+
# tool = (await self.get_tools(ctx))[params.name]
349+
# return await FunctionToolset[AgentDepsT].call_tool(self, params.name, params.tool_args, ctx, tool)
350+
351+
# self.call_tool_activity = call_tool_activity
352+
353+
# async def call_tool(
354+
# self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
355+
# ) -> Any:
356+
# serialized_run_context = self.serialize_run_context(ctx)
357+
# return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
358+
# activity=self.call_tool_activity,
359+
# arg=FunctionCallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context),
360+
# **self.temporal_settings.__dict__,
361+
# )
362+
363+
364+
deftemporalize_agent(agent:Agent,temporal_settings:TemporalSettings|None=None)->list[Callable[...,Any]]:
365+
ifexisting_activities:=getattr(agent,'__temporal_activities',None):
366+
returnexisting_activities
367+
368+
temporal_settings=temporal_settingsorTemporalSettings()
369+
370+
activities:list[Callable[...,Any]]= []
371+
ifisinstance(agent.model,Model):
372+
# Doesn't work when model is not set already
373+
activities.extend(temporalize_model(agent.model,temporal_settings))
374+
375+
# TODO : Make TemporalMCPServer a wrapper
376+
iftoolset:=agent._get_toolset():
377+
# Doesn't consider toolsets passed at iter time
378+
deftemporalize_toolset(toolset:AbstractToolset[AgentDepsT])->None:
379+
ifisinstance(toolset,FunctionToolset):
380+
activities.extend(temporalize_function_toolset(toolset,temporal_settings))
381+
elifisinstance(toolset,MCPServer):
382+
activities.extend(temporalize_mcp_server(toolset,temporal_settings))
383+
384+
toolset.apply(temporalize_toolset)
385+
386+
setattr(agent,'__temporal_activities',activities)
387+
returnactivities

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp