33
44import sqlalchemy
55from sqlalchemy import event ,DDL
6- from sqlalchemy .engine import Engine ,default ,reflection
6+ from sqlalchemy .engine import Engine ,default ,reflection , Connection , Row , CursorResult
77from sqlalchemy .engine .interfaces import (
88ReflectedForeignKeyConstraint ,
99ReflectedPrimaryKeyConstraint ,
class DatabricksImpl(DefaultImpl):
    """Implementation class bound to the "databricks" dialect name.

    The class adds no behaviour of its own; it only sets ``__dialect__``.

    NOTE(review): ``DefaultImpl`` is imported outside the visible part of the
    file -- presumably the Alembic migration base class; confirm.
    """

    # Dialect name this implementation is associated with.
    __dialect__ = "databricks"
3333
34+
3435DBR_LTE_12_NOT_FOUND_STRING = "Table or view not found"
3536DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND"
3637
38+
39+ def _match_table_not_found_string (message :str )-> bool :
40+ """Return True if the message contains a substring indicating that a table was not found"""
41+ return any (
42+ [
43+ DBR_LTE_12_NOT_FOUND_STRING in message ,
44+ DBR_GT_12_NOT_FOUND_STRING in message ,
45+ ]
46+ )
47+
48+
49+ def _describe_table_extended_result_to_dict (result :CursorResult )-> dict :
50+ """Transform the output of DESCRIBE TABLE EXTENDED into a dictionary
51+
52+ The output from DESCRIBE TABLE EXTENDED puts all values in the `data_type` column
53+ Even CONSTRAINT descriptions are contained in the `data_type` column
54+ Some rows have an empty string for their col_name. These are present only for spacing
55+ so we ignore them.
56+ """
57+
58+ result_dict = {row .col_name :row .data_type for row in result if row .col_name != "" }
59+
60+ return result_dict
61+
62+
3763COLUMN_TYPE_MAP = {
3864"boolean" :sqlalchemy .types .Boolean ,
3965"smallint" :sqlalchemy .types .SmallInteger ,
@@ -54,6 +80,7 @@ class DatabricksImpl(DefaultImpl):
5480"date" :sqlalchemy .types .Date ,
5581}
5682
83+
5784class DatabricksDialect (default .DefaultDialect ):
5885"""This dialect implements only those methods required to pass our e2e tests"""
5986
@@ -156,6 +183,46 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
156183
157184return columns
158185
def _describe_table_extended(
    self,
    connection: Connection,
    table_name: str,
    catalog_name: Optional[str] = None,
    schema_name: Optional[str] = None,
    expect_result: bool = True,
) -> Optional[dict]:
    """Run DESCRIBE TABLE EXTENDED on a table and return a dictionary of the result.

    This method is the fastest way to check for the presence of a table in a schema.

    :param connection: connection used to execute the statement.
    :param table_name: unqualified name of the table to describe.
    :param catalog_name: catalog to look in; falls back to ``self.catalog``.
    :param schema_name: schema to look in; falls back to ``self.schema``.
    :param expect_result: if False, this method returns None as the output
        dict isn't required.
    :return: ``{col_name: data_type}`` built from the statement output, or
        None when ``expect_result`` is False.
    :raises sqlalchemy.exc.NoSuchTableError: if the table is not present in
        the schema. Any other ``DatabaseError`` is re-raised unchanged.
    """

    name_parts = (
        catalog_name or self.catalog,
        schema_name or self.schema,
        table_name,
    )

    # DESCRIBE TABLE EXTENDED in DBR doesn't support parameterised inputs :(
    # The statement therefore has to be built as text. Each part is wrapped in
    # backticks and any embedded backtick is doubled, so a name cannot close
    # its own quoting and inject extra SQL.
    target = ".".join(
        "`" + str(part).replace("`", "``") + "`" for part in name_parts
    )
    stmt = DDL(f"DESCRIBE TABLE EXTENDED {target}")

    try:
        result = connection.execute(stmt).all()
    except DatabaseError as e:
        if _match_table_not_found_string(str(e)):
            raise sqlalchemy.exc.NoSuchTableError(
                f"No such table {table_name}"
            ) from e
        # Not a "table missing" failure: propagate with the original traceback.
        raise

    if not expect_result:
        return None

    return _describe_table_extended_result_to_dict(result)
225+
159226@reflection .cache
160227def get_pk_constraint (
161228self ,
@@ -169,16 +236,18 @@ def get_pk_constraint(
169236 """
170237
171238with self .get_connection_cursor (connection )as cursor :
172-
173239try :
174240# DESCRIBE TABLE EXTENDED doesn't support parameterised inputs :(
175- result = cursor .execute (f"DESCRIBE TABLE EXTENDED{ table_name } " ).fetchall ()
241+ result = cursor .execute (
242+ f"DESCRIBE TABLE EXTENDED{ table_name } "
243+ ).fetchall ()
176244except ServerOperationError as e :
177245if DBR_GT_12_NOT_FOUND_STRING in str (
178246e
179247 )or DBR_LTE_12_NOT_FOUND_STRING in str (e ):
180- raise sqlalchemy .exc .NoSuchTableError (f"No such table{ table_name } " )from e
181-
248+ raise sqlalchemy .exc .NoSuchTableError (
249+ f"No such table{ table_name } "
250+ )from e
182251
183252# DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the field where
184253# a primary key constraint will be found in its output. So we cycle through its
@@ -237,7 +306,9 @@ def get_foreign_keys(
237306if DBR_GT_12_NOT_FOUND_STRING in str (
238307e
239308 )or DBR_LTE_12_NOT_FOUND_STRING in str (e ):
240- raise sqlalchemy .exc .NoSuchTableError (f"No such table{ table_name } " )from e
309+ raise sqlalchemy .exc .NoSuchTableError (
310+ f"No such table{ table_name } "
311+ )from e
241312
242313# DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the field where
243314# a foreign key constraint will be found in its output. So we cycle through its
@@ -333,29 +404,20 @@ def do_rollback(self, dbapi_connection):
def has_table(
    self, connection, table_name, schema=None, catalog=None, **kwargs
) -> bool:
    """For internal dialect use, check the existence of a particular table
    or view in the database.

    :param connection: connection used to run the existence probe.
    :param table_name: unqualified name of the table or view.
    :param schema: schema to look in; None lets the probe use its default.
    :param catalog: catalog to look in; None lets the probe use its default.
    :param kwargs: accepted for signature compatibility; not used here.
    :return: True if the table can be described, False if the probe raises
        ``NoSuchTableError``. Any other error propagates to the caller.
    """

    try:
        # Only the presence of the table matters, so the probe is told not to
        # build the result dictionary.
        self._describe_table_extended(
            connection=connection,
            table_name=table_name,
            catalog_name=catalog,
            schema_name=schema,
            expect_result=False,
        )
    except sqlalchemy.exc.NoSuchTableError:
        return False
    return True
359421
360422def get_connection_cursor (self ,connection ):
361423"""Added for backwards compatibility with 1.3.x"""
@@ -372,8 +434,7 @@ def get_connection_cursor(self, connection):
372434
373435@reflection .cache
374436def get_schema_names (self ,connection ,** kw ):
375- """Return a list of all schema names available in the database.
376- """
437+ """Return a list of all schema names available in the database."""
377438stmt = DDL ("SHOW SCHEMAS" )
378439result = connection .execute (stmt )
379440schema_list = [row [0 ]for row in result ]