Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitaee6863

Browse files
authored
Telemetry server-side flag integration (databricks#646)
* feature_flagSigned-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>* fix static type checkSigned-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>* fix static type checkSigned-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>* force enable telemetrySigned-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>* added flagSigned-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>* lintingSigned-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>* testsSigned-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>* changed flag value to be of any typeSigned-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>* test fixSigned-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>---------Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent2f8b1ab commitaee6863

File tree

5 files changed

+289
-11
lines changed

5 files changed

+289
-11
lines changed

‎src/databricks/sql/client.py‎

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,6 @@ def read(self) -> Optional[OAuthToken]:
248248
self.lz4_compression=kwargs.get("enable_query_result_lz4_compression",True)
249249
self.use_cloud_fetch=kwargs.get("use_cloud_fetch",True)
250250
self._cursors= []# type: List[Cursor]
251-
252-
self.server_telemetry_enabled=True
253-
self.client_telemetry_enabled=kwargs.get("enable_telemetry",False)
254-
self.telemetry_enabled= (
255-
self.client_telemetry_enabledandself.server_telemetry_enabled
256-
)
257251
self.telemetry_batch_size=kwargs.get(
258252
"telemetry_batch_size",TelemetryClientFactory.DEFAULT_BATCH_SIZE
259253
)
@@ -288,6 +282,10 @@ def read(self) -> Optional[OAuthToken]:
288282
)
289283
self.staging_allowed_local_path=kwargs.get("staging_allowed_local_path",None)
290284

