1010from requests .exceptions import RequestException
1111
1212from databricks .sql .auth .authenticators import CredentialsProvider ,HeaderFactory
13- from databricks .sql .auth .oidc_utils import OIDCDiscoveryUtil
13+ from databricks .sql .auth .oidc_utils import OIDCDiscoveryUtil , is_same_host
1414from databricks .sql .auth .token import Token
1515
1616logger = logging .getLogger (__name__ )
@@ -79,15 +79,6 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
7979 Configure and return a HeaderFactory that provides authentication headers.
8080 This is called by the ExternalAuthProvider to get headers for authentication.
8181 """
82- # First call the underlying credentials provider to get its headers
83- header_factory = self .credentials_provider (* args ,** kwargs )
84-
85- # Get the standard token endpoint if not already set
86- if self .token_endpoint is None :
87- self .token_endpoint = OIDCDiscoveryUtil .discover_token_endpoint (
88- self .hostname
89- )
90-
9182# Return a function that will get authentication headers
9283return self .get_auth_headers
9384
@@ -156,34 +147,6 @@ def _get_expiry_from_jwt(self, token: str) -> Optional[datetime]:
156147
157148return None
158149
159- def _is_same_host (self ,url1 :str ,url2 :str )-> bool :
160- """
161- Check if two URLs have the same host.
162-
163- Args:
164- url1: First URL
165- url2: Second URL
166-
167- Returns:
168- bool: True if hosts are the same, False otherwise
169- """
170- try :
171- # Add protocol if missing to ensure proper parsing
172- if not url1 .startswith (("http://" ,"https://" )):
173- url1 = f"https://{ url1 } "
174- if not url2 .startswith (("http://" ,"https://" )):
175- url2 = f"https://{ url2 } "
176-
177- # Parse the URLs
178- parsed1 = urlparse (url1 )
179- parsed2 = urlparse (url2 )
180-
181- # Compare the hostnames
182- return parsed1 .netloc .lower ()== parsed2 .netloc .lower ()
183- except Exception as e :
184- logger .warning (f"Error comparing hosts:{ str (e )} " )
185- return False
186-
187150def refresh_token (self )-> Token :
188151"""
189152 Refresh the token and return the new Token object.
@@ -210,24 +173,34 @@ def refresh_token(self) -> Token:
210173token_claims = self ._parse_jwt_claims (access_token )
211174
212175# Create new token based on whether it's from the same host or not
213- if self . _is_same_host (token_claims .get ("iss" ,"" ),self .hostname ):
176+ if is_same_host (token_claims .get ("iss" ,"" ),self .hostname ):
214177# Token is from the same host, no need to exchange
215178logger .debug ("Token from same host, creating token without exchange" )
216-
217179expiry = self ._get_expiry_from_jwt (access_token )
218180if expiry is None :
219181raise ValueError ("Could not determine token expiry from JWT" )
220-
221182new_token = Token (access_token ,token_type ,"" ,expiry )
183+ self .current_token = new_token
184+ return new_token
222185else :
223186# Token is from a different host, need to exchange
224187logger .debug ("Token from different host, exchanging token" )
225- new_token = self ._exchange_token (access_token )
226-
227- # Store the token
228- self .current_token = new_token
229-
230- return new_token
188+ try :
189+ new_token = self ._exchange_token (access_token )
190+ self .current_token = new_token
191+ return new_token
192+ except Exception as e :
193+ logger .error (
194+ f"Token exchange failed:{ e } . Using external token as fallback."
195+ )
196+ expiry = self ._get_expiry_from_jwt (access_token )
197+ if expiry is None :
198+ raise ValueError (
199+ "Could not determine token expiry from JWT (after exchange failure)"
200+ )
201+ fallback_token = Token (access_token ,token_type ,"" ,expiry )
202+ self .current_token = fallback_token
203+ return fallback_token
231204
232205def get_current_token (self )-> Token :
233206"""
@@ -254,24 +227,19 @@ def get_auth_headers(self) -> Dict[str, str]:
254227"""
255228 Get authorization headers using the current token.
256229
257- This method gets the current token and returns it formatted
258- as authorization headers.
259-
260230 Returns:
261- Dict[str, str]: Authorization headers
231+ Dict[str, str]: Authorization headers (may include extra headers from provider)
262232 """
263233try :
264234token = self .get_current_token ()
265- return {"Authorization" :f"{ token .token_type } { token .access_token } " }
235+ # Always get the latest headers from the credentials provider
236+ header_factory = self .credentials_provider ()
237+ headers = dict (header_factory ())if header_factory else {}
238+ headers ["Authorization" ]= f"{ token .token_type } { token .access_token } "
239+ return headers
266240except Exception as e :
267241logger .error (f"Error getting auth headers:{ str (e )} " )
268-
269- # Fall back to external headers if available
270- if self .external_headers :
271- return self .external_headers
272-
273- # Return empty dict as a last resort
274- return {}
242+ return dict (self .external_headers )if self .external_headers else {}
275243
276244def _send_token_exchange_request (
277245self ,token_exchange_data :Dict [str ,str ]
@@ -286,7 +254,7 @@ def _send_token_exchange_request(
286254 Dict[str, Any]: Token exchange response
287255
288256 Raises:
289- ValueError : If token exchange fails
257+ requests.HTTPError : If token exchange fails
290258 """
291259if not self .token_endpoint :
292260raise ValueError ("Token endpoint not initialized" )
@@ -296,9 +264,9 @@ def _send_token_exchange_request(
296264 )
297265
298266if response .status_code != 200 :
299- raise ValueError (
300- f"Token exchange failed with status code{ response .status_code } :"
301- f" { response . text } "
267+ raise requests . HTTPError (
268+ f"Token exchange failed with status code{ response .status_code } :{ response . text } " ,
269+ response = response ,
302270 )
303271
304272return response .json ()
@@ -316,6 +284,10 @@ def _exchange_token(self, access_token: str) -> Token:
316284 Raises:
317285 ValueError: If token exchange fails
318286 """
287+ if self .token_endpoint is None :
288+ self .token_endpoint = OIDCDiscoveryUtil .discover_token_endpoint (
289+ self .hostname
290+ )
319291# Prepare the request data
320292token_exchange_data = dict (self .TOKEN_EXCHANGE_PARAMS )
321293token_exchange_data ["subject_token" ]= access_token