Commit f99123c

[SC-110400] Enabling compression in Python SQL Connector (#49)

Signed-off-by: Mohit Singla <mohit.singla@databricks.com>
Co-authored-by: Moe Derakhshani <moe.derakhshani@databricks.com>

1 parent 2e681b5, commit f99123c

File tree

9 files changed: 325 additions, 57 deletions

poetry.lock

Lines changed: 207 additions & 23 deletions

Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@ python = "^3.7.1"
 thrift = "^0.13.0"
 pandas = "^1.3.0"
 pyarrow = "^9.0.0"
+lz4 = "^4.0.2"
 requests = ">2.18.1"
 oauthlib = ">=3.1.0"

src/databricks/sql/client.py

Lines changed: 4 additions & 0 deletions

@@ -152,6 +152,7 @@ def read(self) -> Optional[OAuthToken]:
         self.host = server_hostname
         self.port = kwargs.get("_port", 443)
         self.disable_pandas = kwargs.get("_disable_pandas", False)
+        self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)

         auth_provider = get_python_sql_connector_auth_provider(
             server_hostname, **kwargs
@@ -318,6 +319,7 @@ def execute(
             session_handle=self.connection._session_handle,
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
+            lz4_compression=self.connection.lz4_compression,
             cursor=self,
         )
         self.active_result_set = ResultSet(
@@ -614,6 +616,7 @@ def __init__(
         self.has_been_closed_server_side = execute_response.has_been_closed_server_side
         self.has_more_rows = execute_response.has_more_rows
         self.buffer_size_bytes = result_buffer_size_bytes
+        self.lz4_compressed = execute_response.lz4_compressed
         self.arraysize = arraysize
         self.thrift_backend = thrift_backend
         self.description = execute_response.description
@@ -642,6 +645,7 @@ def _fill_results_buffer(self):
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
             expected_row_start_offset=self._next_row_index,
+            lz4_compressed=self.lz4_compressed,
             arrow_schema_bytes=self._arrow_schema_bytes,
             description=self.description,
         )
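
Together these four hunks thread the new setting from the connection into the fetch path: the `enable_query_result_lz4_compression` kwarg (default `True`) is stored on the connection, passed into `execute_command`, recorded on the `ResultSet`, and forwarded on every `fetch_results` call. A minimal usage sketch, assuming the connector's standard `databricks.sql.connect` entry point; the hostname, HTTP path, and token below are placeholders:

```python
from databricks import sql

# Placeholder connection details; enable_query_result_lz4_compression is the
# new kwarg from this commit. It defaults to True, so passing False opts a
# connection out of LZ4-compressed Arrow result batches.
connection = sql.connect(
    server_hostname="<workspace>.cloud.databricks.com",
    http_path="/sql/1.0/endpoints/<endpoint-id>",
    access_token="<token>",
    enable_query_result_lz4_compression=False,
)
cursor = connection.cursor()
cursor.execute("SELECT 1")
print(cursor.fetchall())
cursor.close()
connection.close()
```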

src/databricks/sql/thrift_backend.py

Lines changed: 28 additions & 13 deletions

@@ -4,6 +4,7 @@
 import math
 import time
 import threading
+import lz4.frame
 from ssl import CERT_NONE, CERT_REQUIRED, create_default_context

 import pyarrow
@@ -451,7 +452,7 @@ def open_session(self, session_configuration, catalog, schema):
             initial_namespace = None

         open_session_req = ttypes.TOpenSessionReq(
-            client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5,
+            client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6,
             client_protocol=None,
             initialNamespace=initial_namespace,
             canUseMultipleCatalogs=True,
@@ -507,7 +508,7 @@ def _poll_for_status(self, op_handle):
         )
         return self.make_request(self._client.GetOperationStatus, req)

-    def _create_arrow_table(self, t_row_set, schema_bytes, description):
+    def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
         if t_row_set.columns is not None:
             (
                 arrow_table,
@@ -520,7 +521,7 @@ def _create_arrow_table(self, t_row_set, schema_bytes, description):
                 arrow_table,
                 num_rows,
             ) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
-                t_row_set.arrowBatches, schema_bytes
+                t_row_set.arrowBatches, lz4_compressed, schema_bytes
             )
         else:
             raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
@@ -545,13 +546,20 @@ def _convert_decimals_in_arrow_table(table, description):
         return table

     @staticmethod
-    def _convert_arrow_based_set_to_arrow_table(arrow_batches, schema_bytes):
+    def _convert_arrow_based_set_to_arrow_table(
+        arrow_batches, lz4_compressed, schema_bytes
+    ):
         ba = bytearray()
         ba += schema_bytes
         n_rows = 0
-        for arrow_batch in arrow_batches:
-            n_rows += arrow_batch.rowCount
-            ba += arrow_batch.batch
+        if lz4_compressed:
+            for arrow_batch in arrow_batches:
+                n_rows += arrow_batch.rowCount
+                ba += lz4.frame.decompress(arrow_batch.batch)
+        else:
+            for arrow_batch in arrow_batches:
+                n_rows += arrow_batch.rowCount
+                ba += arrow_batch.batch
         arrow_table = pyarrow.ipc.open_stream(ba).read_all()
         return arrow_table, n_rows

@@ -708,7 +716,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
                 ]
             )
         )
-
         direct_results = resp.directResults
         has_been_closed_server_side = direct_results and direct_results.closeOperation
         has_more_rows = (
@@ -725,12 +732,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
             .serialize()
             .to_pybytes()
         )
-
+        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         if direct_results and direct_results.resultSet:
             assert direct_results.resultSet.results.startRowOffset == 0
             assert direct_results.resultSetMetadata
+
             arrow_results, n_rows = self._create_arrow_table(
-                direct_results.resultSet.results, schema_bytes, description
+                direct_results.resultSet.results,
+                lz4_compressed,
+                schema_bytes,
+                description,
             )
             arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
         else:
@@ -740,6 +751,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
             status=operation_state,
             has_been_closed_server_side=has_been_closed_server_side,
             has_more_rows=has_more_rows,
+            lz4_compressed=lz4_compressed,
             command_handle=resp.operationHandle,
             description=description,
             arrow_schema_bytes=schema_bytes,
@@ -783,7 +795,9 @@ def _check_direct_results_for_error(t_spark_direct_results):
             t_spark_direct_results.closeOperation
         )

-    def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor):
+    def execute_command(
+        self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
+    ):
         assert session_handle is not None

         spark_arrow_types = ttypes.TSparkArrowTypes(
@@ -802,7 +816,7 @@ def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor
                 maxRows=max_rows, maxBytes=max_bytes
             ),
             canReadArrowResult=True,
-            canDecompressLZ4Result=False,
+            canDecompressLZ4Result=lz4_compression,
             canDownloadResult=False,
             confOverlay={
                 # We want to receive proper Timestamp arrow types.
@@ -916,6 +930,7 @@ def fetch_results(
         max_rows,
         max_bytes,
         expected_row_start_offset,
+        lz4_compressed,
         arrow_schema_bytes,
         description,
     ):
@@ -941,7 +956,7 @@ def fetch_results(
             )
         )
         arrow_results, n_rows = self._create_arrow_table(
-            resp.results, arrow_schema_bytes, description
+            resp.results, lz4_compressed, arrow_schema_bytes, description
         )
         arrow_queue = ArrowQueue(arrow_results, n_rows)
src/databricks/sql/utils.py

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ def remaining_rows(self) -> pyarrow.Table:

 ExecuteResponse = namedtuple(
     "ExecuteResponse",
-    "status has_been_closed_server_side has_more_rows description "
+    "status has_been_closed_server_side has_more_rows description lz4_compressed "
     "command_handle arrow_queue arrow_schema_bytes",
 )
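
Worth noting: the field names are adjacent string literals that Python concatenates, so the trailing space after `lz4_compressed ` is what separates it from `command_handle` in the combined field string. A quick sanity check:

```python
from collections import namedtuple

# Without the trailing space, concatenation would yield a single bogus
# "lz4_compressedcommand_handle" field.
ExecuteResponse = namedtuple(
    "ExecuteResponse",
    "status has_been_closed_server_side has_more_rows description lz4_compressed "
    "command_handle arrow_queue arrow_schema_bytes",
)
print(ExecuteResponse._fields[4:6])  # ('lz4_compressed', 'command_handle')
```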

tests/e2e/common/large_queries_mixin.py

Lines changed: 13 additions & 9 deletions

@@ -49,11 +49,15 @@ def test_query_with_large_wide_result_set(self):
         # This is used by PyHive tests to determine the buffer size
         self.arraysize = 1000
         with self.cursor() as cursor:
-            uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
-            cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
-            for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
-                self.assertEqual(row[0], row_id)  # Verify no rows are dropped in the middle.
-                self.assertEqual(len(row[1]), 36)
+            for lz4_compression in [False, True]:
+                cursor.connection.lz4_compression = lz4_compression
+                uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
+                cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
+                self.assertEqual(lz4_compression, cursor.active_result_set.lz4_compressed)
+                for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
+                    self.assertEqual(row[0], row_id)  # Verify no rows are dropped in the middle.
+                    self.assertEqual(len(row[1]), 36)

     def test_query_with_large_narrow_result_set(self):
         resultSize = 300 * 1000 * 1000  # 300 MB
@@ -85,10 +89,10 @@ def test_long_running_query(self):
         start = time.time()

         cursor.execute("""SELECT count(*)
-                FROM RANGE({scale}) x
-                JOIN RANGE({scale0}) y
-                ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
-                """.format(scale=scale_factor * scale0, scale0=scale0))
+            FROM RANGE({scale}) x
+            JOIN RANGE({scale0}) y
+            ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
+            """.format(scale=scale_factor * scale0, scale0=scale0))

         n, = cursor.fetchone()
         self.assertEqual(n, 0)

tests/e2e/driver_tests.py

Lines changed: 14 additions & 0 deletions

@@ -510,6 +510,20 @@ def test_timezone_with_timestamp(self):
         self.assertEqual(arrow_result_table.field(0).type, ts_type)
         self.assertEqual(arrow_result_value, expected.timestamp() * 1000000)

+    @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support')
+    def test_can_flip_compression(self):
+        with self.cursor() as cursor:
+            cursor.execute("SELECT array(1,2,3,4)")
+            cursor.fetchall()
+            lz4_compressed = cursor.active_result_set.lz4_compressed
+            # The endpoint should support compression
+            self.assertEqual(lz4_compressed, True)
+            cursor.connection.lz4_compression = False
+            cursor.execute("SELECT array(1,2,3,4)")
+            cursor.fetchall()
+            lz4_compressed = cursor.active_result_set.lz4_compressed
+            self.assertEqual(lz4_compressed, False)
+
     def _should_have_native_complex_types(self):
         return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments)
tests/unit/test_fetches.py

Lines changed: 3 additions & 1 deletion

@@ -38,6 +38,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
                 has_been_closed_server_side=True,
                 has_more_rows=False,
                 description=Mock(),
+                lz4_compressed=Mock(),
                 command_handle=None,
                 arrow_queue=arrow_queue,
                 arrow_schema_bytes=schema.serialize().to_pybytes()))
@@ -50,7 +51,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
     def make_dummy_result_set_from_batch_list(batch_list):
         batch_index = 0

-        def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset,
+        def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, lz4_compressed,
                           arrow_schema_bytes, description):
             nonlocal batch_index
             results = FetchTests.make_arrow_queue(batch_list[batch_index])
@@ -71,6 +72,7 @@ def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset,
                 has_more_rows=True,
                 description=[(f'col{col_id}', 'integer', None, None, None, None, None)
                              for col_id in range(num_cols)],
+                lz4_compressed=Mock(),
                 command_handle=None,
                 arrow_queue=None,
                 arrow_schema_bytes=None))
