- Notifications
You must be signed in to change notification settings - Fork126
[PECOBLR-330] Support for complex params#559
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Changes fromall commits
65fe00dcfc58d672718667c42faa8e5038cddbe54e4dd26d971be1edb032c3f37c89b8b2b8a2a4c36d9909f4d182dc578d8346de6c22f6c716311423178332File filter
Filter by extension
Conversations
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -12,4 +12,6 @@ | ||
| TimestampNTZParameter, | ||
| TinyIntParameter, | ||
| DecimalParameter, | ||
| MapParameter, | ||
| ArrayParameter, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,13 @@ | ||
| import datetime | ||
| import decimal | ||
| from enum import Enum, auto | ||
| from typing import Optional, Sequence, Any | ||
| from databricks.sql.exc import NotSupportedError | ||
| from databricks.sql.thrift_api.TCLIService.ttypes import ( | ||
| TSparkParameter, | ||
| TSparkParameterValue, | ||
| TSparkParameterValueArg, | ||
| ) | ||
| import datetime | ||
| @@ -54,7 +55,17 @@ class DatabricksSupportedType(Enum): | ||
| TAllowedParameterValue = Union[ | ||
| str, | ||
| int, | ||
| float, | ||
| datetime.datetime, | ||
| datetime.date, | ||
| bool, | ||
| decimal.Decimal, | ||
| None, | ||
| list, | ||
| dict, | ||
| tuple, | ||
| ] | ||
| @@ -82,6 +93,7 @@ class DbsqlParameterBase: | ||
| CAST_EXPR: str | ||
| name: Optional[str] | ||
| value: Any | ||
| def as_tspark_param(self, named: bool) -> TSparkParameter: | ||
| """Returns a TSparkParameter object that can be passed to the DBR thrift server.""" | ||
| @@ -98,6 +110,10 @@ def as_tspark_param(self, named: bool) -> TSparkParameter: | ||
| def _tspark_param_value(self): | ||
| return TSparkParameterValue(stringValue=str(self.value)) | ||
| def _tspark_value_arg(self): | ||
| """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server.""" | ||
| return TSparkParameterValueArg(value=str(self.value), type=self._cast_expr()) | ||
| def _cast_expr(self): | ||
| return self.CAST_EXPR | ||
| @@ -428,6 +444,99 @@ def __init__(self, value: int, name: Optional[str] = None): | ||
| CAST_EXPR = DatabricksSupportedType.TINYINT.name | ||
| class ArrayParameter(DbsqlParameterBase): | ||
| """Wrap a Python `Sequence` that will be bound to a Databricks SQL ARRAY type.""" | ||
| def __init__(self, value: Sequence[Any], name: Optional[str] = None): | ||
| """ | ||
| :value: | ||
| The value to bind for this parameter. This will be casted to a ARRAY. | ||
| :name: | ||
| If None, your query must contain a `?` marker. Like: | ||
| ```sql | ||
| SELECT * FROM table WHERE field = ? | ||
| ``` | ||
| If not None, your query should contain a named parameter marker. Like: | ||
| ```sql | ||
| SELECT * FROM table WHERE field = :my_param | ||
| ``` | ||
| The `name` argument to this function would be `my_param`. | ||
| """ | ||
| self.name = name | ||
| self.value = [dbsql_parameter_from_primitive(val) for val in value] | ||
jprakash-db marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| def as_tspark_param(self, named: bool = False) -> TSparkParameter: | ||
| """Returns a TSparkParameter object that can be passed to the DBR thrift server.""" | ||
| tsp = TSparkParameter(type=self._cast_expr()) | ||
| tsp.arguments = [val._tspark_value_arg() for val in self.value] | ||
| if named: | ||
| tsp.name = self.name | ||
| tsp.ordinal = False | ||
| elif not named: | ||
| tsp.ordinal = True | ||
jprakash-db marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| return tsp | ||
| def _tspark_value_arg(self): | ||
| """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server.""" | ||
| tva = TSparkParameterValueArg(type=self._cast_expr()) | ||
| tva.arguments = [val._tspark_value_arg() for val in self.value] | ||
| return tva | ||
| CAST_EXPR = DatabricksSupportedType.ARRAY.name | ||
| class MapParameter(DbsqlParameterBase): | ||
| """Wrap a Python `dict` that will be bound to a Databricks SQL MAP type.""" | ||
| def __init__(self, value: dict, name: Optional[str] = None): | ||
| """ | ||
| :value: | ||
| The value to bind for this parameter. This will be casted to a MAP. | ||
| :name: | ||
| If None, your query must contain a `?` marker. Like: | ||
| ```sql | ||
| SELECT * FROM table WHERE field = ? | ||
| ``` | ||
| If not None, your query should contain a named parameter marker. Like: | ||
| ```sql | ||
| SELECT * FROM table WHERE field = :my_param | ||
| ``` | ||
| The `name` argument to this function would be `my_param`. | ||
| """ | ||
| self.name = name | ||
| self.value = [ | ||
| dbsql_parameter_from_primitive(item) | ||
| for key, val in value.items() | ||
| for item in (key, val) | ||
| ] | ||
| def as_tspark_param(self, named: bool = False) -> TSparkParameter: | ||
| """Returns a TSparkParameter object that can be passed to the DBR thrift server.""" | ||
| tsp = TSparkParameter(type=self._cast_expr()) | ||
| tsp.arguments = [val._tspark_value_arg() for val in self.value] | ||
| if named: | ||
| tsp.name = self.name | ||
| tsp.ordinal = False | ||
| elif not named: | ||
| tsp.ordinal = True | ||
| return tsp | ||
| def _tspark_value_arg(self): | ||
| """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server.""" | ||
| tva = TSparkParameterValueArg(type=self._cast_expr()) | ||
| tva.arguments = [val._tspark_value_arg() for val in self.value] | ||
| return tva | ||
| CAST_EXPR = DatabricksSupportedType.MAP.name | ||
| class DecimalParameter(DbsqlParameterBase): | ||
| """Wrap a Python `Decimal` that will be bound to a Databricks SQL DECIMAL type.""" | ||
| @@ -543,23 +652,26 @@ def dbsql_parameter_from_primitive( | ||
| # havoc. We can't use TYPE_INFERRENCE_MAP because mypy doesn't trust | ||
| # its logic | ||
| if isinstance(value, bool): | ||
| return BooleanParameter(value=value, name=name) | ||
| elif isinstance(value, int): | ||
| return dbsql_parameter_from_int(value, name=name) | ||
| elifisinstance(value,str): | ||
| return StringParameter(value=value, name=name) | ||
| elifisinstance(value,float): | ||
| return FloatParameter(value=value, name=name) | ||
| elifisinstance(value,datetime.datetime): | ||
| return TimestampParameter(value=value, name=name) | ||
| elifisinstance(value,datetime.date): | ||
| return DateParameter(value=value, name=name) | ||
| elif isinstance(value, decimal.Decimal): | ||
| return DecimalParameter(value=value, name=name) | ||
| elif isinstance(value, dict): | ||
| return MapParameter(value=value, name=name) | ||
| elif isinstance(value, Sequence) and not isinstance(value, str): | ||
jprakash-db marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| return ArrayParameter(value=value, name=name) | ||
| elif value is None: | ||
| return VoidParameter(value=value, name=name) | ||
| else: | ||
| raise NotSupportedError( | ||
| f"Could not infer parameter type from value: {value} - {type(value)} \n" | ||
| @@ -581,6 +693,8 @@ def dbsql_parameter_from_primitive( | ||
| TimestampNTZParameter, | ||
| TinyIntParameter, | ||
| DecimalParameter, | ||
| ArrayParameter, | ||
| MapParameter, | ||
| ] | ||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -5,10 +5,10 @@ | ||||
| import decimal | ||||
| from abc import ABC, abstractmethod | ||||
| from collections import OrderedDict, namedtuple | ||||
| from collections.abc importMapping | ||||
| from decimal import Decimal | ||||
| from enum import Enum | ||||
| from typing import Any, Dict, List, Optional, Union, Sequence | ||||
| import re | ||||
| import lz4.frame | ||||
| @@ -429,7 +429,7 @@ def user_friendly_error_message(self, no_retry_reason, attempt, elapsed): | ||||
| # Taken from PyHive | ||||
| class ParamEscaper: | ||||
| _DATE_FORMAT = "%Y-%m-%d" | ||||
| _TIME_FORMAT = "%H:%M:%S.%f %z" | ||||
Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Timezone will not be there for TIMESTAMP_NTZ param. Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. +1, would like to know how we've accounted/tested for NTZ ContributorAuthor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. If it is not there it will be an empty space, there is already a test suite that inserts NTZ and none NTZ and reads back to compare whether it is equal or not ContributorAuthor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. @vikrantpuppala@shivam2680 There are already existing tests that insert NTZ and non NTZ values and reads back from table to ensure everything is working as expected -
| ||||
| _DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT) | ||||
| def escape_args(self, parameters): | ||||
| @@ -458,13 +458,22 @@ def escape_string(self, item): | ||||
| return "'{}'".format(item.replace("\\", "\\\\").replace("'", "\\'")) | ||||
| def escape_sequence(self, item): | ||||
| l = map(self.escape_item, item) | ||||
| l = list(map(str, l)) | ||||
| return "ARRAY(" + ",".join(l) + ")" | ||||
| def escape_mapping(self, item): | ||||
| l = map( | ||||
| self.escape_item, | ||||
| (element for key, value in item.items() for element in (key, value)), | ||||
| ) | ||||
| l = list(map(str, l)) | ||||
| return "MAP(" + ",".join(l) + ")" | ||||
| def escape_datetime(self, item, format, cutoff=0): | ||||
| dt_str = item.strftime(format) | ||||
| formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str | ||||
| return "'{}'".format(formatted.strip()) | ||||
jprakash-db marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||||
| def escape_decimal(self, item): | ||||
| return str(item) | ||||
| @@ -476,14 +485,16 @@ def escape_item(self, item): | ||||
| return self.escape_number(item) | ||||
| elif isinstance(item, str): | ||||
| return self.escape_string(item) | ||||
| elif isinstance(item, datetime.datetime): | ||||
| return self.escape_datetime(item, self._DATETIME_FORMAT) | ||||
| elif isinstance(item, datetime.date): | ||||
| return self.escape_datetime(item, self._DATE_FORMAT) | ||||
| elif isinstance(item, decimal.Decimal): | ||||
| return self.escape_decimal(item) | ||||
| elif isinstance(item, Sequence): | ||||
| return self.escape_sequence(item) | ||||
| elif isinstance(item, Mapping): | ||||
| return self.escape_mapping(item) | ||||
| else: | ||||
| raise exc.ProgrammingError("Unsupported object {}".format(item)) | ||||
Uh oh!
There was an error while loading.Please reload this page.
Uh oh!
There was an error while loading.Please reload this page.