|
1 | 1 | importabc |
2 | | -importjwt |
3 | 2 | importlogging |
4 | | -importtime |
5 | 3 | fromtypingimportCallable,Dict,List |
6 | | -fromdatabricks.sql.common.httpimportHttpMethod,DatabricksHttpClient,HttpHeader |
7 | | -fromdatabricks.sql.auth.oauthimportOAuthManager |
| 4 | +fromdatabricks.sql.common.httpimportHttpHeader |
| 5 | +fromdatabricks.sql.auth.oauthimport ( |
| 6 | +OAuthManager, |
| 7 | +RefreshableTokenSource, |
| 8 | +ClientCredentialsTokenSource, |
| 9 | +) |
8 | 10 | fromdatabricks.sql.auth.endpointimportget_oauth_endpoints |
9 | | -fromdatabricks.sql.common.httpimportDatabricksHttpClient,OAuthResponse |
10 | | -fromurllib.parseimporturlencode |
| 11 | +fromdatabricks.sql.common.authimportAuthType,get_effective_azure_login_app_id |
11 | 12 |
|
12 | 13 | # Private API: this is an evolving interface and it will change in the future. |
13 | 14 | # Please must not depend on it in your applications. |
@@ -38,35 +39,6 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: |
38 | 39 | ... |
39 | 40 |
|
40 | 41 |
|
41 | | -classToken: |
42 | | -""" |
43 | | - A class to represent a token. |
44 | | -
|
45 | | - Attributes: |
46 | | - access_token (str): The access token string. |
47 | | - token_type (str): The type of token (e.g., "Bearer"). |
48 | | - refresh_token (str): The refresh token string. |
49 | | - """ |
50 | | - |
51 | | -def__init__(self,access_token:str,token_type:str,refresh_token:str): |
52 | | -self.access_token=access_token |
53 | | -self.token_type=token_type |
54 | | -self.refresh_token=refresh_token |
55 | | - |
56 | | -defis_expired(self): |
57 | | -try: |
58 | | -decoded_token=jwt.decode( |
59 | | -self.access_token,options={"verify_signature":False} |
60 | | - ) |
61 | | -exp_time=decoded_token.get("exp") |
62 | | -current_time=time.time() |
63 | | -buffer_time=30# 30 seconds buffer |
64 | | -returnexp_timeand (exp_time-buffer_time)<=current_time |
65 | | -exceptExceptionase: |
66 | | -logger.error("Failed to decode token: %s",e) |
67 | | -returne |
68 | | - |
69 | | - |
70 | 42 | # Private API: this is an evolving interface and it will change in the future. |
71 | 43 | # Please must not depend on it in your applications. |
72 | 44 | classAccessTokenAuthProvider(AuthProvider): |
@@ -192,64 +164,68 @@ class AzureServicePrincipalCredentialProvider(CredentialsProvider): |
192 | 164 | from Azure AD and automatically refreshes them when they expire. |
193 | 165 |
|
194 | 166 | Attributes: |
195 | | - client_id (str): The Azure service principal's client ID. |
196 | | - client_secret (str): The Azure service principal's client secret. |
197 | | - tenant_id (str): The Azure AD tenant ID. |
| 167 | + hostname (str): The Databricks workspace hostname. |
| 168 | + oauth_client_id (str): The Azure service principal's client ID. |
| 169 | + oauth_client_secret (str): The Azure service principal's client secret. |
| 170 | + azure_tenant_id (str): The Azure AD tenant ID. |
| 171 | + azure_workspace_resource_id (str, optional): The Azure workspace resource ID. |
198 | 172 | """ |
199 | 173 |
|
200 | 174 | AZURE_AAD_ENDPOINT="https://login.microsoftonline.com" |
201 | 175 | AZURE_TOKEN_ENDPOINT="oauth2/token" |
202 | 176 |
|
203 | | -def__init__(self,client_id:str,client_secret:str,tenant_id:str): |
204 | | -self.client_id=client_id |
205 | | -self.client_secret=client_secret |
206 | | -self.tenant_id=tenant_id |
207 | | -self._token:Token=None |
208 | | -self._http_client=DatabricksHttpClient.get_instance() |
| 177 | +AZURE_MANAGED_RESOURCE="https://management.core.windows.net/" |
| 178 | + |
| 179 | +DATABRICKS_AZURE_SP_TOKEN_HEADER="X-Databricks-Azure-SP-Management-Token" |
| 180 | +DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER= ( |
| 181 | +"X-Databricks-Azure-Workspace-Resource-Id" |
| 182 | + ) |
| 183 | + |
| 184 | +def__init__( |
| 185 | +self, |
| 186 | +hostname:str, |
| 187 | +oauth_client_id:str, |
| 188 | +oauth_client_secret:str, |
| 189 | +azure_tenant_id:str, |
| 190 | +azure_workspace_resource_id:str=None, |
| 191 | + ): |
| 192 | +self.hostname=hostname |
| 193 | +self.oauth_client_id=oauth_client_id |
| 194 | +self.oauth_client_secret=oauth_client_secret |
| 195 | +self.azure_tenant_id=azure_tenant_id |
| 196 | +self.azure_workspace_resource_id=azure_workspace_resource_id |
209 | 197 |
|
210 | 198 | defauth_type(self)->str: |
211 | | -return"azure-service-principal" |
| 199 | +returnAuthType.AZURE_SP_M2M.value |
| 200 | + |
| 201 | +defget_token_source(self,resource:str)->RefreshableTokenSource: |
| 202 | +returnClientCredentialsTokenSource( |
| 203 | +token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}", |
| 204 | +oauth_client_id=self.oauth_client_id, |
| 205 | +oauth_client_secret=self.oauth_client_secret, |
| 206 | +extra_params={"resource":resource}, |
| 207 | + ) |
212 | 208 |
|
213 | 209 | def__call__(self,*args,**kwargs)->HeaderFactory: |
214 | | -defheader_factory()->Dict[str,str]: |
215 | | -self._refresh() |
216 | | -return { |
217 | | -HttpHeader.AUTHORIZATION.value:f"{self._token.token_type}{self._token.access_token}", |
218 | | - } |
219 | | - |
220 | | -returnheader_factory |
| 210 | +inner=self.get_token_source( |
| 211 | +resource=get_effective_azure_login_app_id(self.hostname) |
| 212 | + ) |
| 213 | +cloud=self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE) |
221 | 214 |
|
222 | | -def_refresh(self)->None: |
223 | | -ifself._tokenisNoneorself._token.is_expired(): |
224 | | -self._token=self._get_token() |
| 215 | +defheader_factory()->Dict[str,str]: |
| 216 | +inner_token=inner.get_token() |
| 217 | +cloud_token=cloud.get_token() |
225 | 218 |
|
226 | | -def_get_token(self)->Token: |
227 | | -request_url= ( |
228 | | -f"{self.AZURE_AAD_ENDPOINT}/{self.tenant_id}/{self.AZURE_TOKEN_ENDPOINT}" |
229 | | - ) |
230 | | -headers= { |
231 | | -HttpHeader.CONTENT_TYPE.value:"application/x-www-form-urlencoded", |
232 | | - } |
233 | | -data=urlencode( |
234 | | - { |
235 | | -"grant_type":"client_credentials", |
236 | | -"client_id":self.client_id, |
237 | | -"client_secret":self.client_secret, |
| 219 | +headers= { |
| 220 | +HttpHeader.AUTHORIZATION.value:f"{inner_token.token_type}{inner_token.access_token}", |
| 221 | +self.DATABRICKS_AZURE_SP_TOKEN_HEADER:cloud_token.access_token, |
238 | 222 | } |
239 | | - ) |
240 | 223 |
|
241 | | -response=self._http_client.execute( |
242 | | -method=HttpMethod.POST,url=request_url,headers=headers,data=data |
243 | | - ) |
| 224 | +ifself.azure_workspace_resource_id: |
| 225 | +headers[ |
| 226 | +self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER |
| 227 | + ]=self.azure_workspace_resource_id |
244 | 228 |
|
245 | | -ifresponse.status_code==200: |
246 | | -oauth_response=OAuthResponse(**response.json()) |
247 | | -returnToken( |
248 | | -oauth_response.access_token, |
249 | | -oauth_response.token_type, |
250 | | -oauth_response.refresh_token, |
251 | | - ) |
252 | | -else: |
253 | | -raiseException( |
254 | | -f"Failed to get token:{response.status_code}{response.text}" |
255 | | - ) |
| 229 | +returnheaders |
| 230 | + |
| 231 | +returnheader_factory |