- Notifications
You must be signed in to change notification settings - Fork 126
Telemetry server-side flag integration#646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Uh oh!
There was an error while loading. Please reload this page.
Changes fromall commits
f9cb8247785d43f0cdfab08c0bd8fb70b703e60cf20d300ab9a8cbbf6ceb092725cce947353296f47cbfa35d0deFile filter
Filter by extension
Conversations
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,176 @@ | ||
| import threading | ||
| import time | ||
| import requests | ||
| from dataclasses import dataclass, field | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from typing import Dict, Optional, List, Any, TYPE_CHECKING | ||
| if TYPE_CHECKING: | ||
| from databricks.sql.client import Connection | ||
saishreeeee marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
@dataclass
class FeatureFlagEntry:
    """One feature flag (a name/value pair) taken from the server response."""

    name: str
    value: str
@dataclass
class FeatureFlagsResponse:
    """Represents the full JSON response from the feature flag endpoint."""

    # Parsed flag entries; empty when the payload carries no flags.
    flags: List["FeatureFlagEntry"] = field(default_factory=list)
    # Value of the payload's "ttl_seconds" key, or None when it is absent.
    ttl_seconds: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "FeatureFlagsResponse":
        """Build an instance from a dictionary (parsed JSON).

        Args:
            data: The decoded JSON object. Its ``"flags"`` key, when present,
                is a list of objects carrying ``name`` and ``value``.

        Returns:
            A ``FeatureFlagsResponse`` holding one ``FeatureFlagEntry`` per
            element of ``"flags"`` and the raw ``"ttl_seconds"`` value.

        Raises:
            TypeError: If a flag object lacks ``name`` or ``value``.
        """
        # `or []` also covers an explicit JSON null, which `.get(key, [])`
        # would pass through as None and then fail to iterate.
        raw_flags = data.get("flags") or []
        known_fields = ("name", "value")
        # Keep only the fields FeatureFlagEntry declares, so an extra key in
        # the payload does not make the constructor raise.
        flags_list = [
            FeatureFlagEntry(
                **{key: val for key, val in raw.items() if key in known_fields}
            )
            for raw in raw_flags
        ]
        return cls(flags=flags_list, ttl_seconds=data.get("ttl_seconds"))
# --- Constants ---
# Path template for the feature flag endpoint; `{}` is a format placeholder.
FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT = (
    "/api/2.0/connector-service/feature-flags/PYTHON/{}"
)
# Cache lifetime used until a response supplies its own TTL.
DEFAULT_TTL_SECONDS = 15 * 60  # 15 minutes
# Lead time before expiry at which a proactive refresh becomes due.
REFRESH_BEFORE_EXPIRY_SECONDS = 10
class FeatureFlagsContext:
    """
    Manages fetching and caching of server-side feature flags for a connection.

    1. The very first check for any flag is a synchronous, BLOCKING operation.
    2. Subsequent refreshes (triggered near TTL expiry) are done asynchronously
       on the supplied executor, returning stale data until the refresh
       completes.
    3. At most one background refresh is outstanding at any time.
    """

    def __init__(self, connection: "Connection", executor: ThreadPoolExecutor):
        """Store the collaborators and build the endpoint URL; no I/O happens here.

        Args:
            connection: Connection whose session supplies host, auth headers
                and the user agent for the flag request.
            executor: Pool on which background refreshes are submitted.
        """
        from databricks.sql import __version__

        self._connection = connection
        self._executor = executor  # Used for ASYNCHRONOUS refreshes
        self._lock = threading.RLock()

        # Cache state: `None` indicates the cache has never been loaded.
        self._flags: Optional[Dict[str, str]] = None
        self._ttl_seconds: int = DEFAULT_TTL_SECONDS
        self._last_refresh_time: float = 0
        # True from the moment a background refresh is submitted until it ends.
        self._refresh_in_flight: bool = False

        endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
        self._feature_flag_endpoint = (
            f"https://{self._connection.session.host}{endpoint_suffix}"
        )

    def _is_refresh_needed(self) -> bool:
        """Return True when a loaded cache is within the pre-expiry window."""
        if self._flags is None:
            return False  # Not eligible for refresh until loaded once.

        refresh_threshold = self._last_refresh_time + (
            self._ttl_seconds - REFRESH_BEFORE_EXPIRY_SECONDS
        )
        return time.monotonic() > refresh_threshold

    def get_flag_value(self, name: str, default_value: Any) -> Any:
        """
        Return the cached value of flag ``name``, or ``default_value`` if unknown.

        - BLOCKS on the first call until flags are fetched (or the fetch fails).
        - Returns cached values on subsequent calls, triggering a non-blocking
          refresh if one is due and none is already running.
        """
        with self._lock:
            if self._flags is None:
                # Never loaded: synchronous, blocking fetch.
                self._refresh_flags()
            elif self._is_refresh_needed() and not self._refresh_in_flight:
                # Without the in-flight check every call inside the refresh
                # window would queue another identical HTTP request.
                self._schedule_background_refresh()

            cached = self._flags
            if cached is None:
                # _refresh_flags always leaves a dict behind; stay defensive
                # rather than relying on an assert that -O strips.
                return default_value
            return cached.get(name, default_value)

    def _schedule_background_refresh(self):
        """Submit one background refresh. The caller must hold ``self._lock``."""
        self._refresh_in_flight = True
        try:
            self._executor.submit(self._run_background_refresh)
        except RuntimeError:
            # submit() raises RuntimeError once the pool has been shut down;
            # keep serving the cached values instead of failing the lookup.
            self._refresh_in_flight = False

    def _run_background_refresh(self):
        """Executor entry point: refresh, then clear the in-flight marker."""
        try:
            self._refresh_flags()
        finally:
            with self._lock:
                self._refresh_in_flight = False

    def _refresh_flags(self):
        """Perform a synchronous network request to fetch and update flags.

        Never raises: on a non-200 status or any exception the existing cache
        is kept, and a never-loaded cache is initialised to an empty dict so
        later lookups do not block again.
        """
        import logging

        headers: Dict[str, str] = {}
        try:
            # Authenticate the request
            self._connection.session.auth_provider.add_headers(headers)
            headers["User-Agent"] = self._connection.session.useragent_header

            response = requests.get(
                self._feature_flag_endpoint, headers=headers, timeout=30
            )

            if response.status_code == 200:
                ff_response = FeatureFlagsResponse.from_dict(response.json())
                self._update_cache_from_response(ff_response)
                return
        except Exception as exc:
            # Best-effort by design, but leave a trace for debugging.
            logging.getLogger(__name__).debug(
                "Feature flag refresh failed: %s", exc
            )

        # Failure path (non-200 or exception): prevent re-blocking.
        with self._lock:
            if self._flags is None:
                self._flags = {}

    def _update_cache_from_response(self, ff_response: "FeatureFlagsResponse"):
        """Replace the cached flags, TTL and refresh timestamp under the lock."""
        with self._lock:
            self._flags = {flag.name: flag.value for flag in ff_response.flags}
            # Only a positive TTL from the response overrides the current one.
            if ff_response.ttl_seconds is not None and ff_response.ttl_seconds > 0:
                self._ttl_seconds = ff_response.ttl_seconds
            self._last_refresh_time = time.monotonic()
