66
77from __future__import annotations
88
9+ import io
910import logging
1011from typing import (
1112List ,
1213Optional ,
1314Any ,
14- Callable ,
1515cast ,
1616TYPE_CHECKING ,
1717)
2020from databricks .sql .backend .sea .result_set import SeaResultSet
2121
2222from databricks .sql .backend .types import ExecuteResponse
23+ from databricks .sql .backend .sea .models .base import ResultData
24+ from databricks .sql .backend .sea .backend import SeaDatabricksClient
25+ from databricks .sql .utils import CloudFetchQueue ,ArrowQueue
26+
27+ try :
28+ import pyarrow
29+ import pyarrow .compute as pc
30+ except ImportError :
31+ pyarrow = None
32+ pc = None
2333
2434logger = logging .getLogger (__name__ )
2535
@@ -30,32 +40,18 @@ class ResultSetFilter:
3040 """
3141
@staticmethod
def _create_execute_response(result_set: SeaResultSet) -> ExecuteResponse:
    """
    Build an ExecuteResponse mirroring the original result set's metadata.

    Args:
        result_set: Original result set to copy parameters from

    Returns:
        ExecuteResponse: New execute response object
    """
    # Collect the copied fields by name, then expand them into the
    # constructor. A filtered copy is never a staging operation, so that
    # flag is pinned to False rather than copied.
    fields = {
        "command_id": result_set.command_id,
        "status": result_set.status,
        "description": result_set.description,
        "has_been_closed_server_side": result_set.has_been_closed_server_side,
        "lz4_compressed": result_set.lz4_compressed,
        "arraysize": result_set.arraysize,
        "is_staging_operation": False,
    }
    return ExecuteResponse(**fields)
6662
67- # Create a new ResultData object with filtered data
68- from databricks .sql .backend .sea .models .base import ResultData
63+ @staticmethod
64+ def _update_manifest (result_set :SeaResultSet ,new_row_count :int ):
65+ """
66+ Create a copy of the manifest with updated row count.
6967
70- result_data = ResultData (data = filtered_rows ,external_links = None )
68+ Args:
69+ result_set: Original result set to copy manifest from
70+ new_row_count: New total row count for filtered data
71+
72+ Returns:
73+ Updated manifest copy
74+ """
75+ filtered_manifest = result_set .manifest
76+ filtered_manifest .total_row_count = new_row_count
77+ return filtered_manifest
78+
@staticmethod
def _create_filtered_result_set(
    result_set: SeaResultSet,
    result_data: ResultData,
    row_count: int,
) -> "SeaResultSet":
    """
    Assemble a new SeaResultSet that carries the filtered data.

    Args:
        result_set: Original result set to copy parameters from
        result_data: New result data for the filtered set
        row_count: Number of rows in the filtered data

    Returns:
        New filtered SeaResultSet
    """
    # NOTE(review): imported locally rather than at module top — presumably
    # to avoid a circular import at load time; confirm before hoisting.
    from databricks.sql.backend.sea.result_set import SeaResultSet

    return SeaResultSet(
        connection=result_set.connection,
        execute_response=ResultSetFilter._create_execute_response(result_set),
        sea_client=cast(SeaDatabricksClient, result_set.backend),
        result_data=result_data,
        manifest=ResultSetFilter._update_manifest(result_set, row_count),
        buffer_size_bytes=result_set.buffer_size_bytes,
        arraysize=result_set.arraysize,
    )
@staticmethod
def _filter_arrow_table(
    table: Any,  # pyarrow.Table
    column_name: str,
    allowed_values: List[str],
    case_sensitive: bool = True,
) -> Any:  # returns pyarrow.Table
    """
    Keep only the rows of a PyArrow table whose ``column_name`` value
    appears in ``allowed_values``.

    Args:
        table: The PyArrow table to filter
        column_name: The name of the column to filter on
        allowed_values: List of allowed values for the column
        case_sensitive: Whether to perform case-sensitive comparison

    Returns:
        A filtered PyArrow table

    Raises:
        ImportError: If PyArrow is not installed
    """
    if not pyarrow:
        raise ImportError("PyArrow is required for Arrow table filtering")

    # An empty table has nothing to filter.
    if table.num_rows == 0:
        return table

    # For case-insensitive matching, uppercase both the column and the
    # allowed values before comparing.
    if case_sensitive:
        column = table[column_name]
        values = allowed_values
    else:
        column = pc.utf8_upper(table[column_name])
        values = [v.upper() for v in allowed_values]

    # Boolean mask: True wherever the column value is in the allowed set.
    mask = pc.is_in(column, value_set=pyarrow.array(values))
    return table.filter(mask)
152+
@staticmethod
def _filter_arrow_result_set(
    result_set: SeaResultSet,
    column_index: int,
    allowed_values: List[str],
    case_sensitive: bool = True,
) -> SeaResultSet:
    """
    Filter a SEA result set whose rows are backed by Arrow data.

    Args:
        result_set: The SEA result set to filter (containing Arrow data)
        column_index: The index of the column to filter on
        allowed_values: List of allowed values for the column
        case_sensitive: Whether to perform case-sensitive comparison

    Returns:
        A filtered SEA result set

    Raises:
        ValueError: If ``column_index`` is past the end of the description
    """
    # Resolve the column name, rejecting indexes beyond the schema.
    if column_index >= len(result_set.description):
        raise ValueError(f"Column index {column_index} is out of bounds")
    column_name = result_set.description[column_index][0]

    # Drain the remaining rows as a single Arrow table and filter it.
    filtered_table = ResultSetFilter._filter_arrow_table(
        result_set.results.remaining_rows(),
        column_name,
        allowed_values,
        case_sensitive,
    )

    # Serialize the filtered table in Arrow IPC stream format so it can
    # travel inside a ResultData attachment.
    sink = io.BytesIO()
    with pyarrow.ipc.new_stream(sink, filtered_table.schema) as writer:
        writer.write_table(filtered_table)

    result_data = ResultData(
        data=None,  # no inline JSON rows
        external_links=None,  # no external links
        attachment=sink.getvalue(),  # Arrow stream bytes carry the data
    )

    return ResultSetFilter._create_filtered_result_set(
        result_set, result_data, filtered_table.num_rows
    )
97199
@staticmethod
def _filter_json_result_set(
    result_set: SeaResultSet,
    column_index: int,
    allowed_values: List[str],
    case_sensitive: bool = True,
) -> SeaResultSet:
    """
    Filter a JSON-backed SEA result set by values in a specific column.

    Args:
        result_set: The SEA result set to filter
        column_index: The index of the column to filter on
        allowed_values: List of allowed values for the column
        case_sensitive: Whether to perform case-sensitive comparison

    Returns:
        A filtered result set

    Raises:
        ValueError: If ``column_index`` is past the end of the description
    """
    # Validate the column index up front so a bad index fails loudly
    # instead of silently filtering every row out.
    if column_index >= len(result_set.description):
        raise ValueError(f"Column index {column_index} is out of bounds")

    # Drain the remaining rows from the source result set.
    all_rows = result_set.results.remaining_rows()

    # Hoist the case-sensitivity decision out of the per-row loop (it is
    # loop-invariant) and use a set for O(1) membership tests instead of
    # scanning the allowed list once per row.
    if case_sensitive:
        allowed = set(allowed_values)
        filtered_rows = [
            row
            for row in all_rows
            if len(row) > column_index and row[column_index] in allowed
        ]
    else:
        allowed = {v.upper() for v in allowed_values}
        filtered_rows = [
            row
            for row in all_rows
            if len(row) > column_index and row[column_index].upper() in allowed
        ]

    # Package the surviving rows as inline JSON data (no external links).
    result_data = ResultData(data=filtered_rows, external_links=None)

    return ResultSetFilter._create_filtered_result_set(
        result_set, result_data, len(filtered_rows)
    )
134249
135250@staticmethod
@@ -150,14 +265,25 @@ def filter_tables_by_type(
150265 Returns:
151266 A filtered result set containing only tables of the specified types
152267 """
153-
154268# Default table types if none specified
155269DEFAULT_TABLE_TYPES = ["TABLE" ,"VIEW" ,"SYSTEM TABLE" ]
156- valid_types = (
157- table_types if table_types and len (table_types )> 0 else DEFAULT_TABLE_TYPES
158- )
270+ valid_types = table_types if table_types else DEFAULT_TABLE_TYPES
159271
272+ # Check if we have an Arrow table (cloud fetch) or JSON data
160273# Table type is the 6th column (index 5)
161- return ResultSetFilter .filter_by_column_values (
162- result_set ,5 ,valid_types ,case_sensitive = True
163- )
274+ if isinstance (result_set .results , (CloudFetchQueue ,ArrowQueue )):
275+ # For Arrow tables, we need to handle filtering differently
276+ return ResultSetFilter ._filter_arrow_result_set (
277+ result_set ,
278+ column_index = 5 ,
279+ allowed_values = valid_types ,
280+ case_sensitive = True ,
281+ )
282+ else :
283+ # For JSON data, use the existing filter method
284+ return ResultSetFilter ._filter_json_result_set (
285+ result_set ,
286+ column_index = 5 ,
287+ allowed_values = valid_types ,
288+ case_sensitive = True ,
289+ )