22
33from __future__import annotations
44
5+ import dataclasses
56from collections .abc import Callable ,Iterator ,Mapping ,Sequence
67from contextlib import contextmanager
78from dataclasses import dataclass
89from typing import (
910Any ,
10- Dict ,
11+ Generic ,
1112NoReturn ,
12- Optional ,
13- Type ,
1413TypeAlias ,
14+ TypeVar ,
1515cast ,
1616)
1717
18+ import nexusrpc .handler
1819import opentelemetry .baggage .propagation
1920import opentelemetry .context
2021import opentelemetry .context .context
5455
5556_CarrierDict :TypeAlias = dict [str ,opentelemetry .propagators .textmap .CarrierValT ]
5657
58+ _ContextT = TypeVar ("_ContextT" ,bound = nexusrpc .handler .OperationContext )
59+
5760
5861class TracingInterceptor (temporalio .client .Interceptor ,temporalio .worker .Interceptor ):
5962"""Interceptor that supports client and worker OpenTelemetry span creation
@@ -133,6 +136,14 @@ def workflow_interceptor_class(
133136 )
134137return TracingWorkflowInboundInterceptor
135138
139+ def intercept_nexus_operation (
140+ self ,next :temporalio .worker .NexusOperationInboundInterceptor
141+ )-> temporalio .worker .NexusOperationInboundInterceptor :
142+ """Implementation of
143+ :py:meth:`temporalio.worker.Interceptor.intercept_nexus_operation`.
144+ """
145+ return _TracingNexusOperationInboundInterceptor (next ,self )
146+
136147def _context_to_headers (
137148self ,headers :Mapping [str ,temporalio .api .common .v1 .Payload ]
138149 )-> Mapping [str ,temporalio .api .common .v1 .Payload ]:
@@ -166,7 +177,8 @@ def _start_as_current_span(
166177name :str ,
167178* ,
168179attributes :opentelemetry .util .types .Attributes ,
169- input :_InputWithHeaders | None = None ,
180+ input_with_headers :_InputWithHeaders | None = None ,
181+ input_with_ctx :_InputWithOperationContext | None = None ,
170182kind :opentelemetry .trace .SpanKind ,
171183context :Context | None = None ,
172184 )-> Iterator [None ]:
@@ -179,8 +191,19 @@ def _start_as_current_span(
179191context = context ,
180192set_status_on_exception = False ,
181193 )as span :
182- if input :
183- input .headers = self ._context_to_headers (input .headers )
194+ if input_with_headers :
195+ input_with_headers .headers = self ._context_to_headers (
196+ input_with_headers .headers
197+ )
198+ if input_with_ctx :
199+ carrier :_CarrierDict = {}
200+ self .text_map_propagator .inject (carrier )
201+ input_with_ctx .ctx = dataclasses .replace (
202+ input_with_ctx .ctx ,
203+ headers = _carrier_to_nexus_headers (
204+ carrier ,input_with_ctx .ctx .headers
205+ ),
206+ )
184207try :
185208yield None
186209except Exception as exc :
@@ -258,7 +281,7 @@ async def start_workflow(
258281with self .root ._start_as_current_span (
259282f"{ prefix } :{ input .workflow } " ,
260283attributes = {"temporalWorkflowID" :input .id },
261- input = input ,
284+ input_with_headers = input ,
262285kind = opentelemetry .trace .SpanKind .CLIENT ,
263286 ):
264287return await super ().start_workflow (input )
@@ -267,7 +290,7 @@ async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> A
267290with self .root ._start_as_current_span (
268291f"QueryWorkflow:{ input .query } " ,
269292attributes = {"temporalWorkflowID" :input .id },
270- input = input ,
293+ input_with_headers = input ,
271294kind = opentelemetry .trace .SpanKind .CLIENT ,
272295 ):
273296return await super ().query_workflow (input )
@@ -278,7 +301,7 @@ async def signal_workflow(
278301with self .root ._start_as_current_span (
279302f"SignalWorkflow:{ input .signal } " ,
280303attributes = {"temporalWorkflowID" :input .id },
281- input = input ,
304+ input_with_headers = input ,
282305kind = opentelemetry .trace .SpanKind .CLIENT ,
283306 ):
284307return await super ().signal_workflow (input )
@@ -289,7 +312,7 @@ async def start_workflow_update(
289312with self .root ._start_as_current_span (
290313f"StartWorkflowUpdate:{ input .update } " ,
291314attributes = {"temporalWorkflowID" :input .id },
292- input = input ,
315+ input_with_headers = input ,
293316kind = opentelemetry .trace .SpanKind .CLIENT ,
294317 ):
295318return await super ().start_workflow_update (input )
@@ -306,7 +329,7 @@ async def start_update_with_start_workflow(
306329with self .root ._start_as_current_span (
307330f"StartUpdateWithStartWorkflow:{ input .start_workflow_input .workflow } " ,
308331attributes = attrs ,
309- input = input .start_workflow_input ,
332+ input_with_headers = input .start_workflow_input ,
310333kind = opentelemetry .trace .SpanKind .CLIENT ,
311334 ):
312335otel_header = input .start_workflow_input .headers .get (self .root .header_key )
@@ -345,10 +368,60 @@ async def execute_activity(
345368return await super ().execute_activity (input )
346369
347370
371+ class _TracingNexusOperationInboundInterceptor (
372+ temporalio .worker .NexusOperationInboundInterceptor
373+ ):
374+ def __init__ (
375+ self ,
376+ next :temporalio .worker .NexusOperationInboundInterceptor ,
377+ root :TracingInterceptor ,
378+ )-> None :
379+ super ().__init__ (next )
380+ self ._root = root
381+
382+ def _context_from_nexus_headers (self ,headers :Mapping [str ,str ]):
383+ return self ._root .text_map_propagator .extract (headers )
384+
385+ async def execute_nexus_operation_start (
386+ self ,input :temporalio .worker .ExecuteNexusOperationStartInput
387+ )-> (
388+ nexusrpc .handler .StartOperationResultSync [Any ]
389+ | nexusrpc .handler .StartOperationResultAsync
390+ ):
391+ with self ._root ._start_as_current_span (
392+ f"RunStartNexusOperationHandler:{ input .ctx .service } /{ input .ctx .operation } " ,
393+ context = self ._context_from_nexus_headers (input .ctx .headers ),
394+ attributes = {},
395+ input_with_ctx = input ,
396+ kind = opentelemetry .trace .SpanKind .SERVER ,
397+ ):
398+ return await self .next .execute_nexus_operation_start (input )
399+
400+ async def execute_nexus_operation_cancel (
401+ self ,input :temporalio .worker .ExecuteNexusOperationCancelInput
402+ )-> None :
403+ with self ._root ._start_as_current_span (
404+ f"RunCancelNexusOperationHandler:{ input .ctx .service } /{ input .ctx .operation } " ,
405+ context = self ._context_from_nexus_headers (input .ctx .headers ),
406+ attributes = {},
407+ input_with_ctx = input ,
408+ kind = opentelemetry .trace .SpanKind .SERVER ,
409+ ):
410+ return await self .next .execute_nexus_operation_cancel (input )
411+
412+
348413class _InputWithHeaders (Protocol ):
349414headers :Mapping [str ,temporalio .api .common .v1 .Payload ]
350415
351416
417+ class _InputWithStringHeaders (Protocol ):
418+ headers :Mapping [str ,str ]| None
419+
420+
421+ class _InputWithOperationContext (Generic [_ContextT ],Protocol ):
422+ ctx :_ContextT
423+
424+
352425class _WorkflowExternFunctions (TypedDict ):
353426__temporal_opentelemetry_completed_span :Callable [
354427 [_CompletedWorkflowSpanParams ],_CarrierDict | None
@@ -602,6 +675,7 @@ def _completed_span(
602675* ,
603676link_context_carrier :_CarrierDict | None = None ,
604677add_to_outbound :_InputWithHeaders | None = None ,
678+ add_to_outbound_str :_InputWithStringHeaders | None = None ,
605679new_span_even_on_replay :bool = False ,
606680additional_attributes :opentelemetry .util .types .Attributes = None ,
607681exception :Exception | None = None ,
@@ -614,12 +688,14 @@ def _completed_span(
614688# Create the span. First serialize current context to carrier.
615689new_context_carrier :_CarrierDict = {}
616690self .text_map_propagator .inject (new_context_carrier )
691+
617692# Invoke
618693info = temporalio .workflow .info ()
619694attributes :dict [str ,opentelemetry .util .types .AttributeValue ]= {
620695"temporalWorkflowID" :info .workflow_id ,
621696"temporalRunID" :info .run_id ,
622697 }
698+
623699if additional_attributes :
624700attributes .update (additional_attributes )
625701updated_context_carrier = self ._extern_functions [
@@ -640,10 +716,16 @@ def _completed_span(
640716 )
641717
642718# Add to outbound if needed
643- if add_to_outbound and updated_context_carrier :
644- add_to_outbound .headers = self ._context_carrier_to_headers (
645- updated_context_carrier ,add_to_outbound .headers
646- )
719+ if updated_context_carrier :
720+ if add_to_outbound :
721+ add_to_outbound .headers = self ._context_carrier_to_headers (
722+ updated_context_carrier ,add_to_outbound .headers
723+ )
724+
725+ if add_to_outbound_str :
726+ add_to_outbound_str .headers = _carrier_to_nexus_headers (
727+ updated_context_carrier ,add_to_outbound_str .headers
728+ )
647729
648730def _set_on_context (
649731self ,context :opentelemetry .context .Context
@@ -722,6 +804,29 @@ def start_local_activity(
722804 )
723805return super ().start_local_activity (input )
724806
807+ async def start_nexus_operation (
808+ self ,input :temporalio .worker .StartNexusOperationInput [Any ,Any ]
809+ )-> temporalio .workflow .NexusOperationHandle [Any ]:
810+ self .root ._completed_span (
811+ f"StartNexusOperation:{ input .service } /{ input .operation_name } " ,
812+ kind = opentelemetry .trace .SpanKind .CLIENT ,
813+ add_to_outbound_str = input ,
814+ )
815+
816+ return await super ().start_nexus_operation (input )
817+
818+
819+ def _carrier_to_nexus_headers (
820+ carrier :_CarrierDict ,initial :Mapping [str ,str ]| None = None
821+ )-> Mapping [str ,str ]:
822+ out = {** initial }if initial else {}
823+ for k ,v in carrier .items ():
824+ if isinstance (v ,list ):
825+ out [k ]= "," .join (v )
826+ else :
827+ out [k ]= v
828+ return out
829+
725830
726831class workflow :
727832"""Contains static methods that are safe to call from within a workflow.