44import math
55import time
66import threading
7+ import lz4 .frame
78from ssl import CERT_NONE ,CERT_REQUIRED ,create_default_context
89
910import pyarrow
@@ -451,7 +452,7 @@ def open_session(self, session_configuration, catalog, schema):
451452initial_namespace = None
452453
453454open_session_req = ttypes .TOpenSessionReq (
454- client_protocol_i64 = ttypes .TProtocolVersion .SPARK_CLI_SERVICE_PROTOCOL_V5 ,
455+ client_protocol_i64 = ttypes .TProtocolVersion .SPARK_CLI_SERVICE_PROTOCOL_V6 ,
455456client_protocol = None ,
456457initialNamespace = initial_namespace ,
457458canUseMultipleCatalogs = True ,
@@ -507,7 +508,7 @@ def _poll_for_status(self, op_handle):
507508 )
508509return self .make_request (self ._client .GetOperationStatus ,req )
509510
510- def _create_arrow_table (self ,t_row_set ,schema_bytes ,description ):
511+ def _create_arrow_table (self ,t_row_set ,lz4_compressed , schema_bytes ,description ):
511512if t_row_set .columns is not None :
512513 (
513514arrow_table ,
@@ -520,7 +521,7 @@ def _create_arrow_table(self, t_row_set, schema_bytes, description):
520521arrow_table ,
521522num_rows ,
522523 )= ThriftBackend ._convert_arrow_based_set_to_arrow_table (
523- t_row_set .arrowBatches ,schema_bytes
524+ t_row_set .arrowBatches ,lz4_compressed , schema_bytes
524525 )
525526else :
526527raise OperationalError ("Unsupported TRowSet instance {}" .format (t_row_set ))
@@ -545,13 +546,20 @@ def _convert_decimals_in_arrow_table(table, description):
545546return table
546547
547548@staticmethod
548- def _convert_arrow_based_set_to_arrow_table (arrow_batches ,schema_bytes ):
549+ def _convert_arrow_based_set_to_arrow_table (
550+ arrow_batches ,lz4_compressed ,schema_bytes
551+ ):
549552ba = bytearray ()
550553ba += schema_bytes
551554n_rows = 0
552- for arrow_batch in arrow_batches :
553- n_rows += arrow_batch .rowCount
554- ba += arrow_batch .batch
555+ if lz4_compressed :
556+ for arrow_batch in arrow_batches :
557+ n_rows += arrow_batch .rowCount
558+ ba += lz4 .frame .decompress (arrow_batch .batch )
559+ else :
560+ for arrow_batch in arrow_batches :
561+ n_rows += arrow_batch .rowCount
562+ ba += arrow_batch .batch
555563arrow_table = pyarrow .ipc .open_stream (ba ).read_all ()
556564return arrow_table ,n_rows
557565
@@ -708,7 +716,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
708716 ]
709717 )
710718 )
711-
712719direct_results = resp .directResults
713720has_been_closed_server_side = direct_results and direct_results .closeOperation
714721has_more_rows = (
@@ -725,12 +732,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
725732 .serialize ()
726733 .to_pybytes ()
727734 )
728-
735+ lz4_compressed = t_result_set_metadata_resp . lz4Compressed
729736if direct_results and direct_results .resultSet :
730737assert direct_results .resultSet .results .startRowOffset == 0
731738assert direct_results .resultSetMetadata
739+
732740arrow_results ,n_rows = self ._create_arrow_table (
733- direct_results .resultSet .results ,schema_bytes ,description
741+ direct_results .resultSet .results ,
742+ lz4_compressed ,
743+ schema_bytes ,
744+ description ,
734745 )
735746arrow_queue_opt = ArrowQueue (arrow_results ,n_rows ,0 )
736747else :
@@ -740,6 +751,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
740751status = operation_state ,
741752has_been_closed_server_side = has_been_closed_server_side ,
742753has_more_rows = has_more_rows ,
754+ lz4_compressed = lz4_compressed ,
743755command_handle = resp .operationHandle ,
744756description = description ,
745757arrow_schema_bytes = schema_bytes ,
@@ -783,7 +795,9 @@ def _check_direct_results_for_error(t_spark_direct_results):
783795t_spark_direct_results .closeOperation
784796 )
785797
786- def execute_command (self ,operation ,session_handle ,max_rows ,max_bytes ,cursor ):
798+ def execute_command (
799+ self ,operation ,session_handle ,max_rows ,max_bytes ,lz4_compression ,cursor
800+ ):
787801assert session_handle is not None
788802
789803spark_arrow_types = ttypes .TSparkArrowTypes (
@@ -802,7 +816,7 @@ def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor
802816maxRows = max_rows ,maxBytes = max_bytes
803817 ),
804818canReadArrowResult = True ,
805- canDecompressLZ4Result = False ,
819+ canDecompressLZ4Result = lz4_compression ,
806820canDownloadResult = False ,
807821confOverlay = {
808822# We want to receive proper Timestamp arrow types.
@@ -916,6 +930,7 @@ def fetch_results(
916930max_rows ,
917931max_bytes ,
918932expected_row_start_offset ,
933+ lz4_compressed ,
919934arrow_schema_bytes ,
920935description ,
921936 ):
@@ -941,7 +956,7 @@ def fetch_results(
941956 )
942957 )
943958arrow_results ,n_rows = self ._create_arrow_table (
944- resp .results ,arrow_schema_bytes ,description
959+ resp .results ,lz4_compressed , arrow_schema_bytes ,description
945960 )
946961arrow_queue = ArrowQueue (arrow_results ,n_rows )
947962