- Notifications
You must be signed in to change notification settings - Fork126
Refactor codebase to use a unified http client#673
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Changes fromall commits
4437a2a30c04a642946003155211d00e3c8000d3a3cba3da72a1f7191dd40a13847acad9a4797ba2a3a9d1f045e7c33fe476fdb98d657b9dFile filter
Filter by extension
Conversations
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -2,7 +2,8 @@ | ||
| import logging | ||
| from typing import Optional, List | ||
| from urllib.parse import urlparse | ||
| from databricks.sql.auth.retry import DatabricksRetryPolicy | ||
| from databricks.sql.common.http import HttpMethod | ||
| logger = logging.getLogger(__name__) | ||
| @@ -36,6 +37,21 @@ def __init__( | ||
| tls_client_cert_file: Optional[str] = None, | ||
| oauth_persistence=None, | ||
| credentials_provider=None, | ||
| # HTTP client configuration parameters | ||
| ssl_options=None, # SSLOptions type | ||
| socket_timeout: Optional[float] = None, | ||
| retry_stop_after_attempts_count: Optional[int] = None, | ||
| retry_delay_min: Optional[float] = None, | ||
| retry_delay_max: Optional[float] = None, | ||
| retry_stop_after_attempts_duration: Optional[float] = None, | ||
| retry_delay_default: Optional[float] = None, | ||
| retry_dangerous_codes: Optional[List[int]] = None, | ||
| http_proxy: Optional[str] = None, | ||
| proxy_username: Optional[str] = None, | ||
| proxy_password: Optional[str] = None, | ||
| pool_connections: Optional[int] = None, | ||
| pool_maxsize: Optional[int] = None, | ||
| user_agent: Optional[str] = None, | ||
| ): | ||
| self.hostname = hostname | ||
| self.access_token = access_token | ||
| @@ -52,6 +68,24 @@ def __init__( | ||
| self.oauth_persistence = oauth_persistence | ||
| self.credentials_provider = credentials_provider | ||
| # HTTP client configuration | ||
| self.ssl_options = ssl_options | ||
| self.socket_timeout = socket_timeout | ||
| self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5 | ||
| self.retry_delay_min = retry_delay_min or 1.0 | ||
| self.retry_delay_max = retry_delay_max or 10.0 | ||
| self.retry_stop_after_attempts_duration = ( | ||
| retry_stop_after_attempts_duration or 300.0 | ||
| ) | ||
| self.retry_delay_default = retry_delay_default or 5.0 | ||
| self.retry_dangerous_codes = retry_dangerous_codes or [] | ||
| self.http_proxy = http_proxy | ||
| self.proxy_username = proxy_username | ||
| self.proxy_password = proxy_password | ||
| self.pool_connections = pool_connections or 10 | ||
vikrantpuppala marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| self.pool_maxsize = pool_maxsize or 20 | ||
| self.user_agent = user_agent | ||
| def get_effective_azure_login_app_id(hostname) -> str: | ||
| """ | ||
| @@ -69,7 +103,7 @@ def get_effective_azure_login_app_id(hostname) -> str: | ||
| return AzureAppId.PROD.value[1] | ||
| def get_azure_tenant_id_from_host(host: str, http_client) -> str: | ||
| """ | ||
| Load the Azure tenant ID from the Azure Databricks login page. | ||
| @@ -78,23 +112,20 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str: | ||
| the Azure login page, and the tenant ID is extracted from the redirect URL. | ||
| """ | ||
| login_url = f"{host}/aad/auth" | ||
| logger.debug("Loading tenant ID from %s", login_url) | ||
| with http_client.request_context(HttpMethod.GET, login_url) as resp: | ||
| entra_id_endpoint = resp.retries.history[-1].redirect_location | ||
| if entra_id_endpoint is None: | ||
| raise ValueError( | ||
| f"No Location header in response from {login_url}:{entra_id_endpoint}" | ||
| ) | ||
| # The final redirect URL has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?... | ||
| # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud). | ||
| url = urlparse(entra_id_endpoint) | ||
| path_segments = url.path.split("/") | ||
| if len(path_segments) < 2: | ||
| raise ValueError(f"Invalid path in Location header: {url.path}") | ||
vikrantpuppala marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| return path_segments[1] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -9,10 +9,8 @@ | ||
| from typing import List, Optional | ||
| import oauthlib.oauth2 | ||
| from oauthlib.oauth2.rfc6749.errors import OAuth2Error | ||
| from databricks.sql.common.http import HttpMethod, HttpHeader | ||
| from databricks.sql.common.http import OAuthResponse | ||
| from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler | ||
| from databricks.sql.auth.endpoint import OAuthEndpointCollection | ||
| @@ -63,33 +61,19 @@ def refresh(self) -> Token: | ||
| pass | ||
| class OAuthManager: | ||
| def __init__( | ||
| self, | ||
| port_range: List[int], | ||
| client_id: str, | ||
| idp_endpoint: OAuthEndpointCollection, | ||
| http_client, | ||
| ): | ||
| self.port_range = port_range | ||
| self.client_id = client_id | ||
| self.redirect_port = None | ||
| self.idp_endpoint = idp_endpoint | ||
| self.http_client = http_client | ||
| @staticmethod | ||
| def __token_urlsafe(nbytes=32): | ||
| @@ -103,8 +87,11 @@ def __fetch_well_known_config(self, hostname: str): | ||
| known_config_url = self.idp_endpoint.get_openid_config_url(hostname) | ||
| try: | ||
| response = self.http_client.request(HttpMethod.GET, url=known_config_url) | ||
| # Convert urllib3 response to requests-like response for compatibility | ||
| response.status_code = response.status | ||
| response.json = lambda: json.loads(response.data.decode()) | ||
| except Exception as e: | ||
| logger.error( | ||
| f"Unable to fetch OAuth configuration from {known_config_url}.\n" | ||
| "Verify it is a valid workspace URL and that OAuth is " | ||
| @@ -122,7 +109,7 @@ def __fetch_well_known_config(self, hostname: str): | ||
| raise RuntimeError(msg) | ||
| try: | ||
| return response.json() | ||
| exceptException as e: | ||
| logger.error( | ||
| f"Unable to decode OAuth configuration from {known_config_url}.\n" | ||
| "Verify it is a valid workspace URL and that OAuth is " | ||
| @@ -203,16 +190,17 @@ def __send_auth_code_token_request( | ||
| data = f"{token_request_body}&code_verifier={verifier}" | ||
| return self.__send_token_request(token_request_url, data) | ||
| def __send_token_request(self, token_request_url, data): | ||
vikrantpuppala marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| headers = { | ||
| "Accept": "application/json", | ||
| "Content-Type": "application/x-www-form-urlencoded", | ||
| } | ||
| # Use unified HTTP client | ||
| response = self.http_client.request( | ||
| HttpMethod.POST, url=token_request_url, body=data, headers=headers | ||
| ) | ||
| # Convert urllib3 response to dict for compatibility | ||
| return json.loads(response.data.decode()) | ||
| def __send_refresh_token_request(self, hostname, refresh_token): | ||
| oauth_config = self.__fetch_well_known_config(hostname) | ||
| @@ -221,7 +209,7 @@ def __send_refresh_token_request(self, hostname, refresh_token): | ||
| token_request_body = client.prepare_refresh_body( | ||
| refresh_token=refresh_token, client_id=client.client_id | ||
| ) | ||
| returnself.__send_token_request(token_request_url, token_request_body) | ||
| @staticmethod | ||
| def __get_tokens_from_response(oauth_response): | ||
| @@ -320,14 +308,15 @@ def __init__( | ||
| token_url, | ||
| client_id, | ||
| client_secret, | ||
| http_client, | ||
| extra_params: dict = {}, | ||
| ): | ||
| self.client_id = client_id | ||
| self.client_secret = client_secret | ||
| self.token_url = token_url | ||
| self.extra_params = extra_params | ||
| self.token: Optional[Token] = None | ||
| self._http_client =http_client | ||
| def get_token(self) -> Token: | ||
| if self.token is None or self.token.is_expired(): | ||
| @@ -348,17 +337,17 @@ def refresh(self) -> Token: | ||
| } | ||
| ) | ||
| response =self._http_client.request( | ||
| method=HttpMethod.POST, url=self.token_url, headers=headers,body=data | ||
| ) | ||
| if response.status == 200: | ||
| oauth_response = OAuthResponse(**json.loads(response.data.decode("utf-8"))) | ||
| return Token( | ||
| oauth_response.access_token, | ||
| oauth_response.token_type, | ||
| oauth_response.refresh_token, | ||
| ) | ||
| else: | ||
| raise Exception( | ||
| f"Failed to get token: {response.status} {response.data.decode('utf-8')}" | ||
| ) | ||
Uh oh!
There was an error while loading.Please reload this page.
Uh oh!
There was an error while loading.Please reload this page.