Commit e8fc63b

Authored by Matthew Kim (mattdeekay)
Cloud fetch queue and integration (#151)
* Cloud fetch queue and integration
* Enable cloudfetch with direct results
* Typing and style changes
* Client-settable max_download_threads
* Docstrings and comments
* Increase default buffer size bytes to 104857600
* Move max_download_threads to kwargs of ThriftBackend, fix unit tests
* Fix tests: staticmethod make_arrow_table mock not callable
* cancel_futures in shutdown() only available in Python >= 3.9.0
* Black linting
* Fix typing errors

Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com>
1 parent 061c763, commit e8fc63b

6 files changed (+596, -136 lines)
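The feature is opt-in: as the diffs below show, use_cloud_fetch is read from the Connection kwargs (default False) and max_download_threads from the ThriftBackend kwargs (default 10). A minimal usage sketch, assuming connection kwargs are forwarded to the Thrift backend as elsewhere in this connector; the hostname, HTTP path, token, and table name are placeholders:

from databricks import sql

connection = sql.connect(
    server_hostname="example.cloud.databricks.com",  # placeholder
    http_path="/sql/1.0/warehouses/abc123",          # placeholder
    access_token="dapi-example-token",               # placeholder
    use_cloud_fetch=True,        # added by this commit; defaults to False
    max_download_threads=10,     # forwarded to ThriftBackend; defaults to 10
)

cursor = connection.cursor()
cursor.execute("SELECT * FROM some_large_table")  # placeholder query
rows = cursor.fetchmany(10000)
cursor.close()
connection.close()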


src/databricks/sql/client.py

Lines changed: 6 additions & 1 deletion

@@ -17,7 +17,7 @@

 logger = logging.getLogger(__name__)

-DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
+DEFAULT_RESULT_BUFFER_SIZE_BYTES = 104857600
 DEFAULT_ARRAY_SIZE = 100000


@@ -153,6 +153,8 @@ def read(self) -> Optional[OAuthToken]:
         # _use_arrow_native_timestamps
         #   Databricks runtime will return native Arrow types for timestamps instead of Arrow strings
         #   (True by default)
+        # use_cloud_fetch
+        #   Enable use of cloud fetch to extract large query results in parallel via cloud storage

         if access_token:
             access_token_kv = {"access_token": access_token}
@@ -189,6 +191,7 @@ def read(self) -> Optional[OAuthToken]:
         self._session_handle = self.thrift_backend.open_session(
             session_configuration, catalog, schema
         )
+        self.use_cloud_fetch = kwargs.get("use_cloud_fetch", False)
         self.open = True
         logger.info("Successfully opened session " + str(self.get_session_id_hex()))
         self._cursors = []  # type: List[Cursor]
@@ -497,6 +500,7 @@ def execute(
             max_bytes=self.buffer_size_bytes,
             lz4_compression=self.connection.lz4_compression,
             cursor=self,
+            use_cloud_fetch=self.connection.use_cloud_fetch,
         )
         self.active_result_set = ResultSet(
             self.connection,
@@ -822,6 +826,7 @@ def __iter__(self):
                 break

     def _fill_results_buffer(self):
+        # At initialization or if the server does not have cloud fetch result links available
         results, has_more_rows = self.thrift_backend.fetch_results(
             op_handle=self.command_id,
             max_rows=self.arraysize,
src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 2 additions & 2 deletions

@@ -161,6 +161,6 @@ def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
         return True

     def _shutdown_manager(self):
-        # Clear download handlers and shutdown the thread pool to cancel pending futures
+        # Clear download handlers and shutdown the thread pool
         self.download_handlers = []
-        self.thread_pool.shutdown(wait=False, cancel_futures=True)
+        self.thread_pool.shutdown(wait=False)
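The cancel_futures argument is dropped here because ThreadPoolExecutor.shutdown() only accepts it on Python 3.9 and later (see the commit message above). If cancelling queued downloads on newer interpreters were still desired, a version-guarded call would be one option; a minimal sketch, not part of this commit:

import sys
from concurrent.futures import ThreadPoolExecutor


def shutdown_pool(pool: ThreadPoolExecutor) -> None:
    # cancel_futures was added to ThreadPoolExecutor.shutdown() in Python 3.9;
    # passing it on older interpreters raises a TypeError.
    if sys.version_info >= (3, 9):
        pool.shutdown(wait=False, cancel_futures=True)
    else:
        pool.shutdown(wait=False)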

src/databricks/sql/thrift_backend.py

Lines changed: 39 additions & 112 deletions

@@ -5,7 +5,6 @@
 import time
 import uuid
 import threading
-import lz4.frame
 from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
 from typing import List, Union

@@ -26,11 +25,14 @@
 )

 from databricks.sql.utils import (
-    ArrowQueue,
     ExecuteResponse,
     _bound,
     RequestErrorInfo,
     NoRetryReason,
+    ResultSetQueueFactory,
+    convert_arrow_based_set_to_arrow_table,
+    convert_decimals_in_arrow_table,
+    convert_column_based_set_to_arrow_table,
 )

 logger = logging.getLogger(__name__)
@@ -67,7 +69,6 @@
 class ThriftBackend:
     CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE
     ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE
-    BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]

     def __init__(
         self,
@@ -115,6 +116,8 @@ def __init__(
         # _socket_timeout
         #   The timeout in seconds for socket send, recv and connect operations. Should be a positive float or integer.
         #   (defaults to 900)
+        # max_download_threads
+        #   Number of threads for handling cloud fetch downloads. Defaults to 10

         port = port or 443
         if kwargs.get("_connection_uri"):
@@ -136,6 +139,9 @@ def __init__(
             "_use_arrow_native_timestamps", True
         )

+        # Cloud fetch
+        self.max_download_threads = kwargs.get("max_download_threads", 10)
+
         # Configure tls context
         ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
         if kwargs.get("_tls_no_verify") is True:
@@ -558,108 +564,14 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
             (
                 arrow_table,
                 num_rows,
-            ) = ThriftBackend._convert_column_based_set_to_arrow_table(
-                t_row_set.columns, description
-            )
+            ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
         elif t_row_set.arrowBatches is not None:
-            (
-                arrow_table,
-                num_rows,
-            ) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
+            (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, schema_bytes
             )
         else:
             raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
-        return self._convert_decimals_in_arrow_table(arrow_table, description), num_rows
-
-    @staticmethod
-    def _convert_decimals_in_arrow_table(table, description):
-        for (i, col) in enumerate(table.itercolumns()):
-            if description[i][1] == "decimal":
-                decimal_col = col.to_pandas().apply(
-                    lambda v: v if v is None else Decimal(v)
-                )
-                precision, scale = description[i][4], description[i][5]
-                assert scale is not None
-                assert precision is not None
-                # Spark limits decimal to a maximum scale of 38,
-                # so 128 is guaranteed to be big enough
-                dtype = pyarrow.decimal128(precision, scale)
-                col_data = pyarrow.array(decimal_col, type=dtype)
-                field = table.field(i).with_type(dtype)
-                table = table.set_column(i, field, col_data)
-        return table
-
-    @staticmethod
-    def _convert_arrow_based_set_to_arrow_table(
-        arrow_batches, lz4_compressed, schema_bytes
-    ):
-        ba = bytearray()
-        ba += schema_bytes
-        n_rows = 0
-        if lz4_compressed:
-            for arrow_batch in arrow_batches:
-                n_rows += arrow_batch.rowCount
-                ba += lz4.frame.decompress(arrow_batch.batch)
-        else:
-            for arrow_batch in arrow_batches:
-                n_rows += arrow_batch.rowCount
-                ba += arrow_batch.batch
-        arrow_table = pyarrow.ipc.open_stream(ba).read_all()
-        return arrow_table, n_rows
-
-    @staticmethod
-    def _convert_column_based_set_to_arrow_table(columns, description):
-        arrow_table = pyarrow.Table.from_arrays(
-            [ThriftBackend._convert_column_to_arrow_array(c) for c in columns],
-            # Only use the column names from the schema, the types are determined by the
-            # physical types used in column based set, as they can differ from the
-            # mapping used in _hive_schema_to_arrow_schema.
-            names=[c[0] for c in description],
-        )
-        return arrow_table, arrow_table.num_rows
-
-    @staticmethod
-    def _convert_column_to_arrow_array(t_col):
-        """
-        Return a pyarrow array from the values in a TColumn instance.
-        Note that ColumnBasedSet has no native support for complex types, so they will be converted
-        to strings server-side.
-        """
-        field_name_to_arrow_type = {
-            "boolVal": pyarrow.bool_(),
-            "byteVal": pyarrow.int8(),
-            "i16Val": pyarrow.int16(),
-            "i32Val": pyarrow.int32(),
-            "i64Val": pyarrow.int64(),
-            "doubleVal": pyarrow.float64(),
-            "stringVal": pyarrow.string(),
-            "binaryVal": pyarrow.binary(),
-        }
-        for field in field_name_to_arrow_type.keys():
-            wrapper = getattr(t_col, field)
-            if wrapper:
-                return ThriftBackend._create_arrow_array(
-                    wrapper, field_name_to_arrow_type[field]
-                )
-
-        raise OperationalError("Empty TColumn instance {}".format(t_col))
-
-    @staticmethod
-    def _create_arrow_array(t_col_value_wrapper, arrow_type):
-        result = t_col_value_wrapper.values
-        nulls = t_col_value_wrapper.nulls  # bitfield describing which values are null
-        assert isinstance(nulls, bytes)
-
-        # The number of bits in nulls can be both larger or smaller than the number of
-        # elements in result, so take the minimum of both to iterate over.
-        length = min(len(result), len(nulls) * 8)
-
-        for i in range(length):
-            if nulls[i >> 3] & ThriftBackend.BIT_MASKS[i & 0x7]:
-                result[i] = None
-
-        return pyarrow.array(result, type=arrow_type)
+        return convert_decimals_in_arrow_table(arrow_table, description), num_rows

     def _get_metadata_resp(self, op_handle):
         req = ttypes.TGetResultSetMetadataReq(operationHandle=op_handle)
@@ -752,6 +664,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
         if t_result_set_metadata_resp.resultFormat not in [
             ttypes.TSparkRowSetType.ARROW_BASED_SET,
             ttypes.TSparkRowSetType.COLUMN_BASED_SET,
+            ttypes.TSparkRowSetType.URL_BASED_SET,
         ]:
             raise OperationalError(
                 "Expected results to be in Arrow or column based format, "
@@ -783,13 +696,14 @@ def _results_message_to_execute_response(self, resp, operation_state):
             assert direct_results.resultSet.results.startRowOffset == 0
             assert direct_results.resultSetMetadata

-            arrow_results, n_rows = self._create_arrow_table(
-                direct_results.resultSet.results,
-                lz4_compressed,
-                schema_bytes,
-                description,
+            arrow_queue_opt = ResultSetQueueFactory.build_queue(
+                row_set_type=t_result_set_metadata_resp.resultFormat,
+                t_row_set=direct_results.resultSet.results,
+                arrow_schema_bytes=schema_bytes,
+                max_download_threads=self.max_download_threads,
+                lz4_compressed=lz4_compressed,
+                description=description,
             )
-            arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
         else:
             arrow_queue_opt = None
         return ExecuteResponse(
@@ -843,7 +757,14 @@ def _check_direct_results_for_error(t_spark_direct_results):
             )

     def execute_command(
-        self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
+        self,
+        operation,
+        session_handle,
+        max_rows,
+        max_bytes,
+        lz4_compression,
+        cursor,
+        use_cloud_fetch=False,
     ):
         assert session_handle is not None

@@ -864,7 +785,7 @@ def execute_command(
             ),
             canReadArrowResult=True,
             canDecompressLZ4Result=lz4_compression,
-            canDownloadResult=False,
+            canDownloadResult=use_cloud_fetch,
             confOverlay={
                 # We want to receive proper Timestamp arrow types.
                 "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
@@ -993,6 +914,7 @@ def fetch_results(
             maxRows=max_rows,
             maxBytes=max_bytes,
             orientation=ttypes.TFetchOrientation.FETCH_NEXT,
+            includeResultSetMetadata=True,
         )

         resp = self.make_request(self._client.FetchResults, req)
@@ -1002,12 +924,17 @@ def fetch_results(
                     expected_row_start_offset, resp.results.startRowOffset
                 )
             )
-        arrow_results, n_rows = self._create_arrow_table(
-            resp.results, lz4_compressed, arrow_schema_bytes, description
+
+        queue = ResultSetQueueFactory.build_queue(
+            row_set_type=resp.resultSetMetadata.resultFormat,
+            t_row_set=resp.results,
+            arrow_schema_bytes=arrow_schema_bytes,
+            max_download_threads=self.max_download_threads,
+            lz4_compressed=lz4_compressed,
+            description=description,
         )
-        arrow_queue = ArrowQueue(arrow_results, n_rows)

-        return arrow_queue, resp.hasMoreRows
+        return queue, resp.hasMoreRows

     def close_command(self, op_handle):
         req = ttypes.TCloseOperationReq(operationHandle=op_handle)
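Both call sites above (direct results in _results_message_to_execute_response and incremental fetches in fetch_results) now delegate queue construction to ResultSetQueueFactory.build_queue, which dispatches on the Thrift result format. The factory itself lives in databricks/sql/utils.py and is not part of the diffs shown on this page; the skeleton below only illustrates the dispatch implied by the three accepted TSparkRowSetType values, with the branch bodies intentionally left unimplemented and the ttypes import path taken from thrift_backend.py:

# Illustrative skeleton only; the real implementation is in databricks/sql/utils.py.
from databricks.sql.thrift_api.TCLIService import ttypes


def build_queue_sketch(
    row_set_type, t_row_set, arrow_schema_bytes, max_download_threads, lz4_compressed, description
):
    if row_set_type == ttypes.TSparkRowSetType.ARROW_BASED_SET:
        # Inline Arrow batches: decompress (if LZ4-compressed) and assemble an
        # in-memory Arrow queue, much as the removed _create_arrow_table path did.
        raise NotImplementedError("assemble queue from t_row_set.arrowBatches")
    elif row_set_type == ttypes.TSparkRowSetType.COLUMN_BASED_SET:
        # Column-based rows: convert the columns to an Arrow table, then queue it.
        raise NotImplementedError("assemble queue from t_row_set.columns")
    elif row_set_type == ttypes.TSparkRowSetType.URL_BASED_SET:
        # Cloud fetch: the row set carries links to result files in cloud storage,
        # downloaded in parallel by up to max_download_threads worker threads.
        raise NotImplementedError("build a cloud fetch queue via the download manager")
    else:
        raise AssertionError("Row set type {} is not supported".format(row_set_type))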

0 commit comments
