Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit49eab2a

Browse files
committed
fmt
1 parent541e82f commit49eab2a

File tree

3 files changed

+100
-53
lines changed

3 files changed

+100
-53
lines changed

‎src/databricks/sql/auth/token_federation.py‎

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,7 @@ def get_headers() -> Dict[str, str]:
150150
else:
151151
# Token is from a different host, need to exchange
152152
logger.debug("Token from different host, attempting exchange")
153-
returnself._try_token_exchange_or_fallback(
154-
access_token,token_type
155-
)
153+
returnself._try_token_exchange_or_fallback(access_token,token_type)
156154
exceptExceptionase:
157155
logger.error(f"Error processing token:{str(e)}")
158156
# Fall back to original headers in case of error
@@ -172,9 +170,7 @@ def _init_oidc_discovery(self):
172170

173171
ifidp_endpoints:
174172
# Get the OpenID configuration URL
175-
openid_config_url=idp_endpoints.get_openid_config_url(
176-
self.hostname
177-
)
173+
openid_config_url=idp_endpoints.get_openid_config_url(self.hostname)
178174

179175
# Fetch the OpenID configuration
180176
response=requests.get(openid_config_url)
@@ -185,7 +181,8 @@ def _init_oidc_discovery(self):
185181
logger.info(f"Discovered token endpoint:{self.token_endpoint}")
186182
else:
187183
logger.warning(
188-
f"Failed to fetch OpenID configuration from{openid_config_url}:{response.status_code}"
184+
f"Failed to fetch OpenID configuration from{openid_config_url}: "
185+
f"{response.status_code}"
189186
)
190187
exceptExceptionase:
191188
logger.warning(
@@ -282,9 +279,15 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]:
282279
self.last_external_token=access_token
283280

284281
# Update the headers with the new token
285-
return {"Authorization":f"{exchanged_token.token_type}{exchanged_token.access_token}"}
282+
return {
283+
"Authorization": (
284+
f"{exchanged_token.token_type}{exchanged_token.access_token}"
285+
)
286+
}
286287
exceptExceptionase:
287-
logger.error(f"Token refresh failed:{str(e)}, falling back to original token")
288+
logger.error(
289+
f"Token refresh failed:{str(e)}, falling back to original token"
290+
)
288291
returnself.external_provider_headers
289292

