Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Add external auth provider#101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletionsexamples/README.md
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -38,3 +38,4 @@ To run all of these examples you can clone the entire repository to your disk. O
this example the string `ExamplePartnerTag` will be added to the the user agent on every request.
- **`staging_ingestion.py`** shows how the connector handles Databricks' experimental staging ingestion commands `GET`, `PUT`, and `REMOVE`.
- **`sqlalchemy.py`** shows a basic example of connecting to Databricks with [SQLAlchemy](https://www.sqlalchemy.org/).
- **`custom_cred_provider.py`** shows how to pass a custom credential provider to bypass connector authentication. Please install databricks-sdk prior to running this example.
29 changes: 29 additions & 0 deletionsexamples/custom_cred_provider.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
# please pip install databricks-sdk prior to running this example.

from databricks import sql
from databricks.sdk.oauth import OAuthClient
import os

oauth_client = OAuthClient(host=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
client_id=os.getenv("DATABRICKS_CLIENT_ID"),
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
redirect_url=os.getenv("APP_REDIRECT_URL"),
scopes=['all-apis', 'offline_access'])

consent = oauth_client.initiate_consent()

creds = consent.launch_external_browser()

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
credentials_provider=creds) as connection:

for x in range(1, 5):
cursor = connection.cursor()
cursor.execute('SELECT 1+1')
result = cursor.fetchall()
for row in result:
print(row)
cursor.close()

connection.close()
6 changes: 6 additions & 0 deletionssrc/databricks/sql/auth/auth.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -5,6 +5,7 @@
AuthProvider,
AccessTokenAuthProvider,
BasicAuthProvider,
ExternalAuthProvider,
DatabricksOAuthProvider,
)
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
Expand All@@ -30,6 +31,7 @@ def __init__(
use_cert_as_auth: str = None,
tls_client_cert_file: str = None,
oauth_persistence=None,
credentials_provider=None,
):
self.hostname = hostname
self.username = username
Expand All@@ -42,9 +44,12 @@ def __init__(
self.use_cert_as_auth = use_cert_as_auth
self.tls_client_cert_file = tls_client_cert_file
self.oauth_persistence = oauth_persistence
self.credentials_provider = credentials_provider


def get_auth_provider(cfg: ClientContext):
if cfg.credentials_provider:
return ExternalAuthProvider(cfg.credentials_provider)
if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value:
assert cfg.oauth_redirect_port_range is not None
assert cfg.oauth_client_id is not None
Expand DownExpand Up@@ -94,5 +99,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
else PYSQL_OAUTH_REDIRECT_PORT_RANGE,
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
credentials_provider=kwargs.get("credentials_provider"),
)
return get_auth_provider(cfg)
29 changes: 28 additions & 1 deletionsrc/databricks/sql/auth/authenticators.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
import abc
import base64
import logging
from typing import Dict, List
from typing importCallable,Dict, List

from databricks.sql.auth.oauth import OAuthManager

Expand All@@ -14,6 +15,22 @@ def add_headers(self, request_headers: Dict[str, str]):
pass


HeaderFactory = Callable[[], Dict[str, str]]

# In order to keep compatibility with SDK
class CredentialsProvider(abc.ABC):
"""CredentialsProvider is the protocol (call-side interface)
for authenticating requests to Databricks REST APIs"""

@abc.abstractmethod
def auth_type(self) -> str:
...

@abc.abstractmethod
def __call__(self, *args, **kwargs) -> HeaderFactory:
...


# Private API: this is an evolving interface and it will change in the future.
# Please must not depend on it in your applications.
class AccessTokenAuthProvider(AuthProvider):
Expand DownExpand Up@@ -120,3 +137,13 @@ def _update_token_if_expired(self):
except Exception as e:
logging.error(f"unexpected error in oauth token update", e, exc_info=True)
raise e


class ExternalAuthProvider(AuthProvider):
def __init__(self, credentials_provider: CredentialsProvider) -> None:
self._header_factory = credentials_provider()

def add_headers(self, request_headers: Dict[str, str]):
headers = self._header_factory()
for k, v in headers.items():
request_headers[k] = v
37 changes: 36 additions & 1 deletiontests/unit/test_auth.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
import unittest

from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider
from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider, ExternalAuthProvider
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory


class Auth(unittest.TestCase):
Expand DownExpand Up@@ -37,6 +38,22 @@ def test_noop_auth_provider(self):
self.assertEqual(len(http_request.keys()), 1)
self.assertEqual(http_request['myKey'], 'myVal')

def test_external_provider(self):
class MyProvider(CredentialsProvider):
def auth_type(self) -> str:
return "mine"

def __call__(self, *args, **kwargs) -> HeaderFactory:
return lambda: {"foo": "bar"}

auth = ExternalAuthProvider(MyProvider())

http_request = {'myKey': 'myVal'}
auth.add_headers(http_request)
self.assertEqual(http_request['foo'], 'bar')
self.assertEqual(len(http_request.keys()), 2)
self.assertEqual(http_request['myKey'], 'myVal')

def test_get_python_sql_connector_auth_provider_access_token(self):
hostname = "moderakh-test.cloud.databricks.com"
kwargs = {'access_token': 'dpi123'}
Expand All@@ -47,6 +64,24 @@ def test_get_python_sql_connector_auth_provider_access_token(self):
auth_provider.add_headers(headers)
self.assertEqual(headers['Authorization'], 'Bearer dpi123')

def test_get_python_sql_connector_auth_provider_external(self):

class MyProvider(CredentialsProvider):
def auth_type(self) -> str:
return "mine"

def __call__(self, *args, **kwargs) -> HeaderFactory:
return lambda: {"foo": "bar"}

hostname = "moderakh-test.cloud.databricks.com"
kwargs = {'credentials_provider': MyProvider()}
auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs)
self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider")

headers = {}
auth_provider.add_headers(headers)
self.assertEqual(headers['foo'], 'bar')

def test_get_python_sql_connector_auth_provider_username_password(self):
username = "moderakh"
password = "Elevate Databricks 123!!!"
Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp