55import time
66import uuid
77import threading
8- import lz4 .frame
98from ssl import CERT_NONE ,CERT_REQUIRED ,create_default_context
109from typing import List ,Union
1110
2625)
2726
2827from databricks .sql .utils import (
29- ArrowQueue ,
3028ExecuteResponse ,
3129_bound ,
3230RequestErrorInfo ,
3331NoRetryReason ,
32+ ResultSetQueueFactory ,
33+ convert_arrow_based_set_to_arrow_table ,
34+ convert_decimals_in_arrow_table ,
35+ convert_column_based_set_to_arrow_table ,
3436)
3537
3638logger = logging .getLogger (__name__ )
6769class ThriftBackend :
6870CLOSED_OP_STATE = ttypes .TOperationState .CLOSED_STATE
6971ERROR_OP_STATE = ttypes .TOperationState .ERROR_STATE
70- BIT_MASKS = [1 ,2 ,4 ,8 ,16 ,32 ,64 ,128 ]
7172
7273def __init__ (
7374self ,
@@ -115,6 +116,8 @@ def __init__(
115116# _socket_timeout
116117# The timeout in seconds for socket send, recv and connect operations. Should be a positive float or integer.
117118# (defaults to 900)
119+ # max_download_threads
120+ # Number of threads for handling cloud fetch downloads. Defaults to 10
118121
119122port = port or 443
120123if kwargs .get ("_connection_uri" ):
@@ -136,6 +139,9 @@ def __init__(
136139"_use_arrow_native_timestamps" ,True
137140 )
138141
142+ # Cloud fetch
143+ self .max_download_threads = kwargs .get ("max_download_threads" ,10 )
144+
139145# Configure tls context
140146ssl_context = create_default_context (cafile = kwargs .get ("_tls_trusted_ca_file" ))
141147if kwargs .get ("_tls_no_verify" )is True :
@@ -558,108 +564,14 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
558564 (
559565arrow_table ,
560566num_rows ,
561- )= ThriftBackend ._convert_column_based_set_to_arrow_table (
562- t_row_set .columns ,description
563- )
567+ )= convert_column_based_set_to_arrow_table (t_row_set .columns ,description )
564568elif t_row_set .arrowBatches is not None :
565- (
566- arrow_table ,
567- num_rows ,
568- )= ThriftBackend ._convert_arrow_based_set_to_arrow_table (
569+ arrow_table ,num_rows = convert_arrow_based_set_to_arrow_table (
569570t_row_set .arrowBatches ,lz4_compressed ,schema_bytes
570571 )
571572else :
572573raise OperationalError ("Unsupported TRowSet instance {}" .format (t_row_set ))
573- return self ._convert_decimals_in_arrow_table (arrow_table ,description ),num_rows
574-
575- @staticmethod
576- def _convert_decimals_in_arrow_table (table ,description ):
577- for (i ,col )in enumerate (table .itercolumns ()):
578- if description [i ][1 ]== "decimal" :
579- decimal_col = col .to_pandas ().apply (
580- lambda v :v if v is None else Decimal (v )
581- )
582- precision ,scale = description [i ][4 ],description [i ][5 ]
583- assert scale is not None
584- assert precision is not None
585- # Spark limits decimal to a maximum scale of 38,
586- # so 128 is guaranteed to be big enough
587- dtype = pyarrow .decimal128 (precision ,scale )
588- col_data = pyarrow .array (decimal_col ,type = dtype )
589- field = table .field (i ).with_type (dtype )
590- table = table .set_column (i ,field ,col_data )
591- return table
592-
593- @staticmethod
594- def _convert_arrow_based_set_to_arrow_table (
595- arrow_batches ,lz4_compressed ,schema_bytes
596- ):
597- ba = bytearray ()
598- ba += schema_bytes
599- n_rows = 0
600- if lz4_compressed :
601- for arrow_batch in arrow_batches :
602- n_rows += arrow_batch .rowCount
603- ba += lz4 .frame .decompress (arrow_batch .batch )
604- else :
605- for arrow_batch in arrow_batches :
606- n_rows += arrow_batch .rowCount
607- ba += arrow_batch .batch
608- arrow_table = pyarrow .ipc .open_stream (ba ).read_all ()
609- return arrow_table ,n_rows
610-
611- @staticmethod
612- def _convert_column_based_set_to_arrow_table (columns ,description ):
613- arrow_table = pyarrow .Table .from_arrays (
614- [ThriftBackend ._convert_column_to_arrow_array (c )for c in columns ],
615- # Only use the column names from the schema, the types are determined by the
616- # physical types used in column based set, as they can differ from the
617- # mapping used in _hive_schema_to_arrow_schema.
618- names = [c [0 ]for c in description ],
619- )
620- return arrow_table ,arrow_table .num_rows
621-
622- @staticmethod
623- def _convert_column_to_arrow_array (t_col ):
624- """
625- Return a pyarrow array from the values in a TColumn instance.
626- Note that ColumnBasedSet has no native support for complex types, so they will be converted
627- to strings server-side.
628- """
629- field_name_to_arrow_type = {
630- "boolVal" :pyarrow .bool_ (),
631- "byteVal" :pyarrow .int8 (),
632- "i16Val" :pyarrow .int16 (),
633- "i32Val" :pyarrow .int32 (),
634- "i64Val" :pyarrow .int64 (),
635- "doubleVal" :pyarrow .float64 (),
636- "stringVal" :pyarrow .string (),
637- "binaryVal" :pyarrow .binary (),
638- }
639- for field in field_name_to_arrow_type .keys ():
640- wrapper = getattr (t_col ,field )
641- if wrapper :
642- return ThriftBackend ._create_arrow_array (
643- wrapper ,field_name_to_arrow_type [field ]
644- )
645-
646- raise OperationalError ("Empty TColumn instance {}" .format (t_col ))
647-
648- @staticmethod
649- def _create_arrow_array (t_col_value_wrapper ,arrow_type ):
650- result = t_col_value_wrapper .values
651- nulls = t_col_value_wrapper .nulls # bitfield describing which values are null
652- assert isinstance (nulls ,bytes )
653-
654- # The number of bits in nulls can be both larger or smaller than the number of
655- # elements in result, so take the minimum of both to iterate over.
656- length = min (len (result ),len (nulls )* 8 )
657-
658- for i in range (length ):
659- if nulls [i >> 3 ]& ThriftBackend .BIT_MASKS [i & 0x7 ]:
660- result [i ]= None
661-
662- return pyarrow .array (result ,type = arrow_type )
574+ return convert_decimals_in_arrow_table (arrow_table ,description ),num_rows
663575
664576def _get_metadata_resp (self ,op_handle ):
665577req = ttypes .TGetResultSetMetadataReq (operationHandle = op_handle )
@@ -752,6 +664,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
752664if t_result_set_metadata_resp .resultFormat not in [
753665ttypes .TSparkRowSetType .ARROW_BASED_SET ,
754666ttypes .TSparkRowSetType .COLUMN_BASED_SET ,
667+ ttypes .TSparkRowSetType .URL_BASED_SET ,
755668 ]:
756669raise OperationalError (
757- "Expected results to be in Arrow or column based format, "
670+ "Expected results to be in Arrow, column based, or URL based format, "
@@ -783,13 +696,14 @@ def _results_message_to_execute_response(self, resp, operation_state):
783696assert direct_results .resultSet .results .startRowOffset == 0
784697assert direct_results .resultSetMetadata
785698
786- arrow_results ,n_rows = self ._create_arrow_table (
787- direct_results .resultSet .results ,
788- lz4_compressed ,
789- schema_bytes ,
790- description ,
699+ arrow_queue_opt = ResultSetQueueFactory .build_queue (
700+ row_set_type = t_result_set_metadata_resp .resultFormat ,
701+ t_row_set = direct_results .resultSet .results ,
702+ arrow_schema_bytes = schema_bytes ,
703+ max_download_threads = self .max_download_threads ,
704+ lz4_compressed = lz4_compressed ,
705+ description = description ,
791706 )
792- arrow_queue_opt = ArrowQueue (arrow_results ,n_rows ,0 )
793707else :
794708arrow_queue_opt = None
795709return ExecuteResponse (
@@ -843,7 +757,14 @@ def _check_direct_results_for_error(t_spark_direct_results):
843757 )
844758
845759def execute_command (
846- self ,operation ,session_handle ,max_rows ,max_bytes ,lz4_compression ,cursor
760+ self ,
761+ operation ,
762+ session_handle ,
763+ max_rows ,
764+ max_bytes ,
765+ lz4_compression ,
766+ cursor ,
767+ use_cloud_fetch = False ,
847768 ):
848769assert session_handle is not None
849770
@@ -864,7 +785,7 @@ def execute_command(
864785 ),
865786canReadArrowResult = True ,
866787canDecompressLZ4Result = lz4_compression ,
867- canDownloadResult = False ,
788+ canDownloadResult = use_cloud_fetch ,
868789confOverlay = {
869790# We want to receive proper Timestamp arrow types.
870791"spark.thriftserver.arrowBasedRowSet.timestampAsString" :"false"
@@ -993,6 +914,7 @@ def fetch_results(
993914maxRows = max_rows ,
994915maxBytes = max_bytes ,
995916orientation = ttypes .TFetchOrientation .FETCH_NEXT ,
917+ includeResultSetMetadata = True ,
996918 )
997919
998920resp = self .make_request (self ._client .FetchResults ,req )
@@ -1002,12 +924,17 @@ def fetch_results(
1002924expected_row_start_offset ,resp .results .startRowOffset
1003925 )
1004926 )
1005- arrow_results ,n_rows = self ._create_arrow_table (
1006- resp .results ,lz4_compressed ,arrow_schema_bytes ,description
927+
928+ queue = ResultSetQueueFactory .build_queue (
929+ row_set_type = resp .resultSetMetadata .resultFormat ,
930+ t_row_set = resp .results ,
931+ arrow_schema_bytes = arrow_schema_bytes ,
932+ max_download_threads = self .max_download_threads ,
933+ lz4_compressed = lz4_compressed ,
934+ description = description ,
1007935 )
1008- arrow_queue = ArrowQueue (arrow_results ,n_rows )
1009936
1010- return arrow_queue ,resp .hasMoreRows
937+ return queue ,resp .hasMoreRows
1011938
1012939def close_command (self ,op_handle ):
1013940req = ttypes .TCloseOperationReq (operationHandle = op_handle )