Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit593ff0e

Browse files
authored
Yield initial node during Graph (and therefore Agent) iteration (pydantic#1412)
1 parent8ba6234 commit593ff0e

File tree

5 files changed

+31
-1
lines changed

5 files changed

+31
-1
lines changed

‎docs/agents.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ async def main():
121121
print(nodes)
122122
"""
123123
[
124+
UserPromptNode(
125+
user_prompt='What is the capital of France?',
126+
system_prompts=(),
127+
system_prompt_functions=[],
128+
system_prompt_dynamic_functions={},
129+
),
124130
ModelRequestNode(
125131
request=ModelRequest(
126132
parts=[
@@ -338,6 +344,7 @@ if __name__ == '__main__':
338344
print(output_messages)
339345
"""
340346
[
347+
'=== UserPromptNode: What will the weather be like in Paris on Tuesday? ===',
341348
'=== ModelRequestNode: streaming partial request tokens ===',
342349
'[Request] Starting part 0: ToolCallPart(tool_name=\'weather_forecast\', args=\'{"location":"Pa\', tool_call_id=\'0001\', part_kind=\'tool-call\')',
343350
'[Request] Part 0 args_delta=ris","forecast_',

‎docs/graph.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,7 @@ async def main():
510510
#> Node: CountDown()
511511
#> Node: CountDown()
512512
#> Node: CountDown()
513+
#> Node: CountDown()
513514
#> Node: End(data=0)
514515
print('Final result:', run.result.output)# (3)!
515516
#> Final result: 0

‎pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,12 @@ async def main():
372372
print(nodes)
373373
'''
374374
[
375+
UserPromptNode(
376+
user_prompt='What is the capital of France?',
377+
system_prompts=(),
378+
system_prompt_functions=[],
379+
system_prompt_dynamic_functions={},
380+
),
375381
ModelRequestNode(
376382
request=ModelRequest(
377383
parts=[
@@ -1355,6 +1361,12 @@ async def main():
13551361
print(nodes)
13561362
'''
13571363
[
1364+
UserPromptNode(
1365+
user_prompt='What is the capital of France?',
1366+
system_prompts=(),
1367+
system_prompt_functions=[],
1368+
system_prompt_dynamic_functions={},
1369+
),
13581370
ModelRequestNode(
13591371
request=ModelRequest(
13601372
parts=[

‎pydantic_graph/pydantic_graph/graph.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,7 @@ async def main():
607607
print(node_states)
608608
'''
609609
[
610+
(Increment(), MyState(number=1)),
610611
(Increment(), MyState(number=1)),
611612
(Check42(), MyState(number=2)),
612613
(End(data=2), MyState(number=2)),
@@ -621,6 +622,7 @@ async def main():
621622
print(node_states)
622623
'''
623624
[
625+
(Increment(), MyState(number=41)),
624626
(Increment(), MyState(number=41)),
625627
(Check42(), MyState(number=42)),
626628
(Increment(), MyState(number=42)),
@@ -665,6 +667,7 @@ def __init__(
665667
self.deps=deps
666668

667669
self._next_node:BaseNode[StateT,DepsT,RunEndT]|End[RunEndT]=start_node
670+
self._is_started:bool=False
668671

669672
@property
670673
defnext_node(self)->BaseNode[StateT,DepsT,RunEndT]|End[RunEndT]:
@@ -777,8 +780,13 @@ def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunE
777780

778781
asyncdef__anext__(self)->BaseNode[StateT,DepsT,RunEndT]|End[RunEndT]:
779782
"""Use the last returned node as the input to `Graph.next`."""
783+
ifnotself._is_started:
784+
self._is_started=True
785+
returnself._next_node
786+
780787
ifisinstance(self._next_node,End):
781788
raiseStopAsyncIteration
789+
782790
returnawaitself.next(self._next_node)
783791

784792
def__repr__(self)->str:

‎tests/graph/test_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,9 @@ async def test_iter():
312312
assertgraph_iter.result
313313
assertgraph_iter.result.output==8
314314

315-
assertnode_reprs==snapshot(["String2Length(input_data='3.14')",'Double(input_data=4)','End(data=8)'])
315+
assertnode_reprs==snapshot(
316+
['Float2String(input_data=3.14)',"String2Length(input_data='3.14')",'Double(input_data=4)','End(data=8)']
317+
)
316318

317319

318320
asyncdeftest_iter_next(mock_snapshot_id:object):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp