Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit ad2bb62

Browse files
authored
[PECOBLR-330] Support for complex params (#559)
* Basic testing * testing examples * Basic working prototype * ttypes fix * Refactored the ttypes * nit * Added inline support * Reordered boolean to be above int * CheckWorking e2e tests prototype * More tests * Added unit tests * refactor * nit * nit * nit * nit
1 parent 3842583 commit ad2bb62

File tree

10 files changed

+478
-67
lines changed

10 files changed

+478
-67
lines changed

‎src/databricks/sql/parameters/__init__.py‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@
1212
TimestampNTZParameter,
1313
TinyIntParameter,
1414
DecimalParameter,
15+
MapParameter,
16+
ArrayParameter,
1517
)

‎src/databricks/sql/parameters/native.py‎

Lines changed: 125 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
importdatetime
22
importdecimal
33
fromenumimportEnum,auto
4-
fromtypingimportOptional,Sequence
4+
fromtypingimportOptional,Sequence,Any
55

66
fromdatabricks.sql.excimportNotSupportedError
77
fromdatabricks.sql.thrift_api.TCLIService.ttypesimport (
88
TSparkParameter,
99
TSparkParameterValue,
10+
TSparkParameterValueArg,
1011
)
1112

1213
importdatetime
@@ -54,7 +55,17 @@ class DatabricksSupportedType(Enum):
5455

5556

5657
TAllowedParameterValue=Union[
57-
str,int,float,datetime.datetime,datetime.date,bool,decimal.Decimal,None
58+
str,
59+
int,
60+
float,
61+
datetime.datetime,
62+
datetime.date,
63+
bool,
64+
decimal.Decimal,
65+
None,
66+
list,
67+
dict,
68+
tuple,
5869
]
5970

6071

@@ -82,6 +93,7 @@ class DbsqlParameterBase:
8293

8394
CAST_EXPR:str
8495
name:Optional[str]
96+
value:Any
8597

8698
defas_tspark_param(self,named:bool)->TSparkParameter:
8799
"""Returns a TSparkParameter object that can be passed to the DBR thrift server."""
@@ -98,6 +110,10 @@ def as_tspark_param(self, named: bool) -> TSparkParameter:
98110
def_tspark_param_value(self):
99111
returnTSparkParameterValue(stringValue=str(self.value))
100112

113+
def_tspark_value_arg(self):
114+
"""Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
115+
returnTSparkParameterValueArg(value=str(self.value),type=self._cast_expr())
116+
101117
def_cast_expr(self):
102118
returnself.CAST_EXPR
103119

@@ -428,6 +444,99 @@ def __init__(self, value: int, name: Optional[str] = None):
428444
CAST_EXPR=DatabricksSupportedType.TINYINT.name
429445

430446

447+
classArrayParameter(DbsqlParameterBase):
448+
"""Wrap a Python `Sequence` that will be bound to a Databricks SQL ARRAY type."""
449+
450+
def__init__(self,value:Sequence[Any],name:Optional[str]=None):
451+
"""
452+
:value:
453+
The value to bind for this parameter. This will be casted to a ARRAY.
454+
:name:
455+
If None, your query must contain a `?` marker. Like:
456+
457+
```sql
458+
SELECT * FROM table WHERE field = ?
459+
```
460+
If not None, your query should contain a named parameter marker. Like:
461+
```sql
462+
SELECT * FROM table WHERE field = :my_param
463+
```
464+
465+
The `name` argument to this function would be `my_param`.
466+
"""
467+
self.name=name
468+
self.value= [dbsql_parameter_from_primitive(val)forvalinvalue]
469+
470+
defas_tspark_param(self,named:bool=False)->TSparkParameter:
471+
"""Returns a TSparkParameter object that can be passed to the DBR thrift server."""
472+
473+
tsp=TSparkParameter(type=self._cast_expr())
474+
tsp.arguments= [val._tspark_value_arg()forvalinself.value]
475+
476+
ifnamed:
477+
tsp.name=self.name
478+
tsp.ordinal=False
479+
elifnotnamed:
480+
tsp.ordinal=True
481+
returntsp
482+
483+
def_tspark_value_arg(self):
484+
"""Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
485+
tva=TSparkParameterValueArg(type=self._cast_expr())
486+
tva.arguments= [val._tspark_value_arg()forvalinself.value]
487+
returntva
488+
489+
CAST_EXPR=DatabricksSupportedType.ARRAY.name
490+
491+
492+
classMapParameter(DbsqlParameterBase):
493+
"""Wrap a Python `dict` that will be bound to a Databricks SQL MAP type."""
494+
495+
def__init__(self,value:dict,name:Optional[str]=None):
496+
"""
497+
:value:
498+
The value to bind for this parameter. This will be casted to a MAP.
499+
:name:
500+
If None, your query must contain a `?` marker. Like:
501+
502+
```sql
503+
SELECT * FROM table WHERE field = ?
504+
```
505+
If not None, your query should contain a named parameter marker. Like:
506+
```sql
507+
SELECT * FROM table WHERE field = :my_param
508+
```
509+
510+
The `name` argument to this function would be `my_param`.
511+
"""
512+
self.name=name
513+
self.value= [
514+
dbsql_parameter_from_primitive(item)
515+
forkey,valinvalue.items()
516+
foritemin (key,val)
517+
]
518+
519+
defas_tspark_param(self,named:bool=False)->TSparkParameter:
520+
"""Returns a TSparkParameter object that can be passed to the DBR thrift server."""
521+
522+
tsp=TSparkParameter(type=self._cast_expr())
523+
tsp.arguments= [val._tspark_value_arg()forvalinself.value]
524+
ifnamed:
525+
tsp.name=self.name
526+
tsp.ordinal=False
527+
elifnotnamed:
528+
tsp.ordinal=True
529+
returntsp
530+
531+
def_tspark_value_arg(self):
532+
"""Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
533+
tva=TSparkParameterValueArg(type=self._cast_expr())
534+
tva.arguments= [val._tspark_value_arg()forvalinself.value]
535+
returntva
536+
537+
CAST_EXPR=DatabricksSupportedType.MAP.name
538+
539+
431540
classDecimalParameter(DbsqlParameterBase):
432541
"""Wrap a Python `Decimal` that will be bound to a Databricks SQL DECIMAL type."""
433542

@@ -543,23 +652,26 @@ def dbsql_parameter_from_primitive(
543652
# havoc. We can't use TYPE_INFERRENCE_MAP because mypy doesn't trust
544653
# its logic
545654

546-
iftype(value)isint:
655+
ifisinstance(value,bool):
656+
returnBooleanParameter(value=value,name=name)
657+
elifisinstance(value,int):
547658
returndbsql_parameter_from_int(value,name=name)
548-
eliftype(value)isstr:
659+
elifisinstance(value,str):
549660
returnStringParameter(value=value,name=name)
550-
eliftype(value)isfloat:
661+
elifisinstance(value,float):
551662
returnFloatParameter(value=value,name=name)
552-
eliftype(value)isdatetime.datetime:
663+
elifisinstance(value,datetime.datetime):
553664
returnTimestampParameter(value=value,name=name)
554-
eliftype(value)isdatetime.date:
665+
elifisinstance(value,datetime.date):
555666
returnDateParameter(value=value,name=name)
556-
eliftype(value)isbool:
557-
returnBooleanParameter(value=value,name=name)
558-
eliftype(value)isdecimal.Decimal:
667+
elifisinstance(value,decimal.Decimal):
559668
returnDecimalParameter(value=value,name=name)
669+
elifisinstance(value,dict):
670+
returnMapParameter(value=value,name=name)
671+
elifisinstance(value,Sequence)andnotisinstance(value,str):
672+
returnArrayParameter(value=value,name=name)
560673
elifvalueisNone:
561674
returnVoidParameter(value=value,name=name)
562-
563675
else:
564676
raiseNotSupportedError(
565677
f"Could not infer parameter type from value:{value} -{type(value)}\n"
@@ -581,6 +693,8 @@ def dbsql_parameter_from_primitive(
581693
TimestampNTZParameter,
582694
TinyIntParameter,
583695
DecimalParameter,
696+
ArrayParameter,
697+
MapParameter,
584698
]
585699

586700

‎src/databricks/sql/utils.py‎

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
importdecimal
66
fromabcimportABC,abstractmethod
77
fromcollectionsimportOrderedDict,namedtuple
8-
fromcollections.abcimportIterable
8+
fromcollections.abcimportMapping
99
fromdecimalimportDecimal
1010
fromenumimportEnum
11-
fromtypingimportAny,Dict,List,Optional,Union
11+
fromtypingimportAny,Dict,List,Optional,Union,Sequence
1212
importre
1313

1414
importlz4.frame
@@ -429,7 +429,7 @@ def user_friendly_error_message(self, no_retry_reason, attempt, elapsed):
429429
# Taken from PyHive
430430
classParamEscaper:
431431
_DATE_FORMAT="%Y-%m-%d"
432-
_TIME_FORMAT="%H:%M:%S.%f"
432+
_TIME_FORMAT="%H:%M:%S.%f %z"
433433
_DATETIME_FORMAT="{} {}".format(_DATE_FORMAT,_TIME_FORMAT)
434434

435435
defescape_args(self,parameters):
@@ -458,13 +458,22 @@ def escape_string(self, item):
458458
return"'{}'".format(item.replace("\\","\\\\").replace("'","\\'"))
459459

460460
defescape_sequence(self,item):
461-
l=map(str,map(self.escape_item,item))
462-
return"("+",".join(l)+")"
461+
l=map(self.escape_item,item)
462+
l=list(map(str,l))
463+
return"ARRAY("+",".join(l)+")"
464+
465+
defescape_mapping(self,item):
466+
l=map(
467+
self.escape_item,
468+
(elementforkey,valueinitem.items()forelementin (key,value)),
469+
)
470+
l=list(map(str,l))
471+
return"MAP("+",".join(l)+")"
463472

464473
defescape_datetime(self,item,format,cutoff=0):
465474
dt_str=item.strftime(format)
466475
formatted=dt_str[:-cutoff]ifcutoffandformat.endswith(".%f")elsedt_str
467-
return"'{}'".format(formatted)
476+
return"'{}'".format(formatted.strip())
468477

469478
defescape_decimal(self,item):
470479
returnstr(item)
@@ -476,14 +485,16 @@ def escape_item(self, item):
476485
returnself.escape_number(item)
477486
elifisinstance(item,str):
478487
returnself.escape_string(item)
479-
elifisinstance(item,Iterable):
480-
returnself.escape_sequence(item)
481488
elifisinstance(item,datetime.datetime):
482489
returnself.escape_datetime(item,self._DATETIME_FORMAT)
483490
elifisinstance(item,datetime.date):
484491
returnself.escape_datetime(item,self._DATE_FORMAT)
485492
elifisinstance(item,decimal.Decimal):
486493
returnself.escape_decimal(item)
494+
elifisinstance(item,Sequence):
495+
returnself.escape_sequence(item)
496+
elifisinstance(item,Mapping):
497+
returnself.escape_mapping(item)
487498
else:
488499
raiseexc.ProgrammingError("Unsupported object {}".format(item))
489500

‎tests/e2e/test_complex_types.py‎

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
importpytest
22
fromnumpyimportndarray
3+
fromtypingimportSequence
34

45
fromtests.e2e.test_driverimportPySQLPytestTestCase
56

@@ -14,50 +15,73 @@ def table_fixture(self, connection_details):
1415
# Create the table
1516
cursor.execute(
1617
"""
17-
CREATE TABLE IF NOT EXISTSpysql_e2e_test_complex_types_table (
18+
CREATE TABLE IF NOT EXISTSpysql_test_complex_types_table (
1819
array_col ARRAY<STRING>,
1920
map_col MAP<STRING, INTEGER>,
20-
struct_col STRUCT<field1: STRING, field2: INTEGER>
21-
)
21+
struct_col STRUCT<field1: STRING, field2: INTEGER>,
22+
array_array_col ARRAY<ARRAY<STRING>>,
23+
array_map_col ARRAY<MAP<STRING, INTEGER>>,
24+
map_array_col MAP<STRING, ARRAY<STRING>>
25+
) USING DELTA
2226
"""
2327
)
2428
# Insert a record
2529
cursor.execute(
2630
"""
27-
INSERT INTOpysql_e2e_test_complex_types_table
31+
INSERT INTOpysql_test_complex_types_table
2832
VALUES (
2933
ARRAY('a', 'b', 'c'),
3034
MAP('a', 1, 'b', 2, 'c', 3),
31-
NAMED_STRUCT('field1', 'a', 'field2', 1)
35+
NAMED_STRUCT('field1', 'a', 'field2', 1),
36+
ARRAY(ARRAY('a','b','c')),
37+
ARRAY(MAP('a', 1, 'b', 2, 'c', 3)),
38+
MAP('a', ARRAY('a', 'b', 'c'), 'b', ARRAY('d', 'e'))
3239
)
3340
"""
3441
)
3542
yield
3643
# Clean up the table after the test
37-
cursor.execute("DROP TABLE IF EXISTS pysql_e2e_test_complex_types_table")
44+
cursor.execute("DELETE FROM pysql_test_complex_types_table")
3845

3946
@pytest.mark.parametrize(
4047
"field,expected_type",
41-
[("array_col",ndarray), ("map_col",list), ("struct_col",dict)],
48+
[
49+
("array_col",ndarray),
50+
("map_col",list),
51+
("struct_col",dict),
52+
("array_array_col",ndarray),
53+
("array_map_col",ndarray),
54+
("map_array_col",list),
55+
],
4256
)
4357
deftest_read_complex_types_as_arrow(self,field,expected_type,table_fixture):
4458
"""Confirms the return types of a complex type field when reading as arrow"""
4559

4660
withself.cursor()ascursor:
4761
result=cursor.execute(
48-
"SELECT * FROMpysql_e2e_test_complex_types_table LIMIT 1"
62+
"SELECT * FROMpysql_test_complex_types_table LIMIT 1"
4963
).fetchone()
5064

5165
assertisinstance(result[field],expected_type)
5266

53-
@pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col")])
67+
@pytest.mark.parametrize(
68+
"field",
69+
[
70+
("array_col"),
71+
("map_col"),
72+
("struct_col"),
73+
("array_array_col"),
74+
("array_map_col"),
75+
("map_array_col"),
76+
],
77+
)
5478
deftest_read_complex_types_as_string(self,field,table_fixture):
5579
"""Confirms the return type of a complex type that is returned as a string"""
5680
withself.cursor(
5781
extra_params={"_use_arrow_native_complex_types":False}
5882
)ascursor:
5983
result=cursor.execute(
60-
"SELECT * FROMpysql_e2e_test_complex_types_table LIMIT 1"
84+
"SELECT * FROMpysql_test_complex_types_table LIMIT 1"
6185
).fetchone()
6286

6387
assertisinstance(result[field],str)

‎tests/e2e/test_driver.py‎

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,9 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog):
856856
raiseKeyboardInterrupt("Simulated interrupt")
857857
finally:
858858
ifconnisnotNone:
859-
assertnotconn.open,"Connection should be closed after KeyboardInterrupt"
859+
assert (
860+
notconn.open
861+
),"Connection should be closed after KeyboardInterrupt"
860862

861863
deftest_cursor_close_properly_closes_operation(self):
862864
"""Test that Cursor.close() properly closes the active operation handle on the server."""
@@ -883,7 +885,9 @@ def test_cursor_close_properly_closes_operation(self):
883885
raiseKeyboardInterrupt("Simulated interrupt")
884886
finally:
885887
ifcursorisnotNone:
886-
assertnotcursor.open,"Cursor should be closed after KeyboardInterrupt"
888+
assert (
889+
notcursor.open
890+
),"Cursor should be closed after KeyboardInterrupt"
887891

888892
deftest_nested_cursor_context_managers(self):
889893
"""Test that nested cursor context managers properly close operations on the server."""

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp