1+ import random
2+ import threading
3+ import time
4+ from unittest .mock import patch
5+ import pytest
6+
7+ from databricks .sql .telemetry .models .enums import StatementType
8+ from databricks .sql .telemetry .telemetry_client import TelemetryClient ,TelemetryClientFactory
9+ from tests .e2e .test_driver import PySQLPytestTestCase
10+
11+ def run_in_threads (target ,num_threads ,pass_index = False ):
12+ """Helper to run target function in multiple threads."""
13+ threads = [
14+ threading .Thread (target = target ,args = (i ,)if pass_index else ())
15+ for i in range (num_threads )
16+ ]
17+ for t in threads :
18+ t .start ()
19+ for t in threads :
20+ t .join ()
21+
22+
23+ class TestE2ETelemetry (PySQLPytestTestCase ):
24+
25+ @pytest .fixture (autouse = True )
26+ def telemetry_setup_teardown (self ):
27+ """
28+ This fixture ensures the TelemetryClientFactory is in a clean state
29+ before each test and shuts it down afterward. Using a fixture makes
30+ this robust and automatic.
31+ """
32+ try :
33+ yield
34+ finally :
35+ if TelemetryClientFactory ._executor :
36+ TelemetryClientFactory ._executor .shutdown (wait = True )
37+ TelemetryClientFactory ._executor = None
38+ TelemetryClientFactory ._initialized = False
39+
40+ def test_concurrent_queries_sends_telemetry (self ):
41+ """
42+ An E2E test where concurrent threads execute real queries against
43+ the staging endpoint, while we capture and verify the generated telemetry.
44+ """
45+ num_threads = 30
46+ capture_lock = threading .Lock ()
47+ captured_telemetry = []
48+ captured_session_ids = []
49+ captured_statement_ids = []
50+ captured_responses = []
51+ captured_exceptions = []
52+
53+ original_send_telemetry = TelemetryClient ._send_telemetry
54+ original_callback = TelemetryClient ._telemetry_request_callback
55+
56+ def send_telemetry_wrapper (self_client ,events ):
57+ with capture_lock :
58+ captured_telemetry .extend (events )
59+ original_send_telemetry (self_client ,events )
60+
61+ def callback_wrapper (self_client ,future ,sent_count ):
62+ """
63+ Wraps the original callback to capture the server's response
64+ or any exceptions from the async network call.
65+ """
66+ try :
67+ original_callback (self_client ,future ,sent_count )
68+
69+ # Now, capture the result for our assertions
70+ response = future .result ()
71+ response .raise_for_status ()# Raise an exception for 4xx/5xx errors
72+ telemetry_response = response .json ()
73+ with capture_lock :
74+ captured_responses .append (telemetry_response )
75+ except Exception as e :
76+ with capture_lock :
77+ captured_exceptions .append (e )
78+
79+ with patch .object (TelemetryClient ,"_send_telemetry" ,send_telemetry_wrapper ), \
80+ patch .object (TelemetryClient ,"_telemetry_request_callback" ,callback_wrapper ):
81+
82+ def execute_query_worker (thread_id ):
83+ """Each thread creates a connection and executes a query."""
84+
85+ time .sleep (random .uniform (0 ,0.05 ))
86+
87+ with self .connection (extra_params = {"enable_telemetry" :True })as conn :
88+ # Capture the session ID from the connection before executing the query
89+ session_id_hex = conn .get_session_id_hex ()
90+ with capture_lock :
91+ captured_session_ids .append (session_id_hex )
92+
93+ with conn .cursor ()as cursor :
94+ cursor .execute (f"SELECT{ thread_id } " )
95+ # Capture the statement ID after executing the query
96+ statement_id = cursor .query_id
97+ with capture_lock :
98+ captured_statement_ids .append (statement_id )
99+ cursor .fetchall ()
100+
101+ # Run the workers concurrently
102+ run_in_threads (execute_query_worker ,num_threads ,pass_index = True )
103+
104+ if TelemetryClientFactory ._executor :
105+ TelemetryClientFactory ._executor .shutdown (wait = True )
106+
107+ # --- VERIFICATION ---
108+ assert not captured_exceptions
109+ assert len (captured_responses )> 0
110+
111+ total_successful_events = 0
112+ for response in captured_responses :
113+ assert "errors" not in response or not response ["errors" ]
114+ if "numProtoSuccess" in response :
115+ total_successful_events += response ["numProtoSuccess" ]
116+ assert total_successful_events == num_threads * 2
117+
118+ assert len (captured_telemetry )== num_threads * 2 # 2 events per thread (initial_telemetry_log, latency_log (execute))
119+ assert len (captured_session_ids )== num_threads # One session ID per thread
120+ assert len (captured_statement_ids )== num_threads # One statement ID per thread (per query)
121+
122+ # Separate initial logs from latency logs
123+ initial_logs = [
124+ e for e in captured_telemetry
125+ if e .entry .sql_driver_log .operation_latency_ms is None
126+ and e .entry .sql_driver_log .driver_connection_params is not None
127+ and e .entry .sql_driver_log .system_configuration is not None
128+ ]
129+ latency_logs = [
130+ e for e in captured_telemetry
131+ if e .entry .sql_driver_log .operation_latency_ms is not None
132+ and e .entry .sql_driver_log .sql_statement_id is not None
133+ and e .entry .sql_driver_log .sql_operation .statement_type == StatementType .QUERY
134+ ]
135+
136+ # Verify counts
137+ assert len (initial_logs )== num_threads
138+ assert len (latency_logs )== num_threads
139+
140+ # Verify that telemetry events contain the exact session IDs we captured from connections
141+ telemetry_session_ids = set ()
142+ for event in captured_telemetry :
143+ session_id = event .entry .sql_driver_log .session_id
144+ assert session_id is not None
145+ telemetry_session_ids .add (session_id )
146+
147+ captured_session_ids_set = set (captured_session_ids )
148+ assert telemetry_session_ids == captured_session_ids_set
149+ assert len (captured_session_ids_set )== num_threads
150+
151+ # Verify that telemetry latency logs contain the exact statement IDs we captured from cursors
152+ telemetry_statement_ids = set ()
153+ for event in latency_logs :
154+ statement_id = event .entry .sql_driver_log .sql_statement_id
155+ assert statement_id is not None
156+ telemetry_statement_ids .add (statement_id )
157+
158+ captured_statement_ids_set = set (captured_statement_ids )
159+ assert telemetry_statement_ids == captured_statement_ids_set
160+ assert len (captured_statement_ids_set )== num_threads
161+
162+ # Verify that each latency log has a statement ID from our captured set
163+ for event in latency_logs :
164+ log = event .entry .sql_driver_log
165+ assert log .sql_statement_id in captured_statement_ids
166+ assert log .session_id in captured_session_ids