SQLAlchemy 2: add type compilation for all CamelCase types #238

Merged

susodapop merged 2 commits into main from implement-types on Oct 2, 2023
76 changes: 20 additions & 56 deletions src/databricks/sqlalchemy/__init__.py
@@ -12,12 +12,14 @@

 from databricks import sql

+# This import is required to process our @compiles decorators
+import databricks.sqlalchemy.types
+

 from databricks.sqlalchemy.base import (
     DatabricksDDLCompiler,
     DatabricksIdentifierPreparer,
 )
-from databricks.sqlalchemy.compiler import DatabricksTypeCompiler

 try:
     import alembic
@@ -30,52 +32,14 @@ class DatabricksImpl(DefaultImpl):
     __dialect__ = "databricks"


-class DatabricksDecimal(types.TypeDecorator):
-    """Translates strings to decimals"""
-
-    impl = types.DECIMAL
-
-    def process_result_value(self, value, dialect):
-        if value is not None:
-            return decimal.Decimal(value)
-        else:
-            return None
-
-
-class DatabricksTimestamp(types.TypeDecorator):
-    """Translates timestamp strings to datetime objects"""
-
-    impl = types.TIMESTAMP
-
-    def process_result_value(self, value, dialect):
-        return value
-
-    def adapt(self, impltype, **kwargs):
-        return self.impl
-
-
-class DatabricksDate(types.TypeDecorator):
-    """Translates date strings to date objects"""
-
-    impl = types.DATE
-
-    def process_result_value(self, value, dialect):
-        return value
-
-    def adapt(self, impltype, **kwargs):
-        return self.impl
-
-
 class DatabricksDialect(default.DefaultDialect):
     """This dialect implements only those methods required to pass our e2e tests"""

     # Possible attributes are defined here: https://docs.sqlalchemy.org/en/14/core/internals.html#sqlalchemy.engine.Dialect
     name: str = "databricks"
     driver: str = "databricks"
     default_schema_name: str = "default"

     preparer = DatabricksIdentifierPreparer  # type: ignore
-    type_compiler = DatabricksTypeCompiler
     ddl_compiler = DatabricksDDLCompiler
     supports_statement_cache: bool = True
     supports_multivalues_insert: bool = True
@@ -137,23 +101,23 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
         """

         _type_map = {
-            "boolean": types.Boolean,
-            "smallint": types.SmallInteger,
-            "int": types.Integer,
-            "bigint": types.BigInteger,
-            "float": types.Float,
-            "double": types.Float,
-            "string": types.String,
-            "varchar": types.String,
-            "char": types.String,
-            "binary": types.String,
-            "array": types.String,
-            "map": types.String,
-            "struct": types.String,
-            "uniontype": types.String,
-            "decimal": DatabricksDecimal,
-            "timestamp": DatabricksTimestamp,
-            "date": DatabricksDate,
+            "boolean": sqlalchemy.types.Boolean,
+            "smallint": sqlalchemy.types.SmallInteger,
+            "int": sqlalchemy.types.Integer,
+            "bigint": sqlalchemy.types.BigInteger,
+            "float": sqlalchemy.types.Float,
+            "double": sqlalchemy.types.Float,
+            "string": sqlalchemy.types.String,
+            "varchar": sqlalchemy.types.String,
+            "char": sqlalchemy.types.String,
+            "binary": sqlalchemy.types.String,
+            "array": sqlalchemy.types.String,
+            "map": sqlalchemy.types.String,
+            "struct": sqlalchemy.types.String,
+            "uniontype": sqlalchemy.types.String,
+            "decimal": sqlalchemy.types.Numeric,
+            "timestamp": sqlalchemy.types.DateTime,
+            "date": sqlalchemy.types.Date,
         }

         with self.get_connection_cursor(connection) as cur:
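A note on the "# This import is required to process our @compiles decorators" comment added above: SQLAlchemy's @compiles hook registers its override globally as a side effect of the decorated function being defined, so the module containing the overrides must be imported before any DDL is compiled. A minimal sketch of that registration-by-import pattern, using a hypothetical dialect name rather than the actual Databricks code:

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import Integer


@compiles(Integer, "somedialect")
def render_int(type_, compiler, **kw):
    # Defining this decorated function is what installs the override;
    # user code never calls it directly, the type compiler does.
    return "INT"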
38 changes: 0 additions & 38 deletions src/databricks/sqlalchemy/compiler.py

This file was deleted.

128 changes: 128 additions & 0 deletions src/databricks/sqlalchemy/test_local/test_types.py
@@ -0,0 +1,128 @@
import enum

import pytest
from sqlalchemy.types import (
    BigInteger,
    Boolean,
    Date,
    DateTime,
    Double,
    Enum,
    Float,
    Integer,
    Interval,
    LargeBinary,
    MatchType,
    Numeric,
    PickleType,
    SchemaType,
    SmallInteger,
    String,
    Text,
    Time,
    TypeEngine,
    Unicode,
    UnicodeText,
    Uuid,
)

from databricks.sqlalchemy import DatabricksDialect


class DatabricksDataType(enum.Enum):
    """https://docs.databricks.com/en/sql/language-manual/sql-ref-datatypes.html"""

    BIGINT = enum.auto()
    BINARY = enum.auto()
    BOOLEAN = enum.auto()
    DATE = enum.auto()
    DECIMAL = enum.auto()
    DOUBLE = enum.auto()
    FLOAT = enum.auto()
    INT = enum.auto()
    INTERVAL = enum.auto()
    VOID = enum.auto()
    SMALLINT = enum.auto()
    STRING = enum.auto()
    TIMESTAMP = enum.auto()
    TIMESTAMP_NTZ = enum.auto()
    TINYINT = enum.auto()
    ARRAY = enum.auto()
    MAP = enum.auto()
    STRUCT = enum.auto()


# Defines the way that SQLAlchemy CamelCase types are compiled into Databricks SQL types.
# Note: I wish I could define this within the TestCamelCaseTypesCompilation class, but pytest doesn't like that.
camel_case_type_map = {
    BigInteger: DatabricksDataType.BIGINT,
    LargeBinary: DatabricksDataType.BINARY,
    Boolean: DatabricksDataType.BOOLEAN,
    Date: DatabricksDataType.DATE,
    DateTime: DatabricksDataType.TIMESTAMP,
    Double: DatabricksDataType.DOUBLE,
    Enum: DatabricksDataType.STRING,
    Float: DatabricksDataType.FLOAT,
    Integer: DatabricksDataType.INT,
    Interval: DatabricksDataType.TIMESTAMP,
    Numeric: DatabricksDataType.DECIMAL,
    PickleType: DatabricksDataType.BINARY,
    SmallInteger: DatabricksDataType.SMALLINT,
    String: DatabricksDataType.STRING,
    Text: DatabricksDataType.STRING,
    Time: DatabricksDataType.STRING,
    Unicode: DatabricksDataType.STRING,
    UnicodeText: DatabricksDataType.STRING,
    Uuid: DatabricksDataType.STRING,
}

# Convert the dictionary into a list of tuples for use in pytest.mark.parametrize
_as_tuple_list = [(key, value) for key, value in camel_case_type_map.items()]


class CompilationTestBase:
    dialect = DatabricksDialect()

    def _assert_compiled_value(self, type_: TypeEngine, expected: DatabricksDataType):
        """Assert that when type_ is compiled for the databricks dialect, it renders the DatabricksDataType name.

        This method initialises the type_ with no arguments.
        """
        compiled_result = type_().compile(dialect=self.dialect)  # type: ignore
        assert compiled_result == expected.name

    def _assert_compiled_value_explicit(self, type_: TypeEngine, expected: str):
        """Assert that when type_ is compiled for the databricks dialect, it renders the expected string.

        This method expects an initialised type_ so that we can test how a TypeEngine created with arguments
        is compiled.
        """
        compiled_result = type_.compile(dialect=self.dialect)
        assert compiled_result == expected


class TestCamelCaseTypesCompilation(CompilationTestBase):
    """Per the sqlalchemy documentation[1], the camel case members of sqlalchemy.types are
    expected to work across all dialects. These tests verify that the types compile into valid
    Databricks SQL type strings. For example, sqlalchemy.types.Integer() should compile as "INT".

    Truly custom types like STRUCT (notice the uppercase) are not expected to work across all dialects.
    We test these separately.

    Note that these tests concern type **name** compilation, which is separate from actually
    mapping values between Python and Databricks.

    Note: SchemaType and MatchType are not tested because they are not used in table definitions.

    [1]: https://docs.sqlalchemy.org/en/20/core/type_basics.html#generic-camelcase-types
    """

    @pytest.mark.parametrize("type_, expected", _as_tuple_list)
    def test_bare_camel_case_types_compile(self, type_, expected):
        self._assert_compiled_value(type_, expected)

    def test_numeric_renders_as_decimal_with_precision(self):
        self._assert_compiled_value_explicit(Numeric(10), "DECIMAL(10)")

    def test_numeric_renders_as_decimal_with_precision_and_scale(self):
        self._assert_compiled_value_explicit(Numeric(10, 2), "DECIMAL(10, 2)")
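These tests exercise type-name compilation only, so no live Databricks connection is needed; assuming pytest is installed, something like the following should run just this module from the repository root:

python -m pytest src/databricks/sqlalchemy/test_local/test_types.py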
75 changes: 75 additions & 0 deletions src/databricks/sqlalchemy/types.py
@@ -0,0 +1,75 @@
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.compiler import GenericTypeCompiler
from sqlalchemy.types import (
    DateTime,
    Enum,
    Integer,
    LargeBinary,
    Numeric,
    String,
    Text,
    Time,
    Unicode,
    UnicodeText,
    Uuid,
)


@compiles(Enum, "databricks")
@compiles(String, "databricks")
@compiles(Text, "databricks")
@compiles(Time, "databricks")
@compiles(Unicode, "databricks")
@compiles(UnicodeText, "databricks")
@compiles(Uuid, "databricks")
def compile_string_databricks(type_, compiler, **kw):
    """
    We override the default compilation for Enum(), String(), Text(), Time(), Unicode(),
    UnicodeText(), and Uuid() because SQLAlchemy defaults to incompatible / abnormal
    compiled names:

    Enum -> VARCHAR
    String -> VARCHAR[LENGTH]
    Text -> VARCHAR[LENGTH]
    Time -> TIME
    Unicode -> VARCHAR[LENGTH]
    UnicodeText -> TEXT
    Uuid -> CHAR[32]

    All of these types are compiled to STRING in Databricks SQL.
    """
    return "STRING"


@compiles(Integer, "databricks")
def compile_integer_databricks(type_, compiler, **kw):
    """
    We need to override the default Integer compilation rendering because Databricks uses "INT" instead of "INTEGER"
    """
    return "INT"


@compiles(LargeBinary, "databricks")
def compile_binary_databricks(type_, compiler, **kw):
    """
    We need to override the default LargeBinary compilation rendering because Databricks uses "BINARY" instead of "BLOB"
    """
    return "BINARY"


@compiles(Numeric, "databricks")
def compile_numeric_databricks(type_, compiler, **kw):
    """
    We need to override the default Numeric compilation rendering because Databricks uses "DECIMAL" instead of "NUMERIC"

    The built-in visit_DECIMAL behaviour captures the precision and scale, so here we simply
    delegate Numeric compilation to SQLAlchemy's DECIMAL rendering.
    """
    return compiler.visit_DECIMAL(type_, **kw)


@compiles(DateTime, "databricks")
def compile_datetime_databricks(type_, compiler, **kw):
    """
    We need to override the default DateTime compilation rendering because Databricks uses "TIMESTAMP" instead of "DATETIME"
    """
    return "TIMESTAMP"
