Open
Description
Question
I'm trying to implement an agent which streams outputs to the user and uses tools. Some tools require user approval, which means that we need to
- suspend the agent pending user approval and thus
- serialize all the messages so far so we can wake up the agent once we have the user's approval.
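As a rough illustration of what has to survive the suspension, here is a minimal sketch of a checkpoint object. It assumes the messages have already been reduced to JSON-compatible dicts (pydantic_ai documents `ModelMessagesTypeAdapter` in `pydantic_ai.messages` for converting real `ModelMessage` lists to and from JSON). The `Checkpoint` class, its fields, and the `delete_files` tool are hypothetical and exist only for illustration.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Any, Optional


@dataclass
class Checkpoint:
    """Hypothetical snapshot of a suspended run (not a pydantic_ai type)."""

    run_id: str
    # Conversation so far, already converted to JSON-compatible dicts.
    messages: list[dict[str, Any]] = field(default_factory=list)
    # The tool call that is waiting for the user's decision, if any.
    pending_tool_call: Optional[dict[str, Any]] = None

    def to_json(self) -> str:
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, raw: str) -> "Checkpoint":
        return cls(**json.loads(raw))


# Suspend: persist everything needed to continue later.
saved = Checkpoint(
    run_id="run-1",
    messages=[{"role": "user", "content": "Delete old logs"}],
    pending_tool_call={"tool_name": "delete_files", "args": {"path": "/var/log/old"}},
).to_json()

# Resume (possibly in another process): rebuild the snapshot from storage.
restored = Checkpoint.from_json(saved)
```

Keeping the pending call next to the message history means the resuming process knows which call the approval refers to without re-asking the model.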
I'm trying to use `Agent.iter` to achieve this, with a function along the lines of:
```python
from abc import ABC, abstractmethod
from typing import Iterable, Literal, Optional

from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.messages import (
    FinalResultEvent,
    FunctionToolCallEvent,
    FunctionToolResultEvent,
    ModelMessage,
    PartDeltaEvent,
    PartStartEvent,
    TextPart,
    TextPartDelta,
    ToolCallPartDelta,
)


# Allows sending the user messages over e.g. a WebSocket:
class OutputStreamBase(ABC):
    @abstractmethod
    def on_tool_call_start(self, event: FunctionToolCallEvent) -> None:
        """Handle the start of a tool call event."""
        ...

    @abstractmethod
    def on_tool_call_end(self, event: FunctionToolResultEvent) -> None:
        """Handle the end of a tool call event."""
        ...

    @abstractmethod
    def on_part_start(self, index: int, part: TextPart) -> None:
        """Handle the start of a part event."""
        ...

    @abstractmethod
    def on_part_delta(self, delta: TextPartDelta) -> None:
        """Handle the delta of a part event."""
        ...

    @abstractmethod
    def on_final_result(self, event: FinalResultEvent) -> None:
        """Handle the final result event."""
        ...


# Allows storing the current state to e.g. a database:
class BaseStore(ABC):
    @abstractmethod
    async def add_messages(self, messages: Iterable[ModelMessage]) -> None: ...

    @abstractmethod
    async def get_messages(self) -> list[ModelMessage]: ...


# Track the status of the agent (run):
RunStatus = Literal['COMPLETED', 'AWAITING_APPROVAL', 'RUNNING']


class RunResult(BaseModel):
    status: RunStatus
    output: str | None = None


# This is the interesting part:
async def run_agent_graph_step(
    agent: Agent,
    store: BaseStore,
    sink: OutputStreamBase,
    user_prompt: Optional[str] = None,
) -> RunResult:
    messages = await store.get_messages()
    async with agent.iter(user_prompt, message_history=messages) as run:
        async for node in run:
            if Agent.is_user_prompt_node(node):
                pass
            elif Agent.is_model_request_node(node):
                async with node.stream(run.ctx) as request_stream:
                    async for event in request_stream:
                        if isinstance(event, PartStartEvent):
                            # A new part arrives as a complete TextPart, not a delta.
                            if isinstance(event.part, TextPart):
                                sink.on_part_start(event.index, event.part)
                        elif isinstance(event, PartDeltaEvent):
                            if isinstance(event.delta, TextPartDelta):
                                sink.on_part_delta(event.delta)
                            elif isinstance(event.delta, ToolCallPartDelta):
                                pass
                        elif isinstance(event, FinalResultEvent):
                            pass
            elif Agent.is_call_tools_node(node):
                async with node.stream(run.ctx) as handle_stream:
                    async for tool_event in handle_stream:
                        if isinstance(tool_event, FunctionToolCallEvent):
                            # requires_approval() is a user-defined predicate (not shown).
                            if requires_approval(tool_event):
                                # ----> HERE: we need to get new_messages() and serialize them <----
                                return RunResult(status='AWAITING_APPROVAL')
                            sink.on_tool_call_start(tool_event)
                        elif isinstance(tool_event, FunctionToolResultEvent):
                            sink.on_tool_call_end(tool_event)
            elif Agent.is_end_node(node):
                assert run.result is not None and run.result.output == node.data.output
                # Persist the finished conversation before returning; placed after
                # the return, this call would never run on the completion path.
                await store.add_messages(run.result.new_messages())
                return RunResult(status='COMPLETED', output=node.data.output)
    return RunResult(status='RUNNING')
```
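The function above calls `requires_approval()` without defining it. One possible shape is sketched below, assuming a fixed allow-list of gated tool names; the set `TOOLS_REQUIRING_APPROVAL` and the tool names in it are hypothetical. It relies only on `FunctionToolCallEvent` exposing the call as `event.part` (a `ToolCallPart` with a `tool_name` attribute), so it is written with duck typing and can be exercised with stand-in objects.

```python
from types import SimpleNamespace
from typing import Any

# Hypothetical policy: names of tools that must be confirmed by the user.
TOOLS_REQUIRING_APPROVAL = frozenset({"delete_files", "send_email"})


def requires_approval(event: Any) -> bool:
    """Return True if the tool named in the event needs user approval.

    Assumes `event.part.tool_name` is available, as on a pydantic_ai
    FunctionToolCallEvent whose `part` is a ToolCallPart.
    """
    return event.part.tool_name in TOOLS_REQUIRING_APPROVAL


# Usage with a stand-in object shaped like FunctionToolCallEvent:
fake_event = SimpleNamespace(part=SimpleNamespace(tool_name="delete_files"))
print(requires_approval(fake_event))  # True
```

A name-based policy is the simplest option; a per-call policy could also inspect the call's arguments before deciding.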
Is there a way to get `new_messages()` (so far) in the middle of an `iter` run? Is there a better way to model user-in-the-loop pausing and resuming of an agent that needs tool-call approvals?
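To illustrate the pause/resume pattern independently of pydantic_ai, here is a toy runner in plain Python. Everything in it is hypothetical: `run_step`, the dict-based message records, and the scripted `planned_calls` list that stands in for the tool calls a model would request. The runner owns the history list and appends to it as each step finishes, so at the moment it must pause it can serialize exactly the messages produced so far plus the pending call; a later invocation then skips work that already completed. With a real `Agent.iter` run, the equivalent of `history` would have to come from the run's own state (recent pydantic_ai versions appear to keep it in `run.ctx.state.message_history`, an internal detail worth verifying against the installed version).

```python
import json
from typing import Callable, Optional

# Hypothetical tool call record: {"id": str, "tool": str}
ToolCall = dict


def run_step(
    planned_calls: list[ToolCall],
    history: list[dict],
    needs_approval: Callable[[str], bool],
    approved_ids: set[str],
) -> tuple[str, Optional[str]]:
    """Execute planned tool calls until one needs approval.

    Returns (status, saved_json). saved_json holds every message produced so
    far plus the pending call, and is only set when the run is suspended.
    """
    done = {m["id"] for m in history if m["kind"] == "tool_result"}
    for call in planned_calls:
        if call["id"] in done:
            continue  # finished before an earlier suspension
        if needs_approval(call["tool"]) and call["id"] not in approved_ids:
            pending = {"kind": "pending_call", **call}
            return "AWAITING_APPROVAL", json.dumps(history + [pending])
        # Stand-in for actually running the tool and recording its result.
        history.append({"kind": "tool_result", "id": call["id"], "tool": call["tool"]})
    return "COMPLETED", None


def gated(tool_name: str) -> bool:
    return tool_name == "delete_files"


calls = [{"id": "c1", "tool": "list_files"}, {"id": "c2", "tool": "delete_files"}]

status, saved = run_step(calls, [], gated, approved_ids=set())
print(status)  # AWAITING_APPROVAL

# Later, after the user approves c2: reload the history and continue.
history = json.loads(saved)
status, _ = run_step(calls, history, gated, approved_ids={"c2"})
print(status)  # COMPLETED
```

Because approval is tracked by call id rather than by re-running the model, the resumed run does not repeat `list_files` and only executes the call the user approved.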
Additional Context
No response