Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit ca562ab

Browse files
NiallEgan authored and susodapop committed
Use Arrow schema if available
This PR changes the Python client to use the Arrow schema if it has been sent by the server, instead of re-constructing an approximation from the Hive schema. The primary difference is in the timezone information for timestamps.
* Added new unit tests to check that the correct field is used
* Adapted integration tests to add timezones as appropriate
1 parent 963d5b0, commit ca562ab

File tree

6 files changed

+90
-48
lines changed

6 files changed

+90
-48
lines changed

‎cmdexec/clients/python/dev_requirements.txt‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ thrift==0.13.0
66
pandas==1.3.4
77
future==0.18.2
88
packaging==21.3
9+
pytz==2021.3

‎cmdexec/clients/python/src/databricks/sql/client.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def __init__(self,
480480
self.arraysize=arraysize
481481
self.thrift_backend=thrift_backend
482482
self.description=execute_response.description
483-
self._arrow_schema=execute_response.arrow_schema
483+
self._arrow_schema_bytes=execute_response.arrow_schema_bytes
484484
self._next_row_index=0
485485

486486
ifexecute_response.arrow_queue:
@@ -505,7 +505,7 @@ def _fill_results_buffer(self):
505505
max_rows=self.arraysize,
506506
max_bytes=self.buffer_size_bytes,
507507
expected_row_start_offset=self._next_row_index,
508-
arrow_schema=self._arrow_schema,
508+
arrow_schema_bytes=self._arrow_schema_bytes,
509509
description=self.description)
510510
self.results=results
511511
self.has_more_rows=has_more_rows

‎cmdexec/clients/python/src/databricks/sql/thrift_backend.py‎

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def open_session(self, session_configuration, catalog, schema):
330330
initial_namespace=None
331331

332332
open_session_req=ttypes.TOpenSessionReq(
333-
client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4,
333+
client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5,
334334
client_protocol=None,
335335
initialNamespace=initial_namespace,
336336
canUseMultipleCatalogs=True,
@@ -376,13 +376,13 @@ def _poll_for_status(self, op_handle):
376376
)
377377
returnself.make_request(self._client.GetOperationStatus,req)
378378

379-
def_create_arrow_table(self,t_row_set,arrow_schema,description):
379+
def_create_arrow_table(self,t_row_set,schema_bytes,description):
380380
ift_row_set.columnsisnotNone:
381381
arrow_table,num_rows=ThriftBackend._convert_column_based_set_to_arrow_table(
382-
t_row_set.columns,arrow_schema)
382+
t_row_set.columns,description)
383383
elift_row_set.arrowBatchesisnotNone:
384384
arrow_table,num_rows=ThriftBackend._convert_arrow_based_set_to_arrow_table(
385-
t_row_set.arrowBatches,arrow_schema)
385+
t_row_set.arrowBatches,schema_bytes)
386386
else:
387387
raiseOperationalError("Unsupported TRowSet instance {}".format(t_row_set))
388388
returnself._convert_decimals_in_arrow_table(arrow_table,description),num_rows
@@ -404,9 +404,8 @@ def _convert_decimals_in_arrow_table(table, description):
404404
returntable
405405

406406
@staticmethod
407-
def_convert_arrow_based_set_to_arrow_table(arrow_batches,schema):
407+
def_convert_arrow_based_set_to_arrow_table(arrow_batches,schema_bytes):
408408
ba=bytearray()
409-
schema_bytes=schema.serialize().to_pybytes()
410409
ba+=schema_bytes
411410
n_rows=0
412411
forarrow_batchinarrow_batches:
@@ -416,13 +415,13 @@ def _convert_arrow_based_set_to_arrow_table(arrow_batches, schema):
416415
returnarrow_table,n_rows
417416

418417
@staticmethod
419-
def_convert_column_based_set_to_arrow_table(columns,schema):
418+
def_convert_column_based_set_to_arrow_table(columns,description):
420419
arrow_table=pyarrow.Table.from_arrays(
421420
[ThriftBackend._convert_column_to_arrow_array(c)forcincolumns],
422421
# Only use the column names from the schema, the types are determined by the
423422
# physical types used in column based set, as they can differ from the
424423
# mapping used in _hive_schema_to_arrow_schema.
425-
names=[c.nameforcinschema])
424+
names=[c[0]forcindescription])
426425
returnarrow_table,arrow_table.num_rows
427426

428427
@staticmethod
@@ -555,13 +554,14 @@ def _results_message_to_execute_response(self, resp, operation_state):
555554
has_more_rows= (notdirect_results)or (notdirect_results.resultSet) \
556555
ordirect_results.resultSet.hasMoreRows
557556
description=self._hive_schema_to_description(t_result_set_metadata_resp.schema)
558-
arrow_schema=self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
557+
schema_bytes= (t_result_set_metadata_resp.arrowSchemaorself._hive_schema_to_arrow_schema(
558+
t_result_set_metadata_resp.schema).serialize().to_pybytes())
559559

560560
ifdirect_resultsanddirect_results.resultSet:
561561
assert (direct_results.resultSet.results.startRowOffset==0)
562562
assert (direct_results.resultSetMetadata)
563563
arrow_results,n_rows=self._create_arrow_table(direct_results.resultSet.results,
564-
arrow_schema,description)
564+
schema_bytes,description)
565565
arrow_queue_opt=ArrowQueue(arrow_results,n_rows,0)
566566
else:
567567
arrow_queue_opt=None
@@ -572,7 +572,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
572572
has_more_rows=has_more_rows,
573573
command_handle=resp.operationHandle,
574574
description=description,
575-
arrow_schema=arrow_schema)
575+
arrow_schema_bytes=schema_bytes)
576576

577577
def_wait_until_command_done(self,op_handle,initial_operation_status_resp):
578578
ifinitial_operation_status_resp:
@@ -697,8 +697,8 @@ def _handle_execute_response(self, resp, cursor):
697697

698698
returnself._results_message_to_execute_response(resp,final_operation_state)
699699

700-
deffetch_results(self,op_handle,max_rows,max_bytes,expected_row_start_offset,arrow_schema,
701-
description):
700+
deffetch_results(self,op_handle,max_rows,max_bytes,expected_row_start_offset,
701+
arrow_schema_bytes,description):
702702
assert (op_handleisnotNone)
703703

704704
req=ttypes.TFetchResultsReq(
@@ -716,7 +716,8 @@ def fetch_results(self, op_handle, max_rows, max_bytes, expected_row_start_offse
716716
ifresp.results.startRowOffset>expected_row_start_offset:
717717
logger.warning("Expected results to start from {} but they instead start at {}".format(
718718
expected_row_start_offset,resp.results.startRowOffset))
719-
arrow_results,n_rows=self._create_arrow_table(resp.results,arrow_schema,description)
719+
arrow_results,n_rows=self._create_arrow_table(resp.results,arrow_schema_bytes,
720+
description)
720721
arrow_queue=ArrowQueue(arrow_results,n_rows)
721722

722723
returnarrow_queue,resp.hasMoreRows

‎cmdexec/clients/python/src/databricks/sql/utils.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def remaining_rows(self) -> pyarrow.Table:
3535

3636
ExecuteResponse=namedtuple(
3737
'ExecuteResponse','status has_been_closed_server_side has_more_rows description '
38-
'command_handle arrow_queuearrow_schema')
38+
'command_handle arrow_queuearrow_schema_bytes')
3939

4040

4141
def_bound(min_x,max_x,x):

‎cmdexec/clients/python/tests/test_fetches.py‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
4242
description=Mock(),
4343
command_handle=None,
4444
arrow_queue=arrow_queue,
45-
arrow_schema=schema))
45+
arrow_schema_bytes=schema.serialize().to_pybytes()))
4646
num_cols=len(initial_results[0])ifinitial_resultselse0
4747
rs.description= [(f'col{col_id}','integer',None,None,None,None,None)
4848
forcol_idinrange(num_cols)]
@@ -52,8 +52,8 @@ def make_dummy_result_set_from_initial_results(initial_results):
5252
defmake_dummy_result_set_from_batch_list(batch_list):
5353
batch_index=0
5454

55-
deffetch_results(op_handle,max_rows,max_bytes,expected_row_start_offset,arrow_schema,
56-
description):
55+
deffetch_results(op_handle,max_rows,max_bytes,expected_row_start_offset,
56+
arrow_schema_bytes,description):
5757
nonlocalbatch_index
5858
results=FetchTests.make_arrow_queue(batch_list[batch_index])
5959
batch_index+=1
@@ -75,7 +75,7 @@ def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, arr
7575
forcol_idinrange(num_cols)],
7676
command_handle=None,
7777
arrow_queue=None,
78-
arrow_schema=None))
78+
arrow_schema_bytes=None))
7979
returnrs
8080

8181
defassertEqualRowValues(self,actual,expected):

‎cmdexec/clients/python/tests/test_thrift_backend.py‎

Lines changed: 67 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,54 @@ def test_handle_execute_response_can_handle_with_direct_results(self):
497497
ttypes.TOperationState.FINISHED_STATE,
498498
)
499499

500+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
501+
deftest_use_arrow_schema_if_available(self,tcli_service_class):
502+
tcli_service_instance=tcli_service_class.return_value
503+
arrow_schema_mock=MagicMock(name="Arrow schema mock")
504+
hive_schema_mock=MagicMock(name="Hive schema mock")
505+
506+
t_get_result_set_metadata_resp=ttypes.TGetResultSetMetadataResp(
507+
status=self.okay_status,
508+
resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET,
509+
schema=hive_schema_mock,
510+
arrowSchema=arrow_schema_mock)
511+
512+
t_execute_resp=ttypes.TExecuteStatementResp(
513+
status=self.okay_status,
514+
directResults=None,
515+
operationHandle=self.operation_handle,
516+
)
517+
518+
tcli_service_instance.GetResultSetMetadata.return_value=t_get_result_set_metadata_resp
519+
thrift_backend=self._make_fake_thrift_backend()
520+
execute_response=thrift_backend._handle_execute_response(t_execute_resp,Mock())
521+
522+
self.assertEqual(execute_response.arrow_schema_bytes,arrow_schema_mock)
523+
524+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
525+
deftest_fall_back_to_hive_schema_if_no_arrow_schema(self,tcli_service_class):
526+
tcli_service_instance=tcli_service_class.return_value
527+
hive_schema_mock=MagicMock(name="Hive schema mock")
528+
529+
hive_schema_req=ttypes.TGetResultSetMetadataResp(
530+
status=self.okay_status,
531+
resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET,
532+
arrowSchema=None,
533+
schema=hive_schema_mock)
534+
535+
t_execute_resp=ttypes.TExecuteStatementResp(
536+
status=self.okay_status,
537+
directResults=None,
538+
operationHandle=self.operation_handle,
539+
)
540+
541+
tcli_service_instance.GetResultSetMetadata.return_value=hive_schema_req
542+
thrift_backend=self._make_fake_thrift_backend()
543+
thrift_backend._handle_execute_response(t_execute_resp,Mock())
544+
545+
self.assertEqual(hive_schema_mock,
546+
thrift_backend._hive_schema_to_arrow_schema.call_args[0][0])
547+
500548
@patch("databricks.sql.thrift_backend.TCLIService.Client")
501549
deftest_handle_execute_response_reads_has_more_rows_in_direct_results(
502550
self,tcli_service_class):
@@ -567,7 +615,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
567615
max_rows=1,
568616
max_bytes=1,
569617
expected_row_start_offset=0,
570-
arrow_schema=Mock(),
618+
arrow_schema_bytes=Mock(),
571619
description=Mock())
572620

573621
self.assertEqual(has_more_rows,has_more_rows_resp)
@@ -591,15 +639,15 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
591639
pyarrow.field("column2",pyarrow.string()),
592640
pyarrow.field("column3",pyarrow.float64()),
593641
pyarrow.field("column3",pyarrow.binary())
594-
])
642+
]).serialize().to_pybytes()
595643

596644
thrift_backend=ThriftBackend("foobar",443,"path", [])
597645
arrow_queue,has_more_results=thrift_backend.fetch_results(
598646
op_handle=Mock(),
599647
max_rows=1,
600648
max_bytes=1,
601649
expected_row_start_offset=0,
602-
arrow_schema=schema,
650+
arrow_schema_bytes=schema,
603651
description=MagicMock())
604652

605653
self.assertEqual(arrow_queue.n_valid_rows,15*10)
@@ -792,24 +840,21 @@ def test_create_arrow_table_calls_correct_conversion_method(self, convert_col_mo
792840
schema=Mock()
793841
cols=Mock()
794842
arrow_batches=Mock()
843+
description=Mock()
795844

796845
t_col_set=ttypes.TRowSet(columns=cols)
797-
thrift_backend._create_arrow_table(t_col_set,schema,Mock())
846+
thrift_backend._create_arrow_table(t_col_set,schema,description)
798847
convert_arrow_mock.assert_not_called()
799-
convert_col_mock.assert_called_once_with(cols,schema)
848+
convert_col_mock.assert_called_once_with(cols,description)
800849

801850
t_arrow_set=ttypes.TRowSet(arrowBatches=arrow_batches)
802851
thrift_backend._create_arrow_table(t_arrow_set,schema,Mock())
803852
convert_arrow_mock.assert_called_once_with(arrow_batches,schema)
804-
convert_col_mock.assert_called_once_with(cols,schema)
805853

806854
deftest_convert_column_based_set_to_arrow_table_without_nulls(self):
807-
schema=pyarrow.schema([
808-
pyarrow.field("column1",pyarrow.int32()),
809-
pyarrow.field("column2",pyarrow.string()),
810-
pyarrow.field("column3",pyarrow.float64()),
811-
pyarrow.field("column3",pyarrow.binary())
812-
])
855+
# Deliberately duplicate the column name to check that dups work
856+
field_names= ["column1","column2","column3","column3"]
857+
description= [(name, )fornameinfield_names]
813858

814859
t_cols= [
815860
ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1,2,3],nulls=bytes(1))),
@@ -820,7 +865,8 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self):
820865
binaryVal=ttypes.TBinaryColumn(values=[b'\x11',b'\x22',b'\x33'],nulls=bytes(1)))
821866
]
822867

823-
arrow_table,n_rows=ThriftBackend._convert_column_based_set_to_arrow_table(t_cols,schema)
868+
arrow_table,n_rows=ThriftBackend._convert_column_based_set_to_arrow_table(
869+
t_cols,description)
824870
self.assertEqual(n_rows,3)
825871

826872
# Check schema, column names and types
@@ -841,12 +887,8 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self):
841887
self.assertEqual(arrow_table.column(3).to_pylist(), [b'\x11',b'\x22',b'\x33'])
842888

843889
deftest_convert_column_based_set_to_arrow_table_with_nulls(self):
844-
schema=pyarrow.schema([
845-
pyarrow.field("column1",pyarrow.int32()),
846-
pyarrow.field("column2",pyarrow.string()),
847-
pyarrow.field("column3",pyarrow.float64()),
848-
pyarrow.field("column3",pyarrow.binary())
849-
])
890+
field_names= ["column1","column2","column3","column3"]
891+
description= [(name, )fornameinfield_names]
850892

851893
t_cols= [
852894
ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1,2,3],nulls=bytes([1]))),
@@ -859,7 +901,8 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self):
859901
values=[b'\x11',b'\x22',b'\x33'],nulls=bytes([3])))
860902
]
861903

862-
arrow_table,n_rows=ThriftBackend._convert_column_based_set_to_arrow_table(t_cols,schema)
904+
arrow_table,n_rows=ThriftBackend._convert_column_based_set_to_arrow_table(
905+
t_cols,description)
863906
self.assertEqual(n_rows,3)
864907

865908
# Check data
@@ -869,12 +912,8 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self):
869912
self.assertEqual(arrow_table.column(3).to_pylist(), [None,None,b'\x33'])
870913

871914
deftest_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self):
872-
schema=pyarrow.schema([
873-
pyarrow.field("column1",pyarrow.string()),
874-
pyarrow.field("column2",pyarrow.string()),
875-
pyarrow.field("column3",pyarrow.string()),
876-
pyarrow.field("column3",pyarrow.string())
877-
])
915+
field_names= ["column1","column2","column3","column3"]
916+
description= [(name, )fornameinfield_names]
878917

879918
t_cols= [
880919
ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1,2,3],nulls=bytes(1))),
@@ -885,7 +924,8 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self):
885924
binaryVal=ttypes.TBinaryColumn(values=[b'\x11',b'\x22',b'\x33'],nulls=bytes(1)))
886925
]
887926

888-
arrow_table,n_rows=ThriftBackend._convert_column_based_set_to_arrow_table(t_cols,schema)
927+
arrow_table,n_rows=ThriftBackend._convert_column_based_set_to_arrow_table(
928+
t_cols,description)
889929
self.assertEqual(n_rows,3)
890930

891931
# Check schema, column names and types

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp