Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 048fae1

Browse files
authored
[PECOBLR-201] add variant support (#560)
1 parent 415fb53 · commit 048fae1

File tree

3 files changed

+214
-14
lines changed

3 files changed

+214
-14
lines changed

‎src/databricks/sql/backend/thrift_backend.py‎

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ def convert_col(t_column_desc):
735 735
        return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])
736 736

737 737
    @staticmethod
738 -
    def _col_to_description(col, session_id_hex=None):
738 +
    def _col_to_description(col, field=None, session_id_hex=None):
739 739
        type_entry = col.typeDesc.types[0]
740 740

741 741
        if type_entry.primitiveEntry:
@@ -764,12 +764,39 @@ def _col_to_description(col, session_id_hex=None):
764 764
        else:
765 765
            precision, scale = None, None
766 766

767 +
        # Extract variant type from field if available
768 +
        if field is not None:
769 +
            try:
770 +
                # Check for variant type in metadata
771 +
                if field.metadata and b"Spark:DataType:SqlName" in field.metadata:
772 +
                    sql_type = field.metadata.get(b"Spark:DataType:SqlName")
773 +
                    if sql_type == b"VARIANT":
774 +
                        cleaned_type = "variant"
775 +
            except Exception as e:
776 +
                logger.debug(f"Could not extract variant type from field: {e}")
777 +

767 778
        return col.columnName, cleaned_type, None, None, precision, scale, None
768779

769780
@staticmethod
770-
def_hive_schema_to_description(t_table_schema,session_id_hex=None):
781+
def_hive_schema_to_description(
782+
t_table_schema,schema_bytes=None,session_id_hex=None
783+
):
784+
field_dict= {}
785+
ifpyarrowandschema_bytes:
786+
try:
787+
arrow_schema=pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
788+
# Build a dictionary mapping column names to fields
789+
forfieldinarrow_schema:
790+
field_dict[field.name]=field
791+
exceptExceptionase:
792+
logger.debug(f"Could not parse arrow schema:{e}")
793+
771794
return [
772-
ThriftDatabricksClient._col_to_description(col,session_id_hex)
795+
ThriftDatabricksClient._col_to_description(
796+
col,
797+
field_dict.get(col.columnName)iffield_dictelseNone,
798+
session_id_hex,
799+
)
773800
forcolint_table_schema.columns
774801
]
775802

@@ -802,11 +829,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
802 829
                or direct_results.resultSet.hasMoreRows
803 830
            )
804 831

805 -
        description = self._hive_schema_to_description(
806 -
            t_result_set_metadata_resp.schema,
807 -
            self._session_id_hex,
808 -
        )
809 -
810 832
        if pyarrow:
811 833
            schema_bytes = (
812 834
                t_result_set_metadata_resp.arrowSchema
@@ -819,6 +841,12 @@ def _results_message_to_execute_response(self, resp, operation_state):
819 841
        else:
820 842
            schema_bytes = None
821 843

844 +
        description = self._hive_schema_to_description(
845 +
            t_result_set_metadata_resp.schema,
846 +
            schema_bytes,
847 +
            self._session_id_hex,
848 +
        )
849 +
822 850
        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
823 851
        command_id = CommandId.from_thrift_handle(resp.operationHandle)

@@ -863,11 +891,6 @@ def get_execution_result(
863 891

864 892
        t_result_set_metadata_resp = resp.resultSetMetadata
865 893

866 -
        description = self._hive_schema_to_description(
867 -
            t_result_set_metadata_resp.schema,
868 -
            self._session_id_hex,
869 -
        )
870 -
871 894
        if pyarrow:
872 895
            schema_bytes = (
873 896
                t_result_set_metadata_resp.arrowSchema
@@ -880,6 +903,12 @@ def get_execution_result(
880 903
        else:
881 904
            schema_bytes = None
882 905

906 +
        description = self._hive_schema_to_description(
907 +
            t_result_set_metadata_resp.schema,
908 +
            schema_bytes,
909 +
            self._session_id_hex,
910 +
        )
911 +
883 912
        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
884 913
        is_staging_operation = t_result_set_metadata_resp.isStagingOperation
885 914
        has_more_rows = resp.hasMoreRows

‎tests/e2e/test_variant_types.py‎

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1 +
import pytest
2 +
from datetime import datetime
3 +
import json
4 +

5 +
try:
6 +
    import pyarrow
7 +
except ImportError:
8 +
    pyarrow = None
9 +

10 +
from tests.e2e.test_driver import PySQLPytestTestCase
11 +
from tests.e2e.common.predicates import pysql_supports_arrow
12 +

13 +
14 +
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
15 +
class TestVariantTypes(PySQLPytestTestCase):
16 +
    """Tests for the proper detection and handling of VARIANT type columns"""
17 +

18 +
    @pytest.fixture(scope="class")
19 +
    def variant_table(self, connection_details):
20 +
        """A pytest fixture that creates a test table and cleans up after tests"""
21 +
        self.arguments = connection_details.copy()
22 +
        table_name = "pysql_test_variant_types_table"
23 +

24 +
        with self.cursor() as cursor:
25 +
            try:
26 +
                # Create the table with variant columns
27 +
                cursor.execute(
28 +
                    """
29 +
                    CREATE TABLE IF NOT EXISTS pysql_test_variant_types_table (
30 +
                        id INTEGER,
31 +
                        variant_col VARIANT,
32 +
                        regular_string_col STRING
33 +
                    )
34 +
                    """
35 +
                )
36 +

37 +
                # Insert test records with different variant values
38 +
                cursor.execute(
39 +
                    """
40 +
                    INSERT INTO pysql_test_variant_types_table
41 +
                    VALUES
42 +
                    (1, PARSE_JSON('{"name": "John", "age": 30}'), 'regular string'),
43 +
                    (2, PARSE_JSON('[1, 2, 3, 4]'), 'another string')
44 +
                    """
45 +
                )
46 +
                yield table_name
47 +
            finally:
48 +
                cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
49 +
50 +
    def test_variant_type_detection(self, variant_table):
51 +
        """Test that VARIANT type columns are properly detected in schema"""
52 +
        with self.cursor() as cursor:
53 +
            cursor.execute(f"SELECT * FROM {variant_table} LIMIT 0")
54 +

55 +
            # Verify column types in description
56 +
            assert (
57 +
                cursor.description[0][1] == "int"
58 +
            ), "Integer column type not correctly identified"
59 +
            assert (
60 +
                cursor.description[1][1] == "variant"
61 +
            ), "VARIANT column type not correctly identified"
62 +
            assert (
63 +
                cursor.description[2][1] == "string"
64 +
            ), "String column type not correctly identified"
65 +
66 +
    def test_variant_data_retrieval(self, variant_table):
67 +
        """Test that VARIANT data is properly retrieved and can be accessed as JSON"""
68 +
        with self.cursor() as cursor:
69 +
            cursor.execute(f"SELECT * FROM {variant_table} ORDER BY id")
70 +
            rows = cursor.fetchall()
71 +

72 +
            # First row should have a JSON object
73 +
            json_obj = rows[0][1]
74 +
            assert isinstance(
75 +
                json_obj, str
76 +
            ), "VARIANT column should be returned as string"
77 +

78 +
            parsed = json.loads(json_obj)
79 +
            assert parsed.get("name") == "John"
80 +
            assert parsed.get("age") == 30
81 +

82 +
            # Second row should have a JSON array
83 +
            json_array = rows[1][1]
84 +
            assert isinstance(
85 +
                json_array, str
86 +
            ), "VARIANT array should be returned as string"
87 +

88 +
            # Parsing to verify it's valid JSON array
89 +
            parsed_array = json.loads(json_array)
90 +
            assert isinstance(parsed_array, list)
91 +
            assert parsed_array == [1, 2, 3, 4]

‎tests/unit/test_thrift_backend.py‎

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2330,7 +2330,7 @@ def test_execute_command_sets_complex_type_fields_correctly(
23302330
[],
23312331
auth_provider=AuthProvider(),
23322332
ssl_options=SSLOptions(),
2333-
http_client=MagicMock(),
2333+
http_client=MagicMock(),
23342334
**complex_arg_types,
23352335
)
23362336
thrift_backend.execute_command(
@@ -2356,6 +2356,86 @@ def test_execute_command_sets_complex_type_fields_correctly(
23562356
t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow
23572357
)
23582358

2359 +
    @unittest.skipIf(pyarrow is None, "Requires pyarrow")
2360 +
    def test_col_to_description(self):
2361 +
        test_cases = [
2362 +
            ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}, "variant"),
2363 +
            ("normal_col", {}, "string"),
2364 +
            ("weird_field", {b"Spark:DataType:SqlName": b"Some unexpected value"}, "string"),
2365 +
            ("missing_field", None, "string"),  # None field case
2366 +
        ]
2367 +

2368 +
        for column_name, field_metadata, expected_type in test_cases:
2369 +
            with self.subTest(column_name=column_name, expected_type=expected_type):
2370 +
                col = ttypes.TColumnDesc(
2371 +
                    columnName=column_name,
2372 +
                    typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
2373 +
                )
2374 +

2375 +
                field = (
2376 +
                    None
2377 +
                    if field_metadata is None
2378 +
                    else pyarrow.field(column_name, pyarrow.string(), metadata=field_metadata)
2379 +
                )
2380 +

2381 +
                result = ThriftDatabricksClient._col_to_description(col, field)
2382 +

2383 +
                self.assertEqual(result[0], column_name)
2384 +
                self.assertEqual(result[1], expected_type)
2385 +
                self.assertIsNone(result[2])
2386 +
                self.assertIsNone(result[3])
2387 +
                self.assertIsNone(result[4])
2388 +
                self.assertIsNone(result[5])
2389 +
                self.assertIsNone(result[6])
2390 +
2391 +
    @unittest.skipIf(pyarrow is None, "Requires pyarrow")
2392 +
    def test_hive_schema_to_description(self):
2393 +
        test_cases = [
2394 +
            (
2395 +
                [
2396 +
                    ("regular_col", ttypes.TTypeId.STRING_TYPE),
2397 +
                    ("variant_col", ttypes.TTypeId.STRING_TYPE),
2398 +
                ],
2399 +
                [
2400 +
                    ("regular_col", {}),
2401 +
                    ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}),
2402 +
                ],
2403 +
                [("regular_col", "string"), ("variant_col", "variant")],
2404 +
            ),
2405 +
            (
2406 +
                [("regular_col", ttypes.TTypeId.STRING_TYPE)],
2407 +
                None,  # No arrow schema
2408 +
                [("regular_col", "string")],
2409 +
            ),
2410 +
        ]
2411 +

2412 +
        for columns, arrow_fields, expected_types in test_cases:
2413 +
            with self.subTest(arrow_fields=arrow_fields is not None):
2414 +
                t_table_schema = ttypes.TTableSchema(
2415 +
                    columns=[
2416 +
                        ttypes.TColumnDesc(
2417 +
                            columnName=name, typeDesc=self._make_type_desc(col_type)
2418 +
                        )
2419 +
                        for name, col_type in columns
2420 +
                    ]
2421 +
                )
2422 +

2423 +
                schema_bytes = None
2424 +
                if arrow_fields:
2425 +
                    fields = [
2426 +
                        pyarrow.field(name, pyarrow.string(), metadata=metadata)
2427 +
                        for name, metadata in arrow_fields
2428 +
                    ]
2429 +
                    schema_bytes = pyarrow.schema(fields).serialize().to_pybytes()
2430 +

2431 +
                description = ThriftDatabricksClient._hive_schema_to_description(
2432 +
                    t_table_schema, schema_bytes
2433 +
                )
2434 +

2435 +
                for i, (expected_name, expected_type) in enumerate(expected_types):
2436 +
                    self.assertEqual(description[i][0], expected_name)
2437 +
                    self.assertEqual(description[i][1], expected_type)
2438 +
2359 2439

2360 2440
if __name__ == "__main__":
2361 2441
    unittest.main()

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp