|
| 1 | +importsqlalchemy |
| 2 | +fromsqlalchemy.ext.compilerimportcompiles |
| 3 | + |
| 4 | +fromtypingimportUnion |
| 5 | + |
| 6 | +fromdatetimeimportdatetime,time |
| 7 | + |
| 8 | + |
| 9 | +fromdatabricks.sql.utilsimportParamEscaper |
| 10 | + |
| 11 | + |
@compiles(sqlalchemy.types.Enum, "databricks")
@compiles(sqlalchemy.types.String, "databricks")
@compiles(sqlalchemy.types.Text, "databricks")
@compiles(sqlalchemy.types.Time, "databricks")
@compiles(sqlalchemy.types.Unicode, "databricks")
@compiles(sqlalchemy.types.UnicodeText, "databricks")
@compiles(sqlalchemy.types.Uuid, "databricks")
def compile_string_databricks(type_, compiler, **kw):
    """Render every string-like SQLAlchemy type as Databricks STRING.

    We override the default compilation of these types because SQLAlchemy's
    stock renderings are incompatible / abnormal for Databricks SQL:

        Enum        -> VARCHAR
        String      -> VARCHAR[LENGTH]
        Text        -> VARCHAR[LENGTH]
        Time        -> TIME
        Unicode     -> VARCHAR[LENGTH]
        UnicodeText -> TEXT
        Uuid        -> CHAR[32]

    All of the above compile to STRING in Databricks SQL.
    """
    return "STRING"
| 35 | + |
| 36 | + |
@compiles(sqlalchemy.types.Integer, "databricks")
def compile_integer_databricks(type_, compiler, **kw):
    """Override the default Integer rendering: Databricks spells it "INT", not "INTEGER"."""
    return "INT"
| 43 | + |
| 44 | + |
@compiles(sqlalchemy.types.LargeBinary, "databricks")
def compile_binary_databricks(type_, compiler, **kw):
    """Override the default LargeBinary rendering: Databricks spells it "BINARY", not "BLOB"."""
    return "BINARY"
| 51 | + |
| 52 | + |
@compiles(sqlalchemy.types.Numeric, "databricks")
def compile_numeric_databricks(type_, compiler, **kw):
    """Render Numeric as DECIMAL, which is Databricks' name for this type.

    Rather than hard-coding a string, we delegate to the compiler's built-in
    visit_DECIMAL, which already renders precision and scale correctly.
    """
    return compiler.visit_DECIMAL(type_, **kw)
| 62 | + |
| 63 | + |
@compiles(sqlalchemy.types.DateTime, "databricks")
def compile_datetime_databricks(type_, compiler, **kw):
    """Render DateTime as "TIMESTAMP_NTZ".

    Databricks has no "DATETIME" type. A SQLAlchemy DateTime is timezone-naive
    by default, so it maps to Databricks' timezone-naive TIMESTAMP_NTZ.
    (The previous docstring claimed "TIMESTAMP", which did not match the value
    actually returned here.)
    """
    return "TIMESTAMP_NTZ"
| 70 | + |
| 71 | + |
@compiles(sqlalchemy.types.ARRAY, "databricks")
def compile_array_databricks(type_, compiler, **kw):
    """Render ARRAY types as Databricks' ARRAY<inner> syntax.

    SQLAlchemy's default ARRAY cannot compile at all since it is only
    implemented for PostgreSQL; the PostgreSQL approach works for Databricks
    SQL, so it is duplicated here.

    :type_:
        An instance of sqlalchemy.types.ARRAY, which always carries an
        ``item_type`` attribute that is itself a TypeEngine instance.

    https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.ARRAY
    """
    element_type = compiler.process(type_.item_type, **kw)
    return f"ARRAY<{element_type}>"
| 88 | + |
| 89 | + |
class DatabricksDateTimeNoTimezoneType(sqlalchemy.types.TypeDecorator):
    """DateTime decorator that strips the timezone from TIMESTAMP_NTZ results.

    The value pysql produces for the contents of a TIMESTAMP_NTZ column
    carries an 'Etc/UTC' timezone, but SQLAlchemy's test suite assumes
    sqlalchemy.types.DateTime yields a datetime.datetime _without_ any
    timezone set, so we strip it off here.

    It's not clear if DBR sends a timezone to pysql or if pysql is adding it.
    This could be a bug.
    """

    impl = sqlalchemy.types.DateTime

    cache_ok = True

    def process_result_value(self, value: Union[None, datetime], dialect):
        # Drop tzinfo so callers get a naive datetime, per DateTime's contract.
        return None if value is None else value.replace(tzinfo=None)
| 107 | + |
| 108 | + |
class DatabricksTimeType(sqlalchemy.types.TypeDecorator):
    """Databricks has no native TIME type, so values are stored as strings."""

    impl = sqlalchemy.types.Time
    cache_ok = True

    TIME_WITH_MICROSECONDS_FMT = "%H:%M:%S.%f"
    TIME_NO_MICROSECONDS_FMT = "%H:%M:%S"

    def process_bind_param(self, value: Union[time, None], dialect) -> Union[None, str]:
        """Convert outbound values to %H:%M:%S.%f strings (None passes through)."""
        if value is None:
            return None
        return value.strftime(self.TIME_WITH_MICROSECONDS_FMT)

    # mypy dislikes this workaround: TypeEngine wants process_literal_param -> str
    def process_literal_param(self, value, dialect) -> time:  # type: ignore
        """Pass the datetime.time value through unmodified.

        It's not clear why this override is necessary. Without it, SQLAlchemy's
        TimeTest:test_literal fails because the string-literal renderer receives
        a str() object and calls .isoformat() on it; with it, this method gets a
        datetime.time() that the same renderer handles fine.

        UPDATE: after coping with the literal_processor override in
        DatabricksStringType, a similar mechanism is probably at play — two
        different processors run in sequence, likely a byproduct of Databricks
        lacking a true TIME type. As long as the tests pass, this is acceptable.
        """
        return value

    def process_result_value(
        self, value: Union[None, str], dialect
    ) -> Union[time, None]:
        """Parse inbound strings from the database into datetime.time() objects."""
        if value is None:
            return None

        try:
            parsed = datetime.strptime(value, self.TIME_WITH_MICROSECONDS_FMT)
        except ValueError:
            # The string had no microseconds component; retry without it.
            parsed = datetime.strptime(value, self.TIME_NO_MICROSECONDS_FMT)

        return parsed.time()
| 154 | + |
| 155 | + |
class DatabricksStringType(sqlalchemy.types.TypeDecorator):
    """String type with Databricks-style (backslash) literal escaping.

    SQLAlchemy's default String implementation escapes single-quotes by
    doubling them, but Databricks escapes literal strings with a backslash,
    so the default escaping breaks Databricks SQL.
    """

    impl = sqlalchemy.types.String
    cache_ok = True
    pe = ParamEscaper()

    def process_literal_param(self, value, dialect) -> str:
        """Escape the value with ParamEscaper.

        SQLAlchemy's default string escaping for backslashes doesn't work for
        Databricks; this reproduces our legacy inline escaping logic.
        """
        return self.pe.escape_string(value)

    def literal_processor(self, dialect):
        """Manually override to stop any processing of the literal beyond
        what process_literal_param() already does.

        The SQLAlchemy docs _specifically_ say not to override this method.

        Any processing done by TypeEngine.process_literal_param happens
        _before_ and _in addition to_ the impl.literal_processor() of the
        class. String.literal_processor() doubles every single-quote in the
        string, which raises a syntax error in Databricks and is unnecessary
        because ParamEscaper() already performs all required escaping.

        We should consider opening an issue on the SQLAlchemy project to see
        if we're using it wrong.

        See type_api.py::TypeEngine.literal_processor:

        ```python
        def process(value: Any) -> str:
            return fixed_impl_processor(
                fixed_process_literal_param(value, dialect)
            )
        ```

        That call to fixed_impl_processor wraps the result of
        fixed_process_literal_param (the process_literal_param defined in our
        Databricks dialect).

        https://docs.sqlalchemy.org/en/20/core/custom_types.html#sqlalchemy.types.TypeDecorator.literal_processor
        """

        def process(value):
            # Copy of the default String.literal_processor(), minus its
            # single-quote doubling behaviour.
            escaped = self.process_literal_param(value, dialect="databricks")
            if dialect.identifier_preparer._double_percents:
                escaped = escaped.replace("%", "%%")
            return "%s" % escaped

        return process