Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 776f34b

Browse files
authored
[PECO-1026] Add Parameterized Query support to Python (#217)
* Initial commit * Added tsparkparam handling * Added basic test * Addressed comments * Addressed missed comments * Resolved comments — Signed-off-by: nithinkdb <nithin.krishnamurthi@databricks.com>
1 parent 8211337 · commit 776f34b

File tree

4 files changed

+184
-10
lines changed

4 files changed

+184
-10
lines changed

‎src/databricks/sql/client.py‎

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
CursorAlreadyClosedError,
1515
)
1616
fromdatabricks.sql.thrift_backendimportThriftBackend
17-
fromdatabricks.sql.utilsimportExecuteResponse,ParamEscaper,inject_parameters
17+
fromdatabricks.sql.utilsimport (
18+
ExecuteResponse,
19+
ParamEscaper,
20+
named_parameters_to_tsparkparams,
21+
)
1822
fromdatabricks.sql.typesimportRow
1923
fromdatabricks.sql.auth.authimportget_python_sql_connector_auth_provider
2024
fromdatabricks.sql.experimental.oauth_persistenceimportOAuthPersistence
@@ -482,7 +486,9 @@ def _handle_staging_remove(self, presigned_url: str, headers: dict = None):
482486
)
483487

484488
defexecute(
485-
self,operation:str,parameters:Optional[Dict[str,str]]=None
489+
self,
490+
operation:str,
491+
parameters:Optional[Union[List[Any],Dict[str,str]]]=None,
486492
)->"Cursor":
487493
"""
488494
Execute a query and wait for execution to complete.
@@ -493,10 +499,10 @@ def execute(
493499
Will result in the query "SELECT * FROM table WHERE field = 'foo' being sent to the server
494500
:returns self
495501
"""
496-
ifparametersisnotNone:
497-
operation=inject_parameters(
498-
operation,self.escaper.escape_args(parameters)
499-
)
502+
ifparametersisNone:
503+
parameters=[]
504+
else:
505+
parameters=named_parameters_to_tsparkparams(parameters)
500506

501507
self._check_not_closed()
502508
self._close_and_clear_active_result_set()
@@ -508,6 +514,7 @@ def execute(
508514
lz4_compression=self.connection.lz4_compression,
509515
cursor=self,
510516
use_cloud_fetch=self.connection.use_cloud_fetch,
517+
parameters=parameters,
511518
)
512519
self.active_result_set=ResultSet(
513520
self.connection,

‎src/databricks/sql/thrift_backend.py‎

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def __init__(
224224
def_initialize_retry_args(self,kwargs):
225225
# Configure retries & timing: use user-settings or defaults, and bound
226226
# by policy. Log.warn when given param gets restricted.
227-
for(key, (type_,default,min,max))in_retry_policy.items():
227+
forkey, (type_,default,min,max)in_retry_policy.items():
228228
given_or_default=type_(kwargs.get(key,default))
229229
bound=_bound(min,max,given_or_default)
230230
setattr(self,key,bound)
@@ -368,7 +368,6 @@ def attempt_request(attempt):
368368

369369
error,error_message,retry_delay=None,None,None
370370
try:
371-
372371
this_method_name=getattr(method,"__name__")
373372

374373
logger.debug("Sending request: {}(<REDACTED>)".format(this_method_name))
@@ -614,7 +613,10 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
614613
num_rows,
615614
)=convert_column_based_set_to_arrow_table(t_row_set.columns,description)
616615
elift_row_set.arrowBatchesisnotNone:
617-
(arrow_table,num_rows,)=convert_arrow_based_set_to_arrow_table(
616+
(
617+
arrow_table,
618+
num_rows,
619+
)=convert_arrow_based_set_to_arrow_table(
618620
t_row_set.arrowBatches,lz4_compressed,schema_bytes
619621
)
620622
else:
@@ -813,6 +815,7 @@ def execute_command(
813815
lz4_compression,
814816
cursor,
815817
use_cloud_fetch=False,
818+
parameters=[],
816819
):
817820
assertsession_handleisnotNone
818821

@@ -839,6 +842,7 @@ def execute_command(
839842
"spark.thriftserver.arrowBasedRowSet.timestampAsString":"false"
840843
},
841844
useArrowNativeTypes=spark_arrow_types,
845+
parameters=parameters,
842846
)
843847
resp=self.make_request(self._client.ExecuteStatement,req)
844848
returnself._handle_execute_response(resp,cursor)

‎src/databricks/sql/utils.py‎

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__importannotations
12
fromabcimportABC,abstractmethod
23
fromcollectionsimportnamedtuple,OrderedDict
34
fromcollections.abcimportIterable
@@ -8,13 +9,17 @@
89
importlz4.frame
910
fromtypingimportDict,List,Union,Any
1011
importpyarrow
12+
fromenumimportEnum
13+
importcopy
1114

1215
fromdatabricks.sqlimportexc,OperationalError
1316
fromdatabricks.sql.cloudfetch.download_managerimportResultFileDownloadManager
1417
fromdatabricks.sql.thrift_api.TCLIService.ttypesimport (
1518
TSparkArrowResultLink,
1619
TSparkRowSetType,
1720
TRowSet,
21+
TSparkParameter,
22+
TSparkParameterValue,
1823
)
1924

2025
BIT_MASKS= [1,2,4,8,16,32,64,128]
@@ -404,7 +409,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
404409

405410

406411
defconvert_decimals_in_arrow_table(table,description)->pyarrow.Table:
407-
for(i,col)inenumerate(table.itercolumns()):
412+
fori,colinenumerate(table.itercolumns()):
408413
ifdescription[i][1]=="decimal":
409414
decimal_col=col.to_pandas().apply(
410415
lambdav:vifvisNoneelseDecimal(v)
@@ -470,3 +475,86 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):
470475
result[i]=None
471476

472477
returnpyarrow.array(result,type=arrow_type)
478+
479+
480+
class DbSqlType(Enum):
    """SQL types accepted for parameter binding.

    Each member's value is the type name sent to the server alongside the
    (string-encoded) parameter value.
    """

    STRING = "STRING"
    DATE = "DATE"
    TIMESTAMP = "TIMESTAMP"
    FLOAT = "FLOAT"
    DECIMAL = "DECIMAL"
    INTEGER = "INTEGER"
    BIGINT = "BIGINT"
    SMALLINT = "SMALLINT"
    TINYINT = "TINYINT"
    BOOLEAN = "BOOLEAN"
    INTERVAL_MONTH = "INTERVAL MONTH"
    INTERVAL_DAY = "INTERVAL DAY"
493+
494+
495+
class DbSqlParameter:
    """A single query parameter: an optional name, a raw value, and an
    optional explicit DbSqlType (inferred later when absent)."""

    name: str
    value: Any
    type: DbSqlType

    def __init__(self, name="", value=None, type=None):
        self.name = name
        self.value = value
        self.type = type

    def __eq__(self, other):
        # Two parameters are equal when they are the same class and every
        # attribute (name, value, type) matches.
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
507+
508+
509+
def named_parameters_to_dbsqlparams_v1(parameters: Dict[str, str]):
    """Convert the legacy {name: value} mapping into a list of DbSqlParameter,
    preserving the mapping's iteration order."""
    return [
        DbSqlParameter(name=key, value=val) for key, val in parameters.items()
    ]
514+
515+
516+
def named_parameters_to_dbsqlparams_v2(parameters: List[Any]):
    """Normalize a positional parameter list: entries that are already
    DbSqlParameter pass through unchanged; bare values are wrapped in an
    unnamed DbSqlParameter."""
    return [
        item if isinstance(item, DbSqlParameter) else DbSqlParameter(value=item)
        for item in parameters
    ]
524+
525+
526+
def infer_types(params: list[DbSqlParameter]):
    """Return a copy of *params* with missing types inferred and all values
    coerced to strings.

    :param params: parameters, possibly without an explicit DbSqlType
    :returns: a new list (the input is not mutated) in which every parameter
        has a DbSqlType and its value is the str() of the original value
    :raises ValueError: when a value's Python type has no SQL mapping
    """
    # Exact-type mapping, looked up via type() rather than isinstance so that
    # bool maps to BOOLEAN instead of falling through to int.
    type_lookup_table = {
        str: DbSqlType.STRING,
        int: DbSqlType.INTEGER,
        float: DbSqlType.FLOAT,
        datetime.datetime: DbSqlType.TIMESTAMP,
        bool: DbSqlType.BOOLEAN,
    }
    new_params = copy.deepcopy(params)
    for param in new_params:
        if not param.type:
            # Single dict lookup instead of `in` followed by indexing.
            inferred = type_lookup_table.get(type(param.value))
            if inferred is None:
                # Name the offending Python type so callers can tell which
                # value could not be mapped (the bare message was opaque).
                raise ValueError(
                    "Parameter type cannot be inferred for value of type "
                    f"{type(param.value).__name__}"
                )
            param.type = inferred
        # The Thrift API transports every value as a string, so stringify
        # uniformly — even values whose type was given explicitly.
        param.value = str(param.value)
    return new_params
543+
544+
545+
def named_parameters_to_tsparkparams(parameters: Union[List[Any], Dict[str, str]]):
    """Translate user-supplied parameters (dict or list form) into
    TSparkParameter structs ready to attach to an ExecuteStatement request.

    Dicts go through the legacy v1 conversion (name/value pairs); lists go
    through v2 (mixed bare values and DbSqlParameter). Types are then
    inferred and values string-encoded before building the Thrift structs.
    """
    if isinstance(parameters, dict):
        dbsql_params = named_parameters_to_dbsqlparams_v1(parameters)
    else:
        dbsql_params = named_parameters_to_dbsqlparams_v2(parameters)
    return [
        TSparkParameter(
            type=param.type.value,
            name=param.name,
            value=TSparkParameterValue(stringValue=param.value),
        )
        for param in infer_types(dbsql_params)
    ]

‎tests/unit/test_parameters.py‎

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
fromdatabricks.sql.utilsimport (
2+
named_parameters_to_tsparkparams,
3+
infer_types,
4+
named_parameters_to_dbsqlparams_v1,
5+
named_parameters_to_dbsqlparams_v2,
6+
)
7+
fromdatabricks.sql.thrift_api.TCLIService.ttypesimport (
8+
TSparkParameter,
9+
TSparkParameterValue,
10+
)
11+
fromdatabricks.sql.utilsimportDbSqlParameter,DbSqlType
12+
importpytest
13+
14+
15+
class TestTSparkParameterConversion(object):
    """Unit tests for the parameter-conversion helpers in databricks.sql.utils."""

    def test_conversion_e2e(self):
        """This behaviour falls back to Python's default string formatting of numbers"""
        converted = named_parameters_to_tsparkparams(
            ["a", 1, True, 1.0, DbSqlParameter(value="1.0", type=DbSqlType.DECIMAL)]
        )
        expected = [
            ("STRING", "a"),
            ("INTEGER", "1"),
            ("BOOLEAN", "True"),
            ("FLOAT", "1.0"),
            ("DECIMAL", "1.0"),
        ]
        assert converted == [
            TSparkParameter(
                name="", type=sql_type, value=TSparkParameterValue(stringValue=text)
            )
            for sql_type, text in expected
        ]

    def test_basic_conversions_v1(self):
        # Test legacy codepath: dict keys become names, values pass through.
        converted = named_parameters_to_dbsqlparams_v1({"1": 1, "2": "foo", "3": 2.0})
        assert converted == [
            DbSqlParameter("1", 1),
            DbSqlParameter("2", "foo"),
            DbSqlParameter("3", 2.0),
        ]

    def test_basic_conversions_v2(self):
        # Test interspersing named params with unnamed: bare values are wrapped.
        converted = named_parameters_to_dbsqlparams_v2(
            [DbSqlParameter("1", 1.0, DbSqlType.DECIMAL), 5, DbSqlParameter("3", "foo")]
        )
        assert converted == [
            DbSqlParameter("1", 1.0, DbSqlType.DECIMAL),
            DbSqlParameter("", 5),
            DbSqlParameter("3", "foo"),
        ]

    def test_type_inference(self):
        # Values whose Python type has no SQL mapping are rejected.
        for bad_value in (None, {1: 1}):
            with pytest.raises(ValueError):
                infer_types([DbSqlParameter("", bad_value)])
        # Supported values get a type inferred (or kept) and are stringified.
        cases = [
            (DbSqlParameter("", 1), DbSqlParameter("", "1", DbSqlType.INTEGER)),
            (DbSqlParameter("", True), DbSqlParameter("", "True", DbSqlType.BOOLEAN)),
            (DbSqlParameter("", 1.0), DbSqlParameter("", "1.0", DbSqlType.FLOAT)),
            (DbSqlParameter("", "foo"), DbSqlParameter("", "foo", DbSqlType.STRING)),
            (
                DbSqlParameter("", 1.0, DbSqlType.DECIMAL),
                DbSqlParameter("", "1.0", DbSqlType.DECIMAL),
            ),
        ]
        for given, expected in cases:
            assert infer_types([given]) == [expected]

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp