[PECOBLR-330] Support for complex params #559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
jprakash-db merged 18 commits into main from jprakash-db/complex-param on Jun 16, 2025
Changes from all commits (18 commits)
65fe00d  Basic testing (jprakash-db, Apr 16, 2025)
cfc58d6  testing examples (jprakash-db, May 12, 2025)
7271866  Basic working prototype (jprakash-db, May 17, 2025)
7c42faa  ttypes fix (jprakash-db, May 17, 2025)
8e5038c  Refractored the ttypes (jprakash-db, May 19, 2025)
ddbe54e  nit (jprakash-db, May 19, 2025)
4dd26d9  Merge branch 'main' into jprakash-db/complex-param (jprakash-db, May 19, 2025)
71be1ed  Added inline support (jprakash-db, May 19, 2025)
b032c3f  Reordered boolean to be above int (jprakash-db, May 20, 2025)
37c89b8  Check (jprakash-db, May 20, 2025)
b2b8a2a  More tests (jprakash-db, May 22, 2025)
4c36d99  Added unit tests (jprakash-db, May 22, 2025)
09f4d18  refractor (jprakash-db, May 22, 2025)
2dc578d  nit (jprakash-db, May 22, 2025)
8346de6  nit (jprakash-db, Jun 13, 2025)
c22f6c7  Merge branch 'main' into jprakash-db/complex-param (jprakash-db, Jun 13, 2025)
1631142  nit (jprakash-db, Jun 13, 2025)
3178332  nit (jprakash-db, Jun 13, 2025)
src/databricks/sql/parameters/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -12,4 +12,6 @@
TimestampNTZParameter,
TinyIntParameter,
DecimalParameter,
MapParameter,
ArrayParameter,
)
src/databricks/sql/parameters/native.py (136 changes: 125 additions & 11 deletions)
@@ -1,12 +1,13 @@
import datetime
import decimal
from enum import Enum, auto
from typing import Optional, Sequence
from typing import Optional, Sequence, Any

from databricks.sql.exc import NotSupportedError
from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkParameter,
TSparkParameterValue,
TSparkParameterValueArg,
)

import datetime
@@ -54,7 +55,17 @@ class DatabricksSupportedType(Enum):


TAllowedParameterValue = Union[
str, int, float, datetime.datetime, datetime.date, bool, decimal.Decimal, None
str,
int,
float,
datetime.datetime,
datetime.date,
bool,
decimal.Decimal,
None,
list,
dict,
tuple,
]


@@ -82,6 +93,7 @@ class DbsqlParameterBase:

CAST_EXPR: str
name: Optional[str]
value: Any

def as_tspark_param(self, named: bool) -> TSparkParameter:
"""Returns a TSparkParameter object that can be passed to the DBR thrift server."""
@@ -98,6 +110,10 @@ def as_tspark_param(self, named: bool) -> TSparkParameter:
def _tspark_param_value(self):
return TSparkParameterValue(stringValue=str(self.value))

def _tspark_value_arg(self):
"""Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
return TSparkParameterValueArg(value=str(self.value), type=self._cast_expr())

def _cast_expr(self):
return self.CAST_EXPR

@@ -428,6 +444,99 @@ def __init__(self, value: int, name: Optional[str] = None):
CAST_EXPR = DatabricksSupportedType.TINYINT.name


class ArrayParameter(DbsqlParameterBase):
"""Wrap a Python `Sequence` that will be bound to a Databricks SQL ARRAY type."""

def __init__(self, value: Sequence[Any], name: Optional[str] = None):
"""
:value:
The value to bind for this parameter. This will be casted to a ARRAY.
:name:
If None, your query must contain a `?` marker. Like:

```sql
SELECT * FROM table WHERE field = ?
```
If not None, your query should contain a named parameter marker. Like:
```sql
SELECT * FROM table WHERE field = :my_param
```

The `name` argument to this function would be `my_param`.
"""
self.name = name
self.value = [dbsql_parameter_from_primitive(val) for val in value]

def as_tspark_param(self, named: bool = False) -> TSparkParameter:
"""Returns a TSparkParameter object that can be passed to the DBR thrift server."""

tsp = TSparkParameter(type=self._cast_expr())
tsp.arguments = [val._tspark_value_arg() for val in self.value]

if named:
tsp.name = self.name
tsp.ordinal = False
elif not named:
tsp.ordinal = True
return tsp

def _tspark_value_arg(self):
"""Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
tva = TSparkParameterValueArg(type=self._cast_expr())
tva.arguments = [val._tspark_value_arg() for val in self.value]
return tva

CAST_EXPR = DatabricksSupportedType.ARRAY.name


class MapParameter(DbsqlParameterBase):
"""Wrap a Python `dict` that will be bound to a Databricks SQL MAP type."""

def __init__(self, value: dict, name: Optional[str] = None):
"""
:value:
The value to bind for this parameter. This will be casted to a MAP.
:name:
If None, your query must contain a `?` marker. Like:

```sql
SELECT * FROM table WHERE field = ?
```
If not None, your query should contain a named parameter marker. Like:
```sql
SELECT * FROM table WHERE field = :my_param
```

The `name` argument to this function would be `my_param`.
"""
self.name = name
self.value = [
dbsql_parameter_from_primitive(item)
for key, val in value.items()
for item in (key, val)
]

def as_tspark_param(self, named: bool = False) -> TSparkParameter:
"""Returns a TSparkParameter object that can be passed to the DBR thrift server."""

tsp = TSparkParameter(type=self._cast_expr())
tsp.arguments = [val._tspark_value_arg() for val in self.value]
if named:
tsp.name = self.name
tsp.ordinal = False
elif not named:
tsp.ordinal = True
return tsp

def _tspark_value_arg(self):
"""Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
tva = TSparkParameterValueArg(type=self._cast_expr())
tva.arguments = [val._tspark_value_arg() for val in self.value]
return tva

CAST_EXPR = DatabricksSupportedType.MAP.name


class DecimalParameter(DbsqlParameterBase):
"""Wrap a Python `Decimal` that will be bound to a Databricks SQL DECIMAL type."""

@@ -543,23 +652,26 @@ def dbsql_parameter_from_primitive(
# havoc. We can't use TYPE_INFERRENCE_MAP because mypy doesn't trust
# its logic

if type(value) is int:
if isinstance(value, bool):
return BooleanParameter(value=value, name=name)
elif isinstance(value, int):
return dbsql_parameter_from_int(value, name=name)
elif type(value) is str:
elif isinstance(value, str):
return StringParameter(value=value, name=name)
elif type(value) is float:
elif isinstance(value, float):
return FloatParameter(value=value, name=name)
elif type(value) is datetime.datetime:
elif isinstance(value, datetime.datetime):
return TimestampParameter(value=value, name=name)
elif type(value) is datetime.date:
elif isinstance(value, datetime.date):
return DateParameter(value=value, name=name)
elif type(value) is bool:
return BooleanParameter(value=value, name=name)
elif type(value) is decimal.Decimal:
elif isinstance(value, decimal.Decimal):
return DecimalParameter(value=value, name=name)
elif isinstance(value, dict):
return MapParameter(value=value, name=name)
elif isinstance(value, Sequence) and not isinstance(value, str):
return ArrayParameter(value=value, name=name)
elif value is None:
return VoidParameter(value=value, name=name)

else:
raise NotSupportedError(
f"Could not infer parameter type from value: {value} - {type(value)} \n"
@@ -581,6 +693,8 @@ def dbsql_parameter_from_primitive(
TimestampNTZParameter,
TinyIntParameter,
DecimalParameter,
ArrayParameter,
MapParameter,
]


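To illustrate how the new parameter classes and the revised type inference fit together, here is a minimal usage sketch. It assumes the connector's existing `cursor.execute(operation, parameters)` interface; the connection details and the query itself are placeholders, while `ArrayParameter`, `MapParameter`, and `dbsql_parameter_from_primitive` come from this diff.

```python
from databricks import sql
from databricks.sql.parameters.native import dbsql_parameter_from_primitive

# Inference added in this PR: bool is checked before int (bool is a subclass
# of int in Python), dicts become MapParameter, and non-string sequences
# become ArrayParameter.
assert type(dbsql_parameter_from_primitive(True)).__name__ == "BooleanParameter"
assert type(dbsql_parameter_from_primitive({"a": 1})).__name__ == "MapParameter"
assert type(dbsql_parameter_from_primitive(["x", "y"])).__name__ == "ArrayParameter"

# Binding complex values positionally with `?` markers (connection details
# are placeholders).
with sql.connect(
    server_hostname="...", http_path="...", access_token="..."
) as connection:
    with connection.cursor() as cursor:
        cursor.execute(
            "SELECT array_contains(?, 'a'), map_keys(?)",
            [["a", "b", "c"], {"a": 1, "b": 2}],
        )
        print(cursor.fetchone())
```

Nested structures (for example, a list of dicts) should also round-trip, since `ArrayParameter` and `MapParameter` wrap each element with `dbsql_parameter_from_primitive` and emit nested `TSparkParameterValueArg` objects recursively. Named markers (`:name`) should work the same way when the parameters are supplied as a dict, as described in the docstrings above.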
src/databricks/sql/utils.py (27 changes: 19 additions & 8 deletions)
@@ -5,10 +5,10 @@
import decimal
from abc import ABC, abstractmethod
from collections import OrderedDict, namedtuple
from collections.abc import Iterable
from collections.abc import Mapping
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Sequence
import re

import lz4.frame
@@ -429,7 +429,7 @@ def user_friendly_error_message(self, no_retry_reason, attempt, elapsed):
# Taken from PyHive
class ParamEscaper:
_DATE_FORMAT = "%Y-%m-%d"
_TIME_FORMAT = "%H:%M:%S.%f"
_TIME_FORMAT = "%H:%M:%S.%f %z"
Contributor:
Timezone will not be there for TIMESTAMP_NTZ param.

Contributor:
+1, would like to know how we've accounted/tested for NTZ

Contributor (Author):
If it is not there, it will be an empty space. There is already a test suite that inserts NTZ and non-NTZ values and reads them back to compare whether they are equal.

Contributor (Author):
@vikrantpuppala @shivam2680 There are already existing tests that insert NTZ and non-NTZ values and read back from the table to ensure everything is working as expected -

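The behaviour discussed in this thread is easy to check directly. A minimal sketch, assuming the `ParamEscaper` defined in this file and its updated `_DATETIME_FORMAT`; the expected outputs in the comments follow from the code in this diff:

```python
from datetime import datetime, timezone

from databricks.sql.utils import ParamEscaper

pe = ParamEscaper()

aware = datetime(2025, 6, 16, 12, 0, 0, tzinfo=timezone.utc)
naive = datetime(2025, 6, 16, 12, 0, 0)  # TIMESTAMP_NTZ-style value, no tzinfo

print(pe.escape_item(aware))  # '2025-06-16 12:00:00.000000 +0000'
# For a naive datetime, %z renders as an empty string, leaving a trailing
# space that the .strip() in escape_datetime removes.
print(pe.escape_item(naive))  # '2025-06-16 12:00:00.000000'
```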
_DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT)

def escape_args(self, parameters):
@@ -458,13 +458,22 @@ def escape_string(self, item):
return "'{}'".format(item.replace("\\", "\\\\").replace("'", "\\'"))

def escape_sequence(self, item):
l = map(str, map(self.escape_item, item))
return "(" + ",".join(l) + ")"
l = map(self.escape_item, item)
l = list(map(str, l))
return "ARRAY(" + ",".join(l) + ")"

def escape_mapping(self, item):
l = map(
self.escape_item,
(element for key, value in item.items() for element in (key, value)),
)
l = list(map(str, l))
return "MAP(" + ",".join(l) + ")"

def escape_datetime(self, item, format, cutoff=0):
dt_str = item.strftime(format)
formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str
return "'{}'".format(formatted)
return "'{}'".format(formatted.strip())

def escape_decimal(self, item):
return str(item)
@@ -476,14 +485,16 @@ def escape_item(self, item):
return self.escape_number(item)
elif isinstance(item, str):
return self.escape_string(item)
elif isinstance(item, Iterable):
return self.escape_sequence(item)
elif isinstance(item, datetime.datetime):
return self.escape_datetime(item, self._DATETIME_FORMAT)
elif isinstance(item, datetime.date):
return self.escape_datetime(item, self._DATE_FORMAT)
elif isinstance(item, decimal.Decimal):
return self.escape_decimal(item)
elif isinstance(item, Sequence):
return self.escape_sequence(item)
elif isinstance(item, Mapping):
return self.escape_mapping(item)
else:
raise exc.ProgrammingError("Unsupported object {}".format(item))

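For inline (non-native) parameters, the new `escape_sequence` and `escape_mapping` render Python sequences and mappings as Databricks SQL `ARRAY(...)` and `MAP(...)` constructors instead of a parenthesized tuple. A short sketch of the expected output, based on the logic above:

```python
from databricks.sql.utils import ParamEscaper

pe = ParamEscaper()

# Each element is escaped recursively, then joined inside the constructor.
print(pe.escape_item(["a", "b"]))        # ARRAY('a','b')
print(pe.escape_item({"a": 1, "b": 2}))  # MAP('a',1,'b',2)
# Nesting works too, because escape_item is applied to every element.
print(pe.escape_item({"a": ["x", "y"]}))  # MAP('a',ARRAY('x','y'))
```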
tests/e2e/test_complex_types.py (44 changes: 34 additions & 10 deletions)
@@ -1,5 +1,6 @@
import pytest
from numpy import ndarray
from typing import Sequence

from tests.e2e.test_driver import PySQLPytestTestCase

@@ -14,50 +15,73 @@ def table_fixture(self, connection_details):
# Create the table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS pysql_e2e_test_complex_types_table (
CREATE TABLE IF NOT EXISTS pysql_test_complex_types_table (
array_col ARRAY<STRING>,
map_col MAP<STRING, INTEGER>,
struct_col STRUCT<field1: STRING, field2: INTEGER>
)
struct_col STRUCT<field1: STRING, field2: INTEGER>,
array_array_col ARRAY<ARRAY<STRING>>,
array_map_col ARRAY<MAP<STRING, INTEGER>>,
map_array_col MAP<STRING, ARRAY<STRING>>
) USING DELTA
"""
)
# Insert a record
cursor.execute(
"""
INSERT INTO pysql_e2e_test_complex_types_table
INSERT INTO pysql_test_complex_types_table
VALUES (
ARRAY('a', 'b', 'c'),
MAP('a', 1, 'b', 2, 'c', 3),
NAMED_STRUCT('field1', 'a', 'field2', 1)
NAMED_STRUCT('field1', 'a', 'field2', 1),
ARRAY(ARRAY('a','b','c')),
ARRAY(MAP('a', 1, 'b', 2, 'c', 3)),
MAP('a', ARRAY('a', 'b', 'c'), 'b', ARRAY('d', 'e'))
)
"""
)
yield
# Clean up the table after the test
cursor.execute("DROP TABLE IF EXISTS pysql_e2e_test_complex_types_table")
cursor.execute("DELETE FROM pysql_test_complex_types_table")

@pytest.mark.parametrize(
"field,expected_type",
[("array_col", ndarray), ("map_col", list), ("struct_col", dict)],
[
("array_col", ndarray),
("map_col", list),
("struct_col", dict),
("array_array_col", ndarray),
("array_map_col", ndarray),
("map_array_col", list),
],
)
def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
"""Confirms the return types of a complex type field when reading as arrow"""

with self.cursor() as cursor:
result = cursor.execute(
"SELECT * FROMpysql_e2e_test_complex_types_table LIMIT 1"
"SELECT * FROMpysql_test_complex_types_table LIMIT 1"
).fetchone()

assert isinstance(result[field], expected_type)

@pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col")])
@pytest.mark.parametrize(
"field",
[
("array_col"),
("map_col"),
("struct_col"),
("array_array_col"),
("array_map_col"),
("map_array_col"),
],
)
def test_read_complex_types_as_string(self, field, table_fixture):
"""Confirms the return type of a complex type that is returned as a string"""
with self.cursor(
extra_params={"_use_arrow_native_complex_types": False}
) as cursor:
result = cursor.execute(
"SELECT * FROMpysql_e2e_test_complex_types_table LIMIT 1"
"SELECT * FROMpysql_test_complex_types_table LIMIT 1"
).fetchone()

assert isinstance(result[field], str)
tests/e2e/test_driver.py (8 changes: 6 additions & 2 deletions)
@@ -856,7 +856,9 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog):
raise KeyboardInterrupt("Simulated interrupt")
finally:
if conn is not None:
assert not conn.open, "Connection should be closed after KeyboardInterrupt"
assert (
not conn.open
), "Connection should be closed after KeyboardInterrupt"

def test_cursor_close_properly_closes_operation(self):
"""Test that Cursor.close() properly closes the active operation handle on the server."""
@@ -883,7 +885,9 @@ def test_cursor_close_properly_closes_operation(self):
raise KeyboardInterrupt("Simulated interrupt")
finally:
if cursor is not None:
assert not cursor.open, "Cursor should be closed after KeyboardInterrupt"
assert (
not cursor.open
), "Cursor should be closed after KeyboardInterrupt"

def test_nested_cursor_context_managers(self):
"""Test that nested cursor context managers properly close operations on the server."""
