Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

[PECO-1857] Use SSL options with HTTPS connection pool#425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
kravets-levko merged 7 commits intomainfromPECO-1857-ssl-options-ignored
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletionssrc/databricks/sql/auth/thrift_http_client.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
import base64
import logging
import urllib.parse
from typing import Dict, Union
from typing import Dict, Union, Optional

import six
import thrift

logger = logging.getLogger(__name__)

import ssl
import warnings
from http.client import HTTPResponse
Expand All@@ -16,6 +14,9 @@
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
from urllib3.util import make_headers
from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
from databricks.sql.types import SSLOptions

logger = logging.getLogger(__name__)


class THttpClient(thrift.transport.THttpClient.THttpClient):
Expand All@@ -25,13 +26,12 @@ def __init__(
uri_or_host,
port=None,
path=None,
cafile=None,
cert_file=None,
key_file=None,
ssl_context=None,
ssl_options: Optional[SSLOptions] = None,
max_connections: int = 1,
retry_policy: Union[DatabricksRetryPolicy, int] = 0,
):
self._ssl_options = ssl_options

if port is not None:
warnings.warn(
"Please use the THttpClient('http{s}://host:port/path') constructor",
Expand All@@ -48,13 +48,11 @@ def __init__(
self.scheme = parsed.scheme
assert self.scheme in ("http", "https")
if self.scheme == "https":
self.certfile = cert_file
self.keyfile = key_file
self.context = (
ssl.create_default_context(cafile=cafile)
if (cafile and not ssl_context)
else ssl_context
)
if self._ssl_options is not None:
# TODO: Not sure if those options are used anywhere - need to double-check
self.certfile = self._ssl_options.tls_client_cert_file
self.keyfile = self._ssl_options.tls_client_cert_key_file
self.context = self._ssl_options.create_ssl_context()
self.port = parsed.port
self.host = parsed.hostname
self.path = parsed.path
Expand DownExpand Up@@ -109,12 +107,23 @@ def startRetryTimer(self):
def open(self):

# self.__pool replaces the self.__http used by the original THttpClient
_pool_kwargs = {"maxsize": self.max_connections}

if self.scheme == "http":
pool_class = HTTPConnectionPool
elif self.scheme == "https":
pool_class = HTTPSConnectionPool

_pool_kwargs = {"maxsize": self.max_connections}
_pool_kwargs.update(
{
"cert_reqs": ssl.CERT_REQUIRED
if self._ssl_options.tls_verify
else ssl.CERT_NONE,
"ca_certs": self._ssl_options.tls_trusted_ca_file,
"cert_file": self._ssl_options.tls_client_cert_file,
"key_file": self._ssl_options.tls_client_cert_key_file,
"key_password": self._ssl_options.tls_client_cert_key_password,
}
)

if self.using_proxy():
proxy_manager = ProxyManager(
Expand Down
18 changes: 16 additions & 2 deletionssrc/databricks/sql/client.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -35,7 +35,7 @@
)


fromdatabricks.sql.typesimportRow
fromdatabricks.sql.typesimportRow,SSLOptions
fromdatabricks.sql.auth.authimportget_python_sql_connector_auth_provider
fromdatabricks.sql.experimental.oauth_persistenceimportOAuthPersistence

Expand DownExpand Up@@ -178,8 +178,9 @@ def read(self) -> Optional[OAuthToken]:
# _tls_trusted_ca_file
# Set to the path of the file containing trusted CA certificates for server certificate
# verification. If not provide, uses system truststore.
# _tls_client_cert_file, _tls_client_cert_key_file
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
# Set client SSL certificate.
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
# _retry_stop_after_attempts_count
# The maximum number of attempts during a request retry sequence (defaults to 24)
# _socket_timeout
Expand DownExpand Up@@ -220,12 +221,25 @@ def read(self) -> Optional[OAuthToken]:

base_headers= [("User-Agent",useragent_header)]

self._ssl_options=SSLOptions(
# Double negation is generally a bad thing, but we have to keep backward compatibility
tls_verify=notkwargs.get(
"_tls_no_verify",False
),# by default - verify cert and host
tls_verify_hostname=kwargs.get("_tls_verify_hostname",True),
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
)

self.thrift_backend=ThriftBackend(
self.host,
self.port,
http_path,
(http_headersor [])+base_headers,
auth_provider,
ssl_options=self._ssl_options,
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
**kwargs,
)
Expand Down
9 changes: 5 additions & 4 deletionssrc/databricks/sql/cloudfetch/download_manager.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
importlogging

fromsslimportSSLContext
fromconcurrent.futuresimportThreadPoolExecutor,Future
fromtypingimportList,Union

Expand All@@ -9,6 +8,8 @@
DownloadableResultSettings,
DownloadedFile,
)
fromdatabricks.sql.typesimportSSLOptions

fromdatabricks.sql.thrift_api.TCLIService.ttypesimportTSparkArrowResultLink

logger=logging.getLogger(__name__)
Expand All@@ -20,7 +21,7 @@ def __init__(
links:List[TSparkArrowResultLink],
max_download_threads:int,
lz4_compressed:bool,
ssl_context:SSLContext,
ssl_options:SSLOptions,
):
self._pending_links:List[TSparkArrowResultLink]= []
forlinkinlinks:
Expand All@@ -38,7 +39,7 @@ def __init__(
self._thread_pool=ThreadPoolExecutor(max_workers=self._max_download_threads)

self._downloadable_result_settings=DownloadableResultSettings(lz4_compressed)
self._ssl_context=ssl_context
self._ssl_options=ssl_options

defget_next_downloaded_file(
self,next_row_offset:int
Expand DownExpand Up@@ -95,7 +96,7 @@ def _schedule_downloads(self):
handler=ResultSetDownloadHandler(
settings=self._downloadable_result_settings,
link=link,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
)
task=self._thread_pool.submit(handler.run)
self._download_tasks.append(task)
Expand Down
12 changes: 5 additions & 7 deletionssrc/databricks/sql/cloudfetch/downloader.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -3,13 +3,12 @@

importrequests
fromrequests.adaptersimportHTTPAdapter,Retry
fromsslimportSSLContext,CERT_NONE
importlz4.frame
importtime

fromdatabricks.sql.thrift_api.TCLIService.ttypesimportTSparkArrowResultLink

fromdatabricks.sql.excimportError
fromdatabricks.sql.typesimportSSLOptions

logger=logging.getLogger(__name__)

Expand DownExpand Up@@ -66,11 +65,11 @@ def __init__(
self,
settings:DownloadableResultSettings,
link:TSparkArrowResultLink,
ssl_context:SSLContext,
ssl_options:SSLOptions,
):
self.settings=settings
self.link=link
self._ssl_context=ssl_context
self._ssl_options=ssl_options

defrun(self)->DownloadedFile:
"""
Expand All@@ -95,14 +94,13 @@ def run(self) -> DownloadedFile:
session.mount("http://",HTTPAdapter(max_retries=retryPolicy))
session.mount("https://",HTTPAdapter(max_retries=retryPolicy))

ssl_verify=self._ssl_context.verify_mode!=CERT_NONE

try:
# Get the file via HTTP request
response=session.get(
self.link.fileLink,
timeout=self.settings.download_timeout,
verify=ssl_verify,
verify=self._ssl_options.tls_verify,
# TODO: Pass cert from `self._ssl_options`
)
response.raise_for_status()

Expand Down
43 changes: 6 additions & 37 deletionssrc/databricks/sql/thrift_backend.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -5,7 +5,6 @@
importtime
importuuid
importthreading
fromsslimportCERT_NONE,CERT_REQUIRED,create_default_context
fromtypingimportList,Union

importpyarrow
Expand DownExpand Up@@ -36,6 +35,7 @@
convert_decimals_in_arrow_table,
convert_column_based_set_to_arrow_table,
)
fromdatabricks.sql.typesimportSSLOptions

logger=logging.getLogger(__name__)

Expand DownExpand Up@@ -85,6 +85,7 @@ def __init__(
http_path:str,
http_headers,
auth_provider:AuthProvider,
ssl_options:SSLOptions,
staging_allowed_local_path:Union[None,str,List[str]]=None,
**kwargs,
):
Expand All@@ -93,16 +94,6 @@ def __init__(
# Tag to add to User-Agent header. For use by partners.
# _username, _password
# Username and password Basic authentication (no official support)
# _tls_no_verify
# Set to True (Boolean) to completely disable SSL verification.
# _tls_verify_hostname
# Set to False (Boolean) to disable SSL hostname verification, but check certificate.
# _tls_trusted_ca_file
# Set to the path of the file containing trusted CA certificates for server certificate
# verification. If not provide, uses system truststore.
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
# Set client SSL certificate.
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
# _connection_uri
# Overrides server_hostname and http_path.
# RETRY/ATTEMPT POLICY
Expand DownExpand Up@@ -162,29 +153,7 @@ def __init__(
# Cloud fetch
self.max_download_threads=kwargs.get("max_download_threads",10)

# Configure tls context
ssl_context=create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
ifkwargs.get("_tls_no_verify")isTrue:
ssl_context.check_hostname=False
ssl_context.verify_mode=CERT_NONE
elifkwargs.get("_tls_verify_hostname")isFalse:
ssl_context.check_hostname=False
ssl_context.verify_mode=CERT_REQUIRED
else:
ssl_context.check_hostname=True
ssl_context.verify_mode=CERT_REQUIRED

tls_client_cert_file=kwargs.get("_tls_client_cert_file")
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file")
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password")
iftls_client_cert_file:
ssl_context.load_cert_chain(
certfile=tls_client_cert_file,
keyfile=tls_client_cert_key_file,
password=tls_client_cert_key_password,
)

self._ssl_context=ssl_context
self._ssl_options=ssl_options

self._auth_provider=auth_provider

Expand DownExpand Up@@ -225,7 +194,7 @@ def __init__(
self._transport=databricks.sql.auth.thrift_http_client.THttpClient(
auth_provider=self._auth_provider,
uri_or_host=uri,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
**additional_transport_args,# type: ignore
)

Expand DownExpand Up@@ -776,7 +745,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
)
else:
arrow_queue_opt=None
Expand DownExpand Up@@ -1008,7 +977,7 @@ def fetch_results(
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
)

returnqueue,resp.hasMoreRows
Expand Down
48 changes: 48 additions & 0 deletionssrc/databricks/sql/types.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -19,6 +19,54 @@
fromtypingimportAny,Dict,List,Optional,Tuple,Union,TypeVar
importdatetime
importdecimal
fromsslimportSSLContext,CERT_NONE,CERT_REQUIRED,create_default_context


classSSLOptions:
tls_verify:bool
tls_verify_hostname:bool
tls_trusted_ca_file:Optional[str]
tls_client_cert_file:Optional[str]
tls_client_cert_key_file:Optional[str]
tls_client_cert_key_password:Optional[str]

def__init__(
self,
tls_verify:bool=True,
tls_verify_hostname:bool=True,
tls_trusted_ca_file:Optional[str]=None,
tls_client_cert_file:Optional[str]=None,
tls_client_cert_key_file:Optional[str]=None,
tls_client_cert_key_password:Optional[str]=None,
):
self.tls_verify=tls_verify
self.tls_verify_hostname=tls_verify_hostname
self.tls_trusted_ca_file=tls_trusted_ca_file
self.tls_client_cert_file=tls_client_cert_file
self.tls_client_cert_key_file=tls_client_cert_key_file
self.tls_client_cert_key_password=tls_client_cert_key_password

defcreate_ssl_context(self)->SSLContext:
ssl_context=create_default_context(cafile=self.tls_trusted_ca_file)

ifself.tls_verifyisFalse:
ssl_context.check_hostname=False
ssl_context.verify_mode=CERT_NONE
elifself.tls_verify_hostnameisFalse:
ssl_context.check_hostname=False
ssl_context.verify_mode=CERT_REQUIRED
else:
ssl_context.check_hostname=True
ssl_context.verify_mode=CERT_REQUIRED

ifself.tls_client_cert_file:
ssl_context.load_cert_chain(
certfile=self.tls_client_cert_file,
keyfile=self.tls_client_cert_key_file,
password=self.tls_client_cert_key_password,
)

returnssl_context


classRow(tuple):
Expand Down
Loading

[8]ページ先頭

©2009-2025 Movatter.jp