- Notifications
You must be signed in to change notification settings - Fork126
[PECO-626] Support OAuth flow for Databricks Azure#86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Changes fromall commits
31f6053d4e9b2b84cf5199b52b2f5cdd01402b35d7010c36476297ffFile filter
Filter by extension
Conversations
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| # | ||
| # It implements all the cloud specific OAuth configuration/metadata | ||
| # | ||
| # Azure: It uses AAD | ||
| # AWS: It uses Databricks internal IdP | ||
| # GCP: Not support yet | ||
| # | ||
| from abc import ABC, abstractmethod | ||
| from enum import Enum | ||
| from typing import Optional, List | ||
| import os | ||
| OIDC_REDIRECTOR_PATH = "oidc" | ||
| class OAuthScope: | ||
| OFFLINE_ACCESS = "offline_access" | ||
| SQL = "sql" | ||
| class CloudType(Enum): | ||
| AWS = "aws" | ||
| AZURE = "azure" | ||
| DATABRICKS_AWS_DOMAINS = [".cloud.databricks.com", ".dev.databricks.com"] | ||
| DATABRICKS_AZURE_DOMAINS = [ | ||
| ".azuredatabricks.net", | ||
| ".databricks.azure.cn", | ||
| ".databricks.azure.us", | ||
| ] | ||
| # Infer cloud type from Databricks SQL instance hostname | ||
| def infer_cloud_from_host(hostname: str) -> Optional[CloudType]: | ||
| # normalize | ||
| host = hostname.lower().replace("https://", "").split("/")[0] | ||
| if any(e for e in DATABRICKS_AZURE_DOMAINS if host.endswith(e)): | ||
| return CloudType.AZURE | ||
| elif any(e for e in DATABRICKS_AWS_DOMAINS if host.endswith(e)): | ||
| return CloudType.AWS | ||
| else: | ||
| return None | ||
| def get_databricks_oidc_url(hostname: str): | ||
| maybe_scheme = "https://" if not hostname.startswith("https://") else "" | ||
| maybe_trailing_slash = "/" if not hostname.endswith("/") else "" | ||
| return f"{maybe_scheme}{hostname}{maybe_trailing_slash}{OIDC_REDIRECTOR_PATH}" | ||
susodapop marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| class OAuthEndpointCollection(ABC): | ||
| @abstractmethod | ||
| def get_scopes_mapping(self, scopes: List[str]) -> List[str]: | ||
susodapop marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| raise NotImplementedError() | ||
| # Endpoint for oauth2 authorization e.g https://idp.example.com/oauth2/v2.0/authorize | ||
| @abstractmethod | ||
| def get_authorization_url(self, hostname: str) -> str: | ||
| raise NotImplementedError() | ||
| # Endpoint for well-known openid configuration e.g https://idp.example.com/oauth2/.well-known/openid-configuration | ||
| @abstractmethod | ||
| def get_openid_config_url(self, hostname: str) -> str: | ||
| raise NotImplementedError() | ||
| class AzureOAuthEndpointCollection(OAuthEndpointCollection): | ||
| DATATRICKS_AZURE_APP = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d" | ||
| def get_scopes_mapping(self, scopes: List[str]) -> List[str]: | ||
| # There is no corresponding scopes in Azure, instead, access control will be delegated to Databricks | ||
| tenant_id = os.getenv( | ||
| "DATABRICKS_AZURE_TENANT_ID", | ||
| AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP, | ||
| ) | ||
| azure_scope = f"{tenant_id}/user_impersonation" | ||
jackyhu-db marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| mapped_scopes = [azure_scope] | ||
| if OAuthScope.OFFLINE_ACCESS in scopes: | ||
| mapped_scopes.append(OAuthScope.OFFLINE_ACCESS) | ||
| return mapped_scopes | ||
| def get_authorization_url(self, hostname: str): | ||
| # We need get account specific url, which can be redirected by databricks unified oidc endpoint | ||
| return f"{get_databricks_oidc_url(hostname)}/oauth2/v2.0/authorize" | ||
| def get_openid_config_url(self, hostname: str): | ||
| return "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration" | ||
| class AwsOAuthEndpointCollection(OAuthEndpointCollection): | ||
| def get_scopes_mapping(self, scopes: List[str]) -> List[str]: | ||
jackyhu-db marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| # No scope mapping in AWS | ||
| return scopes.copy() | ||
| def get_authorization_url(self, hostname: str): | ||
| idp_url = get_databricks_oidc_url(hostname) | ||
| return f"{idp_url}/oauth2/v2.0/authorize" | ||
| def get_openid_config_url(self, hostname: str): | ||
| idp_url = get_databricks_oidc_url(hostname) | ||
| return f"{idp_url}/.well-known/oauth-authorization-server" | ||
| def get_oauth_endpoints(cloud: CloudType) -> Optional[OAuthEndpointCollection]: | ||
| if cloud == CloudType.AWS: | ||
| return AwsOAuthEndpointCollection() | ||
| elif cloud == CloudType.AZURE: | ||
| return AzureOAuthEndpointCollection() | ||
| else: | ||
| return None | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -14,17 +14,22 @@ | ||
| from requests.exceptions import RequestException | ||
| from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler | ||
| from databricks.sql.auth.endpoint import OAuthEndpointCollection | ||
| logger = logging.getLogger(__name__) | ||
| class OAuthManager: | ||
| def __init__( | ||
| self, | ||
| port_range: List[int], | ||
| client_id: str, | ||
| idp_endpoint: OAuthEndpointCollection, | ||
| ): | ||
| self.port_range = port_range | ||
| self.client_id = client_id | ||
| self.redirect_port = None | ||
| self.idp_endpoint = idp_endpoint | ||
jackyhu-db marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| @staticmethod | ||
| def __token_urlsafe(nbytes=32): | ||
| @@ -34,14 +39,14 @@ def __token_urlsafe(nbytes=32): | ||
| def __get_redirect_url(redirect_port: int): | ||
| return f"http://localhost:{redirect_port}" | ||
| def __fetch_well_known_config(self, hostname: str): | ||
| known_config_url = self.idp_endpoint.get_openid_config_url(hostname) | ||
| try: | ||
| response = requests.get(url=known_config_url) | ||
| except RequestException as e: | ||
| logger.error( | ||
| f"Unable to fetch OAuth configuration from {known_config_url}.\n" | ||
| "Verify it is a valid workspace URL and that OAuth is " | ||
| "enabled on this account." | ||
| ) | ||
| @@ -50,7 +55,7 @@ def __fetch_well_known_config(idp_url: str): | ||
| if response.status_code != 200: | ||
| msg = ( | ||
| f"Received status {response.status_code} OAuth configuration from " | ||
| f"{known_config_url}.\n Verify it is a valid workspace URL and " | ||
| "that OAuth is enabled on this account." | ||
| ) | ||
| logger.error(msg) | ||
| @@ -59,18 +64,12 @@ def __fetch_well_known_config(idp_url: str): | ||
| return response.json() | ||
| except requests.exceptions.JSONDecodeError as e: | ||
| logger.error( | ||
| f"Unable to decode OAuth configuration from {known_config_url}.\n" | ||
| "Verify it is a valid workspace URL and that OAuth is " | ||
| "enabled on this account." | ||
| ) | ||
| raise e | ||
| @staticmethod | ||
| def __get_challenge(): | ||
| verifier_string = OAuthManager.__token_urlsafe(32) | ||
| @@ -154,8 +153,7 @@ def __send_token_request(token_request_url, data): | ||
| return response.json() | ||
| def __send_refresh_token_request(self, hostname, refresh_token): | ||
| oauth_config = self.__fetch_well_known_config(hostname) | ||
| token_request_url = oauth_config["token_endpoint"] | ||
| client = oauthlib.oauth2.WebApplicationClient(self.client_id) | ||
| token_request_body = client.prepare_refresh_body( | ||
| @@ -215,14 +213,15 @@ def check_and_refresh_access_token( | ||
| return fresh_access_token, fresh_refresh_token, True | ||
| def get_tokens(self, hostname: str, scope=None): | ||
| oauth_config = self.__fetch_well_known_config(hostname) | ||
| # We are going to override oauth_config["authorization_endpoint"] use the | ||
| # /oidc redirector on the hostname, which may inject additional parameters. | ||
| auth_url = self.idp_endpoint.get_authorization_url(hostname) | ||
| state = OAuthManager.__token_urlsafe(16) | ||
| (verifier, challenge) = OAuthManager.__get_challenge() | ||
| client = oauthlib.oauth2.WebApplicationClient(self.client_id) | ||
| try: | ||
| auth_response = self.__get_authorization_code( | ||
| client, auth_url, scope, state, challenge | ||
Uh oh!
There was an error while loading.Please reload this page.