1
1
from __future__import annotations as _annotations
2
2
3
+ import dataclasses
3
4
import inspect
4
5
import json
5
6
from abc import ABC ,abstractmethod
6
7
from collections .abc import Awaitable ,Iterable ,Iterator ,Sequence
7
8
from dataclasses import dataclass ,field
8
9
from typing import TYPE_CHECKING ,Any ,Callable ,Generic ,Literal ,Union ,cast ,overload
9
10
11
+ from opentelemetry .trace import Tracer
10
12
from pydantic import TypeAdapter ,ValidationError
11
13
from pydantic_core import SchemaValidator
12
14
from typing_extensions import TypedDict ,TypeVar ,assert_never
13
15
16
+ from pydantic_graph .nodes import GraphRunContext
17
+
14
18
from .import _function_schema ,_utils ,messages as _messages
15
19
from ._run_context import AgentDepsT ,RunContext
16
20
from .exceptions import ModelRetry ,UserError
29
33
from .tools import GenerateToolJsonSchema ,ObjectJsonSchema ,ToolDefinition
30
34
31
35
if TYPE_CHECKING :
36
+ from pydantic_ai ._agent_graph import DepsT ,GraphAgentDeps ,GraphAgentState
37
+
32
38
from .profiles import ModelProfile
33
39
34
40
T = TypeVar ('T' )
66
72
DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
67
73
68
74
75
+ @dataclass (frozen = True )
76
+ class TraceContext :
77
+ """A context for tracing output processing."""
78
+
79
+ tracer :Tracer
80
+ include_content :bool
81
+ call :_messages .ToolCallPart | None = None
82
+
83
+ def with_call (self ,call :_messages .ToolCallPart ):
84
+ return dataclasses .replace (self ,call = call )
85
+
86
+ async def execute_function_with_span (
87
+ self ,
88
+ function_schema :_function_schema .FunctionSchema ,
89
+ run_context :RunContext [AgentDepsT ],
90
+ args :dict [str ,Any ]| Any ,
91
+ call :_messages .ToolCallPart ,
92
+ include_tool_call_id :bool = True ,
93
+ )-> Any :
94
+ """Execute a function call within a traced span, automatically recording the response."""
95
+ # Set up span attributes
96
+ attributes = {
97
+ 'gen_ai.tool.name' :call .tool_name ,
98
+ 'logfire.msg' :f'running output function:{ call .tool_name } ' ,
99
+ }
100
+ if include_tool_call_id :
101
+ attributes ['gen_ai.tool.call.id' ]= call .tool_call_id
102
+ if self .include_content :
103
+ attributes ['tool_arguments' ]= call .args_as_json_str ()
104
+ attributes ['logfire.json_schema' ]= json .dumps (
105
+ {
106
+ 'type' :'object' ,
107
+ 'properties' : {
108
+ 'tool_arguments' : {'type' :'object' },
109
+ 'tool_response' : {'type' :'object' },
110
+ },
111
+ }
112
+ )
113
+
114
+ # Execute function within span
115
+ with self .tracer .start_as_current_span ('running output function' ,attributes = attributes )as span :
116
+ output = await function_schema .call (args ,run_context )
117
+
118
+ # Record response if content inclusion is enabled
119
+ if self .include_content and span .is_recording ():
120
+ from .models .instrumented import InstrumentedModel
121
+
122
+ span .set_attribute (
123
+ 'tool_response' ,
124
+ output if isinstance (output ,str )else json .dumps (InstrumentedModel .serialize_any (output )),
125
+ )
126
+
127
+ return output
128
+
129
+
130
+ def build_trace_context (ctx :GraphRunContext [GraphAgentState ,GraphAgentDeps [DepsT ,Any ]])-> TraceContext :
131
+ """Build a `TraceContext` from the current agent graph run context."""
132
+ return TraceContext (
133
+ tracer = ctx .deps .tracer ,
134
+ include_content = (
135
+ ctx .deps .instrumentation_settings is not None and ctx .deps .instrumentation_settings .include_content
136
+ ),
137
+ )
138
+
139
+
69
140
class ToolRetryError (Exception ):
70
141
"""Exception used to signal a `ToolRetry` message should be returned to the LLM."""
71
142
@@ -96,6 +167,7 @@ async def validate(
96
167
result: The result data after Pydantic validation the message content.
97
168
tool_call: The original tool call message, `None` if there was no tool call.
98
169
run_context: The current run context.
170
+ trace_context: The trace context to use for tracing the output processing.
99
171
100
172
Returns:
101
173
Result of either the validated result data (ok) or a retry message (Err).
@@ -349,6 +421,7 @@ async def process(
349
421
self ,
350
422
text :str ,
351
423
run_context :RunContext [AgentDepsT ],
424
+ trace_context :TraceContext ,
352
425
allow_partial :bool = False ,
353
426
wrap_validation_errors :bool = True ,
354
427
)-> OutputDataT :
@@ -371,6 +444,7 @@ async def process(
371
444
self ,
372
445
text :str ,
373
446
run_context :RunContext [AgentDepsT ],
447
+ trace_context :TraceContext ,
374
448
allow_partial :bool = False ,
375
449
wrap_validation_errors :bool = True ,
376
450
)-> OutputDataT :
@@ -379,6 +453,7 @@ async def process(
379
453
Args:
380
454
text: The output text to validate.
381
455
run_context: The current run context.
456
+ trace_context: The trace context to use for tracing the output processing.
382
457
allow_partial: If true, allow partial validation.
383
458
wrap_validation_errors: If true, wrap the validation errors in a retry message.
384
459
@@ -389,7 +464,7 @@ async def process(
389
464
return cast (OutputDataT ,text )
390
465
391
466
return await self .processor .process (
392
- text ,run_context ,allow_partial = allow_partial ,wrap_validation_errors = wrap_validation_errors
467
+ text ,run_context ,trace_context , allow_partial = allow_partial ,wrap_validation_errors = wrap_validation_errors
393
468
)
394
469
395
470
@@ -417,6 +492,7 @@ async def process(
417
492
self ,
418
493
text :str ,
419
494
run_context :RunContext [AgentDepsT ],
495
+ trace_context :TraceContext ,
420
496
allow_partial :bool = False ,
421
497
wrap_validation_errors :bool = True ,
422
498
)-> OutputDataT :
@@ -425,14 +501,15 @@ async def process(
425
501
Args:
426
502
text: The output text to validate.
427
503
run_context: The current run context.
504
+ trace_context: The trace context to use for tracing the output processing.
428
505
allow_partial: If true, allow partial validation.
429
506
wrap_validation_errors: If true, wrap the validation errors in a retry message.
430
507
431
508
Returns:
432
509
Either the validated output data (left) or a retry message (right).
433
510
"""
434
511
return await self .processor .process (
435
- text ,run_context ,allow_partial = allow_partial ,wrap_validation_errors = wrap_validation_errors
512
+ text ,run_context ,trace_context , allow_partial = allow_partial ,wrap_validation_errors = wrap_validation_errors
436
513
)
437
514
438
515
@@ -468,6 +545,7 @@ async def process(
468
545
self ,
469
546
text :str ,
470
547
run_context :RunContext [AgentDepsT ],
548
+ trace_context :TraceContext ,
471
549
allow_partial :bool = False ,
472
550
wrap_validation_errors :bool = True ,
473
551
)-> OutputDataT :
@@ -476,6 +554,7 @@ async def process(
476
554
Args:
477
555
text: The output text to validate.
478
556
run_context: The current run context.
557
+ trace_context: The trace context to use for tracing the output processing.
479
558
allow_partial: If true, allow partial validation.
480
559
wrap_validation_errors: If true, wrap the validation errors in a retry message.
481
560
@@ -485,7 +564,7 @@ async def process(
485
564
text = _utils .strip_markdown_fences (text )
486
565
487
566
return await self .processor .process (
488
- text ,run_context ,allow_partial = allow_partial ,wrap_validation_errors = wrap_validation_errors
567
+ text ,run_context ,trace_context , allow_partial = allow_partial ,wrap_validation_errors = wrap_validation_errors
489
568
)
490
569
491
570
@@ -568,6 +647,7 @@ async def process(
568
647
self ,
569
648
data :str ,
570
649
run_context :RunContext [AgentDepsT ],
650
+ trace_context :TraceContext ,
571
651
allow_partial :bool = False ,
572
652
wrap_validation_errors :bool = True ,
573
653
)-> OutputDataT :
@@ -637,6 +717,7 @@ async def process(
637
717
self ,
638
718
data :str | dict [str ,Any ]| None ,
639
719
run_context :RunContext [AgentDepsT ],
720
+ trace_context :TraceContext ,
640
721
allow_partial :bool = False ,
641
722
wrap_validation_errors :bool = True ,
642
723
)-> OutputDataT :
@@ -645,6 +726,7 @@ async def process(
645
726
Args:
646
727
data: The output data to validate.
647
728
run_context: The current run context.
729
+ trace_context: The trace context to use for tracing the output processing.
648
730
allow_partial: If true, allow partial validation.
649
731
wrap_validation_errors: If true, wrap the validation errors in a retry message.
650
732
@@ -670,8 +752,18 @@ async def process(
670
752
output = output [k ]
671
753
672
754
if self ._function_schema :
755
+ # Wraps the output function call in an OpenTelemetry span.
756
+ if trace_context .call :
757
+ call = trace_context .call
758
+ include_tool_call_id = True
759
+ else :
760
+ function_name = getattr (self ._function_schema .function ,'__name__' ,'output_function' )
761
+ call = _messages .ToolCallPart (tool_name = function_name ,args = data )
762
+ include_tool_call_id = False
673
763
try :
674
- output = await self ._function_schema .call (output ,run_context )
764
+ output = await trace_context .execute_function_with_span (
765
+ self ._function_schema ,run_context ,output ,call ,include_tool_call_id
766
+ )
675
767
except ModelRetry as r :
676
768
if wrap_validation_errors :
677
769
m = _messages .RetryPromptPart (
@@ -784,11 +876,12 @@ async def process(
784
876
self ,
785
877
data :str | dict [str ,Any ]| None ,
786
878
run_context :RunContext [AgentDepsT ],
879
+ trace_context :TraceContext ,
787
880
allow_partial :bool = False ,
788
881
wrap_validation_errors :bool = True ,
789
882
)-> OutputDataT :
790
883
union_object = await self ._union_processor .process (
791
- data ,run_context ,allow_partial = allow_partial ,wrap_validation_errors = wrap_validation_errors
884
+ data ,run_context ,trace_context , allow_partial = allow_partial ,wrap_validation_errors = wrap_validation_errors
792
885
)
793
886
794
887
result = union_object .result
@@ -804,7 +897,7 @@ async def process(
804
897
raise
805
898
806
899
return await processor .process (
807
- data ,run_context ,allow_partial = allow_partial ,wrap_validation_errors = wrap_validation_errors
900
+ data ,run_context ,trace_context , allow_partial = allow_partial ,wrap_validation_errors = wrap_validation_errors
808
901
)
809
902
810
903
@@ -835,13 +928,20 @@ async def process(
835
928
self ,
836
929
data :str ,
837
930
run_context :RunContext [AgentDepsT ],
931
+ trace_context :TraceContext ,
838
932
allow_partial :bool = False ,
839
933
wrap_validation_errors :bool = True ,
840
934
)-> OutputDataT :
841
935
args = {self ._str_argument_name :data }
842
-
936
+ # Wraps the output function call in an OpenTelemetry span.
937
+ # Note: PlainTextOutputProcessor is used for text responses (not tool calls),
938
+ # so we don't have tool call attributes like gen_ai.tool.name or gen_ai.tool.call.id
939
+ function_name = getattr (self ._function_schema .function ,'__name__' ,'text_output_function' )
940
+ call = _messages .ToolCallPart (tool_name = function_name ,args = args )
843
941
try :
844
- output = await self ._function_schema .call (args ,run_context )
942
+ output = await trace_context .execute_function_with_span (
943
+ self ._function_schema ,run_context ,args ,call ,include_tool_call_id = False
944
+ )
845
945
except ModelRetry as r :
846
946
if wrap_validation_errors :
847
947
m = _messages .RetryPromptPart (
@@ -881,6 +981,7 @@ async def process(
881
981
self ,
882
982
tool_call :_messages .ToolCallPart ,
883
983
run_context :RunContext [AgentDepsT ],
984
+ trace_context :TraceContext ,
884
985
allow_partial :bool = False ,
885
986
wrap_validation_errors :bool = True ,
886
987
)-> OutputDataT :
@@ -889,6 +990,7 @@ async def process(
889
990
Args:
890
991
tool_call: The tool call from the LLM to validate.
891
992
run_context: The current run context.
993
+ trace_context: The trace context to use for tracing the output processing.
892
994
allow_partial: If true, allow partial validation.
893
995
wrap_validation_errors: If true, wrap the validation errors in a retry message.
894
996
@@ -897,7 +999,11 @@ async def process(
897
999
"""
898
1000
try :
899
1001
output = await self .processor .process (
900
- tool_call .args ,run_context ,allow_partial = allow_partial ,wrap_validation_errors = False
1002
+ tool_call .args ,
1003
+ run_context ,
1004
+ trace_context .with_call (tool_call ),
1005
+ allow_partial = allow_partial ,
1006
+ wrap_validation_errors = False ,
901
1007
)
902
1008
except ValidationError as e :
903
1009
if wrap_validation_errors :