66import pytest
77
88from databricks .sql .telemetry .models .enums import StatementType
9- from databricks .sql .telemetry .telemetry_client import TelemetryClient ,TelemetryClientFactory
9+ from databricks .sql .telemetry .telemetry_client import (
10+ TelemetryClient ,
11+ TelemetryClientFactory ,
12+ )
1013from tests .e2e .test_driver import PySQLPytestTestCase
1114
15+
1216def run_in_threads (target ,num_threads ,pass_index = False ):
1317"""Helper to run target function in multiple threads."""
1418threads = [
@@ -22,7 +26,6 @@ def run_in_threads(target, num_threads, pass_index=False):
2226
2327
2428class TestE2ETelemetry (PySQLPytestTestCase ):
25-
2629@pytest .fixture (autouse = True )
2730def telemetry_setup_teardown (self ):
2831"""
@@ -31,7 +34,7 @@ def telemetry_setup_teardown(self):
3134 this robust and automatic.
3235 """
3336try :
34- yield
37+ yield
3538finally :
3639if TelemetryClientFactory ._executor :
3740TelemetryClientFactory ._executor .shutdown (wait = True )
@@ -68,20 +71,25 @@ def callback_wrapper(self_client, future, sent_count):
6871captured_futures .append (future )
6972original_callback (self_client ,future ,sent_count )
7073
71- with patch .object (TelemetryClient ,"_send_telemetry" ,send_telemetry_wrapper ), \
72- patch .object (TelemetryClient ,"_telemetry_request_callback" ,callback_wrapper ):
74+ with patch .object (
75+ TelemetryClient ,"_send_telemetry" ,send_telemetry_wrapper
76+ ),patch .object (
77+ TelemetryClient ,"_telemetry_request_callback" ,callback_wrapper
78+ ):
7379
7480def execute_query_worker (thread_id ):
7581"""Each thread creates a connection and executes a query."""
7682
7783time .sleep (random .uniform (0 ,0.05 ))
78-
79- with self .connection (extra_params = {"force_enable_telemetry" :True })as conn :
84+
85+ with self .connection (
86+ extra_params = {"force_enable_telemetry" :True }
87+ )as conn :
8088# Capture the session ID from the connection before executing the query
8189session_id_hex = conn .get_session_id_hex ()
8290with capture_lock :
8391captured_session_ids .append (session_id_hex )
84-
92+
8593with conn .cursor ()as cursor :
8694cursor .execute (f"SELECT{ thread_id } " )
8795# Capture the statement ID after executing the query
@@ -97,7 +105,10 @@ def execute_query_worker(thread_id):
97105start_time = time .time ()
98106expected_event_count = num_threads
99107
100- while len (captured_futures )< expected_event_count and time .time ()- start_time < timeout_seconds :
108+ while (
109+ len (captured_futures )< expected_event_count
110+ and time .time ()- start_time < timeout_seconds
111+ ):
101112time .sleep (0.1 )
102113
103114done ,not_done = wait (captured_futures ,timeout = timeout_seconds )
@@ -115,30 +126,37 @@ def execute_query_worker(thread_id):
115126
116127assert not captured_exceptions
117128assert len (captured_responses )> 0
118-
129+
119130total_successful_events = 0
120131for response in captured_responses :
121132assert "errors" not in response or not response ["errors" ]
122133if "numProtoSuccess" in response :
123134total_successful_events += response ["numProtoSuccess" ]
124135assert total_successful_events == num_threads * 2
125136
126- assert len (captured_telemetry )== num_threads * 2 # 2 events per thread (initial_telemetry_log, latency_log (execute))
137+ assert (
138+ len (captured_telemetry )== num_threads * 2
139+ )# 2 events per thread (initial_telemetry_log, latency_log (execute))
127140assert len (captured_session_ids )== num_threads # One session ID per thread
128- assert len (captured_statement_ids )== num_threads # One statement ID per thread (per query)
141+ assert (
142+ len (captured_statement_ids )== num_threads
143+ )# One statement ID per thread (per query)
129144
130145# Separate initial logs from latency logs
131146initial_logs = [
132- e for e in captured_telemetry
147+ e
148+ for e in captured_telemetry
133149if e .entry .sql_driver_log .operation_latency_ms is None
134150and e .entry .sql_driver_log .driver_connection_params is not None
135151and e .entry .sql_driver_log .system_configuration is not None
136152 ]
137153latency_logs = [
138- e for e in captured_telemetry
139- if e .entry .sql_driver_log .operation_latency_ms is not None
140- and e .entry .sql_driver_log .sql_statement_id is not None
141- and e .entry .sql_driver_log .sql_operation .statement_type == StatementType .QUERY
154+ e
155+ for e in captured_telemetry
156+ if e .entry .sql_driver_log .operation_latency_ms is not None
157+ and e .entry .sql_driver_log .sql_statement_id is not None
158+ and e .entry .sql_driver_log .sql_operation .statement_type
159+ == StatementType .QUERY
142160 ]
143161
144162# Verify counts
@@ -171,4 +189,4 @@ def execute_query_worker(thread_id):
171189for event in latency_logs :
172190log = event .entry .sql_driver_log
173191assert log .sql_statement_id in captured_statement_ids
174- assert log .session_id in captured_session_ids
192+ assert log .session_id in captured_session_ids