6
6
from contextlib import AbstractContextManager ,ExitStack ,asynccontextmanager
7
7
from dataclasses import dataclass ,field
8
8
from functools import cached_property
9
- from typing import Any ,Generic ,cast
9
+ from typing import Any ,Generic ,cast , overload
10
10
11
11
import logfire_api
12
12
import typing_extensions
13
- from logfire_api import LogfireSpan
13
+ from opentelemetry . trace import Span
14
14
from typing_extensions import deprecated
15
15
from typing_inspection import typing_objects
16
16
17
17
from .import _utils ,exceptions ,mermaid
18
+ from ._utils import AbstractSpan
18
19
from .nodes import BaseNode ,DepsT ,End ,GraphRunContext ,NodeDef ,RunEndT ,StateT
19
20
from .persistence import BaseStatePersistence
20
21
from .persistence .in_mem import SimpleStatePersistence
@@ -125,7 +126,6 @@ async def run(
125
126
deps :DepsT = None ,
126
127
persistence :BaseStatePersistence [StateT ,RunEndT ]| None = None ,
127
128
infer_name :bool = True ,
128
- span :LogfireSpan | None = None ,
129
129
)-> GraphRunResult [StateT ,RunEndT ]:
130
130
"""Run the graph from a starting node until it ends.
131
131
@@ -137,8 +137,6 @@ async def run(
137
137
persistence: State persistence interface, defaults to
138
138
[`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`.
139
139
infer_name: Whether to infer the graph name from the calling frame.
140
- span: The span to use for the graph run. If not provided, a span will be created depending on the value of
141
- the `auto_instrument` field.
142
140
143
141
Returns:
144
142
A `GraphRunResult` containing information about the run, including its final result.
@@ -164,7 +162,7 @@ async def main():
164
162
self ._infer_name (inspect .currentframe ())
165
163
166
164
async with self .iter (
167
- start_node ,state = state ,deps = deps ,persistence = persistence ,span = span , infer_name = False
165
+ start_node ,state = state ,deps = deps ,persistence = persistence ,infer_name = False
168
166
)as graph_run :
169
167
async for _node in graph_run :
170
168
pass
@@ -214,7 +212,7 @@ async def iter(
214
212
state :StateT = None ,
215
213
deps :DepsT = None ,
216
214
persistence :BaseStatePersistence [StateT ,RunEndT ]| None = None ,
217
- span :AbstractContextManager [Any ]| None = None ,
215
+ span :AbstractContextManager [Span ]| None = None ,
218
216
infer_name :bool = True ,
219
217
)-> AsyncIterator [GraphRun [StateT ,DepsT ,RunEndT ]]:
220
218
"""A contextmanager which can be used to iterate over the graph's nodes as they are executed.
@@ -252,14 +250,15 @@ async def iter(
252
250
persistence = SimpleStatePersistence ()
253
251
persistence .set_graph_types (self )
254
252
255
- if self .auto_instrument and span is None :
256
- span = logfire_api .span ('run graph {graph.name}' ,graph = self )
257
-
258
253
with ExitStack ()as stack :
259
- if span is not None :
260
- stack .enter_context (span )
254
+ entered_span :AbstractSpan | None = None
255
+ if span is None :
256
+ if self .auto_instrument :
257
+ entered_span = stack .enter_context (logfire_api .span ('run graph {graph.name}' ,graph = self ))
258
+ else :
259
+ entered_span = stack .enter_context (span )
261
260
yield GraphRun [StateT ,DepsT ,RunEndT ](
262
- graph = self ,start_node = start_node ,persistence = persistence ,state = state ,deps = deps
261
+ graph = self ,start_node = start_node ,persistence = persistence ,state = state ,deps = deps , span = entered_span
263
262
)
264
263
265
264
@asynccontextmanager
@@ -268,7 +267,7 @@ async def iter_from_persistence(
268
267
persistence :BaseStatePersistence [StateT ,RunEndT ],
269
268
* ,
270
269
deps :DepsT = None ,
271
- span :AbstractContextManager [Any ]| None = None ,
270
+ span :AbstractContextManager [AbstractSpan ]| None = None ,
272
271
infer_name :bool = True ,
273
272
)-> AsyncIterator [GraphRun [StateT ,DepsT ,RunEndT ]]:
274
273
"""A contextmanager to iterate over the graph's nodes as they are executed, created from a persistence object.
@@ -301,15 +300,15 @@ async def iter_from_persistence(
301
300
span = logfire_api .span ('run graph {graph.name}' ,graph = self )
302
301
303
302
with ExitStack ()as stack :
304
- if span is not None :
305
- stack .enter_context (span )
303
+ entered_span = None if span is None else stack .enter_context (span )
306
304
yield GraphRun [StateT ,DepsT ,RunEndT ](
307
305
graph = self ,
308
306
start_node = snapshot .node ,
309
307
persistence = persistence ,
310
308
state = snapshot .state ,
311
309
deps = deps ,
312
310
snapshot_id = snapshot .id ,
311
+ span = entered_span ,
313
312
)
314
313
315
314
async def initialize (
@@ -370,6 +369,7 @@ async def next(
370
369
persistence = persistence ,
371
370
state = state ,
372
371
deps = deps ,
372
+ span = None ,
373
373
)
374
374
return await run .next (node )
375
375
@@ -644,6 +644,7 @@ def __init__(
644
644
persistence :BaseStatePersistence [StateT ,RunEndT ],
645
645
state :StateT ,
646
646
deps :DepsT ,
647
+ span :AbstractSpan | None ,
647
648
snapshot_id :str | None = None ,
648
649
):
649
650
"""Create a new run for a given graph, starting at the specified node.
@@ -658,6 +659,7 @@ def __init__(
658
659
to all nodes via `ctx.state`.
659
660
deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections,
660
661
configuration, or logging clients.
662
+ span: The span used for the graph run.
661
663
snapshot_id: The ID of the snapshot the node came from.
662
664
"""
663
665
self .graph = graph
@@ -666,9 +668,19 @@ def __init__(
666
668
self .state = state
667
669
self .deps = deps
668
670
671
+ self .__span = span
669
672
self ._next_node :BaseNode [StateT ,DepsT ,RunEndT ]| End [RunEndT ]= start_node
670
673
self ._is_started :bool = False
671
674
675
+ @overload
676
+ def _span (self ,* ,required :typing_extensions .Literal [False ])-> AbstractSpan | None : ...
677
+ @overload
678
+ def _span (self )-> AbstractSpan : ...
679
+ def _span (self ,* ,required :bool = True )-> AbstractSpan | None :
680
+ if self .__span is None and required :# pragma: no cover
681
+ raise exceptions .GraphRuntimeError ('No span available for this graph run.' )
682
+ return self .__span
683
+
672
684
@property
673
685
def next_node (self )-> BaseNode [StateT ,DepsT ,RunEndT ]| End [RunEndT ]:
674
686
"""The next node that will be run in the graph.
@@ -682,10 +694,8 @@ def result(self) -> GraphRunResult[StateT, RunEndT] | None:
682
694
"""The final result of the graph run if the run is completed, otherwise `None`."""
683
695
if not isinstance (self ._next_node ,End ):
684
696
return None # The GraphRun has not finished running
685
- return GraphRunResult (
686
- self ._next_node .data ,
687
- state = self .state ,
688
- persistence = self .persistence ,
697
+ return GraphRunResult [StateT ,RunEndT ](
698
+ self ._next_node .data ,state = self .state ,persistence = self .persistence ,span = self ._span (required = False )
689
699
)
690
700
691
701
async def next (
@@ -793,10 +803,31 @@ def __repr__(self) -> str:
793
803
return f'<GraphRun graph={ self .graph .name or "[unnamed]" } >'
794
804
795
805
796
- @dataclass
806
+ @dataclass ( init = False )
797
807
class GraphRunResult (Generic [StateT ,RunEndT ]):
798
808
"""The final result of running a graph."""
799
809
800
810
output :RunEndT
801
811
state :StateT
802
812
persistence :BaseStatePersistence [StateT ,RunEndT ]= field (repr = False )
813
+
814
+ def __init__ (
815
+ self ,
816
+ output :RunEndT ,
817
+ state :StateT ,
818
+ persistence :BaseStatePersistence [StateT ,RunEndT ],
819
+ span :AbstractSpan | None = None ,
820
+ ):
821
+ self .output = output
822
+ self .state = state
823
+ self .persistence = persistence
824
+ self .__span = span
825
+
826
+ @overload
827
+ def _span (self ,* ,required :typing_extensions .Literal [False ])-> AbstractSpan | None : ...
828
+ @overload
829
+ def _span (self )-> AbstractSpan : ...
830
+ def _span (self ,* ,required :bool = True )-> AbstractSpan | None :# pragma: no cover
831
+ if self .__span is None and required :
832
+ raise exceptions .GraphRuntimeError ('No span available for this graph run.' )
833
+ return self .__span