class FeatureFlagsContextFactory:
    """
    Keeps one FeatureFlagsContext per connection session, keyed by session id.

    Also owns the ThreadPoolExecutor that every context shares for its
    background refreshes.
    """

    _context_map: Dict[str, FeatureFlagsContext] = {}
    _executor: Optional[ThreadPoolExecutor] = None
    _lock = threading.Lock()

    @classmethod
    def _initialize(cls):
        """Create the shared executor unless one already exists."""
        if cls._executor is not None:
            return
        cls._executor = ThreadPoolExecutor(
            max_workers=3, thread_name_prefix="feature-flag-refresher"
        )

    @classmethod
    def get_instance(cls, connection: "Connection") -> FeatureFlagsContext:
        """Return the context for ``connection``, creating it on first use."""
        with cls._lock:
            cls._initialize()
            assert cls._executor is not None

            # The session id (hex) identifies the connection in the map.
            session_key = connection.get_session_id_hex()
            context = cls._context_map.get(session_key)
            if context is None:
                context = FeatureFlagsContext(connection, cls._executor)
                cls._context_map[session_key] = context
            return context

    @classmethod
    def remove_instance(cls, connection: "Connection"):
        """Drop the context for ``connection``; shut the executor down once no contexts remain."""
        with cls._lock:
            cls._context_map.pop(connection.get_session_id_hex(), None)

            # Contexts still registered, or no pool to clean up: nothing to do.
            if cls._context_map or cls._executor is None:
                return

            cls._executor.shutdown(wait=False)
            cls._executor = None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -2,7 +2,7 @@ | ||
| import time | ||
| import logging | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from typing import Dict, Optional, TYPE_CHECKING | ||
| from databricks.sql.common.http import TelemetryHttpClient | ||
| from databricks.sql.telemetry.models.event import ( | ||
| TelemetryEvent, | ||
| @@ -36,6 +36,10 @@ | ||
| import uuid | ||
| import locale | ||
| from databricks.sql.telemetry.utils import BaseTelemetryClient | ||
| from databricks.sql.common.feature_flag import FeatureFlagsContextFactory | ||
| if TYPE_CHECKING: | ||
| from databricks.sql.client import Connection | ||
| logger = logging.getLogger(__name__) | ||
| @@ -44,6 +48,7 @@ class TelemetryHelper: | ||
| """Helper class for getting telemetry related information.""" | ||
| _DRIVER_SYSTEM_CONFIGURATION = None | ||
| TELEMETRY_FEATURE_FLAG_NAME = "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForPythonDriver" | ||
| @classmethod | ||
| def get_driver_system_configuration(cls) -> DriverSystemConfiguration: | ||
| @@ -98,6 +103,20 @@ def get_auth_flow(auth_provider): | ||
| else: | ||
| return None | ||
@staticmethod
def is_telemetry_enabled(connection: "Connection") -> bool:
    """Decide whether telemetry is enabled for ``connection``.

    ``force_enable_telemetry`` wins outright. Otherwise telemetry requires
    both ``enable_telemetry`` on the connection and the server-side feature
    flag evaluating to the string "true" (case-insensitive).
    """
    if connection.force_enable_telemetry:
        return True

    if not connection.enable_telemetry:
        return False

    flags_context = FeatureFlagsContextFactory.get_instance(connection)
    raw_value = flags_context.get_flag_value(
        TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False
    )
    # The flag is compared as text, so it is normalised through str() first.
    return str(raw_value).lower() == "true"
| class NoopTelemetryClient(BaseTelemetryClient): | ||
| """ | ||
Uh oh!
There was an error while loading.Please reload this page.