Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

[PECO-1026] Add Parameterized Query support to Python#217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub?Sign in to your account

Merged
nithinkdb merged 6 commits into databricks:main from nithinkdb:PECO-1026
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/databricks/sql/client.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -14,7 +14,11 @@
CursorAlreadyClosedError,
)
from databricks.sql.thrift_backend import ThriftBackend
from databricks.sql.utils import ExecuteResponse, ParamEscaper, inject_parameters
from databricks.sql.utils import (
ExecuteResponse,
ParamEscaper,
named_parameters_to_tsparkparams,
)
from databricks.sql.types import Row
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
Expand DownExpand Up@@ -482,7 +486,9 @@ def _handle_staging_remove(self, presigned_url: str, headers: dict = None):
)

def execute(
self, operation: str, parameters: Optional[Dict[str, str]] = None
self,
operation: str,
parameters: Optional[Union[List[Any], Dict[str, str]]] = None,
) -> "Cursor":
"""
Execute a query and wait for execution to complete.
Expand All@@ -493,10 +499,10 @@ def execute(
Will result in the query "SELECT * FROM table WHERE field = 'foo' being sent to the server
:returns self
"""
if parameters is not None:
operation = inject_parameters(
operation, self.escaper.escape_args(parameters)
)
if parameters is None:
parameters = []
else:
parameters = named_parameters_to_tsparkparams(parameters)

self._check_not_closed()
self._close_and_clear_active_result_set()
Expand All@@ -508,6 +514,7 @@ def execute(
lz4_compression=self.connection.lz4_compression,
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
parameters=parameters,
)
self.active_result_set = ResultSet(
self.connection,
Expand Down
10 changes: 7 additions & 3 deletions src/databricks/sql/thrift_backend.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -224,7 +224,7 @@ def __init__(
def _initialize_retry_args(self, kwargs):
# Configure retries & timing: use user-settings or defaults, and bound
# by policy. Log.warn when given param gets restricted.
for (key, (type_, default, min, max)) in _retry_policy.items():
for key, (type_, default, min, max) in _retry_policy.items():
given_or_default = type_(kwargs.get(key, default))
bound = _bound(min, max, given_or_default)
setattr(self, key, bound)
Expand DownExpand Up@@ -368,7 +368,6 @@ def attempt_request(attempt):

error, error_message, retry_delay = None, None, None
try:

this_method_name = getattr(method, "__name__")

logger.debug("Sending request: {}(<REDACTED>)".format(this_method_name))
Expand DownExpand Up@@ -614,7 +613,10 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
num_rows,
) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
elif t_row_set.arrowBatches is not None:
(arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
(
arrow_table,
num_rows,
) = convert_arrow_based_set_to_arrow_table(
t_row_set.arrowBatches, lz4_compressed, schema_bytes
)
else:
Expand DownExpand Up@@ -813,6 +815,7 @@ def execute_command(
lz4_compression,
cursor,
use_cloud_fetch=False,
parameters=[],
):
assert session_handle is not None

Expand All@@ -839,6 +842,7 @@ def execute_command(
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
},
useArrowNativeTypes=spark_arrow_types,
parameters=parameters,
)
resp = self.make_request(self._client.ExecuteStatement, req)
return self._handle_execute_response(resp, cursor)
Expand Down
90 changes: 89 additions & 1 deletion src/databricks/sql/utils.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections import namedtuple, OrderedDict
from collections.abc import Iterable
Expand All@@ -8,13 +9,17 @@
import lz4.frame
from typing import Dict, List, Union, Any
import pyarrow
from enum import Enum
import copy

from databricks.sql import exc, OperationalError
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkArrowResultLink,
TSparkRowSetType,
TRowSet,
TSparkParameter,
TSparkParameterValue,
)

BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
Expand DownExpand Up@@ -404,7 +409,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema


def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table:
for (i, col) in enumerate(table.itercolumns()):
for i, col in enumerate(table.itercolumns()):
if description[i][1] == "decimal":
decimal_col = col.to_pandas().apply(
lambda v: v if v is None else Decimal(v)
Expand DownExpand Up@@ -470,3 +475,86 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):
result[i] = None

return pyarrow.array(result, type=arrow_type)


class DbSqlType(Enum):
    """Server-side SQL type names used to tag query parameter values.

    Each member's value is the literal type string that is sent to the
    server (it becomes ``TSparkParameter.type`` when the request is built).
    """

    STRING = "STRING"
    DATE = "DATE"
    TIMESTAMP = "TIMESTAMP"
    FLOAT = "FLOAT"
    DECIMAL = "DECIMAL"
    INTEGER = "INTEGER"
    BIGINT = "BIGINT"
    SMALLINT = "SMALLINT"
    TINYINT = "TINYINT"
    BOOLEAN = "BOOLEAN"
    INTERVAL_MONTH = "INTERVAL MONTH"
    INTERVAL_DAY = "INTERVAL DAY"


class DbSqlParameter:
    """A single SQL query parameter: a name, a value, and a SQL type.

    ``name`` is empty for positional parameters.  ``type`` may be left
    ``None``, in which case it is inferred later from ``value``.
    """

    name: str
    value: Any
    type: "DbSqlType"

    def __init__(self, name: str = "", value: Any = None, type: "DbSqlType" = None):
        self.name = name
        self.value = value
        self.type = type

    def __eq__(self, other) -> bool:
        # Two parameters are equal when the other object is an instance of
        # this class and every attribute matches.
        if not isinstance(other, self.__class__):
            return False
        return vars(self) == vars(other)


def named_parameters_to_dbsqlparams_v1(parameters: Dict[str, str]):
    """Convert a legacy ``{name: value}`` mapping into DbSqlParameter objects.

    Types are left unset so they can be inferred later by ``infer_types``.
    """
    return [
        DbSqlParameter(name=key, value=val) for key, val in parameters.items()
    ]


def named_parameters_to_dbsqlparams_v2(parameters: List[Any]):
    """Normalize a positional parameter list into DbSqlParameter objects.

    Items that are already ``DbSqlParameter`` instances pass through
    unchanged; bare values are wrapped with an empty name and no explicit
    type (to be inferred later).
    """
    return [
        item if isinstance(item, DbSqlParameter) else DbSqlParameter(value=item)
        for item in parameters
    ]


def infer_types(params: "list[DbSqlParameter]"):
    """Return a copy of *params* with missing types inferred and all values
    rendered as strings.

    A parameter whose ``type`` is unset gets one looked up from the exact
    Python type of its value; an unrecognized value type raises
    ``ValueError``.  Every value is converted with ``str`` because the wire
    format carries string values only.  The input list is not mutated.
    """
    inferable = {
        str: DbSqlType.STRING,
        int: DbSqlType.INTEGER,
        float: DbSqlType.FLOAT,
        datetime.datetime: DbSqlType.TIMESTAMP,
        bool: DbSqlType.BOOLEAN,
    }
    result = copy.deepcopy(params)
    for param in result:
        if not param.type:
            inferred = inferable.get(type(param.value))
            if inferred is None:
                raise ValueError("Parameter type cannot be inferred")
            param.type = inferred
        # Stringify unconditionally — even explicitly-typed values are sent
        # as their string rendering.
        param.value = str(param.value)
    return result


def named_parameters_to_tsparkparams(parameters: Union[List[Any], Dict[str, str]]):
    """Translate user-supplied parameters into Thrift ``TSparkParameter``s.

    Accepts either the legacy ``{name: value}`` dict form or a positional
    list (which may mix bare values with ``DbSqlParameter`` instances).
    Missing types are inferred and every value is sent as its string
    rendering.
    """
    if isinstance(parameters, dict):
        normalized = named_parameters_to_dbsqlparams_v1(parameters)
    else:
        normalized = named_parameters_to_dbsqlparams_v2(parameters)
    return [
        TSparkParameter(
            type=param.type.value,
            name=param.name,
            value=TSparkParameterValue(stringValue=param.value),
        )
        for param in infer_types(normalized)
    ]
75 changes: 75 additions & 0 deletions tests/unit/test_parameters.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
from databricks.sql.utils import (
named_parameters_to_tsparkparams,
infer_types,
named_parameters_to_dbsqlparams_v1,
named_parameters_to_dbsqlparams_v2,
)
from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkParameter,
TSparkParameterValue,
)
from databricks.sql.utils import DbSqlParameter, DbSqlType
import pytest


class TestTSparkParameterConversion(object):
    """Unit tests for converting user parameters into TSparkParameter objects."""

    def test_conversion_e2e(self):
        """End-to-end: mixed positional values are typed, stringified, and
        wrapped.  This behaviour falls back to Python's default string
        formatting of numbers."""
        converted = named_parameters_to_tsparkparams(
            ["a", 1, True, 1.0, DbSqlParameter(value="1.0", type=DbSqlType.DECIMAL)]
        )
        expected = [
            ("STRING", "a"),
            ("INTEGER", "1"),
            ("BOOLEAN", "True"),
            ("FLOAT", "1.0"),
            ("DECIMAL", "1.0"),
        ]
        assert converted == [
            TSparkParameter(
                name="", type=t, value=TSparkParameterValue(stringValue=v)
            )
            for t, v in expected
        ]

    def test_basic_conversions_v1(self):
        """Legacy dict codepath: each key/value pair becomes a named param."""
        result = named_parameters_to_dbsqlparams_v1({"1": 1, "2": "foo", "3": 2.0})
        assert result == [
            DbSqlParameter("1", 1),
            DbSqlParameter("2", "foo"),
            DbSqlParameter("3", 2.0),
        ]

    def test_basic_conversions_v2(self):
        """Positional codepath: DbSqlParameter instances pass through and
        bare values are wrapped, even when interspersed."""
        result = named_parameters_to_dbsqlparams_v2(
            [DbSqlParameter("1", 1.0, DbSqlType.DECIMAL), 5, DbSqlParameter("3", "foo")]
        )
        assert result == [
            DbSqlParameter("1", 1.0, DbSqlType.DECIMAL),
            DbSqlParameter("", 5),
            DbSqlParameter("3", "foo"),
        ]

    def test_type_inference(self):
        """Types are inferred from Python values and values are stringified;
        uninferable values (None, dict) raise ValueError."""
        for bad_value in (None, {1: 1}):
            with pytest.raises(ValueError):
                infer_types([DbSqlParameter("", bad_value)])

        inference_cases = [
            (1, "1", DbSqlType.INTEGER),
            (True, "True", DbSqlType.BOOLEAN),
            (1.0, "1.0", DbSqlType.FLOAT),
            ("foo", "foo", DbSqlType.STRING),
        ]
        for value, rendered, sql_type in inference_cases:
            assert infer_types([DbSqlParameter("", value)]) == [
                DbSqlParameter("", rendered, sql_type)
            ]

        # An explicitly-set type is preserved while the value is still
        # converted to its string rendering.
        assert infer_types([DbSqlParameter("", 1.0, DbSqlType.DECIMAL)]) == [
            DbSqlParameter("", "1.0", DbSqlType.DECIMAL)
        ]

[8]ページ先頭

©2009-2025 Movatter.jp