290293
def_try_token_exchange_or_fallback(
@@ -305,12 +308,20 @@ def _try_token_exchange_or_fallback(
305308
self.last_exchanged_token=exchanged_token
306309
self.last_external_token=access_token
307310

308-
return {"Authorization":f"{exchanged_token.token_type}{exchanged_token.access_token}"}
311+
return {
312+
"Authorization": (
313+
f"{exchanged_token.token_type}{exchanged_token.access_token}"
314+
)
315+
}
309316
exceptExceptionase:
310-
logger.warning(f"Token exchange failed:{str(e)}, falling back to original token")
317+
logger.warning(
318+
f"Token exchange failed:{str(e)}, falling back to original token"
319+
)
311320
returnself.external_provider_headers
312321

313-
def_send_token_exchange_request(self,token_exchange_data:Dict[str,str])->Dict[str,Any]:
322+
def_send_token_exchange_request(
323+
self,token_exchange_data:Dict[str,str]
324+
)->Dict[str,Any]:
314325
"""
315326
Send the token exchange request to the token endpoint.
316327
@@ -325,20 +336,19 @@ def _send_token_exchange_request(self, token_exchange_data: Dict[str, str]) -> D
325336
"""
326337
ifnotself.token_endpoint:
327338
raiseValueError("Token endpoint not initialized")
328-
339+
329340
headers= {"Accept":"*/*","Content-Type":"application/x-www-form-urlencoded"}
330-
341+
331342
response=requests.post(
332-
self.token_endpoint,
333-
data=token_exchange_data,
334-
headers=headers
343+
self.token_endpoint,data=token_exchange_data,headers=headers
335344
)
336-
345+
337346
ifresponse.status_code!=200:
338347
raiseValueError(
339-
f"Token exchange failed with status code{response.status_code}:{response.text}"
348+
f"Token exchange failed with status code{response.status_code}: "
349+
f"{response.text}"
340350
)
341-
351+
342352
returnresponse.json()
343353

344354
def_exchange_token(self,access_token:str)->Token:
@@ -365,26 +375,28 @@ def _exchange_token(self, access_token: str) -> Token:
365375
try:
366376
# Send the token exchange request
367377
resp_data=self._send_token_exchange_request(token_exchange_data)
368-
378+
369379
# Extract token information
370380
new_access_token=resp_data.get("access_token")
371381
ifnotnew_access_token:
372382
raiseValueError("No access token in exchange response")
373-
383+
374384
token_type=resp_data.get("token_type","Bearer")
375385
refresh_token=resp_data.get("refresh_token","")
376-
386+
377387
# Parse expiry time from token claims if possible
378388
expiry=datetime.now(tz=timezone.utc)
379-
389+
380390
# First try to get expiry from the response's expires_in field
381391
if"expires_in"inresp_dataandresp_data["expires_in"]:
382392
try:
383393
expires_in=int(resp_data["expires_in"])
384-
expiry=datetime.now(tz=timezone.utc)+timedelta(seconds=expires_in)
394+
expiry=datetime.now(tz=timezone.utc)+timedelta(
395+
seconds=expires_in
396+
)
385397
except (ValueError,TypeError)ase:
386398
logger.warning(f"Invalid expires_in value:{str(e)}")
387-
399+
388400
# If that didn't work, try to parse JWT claims for expiry
389401
ifexpiry==datetime.now(tz=timezone.utc):
390402
token_claims=self._parse_jwt_claims(new_access_token)
@@ -394,9 +406,9 @@ def _exchange_token(self, access_token: str) -> Token:
394406
expiry=datetime.fromtimestamp(exp_timestamp,tz=timezone.utc)
395407
except (ValueError,TypeError)ase:
396408
logger.warning(f"Invalid exp claim in token:{str(e)}")
397-
409+
398410
returnToken(new_access_token,token_type,refresh_token,expiry)
399-
411+
400412
exceptExceptionase:
401413
logger.error(f"Token exchange failed:{str(e)}")
402414
raise

‎tests/token_federation/github_oidc_test.py‎

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,17 @@
1414
importbase64
1515
importlogging
1616
fromdatabricksimportsql
17-
importjwt
1817

18+
try:
19+
importjwt
20+
21+
HAS_JWT_LIBRARY=True
22+
exceptImportError:
23+
HAS_JWT_LIBRARY=False
1924

2025

2126
logging.basicConfig(
22-
level=logging.INFO,
23-
format="%(asctime)s - %(levelname)s - %(message)s"
27+
level=logging.INFO,format="%(asctime)s - %(levelname)s - %(message)s"
2428
)
2529
logger=logging.getLogger(__name__)
2630

@@ -35,10 +39,29 @@ def decode_jwt(token):
3539
Returns:
3640
dict: The decoded token claims or None if decoding fails
3741
"""
42+
ifHAS_JWT_LIBRARY:
43+
try:
44+
# Using PyJWT library (preferred method)
45+
# Note: we're not verifying the signature as this is just for debugging
46+
returnjwt.decode(token,options={"verify_signature":False})
47+
exceptExceptionase:
48+
logger.error(f"Failed to decode token with PyJWT:{str(e)}")
49+
50+
# Fallback to manual decoding
3851
try:
39-
returnjwt.decode(token,options={"verify_signature":False})
52+
parts=token.split(".")
53+
iflen(parts)!=3:
54+
raiseValueError("Invalid JWT format")
55+
56+
payload=parts[1]
57+
# Add padding if needed
58+
padding="="* (4-len(payload)%4)
59+
payload+=padding
60+
61+
decoded=base64.b64decode(payload)
62+
returnjson.loads(decoded)
4063
exceptExceptionase:
41-
logger.error(f"Failed to decode token with PyJWT:{str(e)}")
64+
logger.error(f"Failed to decode token:{str(e)}")
4265
return {}
4366

4467

@@ -53,7 +76,7 @@ def get_environment_variables():
5376
host=os.environ.get("DATABRICKS_HOST_FOR_TF")
5477
http_path=os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF")
5578
identity_federation_client_id=os.environ.get("IDENTITY_FEDERATION_CLIENT_ID")
56-
79+
5780
returngithub_token,host,http_path,identity_federation_client_id
5881

5982

@@ -62,7 +85,7 @@ def display_token_info(claims):
6285
ifnotclaims:
6386
logger.warning("No token claims available to display")
6487
return
65-
88+
6689
logger.info("=== GitHub OIDC Token Claims ===")
6790
logger.info(f"Token issuer:{claims.get('iss')}")
6891
logger.info(f"Token subject:{claims.get('sub')}")
@@ -74,7 +97,9 @@ def display_token_info(claims):
7497
logger.info("===============================")
7598

7699

77-
deftest_databricks_connection(host,http_path,github_token,identity_federation_client_id):
100+
deftest_databricks_connection(
101+
host,http_path,github_token,identity_federation_client_id
102+
):
78103
"""
79104
Test connection to Databricks using token federation.
80105
@@ -90,30 +115,30 @@ def test_databricks_connection(host, http_path, github_token, identity_federatio
90115
logger.info("=== Testing Connection via Connector ===")
91116
logger.info(f"Connecting to Databricks at{host}{http_path}")
92117
logger.info(f"Using client ID:{identity_federation_client_id}")
93-
118+
94119
connection_params= {
95120
"server_hostname":host,
96121
"http_path":http_path,
97122
"access_token":github_token,
98123
"auth_type":"token-federation",
99124
"identity_federation_client_id":identity_federation_client_id,
100125
}
101-
126+
102127
try:
103128
withsql.connect(**connection_params)asconnection:
104129
logger.info("Connection established successfully")
105-
130+
106131
# Execute a simple query
107132
cursor=connection.cursor()
108133
cursor.execute("SELECT 1 + 1 as result")
109134
result=cursor.fetchall()
110135
logger.info(f"Query result:{result[0][0]}")
111-
136+
112137
# Show current user
113138
cursor.execute("SELECT current_user() as user")
114139
result=cursor.fetchall()
115140
logger.info(f"Connected as user:{result[0][0]}")
116-
141+
117142
logger.info("Token federation test successful!")
118143
returnTrue
119144
exceptExceptionase:
@@ -125,29 +150,34 @@ def main():
125150
"""Main entry point for the test script."""
126151
try:
127152
# Get environment variables
128-
github_token,host,http_path,identity_federation_client_id=get_environment_variables()
129-
153+
github_token,host,http_path,identity_federation_client_id= (
154+
get_environment_variables()
155+
)
156+
130157
ifnotgithub_token:
131158
logger.error("Missing GitHub OIDC token (OIDC_TOKEN)")
132159
sys.exit(1)
133-
160+
134161
ifnothostornothttp_path:
135-
logger.error("Missing Databricks connection parameters (DATABRICKS_HOST_FOR_TF, DATABRICKS_HTTP_PATH_FOR_TF)")
162+
logger.error(
163+
"Missing Databricks connection parameters "
164+
"(DATABRICKS_HOST_FOR_TF, DATABRICKS_HTTP_PATH_FOR_TF)"
165+
)
136166
sys.exit(1)
137-
167+
138168
# Display token claims
139169
claims=decode_jwt(github_token)
140170
display_token_info(claims)
141-
171+
142172
# Test Databricks connection
143173
success=test_databricks_connection(
144174
host,http_path,github_token,identity_federation_client_id
145175
)
146-
176+
147177
ifnotsuccess:
148178
logger.error("Token federation test failed")
149179
sys.exit(1)
150-
180+
151181
exceptExceptionase:
152182
logger.error(f"Unexpected error:{str(e)}")
153183
sys.exit(1)

‎tests/unit/test_token_federation.py‎

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
Token,
1414
DatabricksTokenFederationProvider,
1515
SimpleCredentialsProvider,
16-
create_token_federation_provider,
1716
TOKEN_REFRESH_BUFFER_SECONDS,
1817
)
1918

@@ -136,19 +135,25 @@ def test_init_oidc_discovery(mock_request_get, mock_get_oauth_endpoints):
136135

137136
@pytest.fixture
138137
defmock_parse_jwt_claims():
139-
withpatch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims")asmock:
138+
withpatch(
139+
"databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims"
140+
)asmock:
140141
yieldmock
141142

142143

143144
@pytest.fixture
144145
defmock_exchange_token():
145-
withpatch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token")asmock:
146+
withpatch(
147+
"databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token"
148+
)asmock:
146149
yieldmock
147150

148151

149152
@pytest.fixture
150153
defmock_is_same_host():
151-
withpatch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host")asmock:
154+
withpatch(
155+
"databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host"
156+
)asmock:
152157
yieldmock
153158

154159

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp