- Notifications
You must be signed in to change notification settings - Fork 126
Enhance Arrow to Pandas conversion with type overrides and additional kwargs #579
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base:main
Are you sure you want to change the base?
Uh oh!
There was an error while loading. Please reload this page.
Changes from all commits
048af73 0b1b05b 647ed39 31b44d4 2f32c6c File filter
Filter by extension
Conversations
Uh oh!
There was an error while loading. Please reload this page.
Jump to
Uh oh!
There was an error while loading. Please reload this page.
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -213,6 +213,11 @@ def read(self) -> Optional[OAuthToken]: | ||
| # (True by default) | ||
| # use_cloud_fetch | ||
| # Enable use of cloud fetch to extract large query results in parallel via cloud storage | ||
| # _arrow_pandas_type_override | ||
| # Override the default pandas dtype mapping for Arrow types. | ||
| # This is a dictionary of Arrow types to pandas dtypes. | ||
| # _arrow_to_pandas_kwargs | ||
| # Additional or modified arguments to pass to pyarrow.Table.to_pandas when converting results. | ||
| logger.debug( | ||
| "Connection.__init__(server_hostname=%s, http_path=%s)", | ||
| @@ -229,6 +234,8 @@ def read(self) -> Optional[OAuthToken]: | ||
| self.port = kwargs.get("_port", 443) | ||
| self.disable_pandas = kwargs.get("_disable_pandas", False) | ||
| self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) | ||
| self._arrow_pandas_type_override = kwargs.get("_arrow_pandas_type_override", {}) | ||
| self._arrow_to_pandas_kwargs = kwargs.get("_arrow_to_pandas_kwargs", {}) | ||
| auth_provider = get_python_sql_connector_auth_provider( | ||
| server_hostname, **kwargs | ||
| @@ -1346,7 +1353,9 @@ def _convert_arrow_table(self, table): | ||
| # Need to use nullable types, as otherwise type can change when there are missing values. | ||
| # See https://arrow.apache.org/docs/python/pandas.html#nullable-types | ||
| # NOTE: This API is experimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html | ||
| DEFAULT_DTYPE_MAPPING: Dict[ | ||
| pyarrow.DataType, pandas.api.extensions.ExtensionDtype | ||
| ] = { | ||
| pyarrow.int8(): pandas.Int8Dtype(), | ||
| pyarrow.int16(): pandas.Int16Dtype(), | ||
| pyarrow.int32(): pandas.Int32Dtype(), | ||
| @@ -1361,13 +1370,35 @@ def _convert_arrow_table(self, table): | ||
| pyarrow.string(): pandas.StringDtype(), | ||
| } | ||
| arrow_pandas_type_override = self.connection._arrow_pandas_type_override | ||
| if not isinstance(arrow_pandas_type_override, dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This if block is not needed; let it fail right here. We don't want the user to pass something incorrect and then have everything appear to work. This is a new change, so there is nothing to keep backward compatible. | ||
| logger.debug( | ||
| "_arrow_pandas_type_override on connection was not a dict, using default type mapping" | ||
| ) | ||
| arrow_pandas_type_override = {} | ||
| dtype_mapping = { | ||
| **DEFAULT_DTYPE_MAPPING, | ||
| **arrow_pandas_type_override, | ||
| } | ||
| to_pandas_kwargs: dict[str, Any] = { | ||
| "types_mapper": dtype_mapping.get, | ||
| "date_as_object": True, | ||
| "timestamp_as_object": True, | ||
| } | ||
| arrow_to_pandas_kwargs = self.connection._arrow_to_pandas_kwargs | ||
| if isinstance(arrow_to_pandas_kwargs, dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here: let it fail when the input format is incorrect. The Python interpreter's default type-mismatch error is enough. | ||
| to_pandas_kwargs.update(arrow_to_pandas_kwargs) | ||
| else: | ||
| logger.debug( | ||
| "_arrow_to_pandas_kwargs on connection was not a dict, using default arguments" | ||
| ) | ||
| # Need to rename columns, as the to_pandas function cannot handle duplicate column names | ||
| table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) | ||
| df = table_renamed.to_pandas(**to_pandas_kwargs) | ||
| res = df.to_numpy(na_value=None, dtype="object") | ||
| return [ResultRow(*v) for v in res] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,184 @@ | ||
| importpytest | ||
| try: | ||
| importpyarrowaspa | ||
| exceptImportError: | ||
| pa=None | ||
| importpandas | ||
| importdatetime | ||
| importunittest | ||
| fromunittest.mockimportMagicMock | ||
| fromdatabricks.sql.clientimportResultSet,Connection,ExecuteResponse | ||
| fromdatabricks.sql.typesimportRow | ||
| fromdatabricks.sql.utilsimportArrowQueue | ||
| @pytest.mark.skipif(paisNone,reason="PyArrow is not installed") | ||
| classArrowConversionTests(unittest.TestCase): | ||
| @staticmethod | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Move this to use fixtures. | ||
| defmock_connection_static(): | ||
| conn=MagicMock(spec=Connection) | ||
| conn.disable_pandas=False | ||
| conn._arrow_pandas_type_override= {} | ||
| conn._arrow_to_pandas_kwargs= {} | ||
| returnconn | ||
| @staticmethod | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Use fixtures or just normal functions; static methods are not needed. | ||
| defsample_arrow_table_static(): | ||
| data= [ | ||
| pa.array([1,2,3],type=pa.int32()), | ||
| pa.array(["a","b","c"],type=pa.string()), | ||
| ] | ||
| schema=pa.schema([("col_int",pa.int32()), ("col_str",pa.string())]) | ||
| returnpa.Table.from_arrays(data,schema=schema) | ||
| @staticmethod | ||
| defmock_thrift_backend_static(): | ||
| sample_table=ArrowConversionTests.sample_arrow_table_static() | ||
| tb=MagicMock() | ||
| empty_arrays= [pa.array([],type=field.type)forfieldinsample_table.schema] | ||
| empty_table=pa.Table.from_arrays(empty_arrays,schema=sample_table.schema) | ||
| tb.fetch_results.return_value= (ArrowQueue(empty_table,0),False) | ||
| returntb | ||
| @staticmethod | ||
| defmock_raw_execute_response_static(): | ||
| er=MagicMock(spec=ExecuteResponse) | ||
| er.description= [ | ||
| ("col_int","int",None,None,None,None,None), | ||
| ("col_str","string",None,None,None,None,None), | ||
| ] | ||
| er.arrow_schema_bytes=None | ||
| er.arrow_queue=None | ||
| er.has_more_rows=False | ||
| er.lz4_compressed=False | ||
| er.command_handle=MagicMock() | ||
| er.status=MagicMock() | ||
| er.has_been_closed_server_side=False | ||
| er.is_staging_operation=False | ||
| returner | ||
| deftest_convert_arrow_table_default(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The test_convert_arrow_table_default, test_convert_arrow_table_disable_pandas and test_convert_arrow_table_type_override tests are essentially the same test flow, just with different arguments. Please use pytest's parameterized tests for such tests where only the arguments change. | ||
| mock_connection=ArrowConversionTests.mock_connection_static() | ||
| sample_arrow_table=ArrowConversionTests.sample_arrow_table_static() | ||
| mock_thrift_backend=ArrowConversionTests.mock_thrift_backend_static() | ||
| mock_raw_execute_response= ( | ||
| ArrowConversionTests.mock_raw_execute_response_static() | ||
| ) | ||
| mock_raw_execute_response.arrow_queue=ArrowQueue( | ||
| sample_arrow_table,sample_arrow_table.num_rows | ||
| ) | ||
| rs=ResultSet(mock_connection,mock_raw_execute_response,mock_thrift_backend) | ||
| result_one=rs.fetchone() | ||
| self.assertIsInstance(result_one,Row) | ||
| self.assertEqual(result_one.col_int,1) | ||
| self.assertEqual(result_one.col_str,"a") | ||
| mock_raw_execute_response.arrow_queue=ArrowQueue( | ||
| sample_arrow_table,sample_arrow_table.num_rows | ||
| ) | ||
| rs=ResultSet(mock_connection,mock_raw_execute_response,mock_thrift_backend) | ||
| result_all=rs.fetchall() | ||
| self.assertEqual(len(result_all),3) | ||
| self.assertIsInstance(result_all[0],Row) | ||
| self.assertEqual(result_all[0].col_int,1) | ||
| self.assertEqual(result_all[1].col_str,"b") | ||
| deftest_convert_arrow_table_disable_pandas(self): | ||
| mock_connection=ArrowConversionTests.mock_connection_static() | ||
| sample_arrow_table=ArrowConversionTests.sample_arrow_table_static() | ||
| mock_thrift_backend=ArrowConversionTests.mock_thrift_backend_static() | ||
| mock_raw_execute_response= ( | ||
| ArrowConversionTests.mock_raw_execute_response_static() | ||
| ) | ||
| mock_connection.disable_pandas=True | ||
| mock_raw_execute_response.arrow_queue=ArrowQueue( | ||
| sample_arrow_table,sample_arrow_table.num_rows | ||
| ) | ||
| rs=ResultSet(mock_connection,mock_raw_execute_response,mock_thrift_backend) | ||
| result=rs.fetchall() | ||
| self.assertEqual(len(result),3) | ||
| self.assertIsInstance(result[0],Row) | ||
| self.assertEqual(result[0].col_int,1) | ||
| self.assertEqual(result[0].col_str,"a") | ||
| self.assertIsInstance(sample_arrow_table.column(0)[0].as_py(),int) | ||
| self.assertIsInstance(sample_arrow_table.column(1)[0].as_py(),str) | ||
| deftest_convert_arrow_table_type_override(self): | ||
| mock_connection=ArrowConversionTests.mock_connection_static() | ||
| sample_arrow_table=ArrowConversionTests.sample_arrow_table_static() | ||
| mock_thrift_backend=ArrowConversionTests.mock_thrift_backend_static() | ||
| mock_raw_execute_response= ( | ||
| ArrowConversionTests.mock_raw_execute_response_static() | ||
| ) | ||
| mock_connection._arrow_pandas_type_override= { | ||
| pa.int32():pandas.Float64Dtype() | ||
| } | ||
| mock_raw_execute_response.arrow_queue=ArrowQueue( | ||
| sample_arrow_table,sample_arrow_table.num_rows | ||
| ) | ||
| rs=ResultSet(mock_connection,mock_raw_execute_response,mock_thrift_backend) | ||
| result=rs.fetchall() | ||
| self.assertEqual(len(result),3) | ||
| self.assertIsInstance(result[0].col_int,float) | ||
| self.assertEqual(result[0].col_int,1.0) | ||
| self.assertEqual(result[0].col_str,"a") | ||
| deftest_convert_arrow_table_to_pandas_kwargs(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Too much code duplication. Can you create a parameterized test, where in this | ||
| mock_connection=ArrowConversionTests.mock_connection_static() | ||
| mock_thrift_backend= ( | ||
| ArrowConversionTests.mock_thrift_backend_static() | ||
| )# Does not use sample_arrow_table | ||
| mock_raw_execute_response= ( | ||
| ArrowConversionTests.mock_raw_execute_response_static() | ||
| ) | ||
| dt_obj=datetime.datetime(2021,1,1,12,0,0,tzinfo=datetime.timezone.utc) | ||
| ts_array=pa.array([dt_obj],type=pa.timestamp("us",tz="UTC")) | ||
| ts_schema=pa.schema([("col_ts",pa.timestamp("us",tz="UTC"))]) | ||
| ts_table=pa.Table.from_arrays([ts_array],schema=ts_schema) | ||
| mock_raw_execute_response.description= [ | ||
| ("col_ts","timestamp",None,None,None,None,None) | ||
| ] | ||
| mock_raw_execute_response.arrow_queue=ArrowQueue(ts_table,ts_table.num_rows) | ||
| # Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row. | ||
| mock_connection._arrow_to_pandas_kwargs= {"timestamp_as_object":True} | ||
| rs_ts_true=ResultSet( | ||
| mock_connection,mock_raw_execute_response,mock_thrift_backend | ||
| ) | ||
| result_true=rs_ts_true.fetchall() | ||
| self.assertEqual(len(result_true),1) | ||
| self.assertIsInstance(result_true[0].col_ts,datetime.datetime) | ||
| # Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input. | ||
| mock_raw_execute_response.arrow_queue=ArrowQueue( | ||
| ts_table,ts_table.num_rows | ||
| )# Reset queue | ||
| mock_connection._arrow_to_pandas_kwargs= {"timestamp_as_object":False} | ||
| rs_ts_false=ResultSet( | ||
| mock_connection,mock_raw_execute_response,mock_thrift_backend | ||
| ) | ||
| result_false=rs_ts_false.fetchall() | ||
| self.assertEqual(len(result_false),1) | ||
| self.assertIsInstance(result_false[0].col_ts,pandas.Timestamp) | ||
| # Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default. | ||
| mock_raw_execute_response.arrow_queue=ArrowQueue( | ||
| ts_table,ts_table.num_rows | ||
| )# Reset queue | ||
| mock_connection._arrow_to_pandas_kwargs= {} | ||
| rs_ts_default=ResultSet( | ||
| mock_connection,mock_raw_execute_response,mock_thrift_backend | ||
| ) | ||
| result_default=rs_ts_default.fetchall() | ||
| self.assertEqual(len(result_default),1) | ||
| self.assertIsInstance(result_default[0].col_ts,datetime.datetime) | ||
| if__name__=="__main__": | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is not needed, as we run tests using pytest. Also, can you move everything to pytest and remove unittest? | ||
| unittest.main() | ||
Uh oh!
There was an error while loading. Please reload this page.