@@ -497,6 +497,54 @@ def test_handle_execute_response_can_handle_with_direct_results(self):
497497ttypes .TOperationState .FINISHED_STATE ,
498498 )
499499
500+ @patch ("databricks.sql.thrift_backend.TCLIService.Client" )
501+ def test_use_arrow_schema_if_available (self ,tcli_service_class ):
502+ tcli_service_instance = tcli_service_class .return_value
503+ arrow_schema_mock = MagicMock (name = "Arrow schema mock" )
504+ hive_schema_mock = MagicMock (name = "Hive schema mock" )
505+
506+ t_get_result_set_metadata_resp = ttypes .TGetResultSetMetadataResp (
507+ status = self .okay_status ,
508+ resultFormat = ttypes .TSparkRowSetType .ARROW_BASED_SET ,
509+ schema = hive_schema_mock ,
510+ arrowSchema = arrow_schema_mock )
511+
512+ t_execute_resp = ttypes .TExecuteStatementResp (
513+ status = self .okay_status ,
514+ directResults = None ,
515+ operationHandle = self .operation_handle ,
516+ )
517+
518+ tcli_service_instance .GetResultSetMetadata .return_value = t_get_result_set_metadata_resp
519+ thrift_backend = self ._make_fake_thrift_backend ()
520+ execute_response = thrift_backend ._handle_execute_response (t_execute_resp ,Mock ())
521+
522+ self .assertEqual (execute_response .arrow_schema_bytes ,arrow_schema_mock )
523+
524+ @patch ("databricks.sql.thrift_backend.TCLIService.Client" )
525+ def test_fall_back_to_hive_schema_if_no_arrow_schema (self ,tcli_service_class ):
526+ tcli_service_instance = tcli_service_class .return_value
527+ hive_schema_mock = MagicMock (name = "Hive schema mock" )
528+
529+ hive_schema_req = ttypes .TGetResultSetMetadataResp (
530+ status = self .okay_status ,
531+ resultFormat = ttypes .TSparkRowSetType .ARROW_BASED_SET ,
532+ arrowSchema = None ,
533+ schema = hive_schema_mock )
534+
535+ t_execute_resp = ttypes .TExecuteStatementResp (
536+ status = self .okay_status ,
537+ directResults = None ,
538+ operationHandle = self .operation_handle ,
539+ )
540+
541+ tcli_service_instance .GetResultSetMetadata .return_value = hive_schema_req
542+ thrift_backend = self ._make_fake_thrift_backend ()
543+ thrift_backend ._handle_execute_response (t_execute_resp ,Mock ())
544+
545+ self .assertEqual (hive_schema_mock ,
546+ thrift_backend ._hive_schema_to_arrow_schema .call_args [0 ][0 ])
547+
500548@patch ("databricks.sql.thrift_backend.TCLIService.Client" )
501549def test_handle_execute_response_reads_has_more_rows_in_direct_results (
502550self ,tcli_service_class ):
@@ -567,7 +615,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
567615max_rows = 1 ,
568616max_bytes = 1 ,
569617expected_row_start_offset = 0 ,
570- arrow_schema = Mock (),
618+ arrow_schema_bytes = Mock (),
571619description = Mock ())
572620
573621self .assertEqual (has_more_rows ,has_more_rows_resp )
@@ -591,15 +639,15 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
591639pyarrow .field ("column2" ,pyarrow .string ()),
592640pyarrow .field ("column3" ,pyarrow .float64 ()),
593641pyarrow .field ("column3" ,pyarrow .binary ())
594- ])
642+ ]). serialize (). to_pybytes ()
595643
596644thrift_backend = ThriftBackend ("foobar" ,443 ,"path" , [])
597645arrow_queue ,has_more_results = thrift_backend .fetch_results (
598646op_handle = Mock (),
599647max_rows = 1 ,
600648max_bytes = 1 ,
601649expected_row_start_offset = 0 ,
602- arrow_schema = schema ,
650+ arrow_schema_bytes = schema ,
603651description = MagicMock ())
604652
605653self .assertEqual (arrow_queue .n_valid_rows ,15 * 10 )
@@ -792,24 +840,21 @@ def test_create_arrow_table_calls_correct_conversion_method(self, convert_col_mo
792840schema = Mock ()
793841cols = Mock ()
794842arrow_batches = Mock ()
843+ description = Mock ()
795844
796845t_col_set = ttypes .TRowSet (columns = cols )
797- thrift_backend ._create_arrow_table (t_col_set ,schema ,Mock () )
846+ thrift_backend ._create_arrow_table (t_col_set ,schema ,description )
798847convert_arrow_mock .assert_not_called ()
799- convert_col_mock .assert_called_once_with (cols ,schema )
848+ convert_col_mock .assert_called_once_with (cols ,description )
800849
801850t_arrow_set = ttypes .TRowSet (arrowBatches = arrow_batches )
802851thrift_backend ._create_arrow_table (t_arrow_set ,schema ,Mock ())
803852convert_arrow_mock .assert_called_once_with (arrow_batches ,schema )
804- convert_col_mock .assert_called_once_with (cols ,schema )
805853
806854def test_convert_column_based_set_to_arrow_table_without_nulls (self ):
807- schema = pyarrow .schema ([
808- pyarrow .field ("column1" ,pyarrow .int32 ()),
809- pyarrow .field ("column2" ,pyarrow .string ()),
810- pyarrow .field ("column3" ,pyarrow .float64 ()),
811- pyarrow .field ("column3" ,pyarrow .binary ())
812- ])
855+ # Deliberately duplicate the column name to check that dups work
856+ field_names = ["column1" ,"column2" ,"column3" ,"column3" ]
857+ description = [(name , )for name in field_names ]
813858
814859t_cols = [
815860ttypes .TColumn (i32Val = ttypes .TI32Column (values = [1 ,2 ,3 ],nulls = bytes (1 ))),
@@ -820,7 +865,8 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self):
820865binaryVal = ttypes .TBinaryColumn (values = [b'\x11 ' ,b'\x22 ' ,b'\x33 ' ],nulls = bytes (1 )))
821866 ]
822867
823- arrow_table ,n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (t_cols ,schema )
868+ arrow_table ,n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (
869+ t_cols ,description )
824870self .assertEqual (n_rows ,3 )
825871
826872# Check schema, column names and types
@@ -841,12 +887,8 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self):
841887self .assertEqual (arrow_table .column (3 ).to_pylist (), [b'\x11 ' ,b'\x22 ' ,b'\x33 ' ])
842888
843889def test_convert_column_based_set_to_arrow_table_with_nulls (self ):
844- schema = pyarrow .schema ([
845- pyarrow .field ("column1" ,pyarrow .int32 ()),
846- pyarrow .field ("column2" ,pyarrow .string ()),
847- pyarrow .field ("column3" ,pyarrow .float64 ()),
848- pyarrow .field ("column3" ,pyarrow .binary ())
849- ])
890+ field_names = ["column1" ,"column2" ,"column3" ,"column3" ]
891+ description = [(name , )for name in field_names ]
850892
851893t_cols = [
852894ttypes .TColumn (i32Val = ttypes .TI32Column (values = [1 ,2 ,3 ],nulls = bytes ([1 ]))),
@@ -859,7 +901,8 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self):
859901values = [b'\x11 ' ,b'\x22 ' ,b'\x33 ' ],nulls = bytes ([3 ])))
860902 ]
861903
862- arrow_table ,n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (t_cols ,schema )
904+ arrow_table ,n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (
905+ t_cols ,description )
863906self .assertEqual (n_rows ,3 )
864907
865908# Check data
@@ -869,12 +912,8 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self):
869912self .assertEqual (arrow_table .column (3 ).to_pylist (), [None ,None ,b'\x33 ' ])
870913
871914def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set (self ):
872- schema = pyarrow .schema ([
873- pyarrow .field ("column1" ,pyarrow .string ()),
874- pyarrow .field ("column2" ,pyarrow .string ()),
875- pyarrow .field ("column3" ,pyarrow .string ()),
876- pyarrow .field ("column3" ,pyarrow .string ())
877- ])
915+ field_names = ["column1" ,"column2" ,"column3" ,"column3" ]
916+ description = [(name , )for name in field_names ]
878917
879918t_cols = [
880919ttypes .TColumn (i32Val = ttypes .TI32Column (values = [1 ,2 ,3 ],nulls = bytes (1 ))),
@@ -885,7 +924,8 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self):
885924binaryVal = ttypes .TBinaryColumn (values = [b'\x11 ' ,b'\x22 ' ,b'\x33 ' ],nulls = bytes (1 )))
886925 ]
887926
888- arrow_table ,n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (t_cols ,schema )
927+ arrow_table ,n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (
928+ t_cols ,description )
889929self .assertEqual (n_rows ,3 )
890930
891931# Check schema, column names and types