Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 9fb0444

Browse files
Merge branch 'main' into col-normalisation
2 parents b2ae83c + 36d3ec4, commit 9fb0444

File tree

8 files changed

+971
-240
lines changed

8 files changed

+971
-240
lines changed

‎src/databricks/sql/auth/retry.py‎

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ def __init__(
127127
total=_attempts_remaining,
128128
respect_retry_after_header=True,
129129
backoff_factor=self.delay_min,
130-
allowed_methods=["POST"],
130+
allowed_methods=["POST","GET","DELETE"],
131131
status_forcelist=[429,503,*self.force_dangerous_codes],
132132
)
133133

‎src/databricks/sql/backend/sea/backend.py‎

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -159,6 +159,7 @@ def __init__(
159159
)
160160

161161
self.use_hybrid_disposition=kwargs.get("use_hybrid_disposition",True)
162+
self.use_cloud_fetch=kwargs.get("use_cloud_fetch",True)
162163

163164
# Extract warehouse ID from http_path
164165
self.warehouse_id=self._extract_warehouse_id(http_path)
@@ -695,7 +696,7 @@ def get_catalogs(
695696
max_bytes=max_bytes,
696697
lz4_compression=False,
697698
cursor=cursor,
698-
use_cloud_fetch=False,
699+
use_cloud_fetch=self.use_cloud_fetch,
699700
parameters=[],
700701
async_op=False,
701702
enforce_embedded_schema_correctness=False,
@@ -731,7 +732,7 @@ def get_schemas(
731732
max_bytes=max_bytes,
732733
lz4_compression=False,
733734
cursor=cursor,
734-
use_cloud_fetch=False,
735+
use_cloud_fetch=self.use_cloud_fetch,
735736
parameters=[],
736737
async_op=False,
737738
enforce_embedded_schema_correctness=False,
@@ -775,7 +776,7 @@ def get_tables(
775776
max_bytes=max_bytes,
776777
lz4_compression=False,
777778
cursor=cursor,
778-
use_cloud_fetch=False,
779+
use_cloud_fetch=self.use_cloud_fetch,
779780
parameters=[],
780781
async_op=False,
781782
enforce_embedded_schema_correctness=False,
@@ -825,7 +826,7 @@ def get_columns(
825826
max_bytes=max_bytes,
826827
lz4_compression=False,
827828
cursor=cursor,
828-
use_cloud_fetch=False,
829+
use_cloud_fetch=self.use_cloud_fetch,
829830
parameters=[],
830831
async_op=False,
831832
enforce_embedded_schema_correctness=False,

‎src/databricks/sql/backend/sea/utils/filters.py‎

Lines changed: 183 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -6,12 +6,12 @@
66

77
from __future__importannotations
88

9+
importio
910
importlogging
1011
fromtypingimport (
1112
List,
1213
Optional,
1314
Any,
14-
Callable,
1515
cast,
1616
TYPE_CHECKING,
1717
)
@@ -20,6 +20,16 @@
2020
fromdatabricks.sql.backend.sea.result_setimportSeaResultSet
2121

2222
fromdatabricks.sql.backend.typesimportExecuteResponse
23+
fromdatabricks.sql.backend.sea.models.baseimportResultData
24+
fromdatabricks.sql.backend.sea.backendimportSeaDatabricksClient
25+
fromdatabricks.sql.utilsimportCloudFetchQueue,ArrowQueue
26+
27+
try:
28+
importpyarrow
29+
importpyarrow.computeaspc
30+
exceptImportError:
31+
pyarrow=None
32+
pc=None
2333

2434
logger=logging.getLogger(__name__)
2535

@@ -30,32 +40,18 @@ class ResultSetFilter:
3040
"""
3141

3242
@staticmethod
33-
def_filter_sea_result_set(
34-
result_set:SeaResultSet,filter_func:Callable[[List[Any]],bool]
35-
)->SeaResultSet:
43+
def_create_execute_response(result_set:SeaResultSet)->ExecuteResponse:
3644
"""
37-
Filter a SEA result set using the provided filter function.
45+
Create an ExecuteResponse with parameters from the original result set.
3846
3947
Args:
40-
result_set: The SEA result set to filter
41-
filter_func: Function that takes a row and returns True if the row should be included
48+
result_set: Original result set to copy parameters from
4249
4350
Returns:
44-
A filtered SEA result set
51+
ExecuteResponse: New execute response object
4552
"""
46-
47-
# Get all remaining rows
48-
all_rows=result_set.results.remaining_rows()
49-
50-
# Filter rows
51-
filtered_rows= [rowforrowinall_rowsiffilter_func(row)]
52-
53-
# Reuse the command_id from the original result set
54-
command_id=result_set.command_id
55-
56-
# Create an ExecuteResponse for the filtered data
57-
execute_response=ExecuteResponse(
58-
command_id=command_id,
53+
returnExecuteResponse(
54+
command_id=result_set.command_id,
5955
status=result_set.status,
6056
description=result_set.description,
6157
has_been_closed_server_side=result_set.has_been_closed_server_side,
@@ -64,39 +60,145 @@ def _filter_sea_result_set(
6460
is_staging_operation=False,
6561
)
6662

67-
# Create a new ResultData object with filtered data
68-
fromdatabricks.sql.backend.sea.models.baseimportResultData
63+
@staticmethod
64+
def_update_manifest(result_set:SeaResultSet,new_row_count:int):
65+
"""
66+
Create a copy of the manifest with updated row count.
6967
70-
result_data=ResultData(data=filtered_rows,external_links=None)
68+
Args:
69+
result_set: Original result set to copy manifest from
70+
new_row_count: New total row count for filtered data
71+
72+
Returns:
73+
Updated manifest copy
74+
"""
75+
filtered_manifest=result_set.manifest
76+
filtered_manifest.total_row_count=new_row_count
77+
returnfiltered_manifest
78+
79+
@staticmethod
80+
def_create_filtered_result_set(
81+
result_set:SeaResultSet,
82+
result_data:ResultData,
83+
row_count:int,
84+
)->"SeaResultSet":
85+
"""
86+
Create a new filtered SeaResultSet with the provided data.
7187
72-
fromdatabricks.sql.backend.sea.backendimportSeaDatabricksClient
88+
Args:
89+
result_set: Original result set to copy parameters from
90+
result_data: New result data for the filtered set
91+
row_count: Number of rows in the filtered data
92+
93+
Returns:
94+
New filtered SeaResultSet
95+
"""
7396
fromdatabricks.sql.backend.sea.result_setimportSeaResultSet
7497

75-
# Create a new SeaResultSet with the filtered data
76-
manifest=result_set.manifest
77-
manifest.total_row_count=len(filtered_rows)
98+
execute_response=ResultSetFilter._create_execute_response(result_set)
99+
filtered_manifest=ResultSetFilter._update_manifest(result_set,row_count)
78100

79-
filtered_result_set=SeaResultSet(
101+
returnSeaResultSet(
80102
connection=result_set.connection,
81103
execute_response=execute_response,
82104
sea_client=cast(SeaDatabricksClient,result_set.backend),
83105
result_data=result_data,
84-
manifest=manifest,
106+
manifest=filtered_manifest,
85107
buffer_size_bytes=result_set.buffer_size_bytes,
86108
arraysize=result_set.arraysize,
87109
)
88110

89-
# Preserve metadata columns setup from original result set
90-
ifhasattr(result_set,"_metadata_columns")andresult_set._metadata_columns:
91-
filtered_result_set._metadata_columns=result_set._metadata_columns
92-
filtered_result_set._column_index_mapping=result_set._column_index_mapping
93-
# Update the description to match the original prepared description
94-
filtered_result_set.description=result_set.description
111+
@staticmethod
112+
def_filter_arrow_table(
113+
table:Any,# pyarrow.Table
114+
column_name:str,
115+
allowed_values:List[str],
116+
case_sensitive:bool=True,
117+
)->Any:# returns pyarrow.Table
118+
"""
119+
Filter a PyArrow table by column values.
120+
121+
Args:
122+
table: The PyArrow table to filter
123+
column_name: The name of the column to filter on
124+
allowed_values: List of allowed values for the column
125+
case_sensitive: Whether to perform case-sensitive comparison
126+
127+
Returns:
128+
A filtered PyArrow table
129+
"""
130+
ifnotpyarrow:
131+
raiseImportError("PyArrow is required for Arrow table filtering")
132+
133+
iftable.num_rows==0:
134+
returntable
135+
136+
# Handle case-insensitive filtering by normalizing both column and allowed values
137+
ifnotcase_sensitive:
138+
# Convert allowed values to uppercase
139+
allowed_values= [v.upper()forvinallowed_values]
140+
# Get column values as uppercase
141+
column=pc.utf8_upper(table[column_name])
142+
else:
143+
# Use column as-is
144+
column=table[column_name]
145+
146+
# Convert allowed_values to PyArrow Array
147+
allowed_array=pyarrow.array(allowed_values)
148+
149+
# Construct a boolean mask: True where column is in allowed_list
150+
mask=pc.is_in(column,value_set=allowed_array)
151+
returntable.filter(mask)
152+
153+
@staticmethod
154+
def_filter_arrow_result_set(
155+
result_set:SeaResultSet,
156+
column_index:int,
157+
allowed_values:List[str],
158+
case_sensitive:bool=True,
159+
)->SeaResultSet:
160+
"""
161+
Filter a SEA result set that contains Arrow tables.
162+
163+
Args:
164+
result_set: The SEA result set to filter (containing Arrow data)
165+
column_index: The index of the column to filter on
166+
allowed_values: List of allowed values for the column
167+
case_sensitive: Whether to perform case-sensitive comparison
168+
169+
Returns:
170+
A filtered SEA result set
171+
"""
172+
# Validate column index and get column name
173+
ifcolumn_index>=len(result_set.description):
174+
raiseValueError(f"Column index{column_index} is out of bounds")
175+
column_name=result_set.description[column_index][0]
176+
177+
# Get all remaining rows as Arrow table and filter it
178+
arrow_table=result_set.results.remaining_rows()
179+
filtered_table=ResultSetFilter._filter_arrow_table(
180+
arrow_table,column_name,allowed_values,case_sensitive
181+
)
182+
183+
# Convert the filtered table to Arrow stream format for ResultData
184+
sink=io.BytesIO()
185+
withpyarrow.ipc.new_stream(sink,filtered_table.schema)aswriter:
186+
writer.write_table(filtered_table)
187+
arrow_stream_bytes=sink.getvalue()
188+
189+
# Create ResultData with attachment containing the filtered data
190+
result_data=ResultData(
191+
data=None,# No JSON data
192+
external_links=None,# No external links
193+
attachment=arrow_stream_bytes,# Arrow data as attachment
194+
)
95195

96-
returnfiltered_result_set
196+
returnResultSetFilter._create_filtered_result_set(
197+
result_set,result_data,filtered_table.num_rows
198+
)
97199

98200
@staticmethod
99-
deffilter_by_column_values(
201+
def_filter_json_result_set(
100202
result_set:SeaResultSet,
101203
column_index:int,
102204
allowed_values:List[str],
@@ -114,22 +216,35 @@ def filter_by_column_values(
114216
Returns:
115217
A filtered result set
116218
"""
219+
# Validate column index (optional - not in arrow version but good practice)
220+
ifcolumn_index>=len(result_set.description):
221+
raiseValueError(f"Column index{column_index} is out of bounds")
117222

118-
# Convert to uppercase for case-insensitive comparison if needed
223+
# Extract rows
224+
all_rows=result_set.results.remaining_rows()
225+
226+
# Convert allowed values if case-insensitive
119227
ifnotcase_sensitive:
120228
allowed_values= [v.upper()forvinallowed_values]
229+
# Helper lambda to get column value based on case sensitivity
230+
get_column_value= (
231+
lambdarow:row[column_index].upper()
232+
ifnotcase_sensitive
233+
elserow[column_index]
234+
)
235+
236+
# Filter rows based on allowed values
237+
filtered_rows= [
238+
row
239+
forrowinall_rows
240+
iflen(row)>column_indexandget_column_value(row)inallowed_values
241+
]
242+
243+
# Create filtered result set
244+
result_data=ResultData(data=filtered_rows,external_links=None)
121245

122-
returnResultSetFilter._filter_sea_result_set(
123-
result_set,
124-
lambdarow: (
125-
len(row)>column_index
126-
and (
127-
row[column_index].upper()
128-
ifnotcase_sensitive
129-
elserow[column_index]
130-
)
131-
inallowed_values
132-
),
246+
returnResultSetFilter._create_filtered_result_set(
247+
result_set,result_data,len(filtered_rows)
133248
)
134249

135250
@staticmethod
@@ -150,14 +265,25 @@ def filter_tables_by_type(
150265
Returns:
151266
A filtered result set containing only tables of the specified types
152267
"""
153-
154268
# Default table types if none specified
155269
DEFAULT_TABLE_TYPES= ["TABLE","VIEW","SYSTEM TABLE"]
156-
valid_types= (
157-
table_typesiftable_typesandlen(table_types)>0elseDEFAULT_TABLE_TYPES
158-
)
270+
valid_types=table_typesiftable_typeselseDEFAULT_TABLE_TYPES
159271

272+
# Check if we have an Arrow table (cloud fetch) or JSON data
160273
# Table type is the 6th column (index 5)
161-
returnResultSetFilter.filter_by_column_values(
162-
result_set,5,valid_types,case_sensitive=True
163-
)
274+
ifisinstance(result_set.results, (CloudFetchQueue,ArrowQueue)):
275+
# For Arrow tables, we need to handle filtering differently
276+
returnResultSetFilter._filter_arrow_result_set(
277+
result_set,
278+
column_index=5,
279+
allowed_values=valid_types,
280+
case_sensitive=True,
281+
)
282+
else:
283+
# For JSON data, use the existing filter method
284+
returnResultSetFilter._filter_json_result_set(
285+
result_set,
286+
column_index=5,
287+
allowed_values=valid_types,
288+
case_sensitive=True,
289+
)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp