Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit b94f59e

Browse files
author
Jesse
authored
[PECO-1109] Parameterized Query: add support for inferring decimal types (#228)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
1 parent 9489087 commit b94f59e

File tree

3 files changed

+190
-32
lines changed

3 files changed

+190
-32
lines changed

‎src/databricks/sql/utils.py‎

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
11
from __future__importannotations
2+
3+
importcopy
4+
importdatetime
5+
importdecimal
26
fromabcimportABC,abstractmethod
3-
fromcollectionsimportnamedtuple,OrderedDict
7+
fromcollectionsimportOrderedDict,namedtuple
48
fromcollections.abcimportIterable
59
fromdecimalimportDecimal
6-
importdatetime
7-
importdecimal
810
fromenumimportEnum
11+
fromtypingimportAny,Dict,List,Union
12+
913
importlz4.frame
10-
fromtypingimportDict,List,Union,Any
1114
importpyarrow
12-
fromenumimportEnum
13-
importcopy
1415

15-
fromdatabricks.sqlimportexc,OperationalError
16+
fromdatabricks.sqlimportOperationalError,exc
1617
fromdatabricks.sql.cloudfetch.download_managerimportResultFileDownloadManager
1718
fromdatabricks.sql.thrift_api.TCLIService.ttypesimport (
18-
TSparkArrowResultLink,
19-
TSparkRowSetType,
2019
TRowSet,
20+
TSparkArrowResultLink,
2121
TSparkParameter,
2222
TSparkParameterValue,
23+
TSparkRowSetType,
2324
)
2425

2526
BIT_MASKS= [1,2,4,8,16,32,64,128]
@@ -478,6 +479,10 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):
478479

479480

480481
classDbSqlType(Enum):
482+
"""The values of this enumeration are passed as literals to be used in a CAST
483+
evaluation by the thrift server.
484+
"""
485+
481486
STRING="STRING"
482487
DATE="DATE"
483488
TIMESTAMP="TIMESTAMP"
@@ -495,7 +500,7 @@ class DbSqlType(Enum):
495500
classDbSqlParameter:
496501
name:str
497502
value:Any
498-
type:DbSqlType
503+
type:Union[DbSqlType,DbsqlDynamicDecimalType,Enum]
499504

500505
def__init__(self,name="",value=None,type=None):
501506
self.name=name
@@ -506,6 +511,11 @@ def __eq__(self, other):
506511
returnisinstance(other,self.__class__)andself.__dict__==other.__dict__
507512

508513

514+
classDbsqlDynamicDecimalType:
515+
def__init__(self,value):
516+
self.value=value
517+
518+
509519
defnamed_parameters_to_dbsqlparams_v1(parameters:Dict[str,str]):
510520
dbsqlparams= []
511521
forname,parameterinparameters.items():
@@ -531,16 +541,49 @@ def infer_types(params: list[DbSqlParameter]):
531541
datetime.datetime:DbSqlType.TIMESTAMP,
532542
datetime.date:DbSqlType.DATE,
533543
bool:DbSqlType.BOOLEAN,
544+
Decimal:DbSqlType.DECIMAL,
534545
}
535-
newParams=copy.deepcopy(params)
536-
forparaminnewParams:
546+
new_params=copy.deepcopy(params)
547+
forparaminnew_params:
537548
ifnotparam.type:
538549
iftype(param.value)intype_lookup_table:
539550
param.type=type_lookup_table[type(param.value)]
540551
else:
541552
raiseValueError("Parameter type cannot be inferred")
553+
554+
ifparam.type==DbSqlType.DECIMAL:
555+
cast_exp=calculate_decimal_cast_string(param.value)
556+
param.type=DbsqlDynamicDecimalType(cast_exp)
557+
542558
param.value=str(param.value)
543-
returnnewParams
559+
returnnew_params
560+
561+
562+
defcalculate_decimal_cast_string(input:Decimal)->str:
563+
"""Returns the smallest SQL cast argument that can contain the passed decimal
564+
565+
Example:
566+
Input: Decimal("1234.5678")
567+
Output: DECIMAL(8,4)
568+
"""
569+
570+
string_decimal=str(input)
571+
572+
ifstring_decimal.startswith("0."):
573+
# This decimal is less than 1
574+
overall=after=len(string_decimal)-2
575+
elif"."notinstring_decimal:
576+
# This decimal has no fractional component
577+
overall=len(string_decimal)
578+
after=0
579+
else:
580+
# This decimal has both whole and fractional parts
581+
parts=string_decimal.split(".")
582+
parts_lengths= [len(i)foriinparts]
583+
before,after=parts_lengths[:2]
584+
overall=before+after
585+
586+
returnf"DECIMAL({overall},{after})"
544587

545588

546589
defnamed_parameters_to_tsparkparams(parameters:Union[List[Any],Dict[str,str]]):

‎tests/e2e/common/parameterized_query_tests.py‎

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
11
importdatetime
22
fromdecimalimportDecimal
3+
fromenumimportEnum
34
fromtypingimportDict,List,Tuple,Union
45

56
importpytz
67

7-
fromdatabricks.sql.clientimportConnection
8-
fromdatabricks.sql.utilsimportDbSqlParameter,DbSqlType
8+
fromdatabricks.sql.excimportDatabaseError
9+
fromdatabricks.sql.utilsimport (
10+
DbSqlParameter,
11+
DbSqlType,
12+
calculate_decimal_cast_string,
13+
)
14+
15+
16+
classMyCustomDecimalType(Enum):
17+
DECIMAL_38_0="DECIMAL(38,0)"
18+
DECIMAL_38_2="DECIMAL(38,2)"
19+
DECIMAL_18_9="DECIMAL(18,9)"
920

1021

1122
classPySQLParameterizedQueryTestSuiteMixin:
@@ -63,6 +74,11 @@ def test_primitive_inferred_string(self):
6374
result=self._get_one_result(self.QUERY,params)
6475
assertresult.col=="Hello"
6576

77+
deftest_primitive_inferred_decimal(self):
78+
params= {"p":Decimal("1234.56")}
79+
result=self._get_one_result(self.QUERY,params)
80+
assertresult.col==Decimal("1234.56")
81+
6682
deftest_dbsqlparam_inferred_bool(self):
6783

6884
params= [DbSqlParameter(name="p",value=True,type=None)]
@@ -103,6 +119,11 @@ def test_dbsqlparam_inferred_string(self):
103119
result=self._get_one_result(self.QUERY,params)
104120
assertresult.col=="Hello"
105121

122+
deftest_dbsqlparam_inferred_decimal(self):
123+
params= [DbSqlParameter(name="p",value=Decimal("1234.56"),type=None)]
124+
result=self._get_one_result(self.QUERY,params)
125+
assertresult.col==Decimal("1234.56")
126+
106127
deftest_dbsqlparam_explicit_bool(self):
107128

108129
params= [DbSqlParameter(name="p",value=True,type=DbSqlType.BOOLEAN)]
@@ -142,3 +163,50 @@ def test_dbsqlparam_explicit_string(self):
142163
params= [DbSqlParameter(name="p",value="Hello",type=DbSqlType.STRING)]
143164
result=self._get_one_result(self.QUERY,params)
144165
assertresult.col=="Hello"
166+
167+
deftest_dbsqlparam_explicit_decimal(self):
168+
params= [
169+
DbSqlParameter(name="p",value=Decimal("1234.56"),type=DbSqlType.DECIMAL)
170+
]
171+
result=self._get_one_result(self.QUERY,params)
172+
assertresult.col==Decimal("1234.56")
173+
174+
deftest_dbsqlparam_custom_explicit_decimal_38_0(self):
175+
176+
# This DECIMAL can be contained in a DECIMAL(38,0) column in Databricks
177+
value=Decimal("12345678912345678912345678912345678912")
178+
params= [
179+
DbSqlParameter(name="p",value=value,type=MyCustomDecimalType.DECIMAL_38_0)
180+
]
181+
result=self._get_one_result(self.QUERY,params)
182+
assertresult.col==value
183+
184+
deftest_dbsqlparam_custom_explicit_decimal_38_2(self):
185+
186+
# This DECIMAL can be contained in a DECIMAL(38,2) column in Databricks
187+
value=Decimal("123456789123456789123456789123456789.12")
188+
params= [
189+
DbSqlParameter(name="p",value=value,type=MyCustomDecimalType.DECIMAL_38_2)
190+
]
191+
result=self._get_one_result(self.QUERY,params)
192+
assertresult.col==value
193+
194+
deftest_dbsqlparam_custom_explicit_decimal_18_9(self):
195+
196+
# This DECIMAL can be contained in a DECIMAL(18,9) column in Databricks
197+
value=Decimal("123456789.123456789")
198+
params= [
199+
DbSqlParameter(name="p",value=value,type=MyCustomDecimalType.DECIMAL_18_9)
200+
]
201+
result=self._get_one_result(self.QUERY,params)
202+
assertresult.col==value
203+
204+
deftest_calculate_decimal_cast_string(self):
205+
206+
assertcalculate_decimal_cast_string(Decimal("10.00"))=="DECIMAL(4,2)"
207+
assert (
208+
calculate_decimal_cast_string(
209+
Decimal("123456789123456789.123456789123456789")
210+
)
211+
=="DECIMAL(36,18)"
212+
)

‎tests/unit/test_parameters.py‎

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
infer_types,
44
named_parameters_to_dbsqlparams_v1,
55
named_parameters_to_dbsqlparams_v2,
6+
calculate_decimal_cast_string,
7+
DbsqlDynamicDecimalType
68
)
79
fromdatabricks.sql.thrift_api.TCLIService.ttypesimport (
810
TSparkParameter,
@@ -11,6 +13,10 @@
1113
fromdatabricks.sql.utilsimportDbSqlParameter,DbSqlType
1214
importpytest
1315

16+
fromdecimalimportDecimal
17+
18+
fromtypingimportList
19+
1420

1521
classTestTSparkParameterConversion(object):
1622
deftest_conversion_e2e(self):
@@ -31,7 +37,7 @@ def test_conversion_e2e(self):
3137
name="",type="FLOAT",value=TSparkParameterValue(stringValue="1.0")
3238
),
3339
TSparkParameter(
34-
name="",type="DECIMAL",value=TSparkParameterValue(stringValue="1.0")
40+
name="",type="DECIMAL(2,1)",value=TSparkParameterValue(stringValue="1.0")
3541
),
3642
]
3743

@@ -53,23 +59,64 @@ def test_basic_conversions_v2(self):
5359
DbSqlParameter("3","foo"),
5460
]
5561

56-
deftest_type_inference(self):
62+
deftest_infer_types_none(self):
5763
withpytest.raises(ValueError):
5864
infer_types([DbSqlParameter("",None)])
65+
66+
deftest_infer_types_dict(self):
5967
withpytest.raises(ValueError):
6068
infer_types([DbSqlParameter("", {1:1})])
61-
assertinfer_types([DbSqlParameter("",1)])== [
62-
DbSqlParameter("","1",DbSqlType.INTEGER)
63-
]
64-
assertinfer_types([DbSqlParameter("",True)])== [
65-
DbSqlParameter("","True",DbSqlType.BOOLEAN)
66-
]
67-
assertinfer_types([DbSqlParameter("",1.0)])== [
68-
DbSqlParameter("","1.0",DbSqlType.FLOAT)
69-
]
70-
assertinfer_types([DbSqlParameter("","foo")])== [
71-
DbSqlParameter("","foo",DbSqlType.STRING)
72-
]
73-
assertinfer_types([DbSqlParameter("",1.0,DbSqlType.DECIMAL)])== [
74-
DbSqlParameter("","1.0",DbSqlType.DECIMAL)
75-
]
69+
70+
deftest_infer_types_integer(self):
71+
input=DbSqlParameter("",1)
72+
output=infer_types([input])
73+
assertoutput== [DbSqlParameter("","1",DbSqlType.INTEGER)]
74+
75+
deftest_infer_types_boolean(self):
76+
input=DbSqlParameter("",True)
77+
output=infer_types([input])
78+
assertoutput== [DbSqlParameter("","True",DbSqlType.BOOLEAN)]
79+
80+
deftest_infer_types_float(self):
81+
input=DbSqlParameter("",1.0)
82+
output=infer_types([input])
83+
assertoutput== [DbSqlParameter("","1.0",DbSqlType.FLOAT)]
84+
85+
deftest_infer_types_string(self):
86+
input=DbSqlParameter("","foo")
87+
output=infer_types([input])
88+
assertoutput== [DbSqlParameter("","foo",DbSqlType.STRING)]
89+
90+
deftest_infer_types_decimal(self):
91+
# The output decimal will have a dynamically calculated decimal type with a value of DECIMAL(2,1)
92+
input=DbSqlParameter("",Decimal("1.0"))
93+
output:List[DbSqlParameter]=infer_types([input])
94+
95+
x=output[0]
96+
97+
assertx.value=="1.0"
98+
assertisinstance(x.type,DbsqlDynamicDecimalType)
99+
assertx.type.value=="DECIMAL(2,1)"
100+
101+
102+
classTestCalculateDecimalCast(object):
103+
104+
deftest_38_38(self):
105+
input=Decimal(".12345678912345678912345678912345678912")
106+
output=calculate_decimal_cast_string(input)
107+
assertoutput=="DECIMAL(38,38)"
108+
109+
deftest_18_9(self):
110+
input=Decimal("123456789.123456789")
111+
output=calculate_decimal_cast_string(input)
112+
assertoutput=="DECIMAL(18,9)"
113+
114+
deftest_38_0(self):
115+
input=Decimal("12345678912345678912345678912345678912")
116+
output=calculate_decimal_cast_string(input)
117+
assertoutput=="DECIMAL(38,0)"
118+
119+
deftest_6_2(self):
120+
input=Decimal("1234.56")
121+
output=calculate_decimal_cast_string(input)
122+
assertoutput=="DECIMAL(6,2)"

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp