|
| 1 | +from __future__importannotations |
| 2 | + |
| 3 | +fromcollections.abcimportAsyncIterator |
| 4 | +fromcontextlibimportasynccontextmanager |
| 5 | +fromdataclassesimportdataclass |
| 6 | +fromdatetimeimporttimedelta |
| 7 | +fromtypingimportAny,Callable |
| 8 | + |
| 9 | +frommcpimporttypesasmcp_types |
| 10 | +frompydanticimportConfigDict,with_config |
| 11 | +fromtemporalioimportactivity,workflow |
| 12 | +fromtemporalio.commonimportPriority,RetryPolicy |
| 13 | +fromtemporalio.workflowimportActivityCancellationType,VersioningIntent |
| 14 | + |
| 15 | +frompydantic_ai.agentimportAgent |
| 16 | +frompydantic_ai.mcpimportMCPServer,ToolResult |
| 17 | +frompydantic_ai.toolsets.abstractimportAbstractToolset |
| 18 | +frompydantic_ai.toolsets.functionimportFunctionToolset |
| 19 | + |
| 20 | +from ._run_contextimportAgentDepsT,RunContext |
| 21 | +from .messagesimport ( |
| 22 | +ModelMessage, |
| 23 | +ModelResponse, |
| 24 | +) |
| 25 | +from .modelsimportModel,ModelRequestParameters,StreamedResponse |
| 26 | +from .settingsimportModelSettings |
| 27 | +from .toolsetsimportToolsetTool |
| 28 | + |
| 29 | + |
| 30 | +@dataclass |
| 31 | +classTemporalSettings: |
| 32 | +task_queue:str|None=None |
| 33 | +schedule_to_close_timeout:timedelta|None=None |
| 34 | +schedule_to_start_timeout:timedelta|None=None |
| 35 | +start_to_close_timeout:timedelta|None=None |
| 36 | +heartbeat_timeout:timedelta|None=None |
| 37 | +retry_policy:RetryPolicy|None=None |
| 38 | +cancellation_type:ActivityCancellationType=ActivityCancellationType.TRY_CANCEL |
| 39 | +activity_id:str|None=None |
| 40 | +versioning_intent:VersioningIntent|None=None |
| 41 | +summary:str|None=None |
| 42 | +priority:Priority=Priority.default |
| 43 | + |
| 44 | + |
| 45 | +definitialize_temporal(): |
| 46 | +frompydantic_ai.messagesimport (# noqa F401 |
| 47 | +ModelResponse,# pyright: ignore[reportUnusedImport] |
| 48 | +ImageUrl,# pyright: ignore[reportUnusedImport] |
| 49 | +AudioUrl,# pyright: ignore[reportUnusedImport] |
| 50 | +DocumentUrl,# pyright: ignore[reportUnusedImport] |
| 51 | +VideoUrl,# pyright: ignore[reportUnusedImport] |
| 52 | +BinaryContent,# pyright: ignore[reportUnusedImport] |
| 53 | +UserContent,# pyright: ignore[reportUnusedImport] |
| 54 | + ) |
| 55 | + |
| 56 | + |
| 57 | +@dataclass |
| 58 | +@with_config(ConfigDict(arbitrary_types_allowed=True)) |
| 59 | +classModelRequestParams: |
| 60 | +messages:list[ModelMessage] |
| 61 | +model_settings:ModelSettings|None |
| 62 | +model_request_parameters:ModelRequestParameters |
| 63 | + |
| 64 | + |
| 65 | +deftemporalize_model(model:Model,temporal_settings:TemporalSettings|None=None)->list[Callable[...,Any]]: |
| 66 | +ifactivities:=getattr(model,'__temporal_activities',None): |
| 67 | +returnactivities |
| 68 | + |
| 69 | +temporal_settings=temporal_settingsorTemporalSettings() |
| 70 | + |
| 71 | +original_request=model.request |
| 72 | + |
| 73 | +@activity.defn(name='model_request') |
| 74 | +asyncdefrequest_activity(params:ModelRequestParams)->ModelResponse: |
| 75 | +returnawaitoriginal_request(params.messages,params.model_settings,params.model_request_parameters) |
| 76 | + |
| 77 | +asyncdefrequest( |
| 78 | +messages:list[ModelMessage], |
| 79 | +model_settings:ModelSettings|None, |
| 80 | +model_request_parameters:ModelRequestParameters, |
| 81 | + )->ModelResponse: |
| 82 | +returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownMemberType] |
| 83 | +activity=request_activity, |
| 84 | +arg=ModelRequestParams( |
| 85 | +messages=messages,model_settings=model_settings,model_request_parameters=model_request_parameters |
| 86 | + ), |
| 87 | +**temporal_settings.__dict__, |
| 88 | + ) |
| 89 | + |
| 90 | +@asynccontextmanager |
| 91 | +asyncdefrequest_stream( |
| 92 | +messages:list[ModelMessage], |
| 93 | +model_settings:ModelSettings|None, |
| 94 | +model_request_parameters:ModelRequestParameters, |
| 95 | + )->AsyncIterator[StreamedResponse]: |
| 96 | +raiseNotImplementedError('Cannot stream with temporal yet') |
| 97 | +yield |
| 98 | + |
| 99 | +model.request=request |
| 100 | +model.request_stream=request_stream |
| 101 | + |
| 102 | +activities= [request_activity] |
| 103 | +setattr(model,'__temporal_activities',activities) |
| 104 | +returnactivities |
| 105 | + |
| 106 | + |
| 107 | +# @dataclass |
| 108 | +# class TemporalModel(WrapperModel): |
| 109 | +# temporal_settings: TemporalSettings |
| 110 | + |
| 111 | +# def __init__( |
| 112 | +# self, |
| 113 | +# wrapped: Model | KnownModelName, |
| 114 | +# temporal_settings: TemporalSettings | None = None, |
| 115 | +# ) -> None: |
| 116 | +# super().__init__(wrapped) |
| 117 | +# self.temporal_settings = temporal_settings or TemporalSettings() |
| 118 | + |
| 119 | +# @activity.defn |
| 120 | +# async def request_activity(params: ModelRequestParams) -> ModelResponse: |
| 121 | +# return await self.wrapped.request(params.messages, params.model_settings, params.model_request_parameters) |
| 122 | + |
| 123 | +# self.request_activity = request_activity |
| 124 | + |
| 125 | +# async def request( |
| 126 | +# self, |
| 127 | +# messages: list[ModelMessage], |
| 128 | +# model_settings: ModelSettings | None, |
| 129 | +# model_request_parameters: ModelRequestParameters, |
| 130 | +# ) -> ModelResponse: |
| 131 | +# return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] |
| 132 | +# activity=self.request_activity, |
| 133 | +# arg=ModelRequestParams( |
| 134 | +# messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters |
| 135 | +# ), |
| 136 | +# **self.temporal_settings.__dict__, |
| 137 | +# ) |
| 138 | + |
| 139 | +# @asynccontextmanager |
| 140 | +# async def request_stream( |
| 141 | +# self, |
| 142 | +# messages: list[ModelMessage], |
| 143 | +# model_settings: ModelSettings | None, |
| 144 | +# model_request_parameters: ModelRequestParameters, |
| 145 | +# ) -> AsyncIterator[StreamedResponse]: |
| 146 | +# raise NotImplementedError('Cannot stream with temporal yet') |
| 147 | +# yield |
| 148 | + |
| 149 | + |
| 150 | +classTemporalRunContext(RunContext[AgentDepsT]): |
| 151 | +_data:dict[str,Any] |
| 152 | + |
| 153 | +def__init__(self,**kwargs:Any): |
| 154 | +self._data=kwargs |
| 155 | +setattr( |
| 156 | +self, |
| 157 | +'__dataclass_fields__', |
| 158 | + {name:fieldforname,fieldinRunContext.__dataclass_fields__.items()ifnameinkwargs}, |
| 159 | + ) |
| 160 | + |
| 161 | +def__getattribute__(self,name:str)->Any: |
| 162 | +try: |
| 163 | +returnsuper().__getattribute__(name) |
| 164 | +exceptAttributeErrorase: |
| 165 | +data=super().__getattribute__('_data') |
| 166 | +ifnameindata: |
| 167 | +returndata[name] |
| 168 | +raisee# TODO: Explain how to make a new run context attribute available |
| 169 | + |
| 170 | +@classmethod |
| 171 | +defserialize_run_context(cls,ctx:RunContext[AgentDepsT])->dict[str,Any]: |
| 172 | +return { |
| 173 | +'deps':ctx.deps, |
| 174 | +'retries':ctx.retries, |
| 175 | +'tool_call_id':ctx.tool_call_id, |
| 176 | +'tool_name':ctx.tool_name, |
| 177 | +'retry':ctx.retry, |
| 178 | +'run_step':ctx.run_step, |
| 179 | + } |
| 180 | + |
| 181 | +@classmethod |
| 182 | +defdeserialize_run_context(cls,ctx:dict[str,Any])->RunContext[AgentDepsT]: |
| 183 | +returncls(**ctx) |
| 184 | + |
| 185 | + |
| 186 | +@dataclass |
| 187 | +@with_config(ConfigDict(arbitrary_types_allowed=True)) |
| 188 | +classMCPCallToolParams: |
| 189 | +name:str |
| 190 | +tool_args:dict[str,Any] |
| 191 | +metadata:dict[str,Any]|None=None |
| 192 | + |
| 193 | + |
| 194 | +@dataclass |
| 195 | +@with_config(ConfigDict(arbitrary_types_allowed=True)) |
| 196 | +classFunctionCallToolParams: |
| 197 | +name:str |
| 198 | +tool_args:dict[str,Any] |
| 199 | +serialized_run_context:Any |
| 200 | + |
| 201 | + |
| 202 | +deftemporalize_mcp_server( |
| 203 | +server:MCPServer,temporal_settings:TemporalSettings|None=None |
| 204 | +)->list[Callable[...,Any]]: |
| 205 | +ifactivities:=getattr(server,'__temporal_activities',None): |
| 206 | +returnactivities |
| 207 | + |
| 208 | +temporal_settings=temporal_settingsorTemporalSettings() |
| 209 | + |
| 210 | +original_list_tools=server.list_tools |
| 211 | +original_direct_call_tool=server.direct_call_tool |
| 212 | + |
| 213 | +@activity.defn( |
| 214 | +name='mcp_server_list_tools' |
| 215 | + )# TODO: Require a name to be passed to TemporalMCPServer? If we get toolsets from a lib, what do we do? Strongly recommend setting a name? |
| 216 | +asyncdeflist_tools_activity()->list[mcp_types.Tool]: |
| 217 | +returnawaitoriginal_list_tools() |
| 218 | + |
| 219 | +@activity.defn(name='mcp_server_call_tool') |
| 220 | +asyncdefcall_tool_activity(params:MCPCallToolParams)->ToolResult: |
| 221 | +returnawaitoriginal_direct_call_tool(params.name,params.tool_args,params.metadata) |
| 222 | + |
| 223 | +asyncdeflist_tools()->list[mcp_types.Tool]: |
| 224 | +returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownVariableType,reportUnknownMemberType] |
| 225 | +activity=list_tools_activity, |
| 226 | +**temporal_settings.__dict__, |
| 227 | + ) |
| 228 | + |
| 229 | +asyncdefdirect_call_tool( |
| 230 | +name:str, |
| 231 | +args:dict[str,Any], |
| 232 | +metadata:dict[str,Any]|None=None, |
| 233 | + )->ToolResult: |
| 234 | +returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownMemberType] |
| 235 | +activity=call_tool_activity, |
| 236 | +arg=MCPCallToolParams(name=name,tool_args=args,metadata=metadata), |
| 237 | +**temporal_settings.__dict__, |
| 238 | + ) |
| 239 | + |
| 240 | +server.list_tools=list_tools |
| 241 | +server.direct_call_tool=direct_call_tool |
| 242 | + |
| 243 | +activities= [list_tools_activity,call_tool_activity] |
| 244 | +setattr(server,'__temporal_activities',activities) |
| 245 | +returnactivities |
| 246 | + |
| 247 | + |
| 248 | +# class TemporalMCPServer(WrapperToolset[Any]): |
| 249 | +# temporal_settings: TemporalSettings |
| 250 | + |
| 251 | +# @property |
| 252 | +# def wrapped_server(self) -> MCPServer: |
| 253 | +# assert isinstance(self.wrapped, MCPServer) |
| 254 | +# return self.wrapped |
| 255 | + |
| 256 | +# def __init__(self, wrapped: MCPServer, temporal_settings: TemporalSettings | None = None): |
| 257 | +# assert isinstance(self.wrapped, MCPServer) |
| 258 | +# super().__init__(wrapped) |
| 259 | +# self.temporal_settings = temporal_settings or TemporalSettings() |
| 260 | + |
| 261 | +# @activity.defn(name='mcp_server_list_tools') |
| 262 | +# async def list_tools_activity() -> list[mcp_types.Tool]: |
| 263 | +# return await self.wrapped_server.list_tools() |
| 264 | + |
| 265 | +# self.list_tools_activity = list_tools_activity |
| 266 | + |
| 267 | +# @activity.defn(name='mcp_server_call_tool') |
| 268 | +# async def call_tool_activity(params: MCPCallToolParams) -> ToolResult: |
| 269 | +# return await self.wrapped_server.direct_call_tool(params.name, params.tool_args, params.metadata) |
| 270 | + |
| 271 | +# self.call_tool_activity = call_tool_activity |
| 272 | + |
| 273 | +# async def list_tools(self) -> list[mcp_types.Tool]: |
| 274 | +# return await workflow.execute_activity( # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType] |
| 275 | +# activity=self.list_tools_activity, |
| 276 | +# **self.temporal_settings.__dict__, |
| 277 | +# ) |
| 278 | + |
| 279 | +# async def direct_call_tool( |
| 280 | +# self, |
| 281 | +# name: str, |
| 282 | +# args: dict[str, Any], |
| 283 | +# metadata: dict[str, Any] | None = None, |
| 284 | +# ) -> ToolResult: |
| 285 | +# return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] |
| 286 | +# activity=self.call_tool_activity, |
| 287 | +# arg=MCPCallToolParams(name=name, tool_args=args, metadata=metadata), |
| 288 | +# **self.temporal_settings.__dict__, |
| 289 | +# ) |
| 290 | + |
| 291 | + |
| 292 | +deftemporalize_function_toolset( |
| 293 | +toolset:FunctionToolset[AgentDepsT], |
| 294 | +temporal_settings:TemporalSettings|None=None, |
| 295 | +serialize_run_context:Callable[[RunContext[AgentDepsT]],Any]|None=None, |
| 296 | +deserialize_run_context:Callable[[Any],RunContext[AgentDepsT]]|None=None, |
| 297 | +)->list[Callable[...,Any]]: |
| 298 | +ifactivities:=getattr(toolset,'__temporal_activities',None): |
| 299 | +returnactivities |
| 300 | + |
| 301 | +temporal_settings=temporal_settingsorTemporalSettings() |
| 302 | +# TODO: Settings per tool name |
| 303 | +serialize_run_context=serialize_run_contextorTemporalRunContext[AgentDepsT].serialize_run_context |
| 304 | +deserialize_run_context=deserialize_run_contextorTemporalRunContext[AgentDepsT].deserialize_run_context |
| 305 | + |
| 306 | +original_call_tool=toolset.call_tool |
| 307 | + |
| 308 | +@activity.defn(name='function_toolset_call_tool') |
| 309 | +asyncdefcall_tool_activity(params:FunctionCallToolParams)->Any: |
| 310 | +ctx=deserialize_run_context(params.serialized_run_context) |
| 311 | +tool= (awaittoolset.get_tools(ctx))[params.name] |
| 312 | +returnawaitoriginal_call_tool(params.name,params.tool_args,ctx,tool) |
| 313 | + |
| 314 | +asyncdefcall_tool( |
| 315 | +name:str,tool_args:dict[str,Any],ctx:RunContext[AgentDepsT],tool:ToolsetTool[AgentDepsT] |
| 316 | + )->Any: |
| 317 | +serialized_run_context=serialize_run_context(ctx) |
| 318 | +returnawaitworkflow.execute_activity(# pyright: ignore[reportUnknownMemberType] |
| 319 | +activity=call_tool_activity, |
| 320 | +arg=FunctionCallToolParams(name=name,tool_args=tool_args,serialized_run_context=serialized_run_context), |
| 321 | +**temporal_settings.__dict__, |
| 322 | + ) |
| 323 | + |
| 324 | +toolset.call_tool=call_tool |
| 325 | + |
| 326 | +activities= [call_tool_activity] |
| 327 | +setattr(toolset,'__temporal_activities',activities) |
| 328 | +returnactivities |
| 329 | + |
| 330 | + |
| 331 | +# class TemporalFunctionToolset(FunctionToolset[AgentDepsT]): |
| 332 | +# def __init__( |
| 333 | +# self, |
| 334 | +# tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], |
| 335 | +# max_retries: int = 1, |
| 336 | +# temporal_settings: TemporalSettings | None = None, |
| 337 | +# serialize_run_context: Callable[[RunContext[AgentDepsT]], Any] | None = None, |
| 338 | +# deserialize_run_context: Callable[[Any], RunContext[AgentDepsT]] | None = None, |
| 339 | +# ): |
| 340 | +# super().__init__(tools, max_retries) |
| 341 | +# self.temporal_settings = temporal_settings or TemporalSettings() |
| 342 | +# self.serialize_run_context = serialize_run_context or TemporalRunContext[AgentDepsT].serialize_run_context |
| 343 | +# self.deserialize_run_context = deserialize_run_context or TemporalRunContext[AgentDepsT].deserialize_run_context |
| 344 | + |
| 345 | +# @activity.defn(name='function_toolset_call_tool') |
| 346 | +# async def call_tool_activity(params: FunctionCallToolParams) -> Any: |
| 347 | +# ctx = self.deserialize_run_context(params.serialized_run_context) |
| 348 | +# tool = (await self.get_tools(ctx))[params.name] |
| 349 | +# return await FunctionToolset[AgentDepsT].call_tool(self, params.name, params.tool_args, ctx, tool) |
| 350 | + |
| 351 | +# self.call_tool_activity = call_tool_activity |
| 352 | + |
| 353 | +# async def call_tool( |
| 354 | +# self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] |
| 355 | +# ) -> Any: |
| 356 | +# serialized_run_context = self.serialize_run_context(ctx) |
| 357 | +# return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] |
| 358 | +# activity=self.call_tool_activity, |
| 359 | +# arg=FunctionCallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), |
| 360 | +# **self.temporal_settings.__dict__, |
| 361 | +# ) |
| 362 | + |
| 363 | + |
| 364 | +deftemporalize_agent(agent:Agent,temporal_settings:TemporalSettings|None=None)->list[Callable[...,Any]]: |
| 365 | +ifexisting_activities:=getattr(agent,'__temporal_activities',None): |
| 366 | +returnexisting_activities |
| 367 | + |
| 368 | +temporal_settings=temporal_settingsorTemporalSettings() |
| 369 | + |
| 370 | +activities:list[Callable[...,Any]]= [] |
| 371 | +ifisinstance(agent.model,Model): |
| 372 | +# Doesn't work when model is not set already |
| 373 | +activities.extend(temporalize_model(agent.model,temporal_settings)) |
| 374 | + |
| 375 | +# TODO : Make TemporalMCPServer a wrapper |
| 376 | +iftoolset:=agent._get_toolset(): |
| 377 | +# Doesn't consider toolsets passed at iter time |
| 378 | +deftemporalize_toolset(toolset:AbstractToolset[AgentDepsT])->None: |
| 379 | +ifisinstance(toolset,FunctionToolset): |
| 380 | +activities.extend(temporalize_function_toolset(toolset,temporal_settings)) |
| 381 | +elifisinstance(toolset,MCPServer): |
| 382 | +activities.extend(temporalize_mcp_server(toolset,temporal_settings)) |
| 383 | + |
| 384 | +toolset.apply(temporalize_toolset) |
| 385 | + |
| 386 | +setattr(agent,'__temporal_activities',activities) |
| 387 | +returnactivities |