1515get_python_sql_connector_auth_provider ,
1616PYSQL_OAUTH_CLIENT_ID ,
1717)
18- from databricks .sql .auth .oauth import OAuthManager
18+ from databricks .sql .auth .oauth import OAuthManager , Token , ClientCredentialsTokenSource
1919from databricks .sql .auth .authenticators import (
2020DatabricksOAuthProvider ,
2121AzureServicePrincipalCredentialProvider ,
22- Token ,
2322)
2423from databricks .sql .auth .endpoint import (
2524CloudType ,
@@ -198,16 +197,16 @@ def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
198197self .assertTrue (auth_provider ._client_id ,PYSQL_OAUTH_CLIENT_ID )
199198
200199
201- class TestAzureServicePrincipalCredentialProvider :
200+ class TestClientCredentialsTokenSource :
202201@pytest .fixture
203202def indefinite_token (self ):
204203secret_key = "mysecret"
205204expires_in_100_years = int (time .time ())+ (100 * 365 * 24 * 60 * 60 )
206205
207206payload = {"sub" :"user123" ,"role" :"admin" ,"exp" :expires_in_100_years }
208207
209- token = jwt .encode (payload ,secret_key ,algorithm = "HS256" )
210- return Token (token ,"Bearer" ,"refresh_token" )
208+ access_token = jwt .encode (payload ,secret_key ,algorithm = "HS256" )
209+ return Token (access_token ,"Bearer" ,"refresh_token" )
211210
212211@pytest .fixture
213212def http_response (self ):
@@ -224,67 +223,75 @@ def status_response(response_status_code):
224223return status_response
225224
226225@pytest .fixture
227- def provider (self ):
228- return AzureServicePrincipalCredentialProvider (
229- client_id = "dummy-client " ,
230- client_secret = "dummy-secret " ,
231- tenant_id = "dummy-tenant " ,
226+ def token_source (self ):
227+ return ClientCredentialsTokenSource (
228+ token_url = "https://token_url.com " ,
229+ oauth_client_id = "client_id " ,
230+ oauth_client_secret = "client_secret " ,
232231 )
233232
234- def test_token_refresh (self ,provider ):
235- with patch .object (provider ,"_get_token" )as mock_get_token :
236- mock_get_token .return_value = Token (
237- "access_token" ,"Bearer" ,"refresh_token"
238- )
239- header_factory = provider ()
240- headers = header_factory ()
241-
242- assert headers ["Authorization" ]== "Bearer access_token"
243- mock_get_token .assert_called_once ()
244-
245233def test_no_token_refresh__when_token_is_not_expired (
246- self ,provider ,indefinite_token
234+ self ,token_source ,indefinite_token
247235 ):
248- with patch .object (provider ,"_get_token " )as mock_get_token :
236+ with patch .object (token_source ,"refresh " )as mock_get_token :
249237mock_get_token .return_value = indefinite_token
250238
251- # Call the provider multiple times
252- header_factory1 = provider ()
253- header_factory2 = provider ()
254- header_factory3 = provider ()
255-
256- # Get headers from each factory
257- headers1 = header_factory1 ()
258- headers2 = header_factory2 ()
259- headers3 = header_factory3 ()
239+ # Mulitple calls for token
240+ token1 = token_source .get_token ()
241+ token2 = token_source .get_token ()
242+ token3 = token_source .get_token ()
260243
261- # Verify _get_token was called only once
262- mock_get_token .assert_called_once ()
244+ assert token1 == token2 == token3
245+ assert token1 .access_token == indefinite_token .access_token
246+ assert token1 .token_type == indefinite_token .token_type
247+ assert token1 .refresh_token == indefinite_token .refresh_token
263248
264- # Verify all headers contain the same token
265- expected_auth_header = f"Bearer{ indefinite_token .access_token } "
266- assert headers1 ["Authorization" ]== expected_auth_header
267- assert headers2 ["Authorization" ]== expected_auth_header
268- assert headers3 ["Authorization" ]== expected_auth_header
249+ # should refresh only once as token is not expired
250+ assert mock_get_token .call_count == 1
269251
270- def test_get_token_success (self ,provider ,http_response ):
271-
272- # Patch the HTTP client's execute method
273- with patch .object (
274- provider ._http_client ,"execute" ,return_value = http_response (200 )
275- )as mock_execute :
276- token = provider ._get_token ()
252+ def test_get_token_success (self ,token_source ,http_response ):
253+ with patch .object (token_source ._http_client ,"execute" )as mock_execute :
254+ mock_execute .return_value = http_response (200 )
255+ token = token_source .get_token ()
277256
278257# Assert
279258assert isinstance (token ,Token )
280259assert token .access_token == "abc123"
281260assert token .token_type == "Bearer"
282261assert token .refresh_token is None
283262
284- def test_get_token_failure (self ,provider ,http_response ):
285- with patch .object (
286- provider ._http_client ,"execute" ,return_value = http_response (400 )
287- )as mock_execute :
263+ def test_get_token_failure (self ,token_source ,http_response ):
264+ with patch .object (token_source ._http_client ,"execute" )as mock_execute :
265+ mock_execute .return_value = http_response (400 )
288266with pytest .raises (Exception )as e :
289- provider . _get_token ()
267+ token_source . get_token ()
290268assert "Failed to get token: 400" in str (e .value )
269+
270+
271+ class TestAzureServicePrincipalCredentialProvider :
272+ @pytest .fixture
273+ def credential_provider (self ):
274+ return AzureServicePrincipalCredentialProvider (
275+ hostname = "hostname" ,
276+ oauth_client_id = "client_id" ,
277+ oauth_client_secret = "client_secret" ,
278+ azure_tenant_id = "tenant_id" ,
279+ )
280+
281+ def test_provider_credentials (self ,credential_provider ):
282+
283+ test_token = Token ("access_token" ,"Bearer" ,"refresh_token" )
284+
285+ with patch .object (
286+ credential_provider ,"get_token_source"
287+ )as mock_get_token_source :
288+ mock_get_token_source .return_value = MagicMock ()
289+ mock_get_token_source .return_value .get_token .return_value = test_token
290+
291+ headers = credential_provider ()()
292+
293+ assert headers ["Authorization" ]== f"Bearer{ test_token .access_token } "
294+ assert (
295+ headers ["X-Databricks-Azure-SP-Management-Token" ]
296+ == test_token .access_token
297+ )