@@ -1,11 +1,12 @@
 from __future__ import annotations
 
-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, List, Optional, TYPE_CHECKING, Dict
 
 import logging
 
 from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
 from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter
+from databricks.sql.backend.sea.utils.result_column import ResultColumn
 
 try:
     import pyarrow
@@ -82,6 +83,10 @@ def __init__(
             arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )
 
+        # Initialize metadata columns for post-fetch transformation
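+        # These stay None for regular queries; metadata queries populate them via
+        # prepare_metadata_columns(), making the transforms below no-ops by default.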
+        self._metadata_columns = None
+        self._column_index_mapping = None
+
     def _convert_json_types(self, row: List[str]) -> List[Any]:
         """
         Convert string values in the row to appropriate Python types based on column metadata.
@@ -160,6 +165,7 @@ def fetchmany_json(self, size: int) -> List[List[str]]:
             raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
 
         results = self.results.next_n_rows(size)
+        results = self._transform_json_rows(results)
         self._next_row_index += len(results)
 
         return results
@@ -173,6 +179,7 @@ def fetchall_json(self) -> List[List[str]]:
173179 """
174180
175181results = self .results .remaining_rows ()
182+ results = self ._transform_json_rows (results )
176183self ._next_row_index += len (results )
177184
178185return results
@@ -197,7 +204,12 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
 
         results = self.results.next_n_rows(size)
         if isinstance(self.results, JsonQueue):
-            results = self._convert_json_to_arrow_table(results)
+            # Transform JSON first, then convert to Arrow
+            transformed_json = self._transform_json_rows(results)
+            results = self._convert_json_to_arrow_table(transformed_json)
+        else:
+            # Transform Arrow table directly
+            results = self._transform_arrow_table(results)
 
         self._next_row_index += results.num_rows
 
@@ -210,7 +222,12 @@ def fetchall_arrow(self) -> "pyarrow.Table":
 
         results = self.results.remaining_rows()
         if isinstance(self.results, JsonQueue):
-            results = self._convert_json_to_arrow_table(results)
+            # Transform JSON first, then convert to Arrow
+            transformed_json = self._transform_json_rows(results)
+            results = self._convert_json_to_arrow_table(transformed_json)
+        else:
+            # Transform Arrow table directly
+            results = self._transform_arrow_table(results)
 
         self._next_row_index += results.num_rows
 
@@ -263,3 +280,108 @@ def fetchall(self) -> List[Row]:
             return self._create_json_table(self.fetchall_json())
         else:
             return self._convert_arrow_table(self.fetchall_arrow())
+
+    def prepare_metadata_columns(self, metadata_columns: List[ResultColumn]) -> None:
+        """
+        Prepare result set for metadata column normalization.
+
+        Args:
+            metadata_columns: List of ResultColumn objects defining the expected columns
+                and their mappings from SEA column names
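+
+        Example (illustrative sketch only; assumes ResultColumn accepts the
+        field names used elsewhere in this module as keyword arguments)::
+
+            result_set.prepare_metadata_columns([
+                ResultColumn(
+                    column_name="TABLE_CAT",           # JDBC-style name exposed to callers
+                    result_set_column_name="catalog",  # name as returned by SEA
+                    column_type="string",              # expected type code
+                ),
+            ])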
291+ """
292+ self ._metadata_columns = metadata_columns
293+ self ._prepare_column_mapping ()
294+
+    def _prepare_column_mapping(self) -> None:
+        """
+        Prepare column index mapping for metadata queries.
+        Updates description to use JDBC column names.
+        """
+        # Ensure description is available
+        if not self.description:
+            raise ValueError("Cannot prepare column mapping without result description")
+
+        # Build mapping from SEA column names to their indices
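+        # (each description entry is a PEP 249 7-tuple; index 0 is the column name)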
+        sea_column_indices = {}
+        for idx, col in enumerate(self.description):
+            sea_column_indices[col[0]] = idx
+
+        # Create new description and index mapping
+        new_description = []
+        self._column_index_mapping = {}  # Maps new index -> old index
+
+        for new_idx, result_column in enumerate(self._metadata_columns):
+            # Find the corresponding SEA column
+            if (
+                result_column.result_set_column_name
+                and result_column.result_set_column_name in sea_column_indices
+            ):
+                old_idx = sea_column_indices[result_column.result_set_column_name]
+                self._column_index_mapping[new_idx] = old_idx
+                # Use the original column metadata but with the JDBC name
+                old_col = self.description[old_idx]
+                new_description.append(
+                    (
+                        result_column.column_name,  # JDBC name
+                        result_column.column_type,  # Expected type
+                        old_col[2],  # display_size
+                        old_col[3],  # internal_size
+                        old_col[4],  # precision
+                        old_col[5],  # scale
+                        old_col[6],  # null_ok
+                    )
+                )
+            else:
+                # Column doesn't exist in SEA - add with None values
+                new_description.append(
+                    (
+                        result_column.column_name,
+                        result_column.column_type,
+                        None,
+                        None,
+                        None,
+                        None,
+                        True,
+                    )
+                )
+                self._column_index_mapping[new_idx] = None
+
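+        # Swap in the normalized description so callers see the JDBC names and order.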
+        self.description = new_description
+
+    def _transform_arrow_table(self, table: "pyarrow.Table") -> "pyarrow.Table":
+        """Transform Arrow table columns for metadata normalization."""
+        if not self._metadata_columns:
+            return table
+
+        # Reorder columns and add missing ones
+        new_columns = []
+        column_names = []
+
+        for new_idx, result_column in enumerate(self._metadata_columns):
+            old_idx = self._column_index_mapping.get(new_idx)
+            if old_idx is not None:
+                new_columns.append(table.column(old_idx))
+            else:
+                # Create null column for missing data
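+                # (pyarrow.nulls() returns a NullType array; every value reads back as None)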
+                null_array = pyarrow.nulls(table.num_rows)
+                new_columns.append(null_array)
+            column_names.append(result_column.column_name)
+
+        return pyarrow.Table.from_arrays(new_columns, names=column_names)
+
+    def _transform_json_rows(self, rows: List[List[str]]) -> List[List[Any]]:
+        """Transform JSON rows for metadata normalization."""
+        if not self._metadata_columns:
+            return rows
+
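+        # Project each SEA row into the JDBC column order; missing columns become None.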
+        transformed_rows = []
+        for row in rows:
+            new_row = []
+            for new_idx in range(len(self._metadata_columns)):
+                old_idx = self._column_index_mapping.get(new_idx)
+                if old_idx is not None:
+                    new_row.append(row[old_idx])
+                else:
+                    new_row.append(None)
+            transformed_rows.append(new_row)
+        return transformed_rows