Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitbef7ac6

Browse files
committed
working
1 parent05ab3e8 commitbef7ac6

File tree

4 files changed

+213
-97
lines changed

4 files changed

+213
-97
lines changed

‎src/databricks/sql/auth/auth.py‎

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,13 @@
1-
fromenumimportEnum
21
fromtypingimportOptional,List
32

43
fromdatabricks.sql.auth.authenticatorsimport (
54
AuthProvider,
65
AccessTokenAuthProvider,
76
ExternalAuthProvider,
87
DatabricksOAuthProvider,
8+
AzureServicePrincipalCredentialProvider,
99
)
10-
11-
12-
classAuthType(Enum):
13-
DATABRICKS_OAUTH="databricks-oauth"
14-
AZURE_OAUTH="azure-oauth"
15-
# other supported types (access_token) can be inferred
16-
# we can add more types as needed later
10+
fromdatabricks.sql.common.authimportAuthType
1711

1812

1913
classClientContext:
@@ -24,6 +18,9 @@ def __init__(
2418
auth_type:Optional[str]=None,
2519
oauth_scopes:Optional[List[str]]=None,
2620
oauth_client_id:Optional[str]=None,
21+
oauth_client_secret:Optional[str]=None,
22+
azure_tenant_id:Optional[str]=None,
23+
azure_workspace_resource_id:Optional[str]=None,
2724
oauth_redirect_port_range:Optional[List[int]]=None,
2825
use_cert_as_auth:Optional[str]=None,
2926
tls_client_cert_file:Optional[str]=None,
@@ -35,6 +32,9 @@ def __init__(
3532
self.auth_type=auth_type
3633
self.oauth_scopes=oauth_scopes
3734
self.oauth_client_id=oauth_client_id
35+
self.oauth_client_secret=oauth_client_secret
36+
self.azure_tenant_id=azure_tenant_id
37+
self.azure_workspace_resource_id=azure_workspace_resource_id
3838
self.oauth_redirect_port_range=oauth_redirect_port_range
3939
self.use_cert_as_auth=use_cert_as_auth
4040
self.tls_client_cert_file=tls_client_cert_file
@@ -45,7 +45,17 @@ def __init__(
4545
defget_auth_provider(cfg:ClientContext):
4646
ifcfg.credentials_provider:
4747
returnExternalAuthProvider(cfg.credentials_provider)
48-
ifcfg.auth_typein [AuthType.DATABRICKS_OAUTH.value,AuthType.AZURE_OAUTH.value]:
48+
elifcfg.auth_type==AuthType.AZURE_SP_M2M.value:
49+
returnExternalAuthProvider(
50+
AzureServicePrincipalCredentialProvider(
51+
cfg.hostname,
52+
cfg.oauth_client_id,
53+
cfg.oauth_client_secret,
54+
cfg.azure_tenant_id,
55+
cfg.azure_workspace_resource_id,
56+
)
57+
)
58+
elifcfg.auth_typein [AuthType.DATABRICKS_OAUTH.value,AuthType.AZURE_OAUTH.value]:
4959
assertcfg.oauth_redirect_port_rangeisnotNone
5060
assertcfg.oauth_client_idisnotNone
5161
assertcfg.oauth_scopesisnotNone
@@ -103,9 +113,15 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
103113

104114
defget_python_sql_connector_auth_provider(hostname:str,**kwargs):
105115
auth_type=kwargs.get("auth_type")
106-
(client_id,redirect_port_range)=get_client_id_and_redirect_port(
107-
auth_type==AuthType.AZURE_OAUTH.value
108-
)
116+
client_id=kwargs.get("oauth_client_id")
117+
redirect_port_range=kwargs.get("oauth_redirect_port_range")
118+
119+
ifauth_type==AuthType.AZURE_SP_M2M.value:
120+
pass
121+
else:
122+
(client_id,redirect_port_range)=get_client_id_and_redirect_port(
123+
auth_type==AuthType.AZURE_OAUTH.value
124+
)
109125
ifkwargs.get("username")orkwargs.get("password"):
110126
raiseValueError(
111127
"Username/password authentication is no longer supported. "
@@ -119,9 +135,12 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
119135
use_cert_as_auth=kwargs.get("_use_cert_as_auth"),
120136
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
121137
oauth_scopes=PYSQL_OAUTH_SCOPES,
122-
oauth_client_id=kwargs.get("oauth_client_id")orclient_id,
138+
oauth_client_id=client_id,
139+
oauth_client_secret=kwargs.get("oauth_client_secret"),
140+
azure_tenant_id=kwargs.get("azure_tenant_id"),
141+
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
123142
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
124-
ifkwargs.get("oauth_client_id")andkwargs.get("oauth_redirect_port")
143+
ifclient_idandkwargs.get("oauth_redirect_port")
125144
elseredirect_port_range,
126145
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
127146
credentials_provider=kwargs.get("credentials_provider"),

‎src/databricks/sql/auth/authenticators.py‎

Lines changed: 58 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
importabc
2-
importjwt
32
importlogging
4-
importtime
53
fromtypingimportCallable,Dict,List
6-
fromdatabricks.sql.common.httpimportHttpMethod,DatabricksHttpClient,HttpHeader
7-
fromdatabricks.sql.auth.oauthimportOAuthManager
4+
fromdatabricks.sql.common.httpimportHttpHeader
5+
fromdatabricks.sql.auth.oauthimport (
6+
OAuthManager,
7+
RefreshableTokenSource,
8+
ClientCredentialsTokenSource,
9+
)
810
fromdatabricks.sql.auth.endpointimportget_oauth_endpoints
9-
fromdatabricks.sql.common.httpimportDatabricksHttpClient,OAuthResponse
10-
fromurllib.parseimporturlencode
11+
fromdatabricks.sql.common.authimportAuthType,get_effective_azure_login_app_id
1112

1213
# Private API: this is an evolving interface and it will change in the future.
1314
# Please must not depend on it in your applications.
@@ -38,35 +39,6 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
3839
...
3940

4041

41-
classToken:
42-
"""
43-
A class to represent a token.
44-
45-
Attributes:
46-
access_token (str): The access token string.
47-
token_type (str): The type of token (e.g., "Bearer").
48-
refresh_token (str): The refresh token string.
49-
"""
50-
51-
def__init__(self,access_token:str,token_type:str,refresh_token:str):
52-
self.access_token=access_token
53-
self.token_type=token_type
54-
self.refresh_token=refresh_token
55-
56-
defis_expired(self):
57-
try:
58-
decoded_token=jwt.decode(
59-
self.access_token,options={"verify_signature":False}
60-
)
61-
exp_time=decoded_token.get("exp")
62-
current_time=time.time()
63-
buffer_time=30# 30 seconds buffer
64-
returnexp_timeand (exp_time-buffer_time)<=current_time
65-
exceptExceptionase:
66-
logger.error("Failed to decode token: %s",e)
67-
returne
68-
69-
7042
# Private API: this is an evolving interface and it will change in the future.
7143
# Please must not depend on it in your applications.
7244
classAccessTokenAuthProvider(AuthProvider):
@@ -192,64 +164,68 @@ class AzureServicePrincipalCredentialProvider(CredentialsProvider):
192164
from Azure AD and automatically refreshes them when they expire.
193165
194166
Attributes:
195-
client_id (str): The Azure service principal's client ID.
196-
client_secret (str): The Azure service principal's client secret.
197-
tenant_id (str): The Azure AD tenant ID.
167+
hostname (str): The Databricks workspace hostname.
168+
oauth_client_id (str): The Azure service principal's client ID.
169+
oauth_client_secret (str): The Azure service principal's client secret.
170+
azure_tenant_id (str): The Azure AD tenant ID.
171+
azure_workspace_resource_id (str, optional): The Azure workspace resource ID.
198172
"""
199173

200174
AZURE_AAD_ENDPOINT="https://login.microsoftonline.com"
201175
AZURE_TOKEN_ENDPOINT="oauth2/token"
202176

203-
def__init__(self,client_id:str,client_secret:str,tenant_id:str):
204-
self.client_id=client_id
205-
self.client_secret=client_secret
206-
self.tenant_id=tenant_id
207-
self._token:Token=None
208-
self._http_client=DatabricksHttpClient.get_instance()
177+
AZURE_MANAGED_RESOURCE="https://management.core.windows.net/"
178+
179+
DATABRICKS_AZURE_SP_TOKEN_HEADER="X-Databricks-Azure-SP-Management-Token"
180+
DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER= (
181+
"X-Databricks-Azure-Workspace-Resource-Id"
182+
)
183+
184+
def__init__(
185+
self,
186+
hostname:str,
187+
oauth_client_id:str,
188+
oauth_client_secret:str,
189+
azure_tenant_id:str,
190+
azure_workspace_resource_id:str=None,
191+
):
192+
self.hostname=hostname
193+
self.oauth_client_id=oauth_client_id
194+
self.oauth_client_secret=oauth_client_secret
195+
self.azure_tenant_id=azure_tenant_id
196+
self.azure_workspace_resource_id=azure_workspace_resource_id
209197

210198
defauth_type(self)->str:
211-
return"azure-service-principal"
199+
returnAuthType.AZURE_SP_M2M.value
200+
201+
defget_token_source(self,resource:str)->RefreshableTokenSource:
202+
returnClientCredentialsTokenSource(
203+
token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}",
204+
oauth_client_id=self.oauth_client_id,
205+
oauth_client_secret=self.oauth_client_secret,
206+
extra_params={"resource":resource},
207+
)
212208

213209
def__call__(self,*args,**kwargs)->HeaderFactory:
214-
defheader_factory()->Dict[str,str]:
215-
self._refresh()
216-
return {
217-
HttpHeader.AUTHORIZATION.value:f"{self._token.token_type}{self._token.access_token}",
218-
}
219-
220-
returnheader_factory
210+
inner=self.get_token_source(
211+
resource=get_effective_azure_login_app_id(self.hostname)
212+
)
213+
cloud=self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
221214

222-
def_refresh(self)->None:
223-
ifself._tokenisNoneorself._token.is_expired():
224-
self._token=self._get_token()
215+
defheader_factory()->Dict[str,str]:
216+
inner_token=inner.get_token()
217+
cloud_token=cloud.get_token()
225218

226-
def_get_token(self)->Token:
227-
request_url= (
228-
f"{self.AZURE_AAD_ENDPOINT}/{self.tenant_id}/{self.AZURE_TOKEN_ENDPOINT}"
229-
)
230-
headers= {
231-
HttpHeader.CONTENT_TYPE.value:"application/x-www-form-urlencoded",
232-
}
233-
data=urlencode(
234-
{
235-
"grant_type":"client_credentials",
236-
"client_id":self.client_id,
237-
"client_secret":self.client_secret,
219+
headers= {
220+
HttpHeader.AUTHORIZATION.value:f"{inner_token.token_type}{inner_token.access_token}",
221+
self.DATABRICKS_AZURE_SP_TOKEN_HEADER:cloud_token.access_token,
238222
}
239-
)
240223

241-
response=self._http_client.execute(
242-
method=HttpMethod.POST,url=request_url,headers=headers,data=data
243-
)
224+
ifself.azure_workspace_resource_id:
225+
headers[
226+
self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
227+
]=self.azure_workspace_resource_id
244228

245-
ifresponse.status_code==200:
246-
oauth_response=OAuthResponse(**response.json())
247-
returnToken(
248-
oauth_response.access_token,
249-
oauth_response.token_type,
250-
oauth_response.refresh_token,
251-
)
252-
else:
253-
raiseException(
254-
f"Failed to get token:{response.status_code}{response.text}"
255-
)
229+
returnheaders
230+
231+
returnheader_factory

‎src/databricks/sql/auth/oauth.py‎

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,57 @@
1212
importrequests
1313
fromoauthlib.oauth2.rfc6749.errorsimportOAuth2Error
1414
fromrequests.exceptionsimportRequestException
15-
15+
fromdatabricks.sql.common.httpimportHttpMethod,DatabricksHttpClient,HttpHeader
16+
fromdatabricks.sql.common.httpimportOAuthResponse
1617
fromdatabricks.sql.auth.oauth_http_handlerimportOAuthHttpSingleRequestHandler
1718
fromdatabricks.sql.auth.endpointimportOAuthEndpointCollection
19+
fromabcimportabstractmethod,ABC
20+
fromurllib.parseimporturlencode
21+
importjwt
22+
importtime
1823

1924
logger=logging.getLogger(__name__)
2025

2126

27+
classToken:
28+
"""
29+
A class to represent a token.
30+
31+
Attributes:
32+
access_token (str): The access token string.
33+
token_type (str): The type of token (e.g., "Bearer").
34+
refresh_token (str): The refresh token string.
35+
"""
36+
37+
def__init__(self,access_token:str,token_type:str,refresh_token:str):
38+
self.access_token=access_token
39+
self.token_type=token_type
40+
self.refresh_token=refresh_token
41+
42+
defis_expired(self):
43+
try:
44+
decoded_token=jwt.decode(
45+
self.access_token,options={"verify_signature":False}
46+
)
47+
exp_time=decoded_token.get("exp")
48+
current_time=time.time()
49+
buffer_time=30# 30 seconds buffer
50+
returnexp_timeand (exp_time-buffer_time)<=current_time
51+
exceptExceptionase:
52+
logger.error("Failed to decode token: %s",e)
53+
returne
54+
55+
56+
classRefreshableTokenSource(ABC):
57+
@abstractmethod
58+
defget_token(self)->Token:
59+
pass
60+
61+
@abstractmethod
62+
defrefresh(self):
63+
pass
64+
65+
2266
classIgnoreNetrcAuth(requests.auth.AuthBase):
2367
"""This auth method is a no-op.
2468
@@ -258,3 +302,53 @@ def get_tokens(self, hostname: str, scope=None):
258302
client,token_request_url,redirect_url,code,verifier
259303
)
260304
returnself.__get_tokens_from_response(oauth_response)
305+
306+
307+
classClientCredentialsTokenSource(RefreshableTokenSource):
308+
def__init__(
309+
self,
310+
token_url:str,
311+
oauth_client_id:str,
312+
oauth_client_secret:str,
313+
extra_params:dict=None,
314+
):
315+
self.oauth_client_id=oauth_client_id
316+
self.oauth_client_secret=oauth_client_secret
317+
self.token_url=token_url
318+
self.extra_params=extra_params
319+
self.token:Token=None
320+
self._http_client=DatabricksHttpClient()
321+
322+
defget_token(self)->Token:
323+
ifself.tokenisNoneorself.token.is_expired():
324+
self.token=self.refresh()
325+
returnself.token
326+
327+
defrefresh(self)->None:
328+
headers= {
329+
HttpHeader.CONTENT_TYPE.value:"application/x-www-form-urlencoded",
330+
}
331+
data=urlencode(
332+
{
333+
"grant_type":"client_credentials",
334+
"client_id":self.oauth_client_id,
335+
"client_secret":self.oauth_client_secret,
336+
**self.extra_params,
337+
}
338+
)
339+
340+
response=self._http_client.execute(
341+
method=HttpMethod.POST,url=self.token_url,headers=headers,data=data
342+
)
343+
344+
ifresponse.status_code==200:
345+
oauth_response=OAuthResponse(**response.json())
346+
returnToken(
347+
oauth_response.access_token,
348+
oauth_response.token_type,
349+
oauth_response.refresh_token,
350+
)
351+
else:
352+
raiseException(
353+
f"Failed to get token:{response.status_code}{response.text}"
354+
)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp