Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Enhance Arrow to Pandas conversion with type overrides and additional kwargs #579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
madhav-db wants to merge 5 commits into main
base:main
Choose a base branch
Loading
fromissue-578
Open
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/databricks/sql/client.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -213,6 +213,11 @@ def read(self) -> Optional[OAuthToken]:
# (True by default)
# use_cloud_fetch
# Enable use of cloud fetch to extract large query results in parallel via cloud storage
# _arrow_pandas_type_override
# Override the default pandas dtype mapping for Arrow types.
# This is a dictionary of Arrow types to pandas dtypes.
# _arrow_to_pandas_kwargs
# Additional or overriding keyword arguments to pass to pyarrow.Table.to_pandas().

logger.debug(
"Connection.__init__(server_hostname=%s, http_path=%s)",
Expand All@@ -229,6 +234,8 @@ def read(self) -> Optional[OAuthToken]:
self.port = kwargs.get("_port", 443)
self.disable_pandas = kwargs.get("_disable_pandas", False)
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
self._arrow_pandas_type_override = kwargs.get("_arrow_pandas_type_override", {})
self._arrow_to_pandas_kwargs = kwargs.get("_arrow_to_pandas_kwargs", {})

auth_provider = get_python_sql_connector_auth_provider(
server_hostname, **kwargs
Expand DownExpand Up@@ -1346,7 +1353,9 @@ def _convert_arrow_table(self, table):
# Need to use nullable types, as otherwise type can change when there are missing values.
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types
# NOTE: This API is experimental; see https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
dtype_mapping = {
DEFAULT_DTYPE_MAPPING: Dict[
pyarrow.DataType, pandas.api.extensions.ExtensionDtype
] = {
pyarrow.int8(): pandas.Int8Dtype(),
pyarrow.int16(): pandas.Int16Dtype(),
pyarrow.int32(): pandas.Int32Dtype(),
Expand All@@ -1361,13 +1370,35 @@ def _convert_arrow_table(self, table):
pyarrow.string(): pandas.StringDtype(),
}

arrow_pandas_type_override = self.connection._arrow_pandas_type_override
if not isinstance(arrow_pandas_type_override, dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

This if block is not needed; let it fail right here. We don't want the user to pass something incorrect and have everything still appear to work. This is a new change, so there is nothing to stay backward compatible with.
Just leave it as: arrow_pandas_type_override = self.connection._arrow_pandas_type_override

logger.debug(
"_arrow_pandas_type_override on connection was not a dict, using default type mapping"
)
arrow_pandas_type_override = {}

dtype_mapping = {
**DEFAULT_DTYPE_MAPPING,
**arrow_pandas_type_override,
}

to_pandas_kwargs: dict[str, Any] = {
"types_mapper": dtype_mapping.get,
"date_as_object": True,
"timestamp_as_object": True,
}

arrow_to_pandas_kwargs = self.connection._arrow_to_pandas_kwargs
if isinstance(arrow_to_pandas_kwargs, dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Same here: let it fail when the input format is incorrect. The Python interpreter's default type-mismatch error is enough.

to_pandas_kwargs.update(arrow_to_pandas_kwargs)
else:
logger.debug(
"_arrow_to_pandas_kwargs on connection was not a dict, using default arguments"
)

# Need to rename columns, as the to_pandas function cannot handle duplicate column names
table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
df = table_renamed.to_pandas(
types_mapper=dtype_mapping.get,
date_as_object=True,
timestamp_as_object=True,
)
df = table_renamed.to_pandas(**to_pandas_kwargs)

res = df.to_numpy(na_value=None, dtype="object")
return [ResultRow(*v) for v in res]
Expand Down
184 changes: 184 additions & 0 deletions tests/unit/test_arrow_conversion.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
importpytest

try:
importpyarrowaspa
exceptImportError:
pa=None
importpandas
importdatetime
importunittest
fromunittest.mockimportMagicMock

fromdatabricks.sql.clientimportResultSet,Connection,ExecuteResponse
fromdatabricks.sql.typesimportRow
fromdatabricks.sql.utilsimportArrowQueue

@pytest.mark.skipif(paisNone,reason="PyArrow is not installed")
classArrowConversionTests(unittest.TestCase):
@staticmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Move this to use fixtures

defmock_connection_static():
conn=MagicMock(spec=Connection)
conn.disable_pandas=False
conn._arrow_pandas_type_override= {}
conn._arrow_to_pandas_kwargs= {}
returnconn

@staticmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Use fixtures or just normal functions, don't need static methods

defsample_arrow_table_static():
data= [
pa.array([1,2,3],type=pa.int32()),
pa.array(["a","b","c"],type=pa.string()),
]
schema=pa.schema([("col_int",pa.int32()), ("col_str",pa.string())])
returnpa.Table.from_arrays(data,schema=schema)

@staticmethod
def mock_thrift_backend_static():
    """Build a mock Thrift backend whose ``fetch_results`` returns no rows.

    The empty table reuses the schema of ``sample_arrow_table_static()`` so
    that any follow-up fetch sees the same columns but zero rows.
    """
    sample_table = ArrowConversionTests.sample_arrow_table_static()
    tb = MagicMock()
    # One zero-length array per schema field keeps the column types intact.
    empty_arrays = [pa.array([], type=field.type) for field in sample_table.schema]
    empty_table = pa.Table.from_arrays(empty_arrays, schema=sample_table.schema)
    # Second tuple element is presumably the "has more rows" flag -- confirm
    # against the backend's fetch_results contract.
    tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False)
    return tb

@staticmethod
def mock_raw_execute_response_static():
    """Build a mock ``ExecuteResponse`` describing an (int, string) result.

    ``arrow_queue`` is left as ``None``; each test assigns its own queue
    before constructing a ``ResultSet``.
    """
    er = MagicMock(spec=ExecuteResponse)
    # Seven-item description tuples; only name and type code are populated.
    er.description = [
        ("col_int", "int", None, None, None, None, None),
        ("col_str", "string", None, None, None, None, None),
    ]
    er.arrow_schema_bytes = None
    er.arrow_queue = None
    er.has_more_rows = False
    er.lz4_compressed = False
    er.command_handle = MagicMock()
    er.status = MagicMock()
    er.has_been_closed_server_side = False
    er.is_staging_operation = False
    return er

deftest_convert_arrow_table_default(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

test_convert_arrow_table_default, test_convert_arrow_table_disable_pandas and test_convert_arrow_table_type_override are essentially the same test flow with different arguments. Please use pytest's parametrized tests (pytest.mark.parametrize) for tests where only the arguments change.

mock_connection=ArrowConversionTests.mock_connection_static()
sample_arrow_table=ArrowConversionTests.sample_arrow_table_static()
mock_thrift_backend=ArrowConversionTests.mock_thrift_backend_static()
mock_raw_execute_response= (
ArrowConversionTests.mock_raw_execute_response_static()
)

mock_raw_execute_response.arrow_queue=ArrowQueue(
sample_arrow_table,sample_arrow_table.num_rows
)
rs=ResultSet(mock_connection,mock_raw_execute_response,mock_thrift_backend)
result_one=rs.fetchone()
self.assertIsInstance(result_one,Row)
self.assertEqual(result_one.col_int,1)
self.assertEqual(result_one.col_str,"a")

mock_raw_execute_response.arrow_queue=ArrowQueue(
sample_arrow_table,sample_arrow_table.num_rows
)
rs=ResultSet(mock_connection,mock_raw_execute_response,mock_thrift_backend)
result_all=rs.fetchall()
self.assertEqual(len(result_all),3)
self.assertIsInstance(result_all[0],Row)
self.assertEqual(result_all[0].col_int,1)
self.assertEqual(result_all[1].col_str,"b")

def test_convert_arrow_table_disable_pandas(self):
    """Fetch all rows with ``disable_pandas=True`` and check the row values."""
    mock_connection = ArrowConversionTests.mock_connection_static()
    sample_arrow_table = ArrowConversionTests.sample_arrow_table_static()
    mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static()
    mock_raw_execute_response = (
        ArrowConversionTests.mock_raw_execute_response_static()
    )

    mock_connection.disable_pandas = True
    mock_raw_execute_response.arrow_queue = ArrowQueue(
        sample_arrow_table, sample_arrow_table.num_rows
    )
    rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
    result = rs.fetchall()
    self.assertEqual(len(result), 3)
    self.assertIsInstance(result[0], Row)
    self.assertEqual(result[0].col_int, 1)
    self.assertEqual(result[0].col_str, "a")
    # NOTE(review): these two checks inspect the input table, not the fetched
    # rows, so they do not exercise the conversion path itself.
    self.assertIsInstance(sample_arrow_table.column(0)[0].as_py(), int)
    self.assertIsInstance(sample_arrow_table.column(1)[0].as_py(), str)

def test_convert_arrow_table_type_override(self):
    """Override the int32 mapping with ``Float64Dtype`` and expect floats."""
    mock_connection = ArrowConversionTests.mock_connection_static()
    sample_arrow_table = ArrowConversionTests.sample_arrow_table_static()
    mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static()
    mock_raw_execute_response = (
        ArrowConversionTests.mock_raw_execute_response_static()
    )

    # Map Arrow int32 to a nullable float dtype instead of the default.
    mock_connection._arrow_pandas_type_override = {
        pa.int32(): pandas.Float64Dtype()
    }
    mock_raw_execute_response.arrow_queue = ArrowQueue(
        sample_arrow_table, sample_arrow_table.num_rows
    )
    rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
    result = rs.fetchall()
    self.assertEqual(len(result), 3)
    self.assertIsInstance(result[0].col_int, float)
    self.assertEqual(result[0].col_int, 1.0)
    self.assertEqual(result[0].col_str, "a")

deftest_convert_arrow_table_to_pandas_kwargs(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Too much code duplication. Can you create a parametrized test in which mock_connection._arrow_to_pandas_kwargs (e.g. {"timestamp_as_object": False}) and the expected assertIsInstance types are the parameters? Otherwise the checking part looks the same and is copied repeatedly.

mock_connection=ArrowConversionTests.mock_connection_static()
mock_thrift_backend= (
ArrowConversionTests.mock_thrift_backend_static()
)# Does not use sample_arrow_table
mock_raw_execute_response= (
ArrowConversionTests.mock_raw_execute_response_static()
)

dt_obj=datetime.datetime(2021,1,1,12,0,0,tzinfo=datetime.timezone.utc)
ts_array=pa.array([dt_obj],type=pa.timestamp("us",tz="UTC"))
ts_schema=pa.schema([("col_ts",pa.timestamp("us",tz="UTC"))])
ts_table=pa.Table.from_arrays([ts_array],schema=ts_schema)

mock_raw_execute_response.description= [
("col_ts","timestamp",None,None,None,None,None)
]
mock_raw_execute_response.arrow_queue=ArrowQueue(ts_table,ts_table.num_rows)

# Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row.
mock_connection._arrow_to_pandas_kwargs= {"timestamp_as_object":True}
rs_ts_true=ResultSet(
mock_connection,mock_raw_execute_response,mock_thrift_backend
)
result_true=rs_ts_true.fetchall()
self.assertEqual(len(result_true),1)
self.assertIsInstance(result_true[0].col_ts,datetime.datetime)

# Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input.
mock_raw_execute_response.arrow_queue=ArrowQueue(
ts_table,ts_table.num_rows
)# Reset queue
mock_connection._arrow_to_pandas_kwargs= {"timestamp_as_object":False}
rs_ts_false=ResultSet(
mock_connection,mock_raw_execute_response,mock_thrift_backend
)
result_false=rs_ts_false.fetchall()
self.assertEqual(len(result_false),1)
self.assertIsInstance(result_false[0].col_ts,pandas.Timestamp)

# Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default.
mock_raw_execute_response.arrow_queue=ArrowQueue(
ts_table,ts_table.num_rows
)# Reset queue
mock_connection._arrow_to_pandas_kwargs= {}
rs_ts_default=ResultSet(
mock_connection,mock_raw_execute_response,mock_thrift_backend
)
result_default=rs_ts_default.fetchall()
self.assertEqual(len(result_default),1)
self.assertIsInstance(result_default[0].col_ts,datetime.datetime)


if__name__=="__main__":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

This is not needed, as we run tests using pytest. Also, can you move everything to pytest and remove unittest?

unittest.main()
Loading

[8]ページ先頭

©2009-2025 Movatter.jp