Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit270edcf

Browse files
kravets-levkovarun-edachali-dbx
authored andcommitted
[PECO-1857] Use SSL options with HTTPS connection pool (#425)
* [PECO-1857] Use SSL options with HTTPS connection poolSigned-off-by: Levko Kravets <levko.ne@gmail.com>* Some cleanupSigned-off-by: Levko Kravets <levko.ne@gmail.com>* Resolve circular dependenciesSigned-off-by: Levko Kravets <levko.ne@gmail.com>* Update existing testsSigned-off-by: Levko Kravets <levko.ne@gmail.com>* Fix MyPy issuesSigned-off-by: Levko Kravets <levko.ne@gmail.com>* Fix `_tls_no_verify` handlingSigned-off-by: Levko Kravets <levko.ne@gmail.com>* Add testsSigned-off-by: Levko Kravets <levko.ne@gmail.com>---------Signed-off-by: Levko Kravets <levko.ne@gmail.com>Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parentb1faa09 commit270edcf

File tree

11 files changed

+267
-159
lines changed

11 files changed

+267
-159
lines changed

‎src/databricks/sql/auth/thrift_http_client.py‎

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
importbase64
22
importlogging
33
importurllib.parse
4-
fromtypingimportDict,Union
4+
fromtypingimportDict,Union,Optional
55

66
importsix
77
importthrift
88

9-
logger=logging.getLogger(__name__)
10-
119
importssl
1210
importwarnings
1311
fromhttp.clientimportHTTPResponse
@@ -16,6 +14,9 @@
1614
fromurllib3importHTTPConnectionPool,HTTPSConnectionPool,ProxyManager
1715
fromurllib3.utilimportmake_headers
1816
fromdatabricks.sql.auth.retryimportCommandType,DatabricksRetryPolicy
17+
fromdatabricks.sql.typesimportSSLOptions
18+
19+
logger=logging.getLogger(__name__)
1920

2021

2122
classTHttpClient(thrift.transport.THttpClient.THttpClient):
@@ -25,13 +26,12 @@ def __init__(
2526
uri_or_host,
2627
port=None,
2728
path=None,
28-
cafile=None,
29-
cert_file=None,
30-
key_file=None,
31-
ssl_context=None,
29+
ssl_options:Optional[SSLOptions]=None,
3230
max_connections:int=1,
3331
retry_policy:Union[DatabricksRetryPolicy,int]=0,
3432
):
33+
self._ssl_options=ssl_options
34+
3535
ifportisnotNone:
3636
warnings.warn(
3737
"Please use the THttpClient('http{s}://host:port/path') constructor",
@@ -48,13 +48,11 @@ def __init__(
4848
self.scheme=parsed.scheme
4949
assertself.schemein ("http","https")
5050
ifself.scheme=="https":
51-
self.certfile=cert_file
52-
self.keyfile=key_file
53-
self.context= (
54-
ssl.create_default_context(cafile=cafile)
55-
if (cafileandnotssl_context)
56-
elsessl_context
57-
)
51+
ifself._ssl_optionsisnotNone:
52+
# TODO: Not sure if those options are used anywhere - need to double-check
53+
self.certfile=self._ssl_options.tls_client_cert_file
54+
self.keyfile=self._ssl_options.tls_client_cert_key_file
55+
self.context=self._ssl_options.create_ssl_context()
5856
self.port=parsed.port
5957
self.host=parsed.hostname
6058
self.path=parsed.path
@@ -109,12 +107,23 @@ def startRetryTimer(self):
109107
defopen(self):
110108

111109
# self.__pool replaces the self.__http used by the original THttpClient
110+
_pool_kwargs= {"maxsize":self.max_connections}
111+
112112
ifself.scheme=="http":
113113
pool_class=HTTPConnectionPool
114114
elifself.scheme=="https":
115115
pool_class=HTTPSConnectionPool
116-
117-
_pool_kwargs= {"maxsize":self.max_connections}
116+
_pool_kwargs.update(
117+
{
118+
"cert_reqs":ssl.CERT_REQUIRED
119+
ifself._ssl_options.tls_verify
120+
elsessl.CERT_NONE,
121+
"ca_certs":self._ssl_options.tls_trusted_ca_file,
122+
"cert_file":self._ssl_options.tls_client_cert_file,
123+
"key_file":self._ssl_options.tls_client_cert_key_file,
124+
"key_password":self._ssl_options.tls_client_cert_key_password,
125+
}
126+
)
118127

119128
ifself.using_proxy():
120129
proxy_manager=ProxyManager(

‎src/databricks/sql/client.py‎

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636

3737

38-
fromdatabricks.sql.typesimportRow
38+
fromdatabricks.sql.typesimportRow,SSLOptions
3939
fromdatabricks.sql.auth.authimportget_python_sql_connector_auth_provider
4040
fromdatabricks.sql.experimental.oauth_persistenceimportOAuthPersistence
4141

@@ -178,8 +178,9 @@ def read(self) -> Optional[OAuthToken]:
178178
# _tls_trusted_ca_file
179179
# Set to the path of the file containing trusted CA certificates for server certificate
180180
# verification. If not provide, uses system truststore.
181-
# _tls_client_cert_file, _tls_client_cert_key_file
181+
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
182182
# Set client SSL certificate.
183+
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
183184
# _retry_stop_after_attempts_count
184185
# The maximum number of attempts during a request retry sequence (defaults to 24)
185186
# _socket_timeout
@@ -220,12 +221,25 @@ def read(self) -> Optional[OAuthToken]:
220221

221222
base_headers= [("User-Agent",useragent_header)]
222223

224+
self._ssl_options=SSLOptions(
225+
# Double negation is generally a bad thing, but we have to keep backward compatibility
226+
tls_verify=notkwargs.get(
227+
"_tls_no_verify",False
228+
),# by default - verify cert and host
229+
tls_verify_hostname=kwargs.get("_tls_verify_hostname",True),
230+
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
231+
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
232+
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
233+
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
234+
)
235+
223236
self.thrift_backend=ThriftBackend(
224237
self.host,
225238
self.port,
226239
http_path,
227240
(http_headersor [])+base_headers,
228241
auth_provider,
242+
ssl_options=self._ssl_options,
229243
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
230244
**kwargs,
231245
)

‎src/databricks/sql/cloudfetch/download_manager.py‎

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
importlogging
22

3-
fromsslimportSSLContext
43
fromconcurrent.futuresimportThreadPoolExecutor,Future
54
fromtypingimportList,Union
65

@@ -9,6 +8,8 @@
98
DownloadableResultSettings,
109
DownloadedFile,
1110
)
11+
fromdatabricks.sql.typesimportSSLOptions
12+
1213
fromdatabricks.sql.thrift_api.TCLIService.ttypesimportTSparkArrowResultLink
1314

1415
logger=logging.getLogger(__name__)
@@ -20,7 +21,7 @@ def __init__(
2021
links:List[TSparkArrowResultLink],
2122
max_download_threads:int,
2223
lz4_compressed:bool,
23-
ssl_context:SSLContext,
24+
ssl_options:SSLOptions,
2425
):
2526
self._pending_links:List[TSparkArrowResultLink]= []
2627
forlinkinlinks:
@@ -38,7 +39,7 @@ def __init__(
3839
self._thread_pool=ThreadPoolExecutor(max_workers=self._max_download_threads)
3940

4041
self._downloadable_result_settings=DownloadableResultSettings(lz4_compressed)
41-
self._ssl_context=ssl_context
42+
self._ssl_options=ssl_options
4243

4344
defget_next_downloaded_file(
4445
self,next_row_offset:int
@@ -95,7 +96,7 @@ def _schedule_downloads(self):
9596
handler=ResultSetDownloadHandler(
9697
settings=self._downloadable_result_settings,
9798
link=link,
98-
ssl_context=self._ssl_context,
99+
ssl_options=self._ssl_options,
99100
)
100101
task=self._thread_pool.submit(handler.run)
101102
self._download_tasks.append(task)

‎src/databricks/sql/cloudfetch/downloader.py‎

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33

44
importrequests
55
fromrequests.adaptersimportHTTPAdapter,Retry
6-
fromsslimportSSLContext,CERT_NONE
76
importlz4.frame
87
importtime
98

109
fromdatabricks.sql.thrift_api.TCLIService.ttypesimportTSparkArrowResultLink
11-
1210
fromdatabricks.sql.excimportError
11+
fromdatabricks.sql.typesimportSSLOptions
1312

1413
logger=logging.getLogger(__name__)
1514

@@ -66,11 +65,11 @@ def __init__(
6665
self,
6766
settings:DownloadableResultSettings,
6867
link:TSparkArrowResultLink,
69-
ssl_context:SSLContext,
68+
ssl_options:SSLOptions,
7069
):
7170
self.settings=settings
7271
self.link=link
73-
self._ssl_context=ssl_context
72+
self._ssl_options=ssl_options
7473

7574
defrun(self)->DownloadedFile:
7675
"""
@@ -95,14 +94,13 @@ def run(self) -> DownloadedFile:
9594
session.mount("http://",HTTPAdapter(max_retries=retryPolicy))
9695
session.mount("https://",HTTPAdapter(max_retries=retryPolicy))
9796

98-
ssl_verify=self._ssl_context.verify_mode!=CERT_NONE
99-
10097
try:
10198
# Get the file via HTTP request
10299
response=session.get(
103100
self.link.fileLink,
104101
timeout=self.settings.download_timeout,
105-
verify=ssl_verify,
102+
verify=self._ssl_options.tls_verify,
103+
# TODO: Pass cert from `self._ssl_options`
106104
)
107105
response.raise_for_status()
108106

‎src/databricks/sql/thrift_backend.py‎

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
importtime
66
importuuid
77
importthreading
8-
fromsslimportCERT_NONE,CERT_REQUIRED,create_default_context
98
fromtypingimportList,Union
109

1110
importpyarrow
@@ -36,6 +35,7 @@
3635
convert_decimals_in_arrow_table,
3736
convert_column_based_set_to_arrow_table,
3837
)
38+
fromdatabricks.sql.typesimportSSLOptions
3939

4040
logger=logging.getLogger(__name__)
4141

@@ -85,6 +85,7 @@ def __init__(
8585
http_path:str,
8686
http_headers,
8787
auth_provider:AuthProvider,
88+
ssl_options:SSLOptions,
8889
staging_allowed_local_path:Union[None,str,List[str]]=None,
8990
**kwargs,
9091
):
@@ -93,16 +94,6 @@ def __init__(
9394
# Tag to add to User-Agent header. For use by partners.
9495
# _username, _password
9596
# Username and password Basic authentication (no official support)
96-
# _tls_no_verify
97-
# Set to True (Boolean) to completely disable SSL verification.
98-
# _tls_verify_hostname
99-
# Set to False (Boolean) to disable SSL hostname verification, but check certificate.
100-
# _tls_trusted_ca_file
101-
# Set to the path of the file containing trusted CA certificates for server certificate
102-
# verification. If not provide, uses system truststore.
103-
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
104-
# Set client SSL certificate.
105-
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
10697
# _connection_uri
10798
# Overrides server_hostname and http_path.
10899
# RETRY/ATTEMPT POLICY
@@ -162,29 +153,7 @@ def __init__(
162153
# Cloud fetch
163154
self.max_download_threads=kwargs.get("max_download_threads",10)
164155

165-
# Configure tls context
166-
ssl_context=create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
167-
ifkwargs.get("_tls_no_verify")isTrue:
168-
ssl_context.check_hostname=False
169-
ssl_context.verify_mode=CERT_NONE
170-
elifkwargs.get("_tls_verify_hostname")isFalse:
171-
ssl_context.check_hostname=False
172-
ssl_context.verify_mode=CERT_REQUIRED
173-
else:
174-
ssl_context.check_hostname=True
175-
ssl_context.verify_mode=CERT_REQUIRED
176-
177-
tls_client_cert_file=kwargs.get("_tls_client_cert_file")
178-
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file")
179-
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password")
180-
iftls_client_cert_file:
181-
ssl_context.load_cert_chain(
182-
certfile=tls_client_cert_file,
183-
keyfile=tls_client_cert_key_file,
184-
password=tls_client_cert_key_password,
185-
)
186-
187-
self._ssl_context=ssl_context
156+
self._ssl_options=ssl_options
188157

189158
self._auth_provider=auth_provider
190159

@@ -225,7 +194,7 @@ def __init__(
225194
self._transport=databricks.sql.auth.thrift_http_client.THttpClient(
226195
auth_provider=self._auth_provider,
227196
uri_or_host=uri,
228-
ssl_context=self._ssl_context,
197+
ssl_options=self._ssl_options,
229198
**additional_transport_args,# type: ignore
230199
)
231200

@@ -776,7 +745,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
776745
max_download_threads=self.max_download_threads,
777746
lz4_compressed=lz4_compressed,
778747
description=description,
779-
ssl_context=self._ssl_context,
748+
ssl_options=self._ssl_options,
780749
)
781750
else:
782751
arrow_queue_opt=None
@@ -1008,7 +977,7 @@ def fetch_results(
1008977
max_download_threads=self.max_download_threads,
1009978
lz4_compressed=lz4_compressed,
1010979
description=description,
1011-
ssl_context=self._ssl_context,
980+
ssl_options=self._ssl_options,
1012981
)
1013982

1014983
returnqueue,resp.hasMoreRows

‎src/databricks/sql/types.py‎

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,54 @@
1919
fromtypingimportAny,Dict,List,Optional,Tuple,Union,TypeVar
2020
importdatetime
2121
importdecimal
22+
fromsslimportSSLContext,CERT_NONE,CERT_REQUIRED,create_default_context
23+
24+
25+
classSSLOptions:
26+
tls_verify:bool
27+
tls_verify_hostname:bool
28+
tls_trusted_ca_file:Optional[str]
29+
tls_client_cert_file:Optional[str]
30+
tls_client_cert_key_file:Optional[str]
31+
tls_client_cert_key_password:Optional[str]
32+
33+
def__init__(
34+
self,
35+
tls_verify:bool=True,
36+
tls_verify_hostname:bool=True,
37+
tls_trusted_ca_file:Optional[str]=None,
38+
tls_client_cert_file:Optional[str]=None,
39+
tls_client_cert_key_file:Optional[str]=None,
40+
tls_client_cert_key_password:Optional[str]=None,
41+
):
42+
self.tls_verify=tls_verify
43+
self.tls_verify_hostname=tls_verify_hostname
44+
self.tls_trusted_ca_file=tls_trusted_ca_file
45+
self.tls_client_cert_file=tls_client_cert_file
46+
self.tls_client_cert_key_file=tls_client_cert_key_file
47+
self.tls_client_cert_key_password=tls_client_cert_key_password
48+
49+
defcreate_ssl_context(self)->SSLContext:
50+
ssl_context=create_default_context(cafile=self.tls_trusted_ca_file)
51+
52+
ifself.tls_verifyisFalse:
53+
ssl_context.check_hostname=False
54+
ssl_context.verify_mode=CERT_NONE
55+
elifself.tls_verify_hostnameisFalse:
56+
ssl_context.check_hostname=False
57+
ssl_context.verify_mode=CERT_REQUIRED
58+
else:
59+
ssl_context.check_hostname=True
60+
ssl_context.verify_mode=CERT_REQUIRED
61+
62+
ifself.tls_client_cert_file:
63+
ssl_context.load_cert_chain(
64+
certfile=self.tls_client_cert_file,
65+
keyfile=self.tls_client_cert_key_file,
66+
password=self.tls_client_cert_key_password,
67+
)
68+
69+
returnssl_context
2270

2371

2472
classRow(tuple):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp