@@ -391,7 +391,7 @@ async def ret_a(x: str) -> str: # pragma: no cover
391
391
)
392
392
393
393
394
- class ResultType (BaseModel ):
394
+ class OutputType (BaseModel ):
395
395
"""Result type used by all tests."""
396
396
397
397
value :str
@@ -407,7 +407,7 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt
407
407
yield {2 :DeltaToolCall ('regular_tool' ,'{"x": 1}' )}
408
408
yield {3 :DeltaToolCall ('another_tool' ,'{"y": 2}' )}
409
409
410
- agent = Agent (FunctionModel (stream_function = sf ),output_type = ResultType ,end_strategy = 'early' )
410
+ agent = Agent (FunctionModel (stream_function = sf ),output_type = OutputType ,end_strategy = 'early' )
411
411
412
412
@agent .tool_plain
413
413
def regular_tool (x :int )-> int :# pragma: no cover
@@ -476,7 +476,7 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt
476
476
yield {1 :DeltaToolCall ('final_result' ,'{"value": "first"}' )}
477
477
yield {2 :DeltaToolCall ('final_result' ,'{"value": "second"}' )}
478
478
479
- agent = Agent (FunctionModel (stream_function = sf ),output_type = ResultType ,end_strategy = 'early' )
479
+ agent = Agent (FunctionModel (stream_function = sf ),output_type = OutputType ,end_strategy = 'early' )
480
480
481
481
async with agent .run_stream ('test multiple final results' )as result :
482
482
response = await result .get_output ()
@@ -529,7 +529,7 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt
529
529
yield {4 :DeltaToolCall ('final_result' ,'{"value": "second"}' )}
530
530
yield {5 :DeltaToolCall ('unknown_tool' ,'{"value": "???"}' )}
531
531
532
- agent = Agent (FunctionModel (stream_function = sf ),output_type = ResultType ,end_strategy = 'exhaustive' )
532
+ agent = Agent (FunctionModel (stream_function = sf ),output_type = OutputType ,end_strategy = 'exhaustive' )
533
533
534
534
@agent .tool_plain
535
535
def regular_tool (x :int )-> int :
@@ -606,7 +606,7 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt
606
606
yield {3 :DeltaToolCall ('another_tool' ,'{"y": 2}' )}
607
607
yield {4 :DeltaToolCall ('unknown_tool' ,'{"value": "???"}' )}
608
608
609
- agent = Agent (FunctionModel (stream_function = sf ),output_type = ResultType ,end_strategy = 'early' )
609
+ agent = Agent (FunctionModel (stream_function = sf ),output_type = OutputType ,end_strategy = 'early' )
610
610
611
611
@agent .tool_plain
612
612
def regular_tool (x :int )-> int :# pragma: no cover
@@ -715,7 +715,7 @@ def another_tool(y: int) -> int: # pragma: no cover
715
715
async def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool ():
716
716
"""Test that 'early' strategy does not apply to tool calls without final tool."""
717
717
tool_called :list [str ]= []
718
- agent = Agent (TestModel (),output_type = ResultType ,end_strategy = 'early' )
718
+ agent = Agent (TestModel (),output_type = OutputType ,end_strategy = 'early' )
719
719
720
720
@agent .tool_plain
721
721
def regular_tool (x :int )-> int :
@@ -777,17 +777,17 @@ async def test_custom_output_type_default_str() -> None:
777
777
response = await result .get_output ()
778
778
assert response == snapshot ('success (no tool calls)' )
779
779
780
- async with agent .run_stream ('test' ,output_type = ResultType )as result :
780
+ async with agent .run_stream ('test' ,output_type = OutputType )as result :
781
781
response = await result .get_output ()
782
- assert response == snapshot (ResultType (value = 'a' ))
782
+ assert response == snapshot (OutputType (value = 'a' ))
783
783
784
784
785
785
async def test_custom_output_type_default_structured ()-> None :
786
- agent = Agent ('test' ,output_type = ResultType )
786
+ agent = Agent ('test' ,output_type = OutputType )
787
787
788
788
async with agent .run_stream ('test' )as result :
789
789
response = await result .get_output ()
790
- assert response == snapshot (ResultType (value = 'a' ))
790
+ assert response == snapshot (OutputType (value = 'a' ))
791
791
792
792
async with agent .run_stream ('test' ,output_type = str )as result :
793
793
response = await result .get_output ()
@@ -880,21 +880,21 @@ def output_validator_simple(data: str) -> str:
880
880
881
881
882
882
async def test_stream_iter_structured_validator ()-> None :
883
- class NotResultType (BaseModel ):
883
+ class NotOutputType (BaseModel ):
884
884
not_value :str
885
885
886
- agent = Agent [None ,Union [ResultType , NotResultType ]]('test' ,output_type = Union [ResultType , NotResultType ])# pyright: ignore[reportArgumentType]
886
+ agent = Agent [None ,Union [OutputType , NotOutputType ]]('test' ,output_type = Union [OutputType , NotOutputType ])# pyright: ignore[reportArgumentType]
887
887
888
888
@agent .output_validator
889
- def output_validator (data :ResultType | NotResultType )-> ResultType | NotResultType :
890
- assert isinstance (data ,ResultType )
891
- return ResultType (value = data .value + ' (validated)' )
889
+ def output_validator (data :OutputType | NotOutputType )-> OutputType | NotOutputType :
890
+ assert isinstance (data ,OutputType )
891
+ return OutputType (value = data .value + ' (validated)' )
892
892
893
- outputs :list [ResultType ]= []
893
+ outputs :list [OutputType ]= []
894
894
async with agent .iter ('test' )as run :
895
895
async for node in run :
896
896
if agent .is_model_request_node (node ):
897
897
async with node .stream (run .ctx )as stream :
898
898
async for output in stream .stream_output (debounce_by = None ):
899
899
outputs .append (output )
900
- assert outputs == [ResultType (value = 'a (validated)' ),ResultType (value = 'a (validated)' )]
900
+ assert outputs == [OutputType (value = 'a (validated)' ),OutputType (value = 'a (validated)' )]