Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit85d0cd9

Browse files
committed
addresses comments
1 parent9fc4c0c commit85d0cd9

File tree

3 files changed

+60
-67
lines changed

3 files changed

+60
-67
lines changed

‎src/databricks/sql/auth/oidc_utils.py‎

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
importlogging
22
importrequests
33
fromtypingimportOptional
4+
fromurllib.parseimporturlparse
45

56
fromdatabricks.sql.auth.endpointimport (
67
get_oauth_endpoints,
@@ -56,3 +57,19 @@ def format_hostname(hostname: str) -> str:
5657
ifnothostname.endswith("/"):
5758
hostname=f"{hostname}/"
5859
returnhostname
60+
61+
62+
defis_same_host(url1:str,url2:str)->bool:
63+
"""
64+
Check if two URLs have the same host.
65+
"""
66+
try:
67+
ifnoturl1.startswith(("http://","https://")):
68+
url1=f"https://{url1}"
69+
ifnoturl2.startswith(("http://","https://")):
70+
url2=f"https://{url2}"
71+
parsed1=urlparse(url1)
72+
parsed2=urlparse(url2)
73+
returnparsed1.netloc.lower()==parsed2.netloc.lower()
74+
exceptException:
75+
returnFalse

‎src/databricks/sql/auth/token_federation.py‎

Lines changed: 35 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
fromrequests.exceptionsimportRequestException
1111

1212
fromdatabricks.sql.auth.authenticatorsimportCredentialsProvider,HeaderFactory
13-
fromdatabricks.sql.auth.oidc_utilsimportOIDCDiscoveryUtil
13+
fromdatabricks.sql.auth.oidc_utilsimportOIDCDiscoveryUtil,is_same_host
1414
fromdatabricks.sql.auth.tokenimportToken
1515

1616
logger=logging.getLogger(__name__)
@@ -79,15 +79,6 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
7979
Configure and return a HeaderFactory that provides authentication headers.
8080
This is called by the ExternalAuthProvider to get headers for authentication.
8181
"""
82-
# First call the underlying credentials provider to get its headers
83-
header_factory=self.credentials_provider(*args,**kwargs)
84-
85-
# Get the standard token endpoint if not already set
86-
ifself.token_endpointisNone:
87-
self.token_endpoint=OIDCDiscoveryUtil.discover_token_endpoint(
88-
self.hostname
89-
)
90-
9182
# Return a function that will get authentication headers
9283
returnself.get_auth_headers
9384

@@ -156,34 +147,6 @@ def _get_expiry_from_jwt(self, token: str) -> Optional[datetime]:
156147

157148
returnNone
158149

159-
def_is_same_host(self,url1:str,url2:str)->bool:
160-
"""
161-
Check if two URLs have the same host.
162-
163-
Args:
164-
url1: First URL
165-
url2: Second URL
166-
167-
Returns:
168-
bool: True if hosts are the same, False otherwise
169-
"""
170-
try:
171-
# Add protocol if missing to ensure proper parsing
172-
ifnoturl1.startswith(("http://","https://")):
173-
url1=f"https://{url1}"
174-
ifnoturl2.startswith(("http://","https://")):
175-
url2=f"https://{url2}"
176-
177-
# Parse the URLs
178-
parsed1=urlparse(url1)
179-
parsed2=urlparse(url2)
180-
181-
# Compare the hostnames
182-
returnparsed1.netloc.lower()==parsed2.netloc.lower()
183-
exceptExceptionase:
184-
logger.warning(f"Error comparing hosts:{str(e)}")
185-
returnFalse
186-
187150
defrefresh_token(self)->Token:
188151
"""
189152
Refresh the token and return the new Token object.
@@ -210,24 +173,34 @@ def refresh_token(self) -> Token:
210173
token_claims=self._parse_jwt_claims(access_token)
211174

212175
# Create new token based on whether it's from the same host or not
213-
ifself._is_same_host(token_claims.get("iss",""),self.hostname):
176+
ifis_same_host(token_claims.get("iss",""),self.hostname):
214177
# Token is from the same host, no need to exchange
215178
logger.debug("Token from same host, creating token without exchange")
216-
217179
expiry=self._get_expiry_from_jwt(access_token)
218180
ifexpiryisNone:
219181
raiseValueError("Could not determine token expiry from JWT")
220-
221182
new_token=Token(access_token,token_type,"",expiry)
183+
self.current_token=new_token
184+
returnnew_token
222185
else:
223186
# Token is from a different host, need to exchange
224187
logger.debug("Token from different host, exchanging token")
225-
new_token=self._exchange_token(access_token)
226-
227-
# Store the token
228-
self.current_token=new_token
229-
230-
returnnew_token
188+
try:
189+
new_token=self._exchange_token(access_token)
190+
self.current_token=new_token
191+
returnnew_token
192+
exceptExceptionase:
193+
logger.error(
194+
f"Token exchange failed:{e}. Using external token as fallback."
195+
)
196+
expiry=self._get_expiry_from_jwt(access_token)
197+
ifexpiryisNone:
198+
raiseValueError(
199+
"Could not determine token expiry from JWT (after exchange failure)"
200+
)
201+
fallback_token=Token(access_token,token_type,"",expiry)
202+
self.current_token=fallback_token
203+
returnfallback_token
231204

232205
defget_current_token(self)->Token:
233206
"""
@@ -254,24 +227,19 @@ def get_auth_headers(self) -> Dict[str, str]:
254227
"""
255228
Get authorization headers using the current token.
256229
257-
This method gets the current token and returns it formatted
258-
as authorization headers.
259-
260230
Returns:
261-
Dict[str, str]: Authorization headers
231+
Dict[str, str]: Authorization headers (may include extra headers from provider)
262232
"""
263233
try:
264234
token=self.get_current_token()
265-
return {"Authorization":f"{token.token_type}{token.access_token}"}
235+
# Always get the latest headers from the credentials provider
236+
header_factory=self.credentials_provider()
237+
headers=dict(header_factory())ifheader_factoryelse {}
238+
headers["Authorization"]=f"{token.token_type}{token.access_token}"
239+
returnheaders
266240
exceptExceptionase:
267241
logger.error(f"Error getting auth headers:{str(e)}")
268-
269-
# Fall back to external headers if available
270-
ifself.external_headers:
271-
returnself.external_headers
272-
273-
# Return empty dict as a last resort
274-
return {}
242+
returndict(self.external_headers)ifself.external_headerselse {}
275243

276244
def_send_token_exchange_request(
277245
self,token_exchange_data:Dict[str,str]
@@ -286,7 +254,7 @@ def _send_token_exchange_request(
286254
Dict[str, Any]: Token exchange response
287255
288256
Raises:
289-
ValueError: If token exchange fails
257+
requests.HTTPError: If token exchange fails
290258
"""
291259
ifnotself.token_endpoint:
292260
raiseValueError("Token endpoint not initialized")
@@ -296,9 +264,9 @@ def _send_token_exchange_request(
296264
)
297265

298266
ifresponse.status_code!=200:
299-
raiseValueError(
300-
f"Token exchange failed with status code{response.status_code}:"
301-
f"{response.text}"
267+
raiserequests.HTTPError(
268+
f"Token exchange failed with status code{response.status_code}:{response.text}",
269+
response=response,
302270
)
303271

304272
returnresponse.json()
@@ -316,6 +284,10 @@ def _exchange_token(self, access_token: str) -> Token:
316284
Raises:
317285
ValueError: If token exchange fails
318286
"""
287+
ifself.token_endpointisNone:
288+
self.token_endpoint=OIDCDiscoveryUtil.discover_token_endpoint(
289+
self.hostname
290+
)
319291
# Prepare the request data
320292
token_exchange_data=dict(self.TOKEN_EXCHANGE_PARAMS)
321293
token_exchange_data["subject_token"]=access_token

‎tests/unit/test_token_federation.py‎

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def mock_dependencies(self):
145145
"databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token"
146146
)asmock_exchange:
147147
withpatch(
148-
"databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host"
148+
"databricks.sql.auth.oidc_utils.is_same_host"
149149
)asmock_is_same_host:
150150
withpatch(
151151
"databricks.sql.auth.token_federation.requests.post"
@@ -179,9 +179,11 @@ def test_provider_initialization(self, federation_provider):
179179
("databricks.com","https://databricks.com",True),
180180
],
181181
)
182-
deftest_is_same_host(self,federation_provider,url1,url2,expected):
182+
deftest_is_same_host(self,url1,url2,expected):
183183
"""Test host comparison logic with various URL formats."""
184-
assertfederation_provider._is_same_host(url1,url2)isexpected
184+
fromdatabricks.sql.auth.oidc_utilsimportis_same_host
185+
186+
assertis_same_host(url1,url2)isexpected
185187

186188
@pytest.mark.parametrize(
187189
"headers,expected_result,should_raise",
@@ -389,7 +391,9 @@ def test_token_exchange_failure(self, federation_provider):
389391
mock_post.return_value=mock_response
390392

391393
# Call the method and expect an exception
394+
importrequests
395+
392396
withpytest.raises(
393-
ValueError,match="Token exchange failed with status code 401"
397+
requests.HTTPError,match="Token exchange failed with status code 401"
394398
):
395399
federation_provider._exchange_token("original_token")

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp