88ClientCredentialsTokenSource ,
99)
1010from databricks .sql .auth .endpoint import get_oauth_endpoints
11- from databricks .sql .auth .common import AuthType ,get_effective_azure_login_app_id
12- from databricks .sdk import WorkspaceClient
11+ from databricks .sql .auth .common import (
12+ AuthType ,
13+ get_effective_azure_login_app_id ,
14+ get_azure_tenant_id_from_host ,
15+ )
1316
1417# Private API: this is an evolving interface and it will change in the future.
1518# Please must not depend on it in your applications.
1619from databricks .sql .experimental .oauth_persistence import OAuthToken ,OAuthPersistence
1720
1821
19-
2022class AuthProvider :
2123def add_headers (self ,request_headers :Dict [str ,str ]):
2224pass
@@ -165,8 +167,8 @@ class AzureServicePrincipalCredentialProvider(CredentialsProvider):
165167
166168 Attributes:
167169 hostname (str): The Databricks workspace hostname.
168- oauth_client_id (str): The Azure service principal's client ID.
169- oauth_client_secret (str): The Azure service principal's client secret.
170+ azure_client_id (str): The Azure service principal's client ID.
171+ azure_client_secret (str): The Azure service principal's client secret.
170172 azure_tenant_id (str): The Azure AD tenant ID.
171173 azure_workspace_resource_id (str, optional): The Azure workspace resource ID.
172174 """
@@ -184,56 +186,50 @@ class AzureServicePrincipalCredentialProvider(CredentialsProvider):
184186def __init__ (
185187self ,
186188hostname ,
187- oauth_client_id ,
188- oauth_client_secret ,
189- azure_tenant_id ,
189+ azure_client_id ,
190+ azure_client_secret ,
191+ azure_tenant_id = None ,
190192azure_workspace_resource_id = None ,
191193 ):
192- self .workspace_api_client = WorkspaceClient (
193- host = hostname ,
194- azure_workspace_resource_id = azure_workspace_resource_id ,
195- azure_tenant_id = azure_tenant_id ,
196- azure_client_id = oauth_client_id ,
197- azure_client_secret = oauth_client_secret ,
198- )
199194self .hostname = hostname
200- self .oauth_client_id = oauth_client_id
201- self .oauth_client_secret = oauth_client_secret
202- self .azure_tenant_id = azure_tenant_id
195+ self .azure_client_id = azure_client_id
196+ self .azure_client_secret = azure_client_secret
203197self .azure_workspace_resource_id = azure_workspace_resource_id
198+ self .azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host (
199+ hostname
200+ )
204201
205202def auth_type (self )-> str :
206203return AuthType .AZURE_SP_M2M .value
207204
208205def get_token_source (self ,resource :str )-> RefreshableTokenSource :
209206return ClientCredentialsTokenSource (
210207token_url = f"{ self .AZURE_AAD_ENDPOINT } /{ self .azure_tenant_id } /{ self .AZURE_TOKEN_ENDPOINT } " ,
211- oauth_client_id = self .oauth_client_id ,
212- oauth_client_secret = self .oauth_client_secret ,
208+ client_id = self .azure_client_id ,
209+ client_secret = self .azure_client_secret ,
213210extra_params = {"resource" :resource },
214211 )
215212
216213def __call__ (self ,* args ,** kwargs )-> HeaderFactory :
217- # inner = self.get_token_source(
218- # resource=get_effective_azure_login_app_id(self.hostname)
219- # )
220- # cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
214+ inner = self .get_token_source (
215+ resource = get_effective_azure_login_app_id (self .hostname )
216+ )
217+ cloud = self .get_token_source (resource = self .AZURE_MANAGED_RESOURCE )
221218
222219def header_factory ()-> Dict [str ,str ]:
223- # inner_token = inner.get_token()
224- # cloud_token = cloud.get_token()
220+ inner_token = inner .get_token ()
221+ cloud_token = cloud .get_token ()
225222
226- # headers = {
227- # HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
228- # self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
229- # }
223+ headers = {
224+ HttpHeader .AUTHORIZATION .value :f"{ inner_token .token_type } { inner_token .access_token } " ,
225+ self .DATABRICKS_AZURE_SP_TOKEN_HEADER :cloud_token .access_token ,
226+ }
230227
231- # if self.azure_workspace_resource_id:
232- # headers[
233- # self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
234- # ] = self.azure_workspace_resource_id
228+ if self .azure_workspace_resource_id :
229+ headers [
230+ self .DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
231+ ]= self .azure_workspace_resource_id
235232
236- # return headers
237- return self .workspace_api_client .config .authenticate ()
233+ return headers
238234
239235return header_factory