285+
self.force_enable_telemetry=kwargs.get("force_enable_telemetry",False)
286+
self.enable_telemetry=kwargs.get("enable_telemetry",False)
287+
self.telemetry_enabled=TelemetryHelper.is_telemetry_enabled(self)
288+
291289
TelemetryClientFactory.initialize_telemetry_client(
292290
telemetry_enabled=self.telemetry_enabled,
293291
session_id_hex=self.get_session_id_hex(),
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
importthreading
2+
importtime
3+
importrequests
4+
fromdataclassesimportdataclass,field
5+
fromconcurrent.futuresimportThreadPoolExecutor
6+
fromtypingimportDict,Optional,List,Any,TYPE_CHECKING
7+
8+
ifTYPE_CHECKING:
9+
fromdatabricks.sql.clientimportConnection
10+
11+
12+
@dataclass
13+
classFeatureFlagEntry:
14+
"""Represents a single feature flag from the server response."""
15+
16+
name:str
17+
value:str
18+
19+
20+
@dataclass
21+
classFeatureFlagsResponse:
22+
"""Represents the full JSON response from the feature flag endpoint."""
23+
24+
flags:List[FeatureFlagEntry]=field(default_factory=list)
25+
ttl_seconds:Optional[int]=None
26+
27+
@classmethod
28+
deffrom_dict(cls,data:Dict[str,Any])->"FeatureFlagsResponse":
29+
"""Factory method to create an instance from a dictionary (parsed JSON)."""
30+
flags_data=data.get("flags", [])
31+
flags_list= [FeatureFlagEntry(**flag)forflaginflags_data]
32+
returncls(flags=flags_list,ttl_seconds=data.get("ttl_seconds"))
33+
34+
35+
# --- Constants ---
36+
FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT= (
37+
"/api/2.0/connector-service/feature-flags/PYTHON/{}"
38+
)
39+
DEFAULT_TTL_SECONDS=900# 15 minutes
40+
REFRESH_BEFORE_EXPIRY_SECONDS=10# Start proactive refresh 10s before expiry
41+
42+
43+
classFeatureFlagsContext:
44+
"""
45+
Manages fetching and caching of server-side feature flags for a connection.
46+
47+
1. The very first check for any flag is a synchronous, BLOCKING operation.
48+
2. Subsequent refreshes (triggered near TTL expiry) are done asynchronously
49+
in the background, returning stale data until the refresh completes.
50+
"""
51+
52+
def__init__(self,connection:"Connection",executor:ThreadPoolExecutor):
53+
fromdatabricks.sqlimport__version__
54+
55+
self._connection=connection
56+
self._executor=executor# Used for ASYNCHRONOUS refreshes
57+
self._lock=threading.RLock()
58+
59+
# Cache state: `None` indicates the cache has never been loaded.
60+
self._flags:Optional[Dict[str,str]]=None
61+
self._ttl_seconds:int=DEFAULT_TTL_SECONDS
62+
self._last_refresh_time:float=0
63+
64+
endpoint_suffix=FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
65+
self._feature_flag_endpoint= (
66+
f"https://{self._connection.session.host}{endpoint_suffix}"
67+
)
68+
69+
def_is_refresh_needed(self)->bool:
70+
"""Checks if the cache is due for a proactive background refresh."""
71+
ifself._flagsisNone:
72+
returnFalse# Not eligible for refresh until loaded once.
73+
74+
refresh_threshold=self._last_refresh_time+ (
75+
self._ttl_seconds-REFRESH_BEFORE_EXPIRY_SECONDS
76+
)
77+
returntime.monotonic()>refresh_threshold
78+
79+
defget_flag_value(self,name:str,default_value:Any)->Any:
80+
"""
81+
Checks if a feature is enabled.
82+
- BLOCKS on the first call until flags are fetched.
83+
- Returns cached values on subsequent calls, triggering non-blocking refreshes if needed.
84+
"""
85+
withself._lock:
86+
# If cache has never been loaded, perform a synchronous, blocking fetch.
87+
ifself._flagsisNone:
88+
self._refresh_flags()
89+
90+
# If a proactive background refresh is needed, start one. This is non-blocking.
91+
elifself._is_refresh_needed():
92+
# We don't check for an in-flight refresh; the executor queues the task, which is safe.
93+
self._executor.submit(self._refresh_flags)
94+
95+
assertself._flagsisnotNone
96+
97+
# Now, return the value from the populated cache.
98+
returnself._flags.get(name,default_value)
99+
100+
def_refresh_flags(self):
101+
"""Performs a synchronous network request to fetch and update flags."""
102+
headers= {}
103+
try:
104+
# Authenticate the request
105+
self._connection.session.auth_provider.add_headers(headers)
106+
headers["User-Agent"]=self._connection.session.useragent_header
107+
108+
response=requests.get(
109+
self._feature_flag_endpoint,headers=headers,timeout=30
110+
)
111+
112+
ifresponse.status_code==200:
113+
ff_response=FeatureFlagsResponse.from_dict(response.json())
114+
self._update_cache_from_response(ff_response)
115+
else:
116+
# On failure, initialize with an empty dictionary to prevent re-blocking.
117+
ifself._flagsisNone:
118+
self._flags= {}
119+
120+
exceptExceptionase:
121+
# On exception, initialize with an empty dictionary to prevent re-blocking.
122+
ifself._flagsisNone:
123+
self._flags= {}
124+
125+
def_update_cache_from_response(self,ff_response:FeatureFlagsResponse):
126+
"""Atomically updates the internal cache state from a successful server response."""
127+
withself._lock:
128+
self._flags= {flag.name:flag.valueforflaginff_response.flags}
129+
ifff_response.ttl_secondsisnotNoneandff_response.ttl_seconds>0:
130+
self._ttl_seconds=ff_response.ttl_seconds
131+
self._last_refresh_time=time.monotonic()
132+
133+
134+
classFeatureFlagsContextFactory:
135+
"""
136+
Manages a singleton instance of FeatureFlagsContext per connection session.
137+
Also manages a shared ThreadPoolExecutor for all background refresh operations.
138+
"""
139+
140+
_context_map:Dict[str,FeatureFlagsContext]= {}
141+
_executor:Optional[ThreadPoolExecutor]=None
142+
_lock=threading.Lock()
143+
144+
@classmethod
145+
def_initialize(cls):
146+
"""Initializes the shared executor for async refreshes if it doesn't exist."""
147+
ifcls._executorisNone:
148+
cls._executor=ThreadPoolExecutor(
149+
max_workers=3,thread_name_prefix="feature-flag-refresher"
150+
)
151+
152+
@classmethod
153+
defget_instance(cls,connection:"Connection")->FeatureFlagsContext:
154+
"""Gets or creates a FeatureFlagsContext for the given connection."""
155+
withcls._lock:
156+
cls._initialize()
157+
assertcls._executorisnotNone
158+
159+
# Use the unique session ID as the key
160+
key=connection.get_session_id_hex()
161+
ifkeynotincls._context_map:
162+
cls._context_map[key]=FeatureFlagsContext(connection,cls._executor)
163+
returncls._context_map[key]
164+
165+
@classmethod
166+
defremove_instance(cls,connection:"Connection"):
167+
"""Removes the context for a given connection and shuts down the executor if no clients remain."""
168+
withcls._lock:
169+
key=connection.get_session_id_hex()
170+
ifkeyincls._context_map:
171+
cls._context_map.pop(key,None)
172+
173+
# If this was the last active context, clean up the thread pool.
174+
ifnotcls._context_mapandcls._executorisnotNone:
175+
cls._executor.shutdown(wait=False)
176+
cls._executor=None

‎src/databricks/sql/telemetry/telemetry_client.py‎

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
importtime
33
importlogging
44
fromconcurrent.futuresimportThreadPoolExecutor
5-
fromtypingimportDict,Optional
5+
fromtypingimportDict,Optional,TYPE_CHECKING
66
fromdatabricks.sql.common.httpimportTelemetryHttpClient
77
fromdatabricks.sql.telemetry.models.eventimport (
88
TelemetryEvent,
@@ -36,6 +36,10 @@
3636
importuuid
3737
importlocale
3838
fromdatabricks.sql.telemetry.utilsimportBaseTelemetryClient
39+
fromdatabricks.sql.common.feature_flagimportFeatureFlagsContextFactory
40+
41+
ifTYPE_CHECKING:
42+
fromdatabricks.sql.clientimportConnection
3943

4044
logger=logging.getLogger(__name__)
4145

@@ -44,6 +48,7 @@ class TelemetryHelper:
4448
"""Helper class for getting telemetry related information."""
4549

4650
_DRIVER_SYSTEM_CONFIGURATION=None
51+
TELEMETRY_FEATURE_FLAG_NAME="databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForPythonDriver"
4752

4853
@classmethod
4954
defget_driver_system_configuration(cls)->DriverSystemConfiguration:
@@ -98,6 +103,20 @@ def get_auth_flow(auth_provider):
98103
else:
99104
returnNone
100105

106+
@staticmethod
107+
defis_telemetry_enabled(connection:"Connection")->bool:
108+
ifconnection.force_enable_telemetry:
109+
returnTrue
110+
111+
ifconnection.enable_telemetry:
112+
context=FeatureFlagsContextFactory.get_instance(connection)
113+
flag_value=context.get_flag_value(
114+
TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME,default_value=False
115+
)
116+
returnstr(flag_value).lower()=="true"
117+
else:
118+
returnFalse
119+
101120

102121
classNoopTelemetryClient(BaseTelemetryClient):
103122
"""

‎tests/e2e/test_concurrent_telemetry.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def execute_query_worker(thread_id):
7676

7777
time.sleep(random.uniform(0,0.05))
7878

79-
withself.connection(extra_params={"enable_telemetry":True})asconn:
79+
withself.connection(extra_params={"force_enable_telemetry":True})asconn:
8080
# Capture the session ID from the connection before executing the query
8181
session_id_hex=conn.get_session_id_hex()
8282
withcapture_lock:

‎tests/unit/test_telemetry.py‎

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
NoopTelemetryClient,
88
TelemetryClientFactory,
99
TelemetryHelper,
10-
BaseTelemetryClient,
1110
)
1211
fromdatabricks.sql.telemetry.models.enumsimportAuthMech,AuthFlow
1312
fromdatabricks.sql.auth.authenticatorsimport (
1413
AccessTokenAuthProvider,
1514
DatabricksOAuthProvider,
1615
ExternalAuthProvider,
1716
)
17+
fromdatabricksimportsql
1818

1919

2020
@pytest.fixture
@@ -311,8 +311,6 @@ def test_connection_failure_sends_correct_telemetry_payload(
311311
mock_session.side_effect=Exception(error_message)
312312

313313
try:
314-
fromdatabricksimportsql
315-
316314
sql.connect(server_hostname="test-host",http_path="/test-path")
317315
exceptExceptionase:
318316
assertstr(e)==error_message
@@ -321,3 +319,90 @@ def test_connection_failure_sends_correct_telemetry_payload(
321319
call_arguments=mock_export_failure_log.call_args
322320
assertcall_arguments[0][0]=="Exception"
323321
assertcall_arguments[0][1]==error_message
322+
323+
324+
@patch("databricks.sql.client.Session")
325+
classTestTelemetryFeatureFlag:
326+
"""Tests the interaction between the telemetry feature flag and connection parameters."""
327+
328+
def_mock_ff_response(self,mock_requests_get,enabled:bool):
329+
"""Helper to configure the mock response for the feature flag endpoint."""
330+
mock_response=MagicMock()
331+
mock_response.status_code=200
332+
payload= {
333+
"flags": [
334+
{
335+
"name":"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForPythonDriver",
336+
"value":str(enabled).lower(),
337+
}
338+
],
339+
"ttl_seconds":3600,
340+
}
341+
mock_response.json.return_value=payload
342+
mock_requests_get.return_value=mock_response
343+
344+
@patch("databricks.sql.common.feature_flag.requests.get")
345+
deftest_telemetry_enabled_when_flag_is_true(
346+
self,mock_requests_get,MockSession
347+
):
348+
"""Telemetry should be ON when enable_telemetry=True and server flag is 'true'."""
349+
self._mock_ff_response(mock_requests_get,enabled=True)
350+
mock_session_instance=MockSession.return_value
351+
mock_session_instance.guid_hex="test-session-ff-true"
352+
mock_session_instance.auth_provider=AccessTokenAuthProvider("token")
353+
354+
conn=sql.client.Connection(
355+
server_hostname="test",
356+
http_path="test",
357+
access_token="test",
358+
enable_telemetry=True,
359+
)
360+
361+
assertconn.telemetry_enabledisTrue
362+
mock_requests_get.assert_called_once()
363+
client=TelemetryClientFactory.get_telemetry_client("test-session-ff-true")
364+
assertisinstance(client,TelemetryClient)
365+
366+
@patch("databricks.sql.common.feature_flag.requests.get")
367+
deftest_telemetry_disabled_when_flag_is_false(
368+
self,mock_requests_get,MockSession
369+
):
370+
"""Telemetry should be OFF when enable_telemetry=True but server flag is 'false'."""
371+
self._mock_ff_response(mock_requests_get,enabled=False)
372+
mock_session_instance=MockSession.return_value
373+
mock_session_instance.guid_hex="test-session-ff-false"
374+
mock_session_instance.auth_provider=AccessTokenAuthProvider("token")
375+
376+
conn=sql.client.Connection(
377+
server_hostname="test",
378+
http_path="test",
379+
access_token="test",
380+
enable_telemetry=True,
381+
)
382+
383+
assertconn.telemetry_enabledisFalse
384+
mock_requests_get.assert_called_once()
385+
client=TelemetryClientFactory.get_telemetry_client("test-session-ff-false")
386+
assertisinstance(client,NoopTelemetryClient)
387+
388+
@patch("databricks.sql.common.feature_flag.requests.get")
389+
deftest_telemetry_disabled_when_flag_request_fails(
390+
self,mock_requests_get,MockSession
391+
):
392+
"""Telemetry should default to OFF if the feature flag network request fails."""
393+
mock_requests_get.side_effect=Exception("Network is down")
394+
mock_session_instance=MockSession.return_value
395+
mock_session_instance.guid_hex="test-session-ff-fail"
396+
mock_session_instance.auth_provider=AccessTokenAuthProvider("token")
397+
398+
conn=sql.client.Connection(
399+
server_hostname="test",
400+
http_path="test",
401+
access_token="test",
402+
enable_telemetry=True,
403+
)
404+
405+
assertconn.telemetry_enabledisFalse
406+
mock_requests_get.assert_called_once()
407+
client=TelemetryClientFactory.get_telemetry_client("test-session-ff-fail")
408+
assertisinstance(client,NoopTelemetryClient)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp