Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Test parameter escaping#46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
susodapop merged 10 commits intomainfromtest-param-escaping
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletionssrc/databricks/sql/client.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -7,7 +7,7 @@
from databricks.sql import *
from databricks.sql.exc import OperationalError
from databricks.sql.thrift_backend import ThriftBackend
from databricks.sql.utils import ExecuteResponse, ParamEscaper
from databricks.sql.utils import ExecuteResponse, ParamEscaper, inject_parameters
from databricks.sql.types import Row
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
Expand DownExpand Up@@ -309,7 +309,9 @@ def execute(
:returns self
"""
if parameters is not None:
operation = operation % self.escaper.escape_args(parameters)
operation = inject_parameters(
operation, self.escaper.escape_args(parameters)
)

self._check_not_closed()
self._close_and_clear_active_result_set()
Expand Down
8 changes: 6 additions & 2 deletionssrc/databricks/sql/utils.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -2,7 +2,7 @@
from collections.abc import Iterable
import datetime
from enum import Enum

from typing import Dict
import pyarrow


Expand DownExpand Up@@ -146,7 +146,7 @@ def escape_string(self, item):
# This is good enough when backslashes are literal, newlines are just followed, and the way
# to escape a single quote is to put two single quotes.
# (i.e. only special character is single quote)
return "'{}'".format(item.replace("'", "''"))
return "'{}'".format(item.replace("\\", "\\\\").replace("'", "\\'"))

def escape_sequence(self, item):
l = map(str, map(self.escape_item, item))
Expand All@@ -172,3 +172,7 @@ def escape_item(self, item):
return self.escape_datetime(item, self._DATE_FORMAT)
else:
raise exc.ProgrammingError("Unsupported object {}".format(item))


def inject_parameters(operation: str, parameters: Dict[str, str]):
return operation % parameters
14 changes: 14 additions & 0 deletionstests/e2e/driver_tests.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -288,6 +288,20 @@ def test_get_columns(self):
for table in table_names:
cursor.execute('DROP TABLE IF EXISTS {}'.format(table))

def test_escape_single_quotes(self):
with self.cursor({}) as cursor:
table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
# Test escape syntax directly
cursor.execute("CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format(table_name))
cursor.execute("SELECT * FROM {} WHERE col_1 LIKE 'you\\'re'".format(table_name))
rows = cursor.fetchall()
assert rows[0]["col_1"] == "you're"

# Test escape syntax in parameter
cursor.execute("SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format(table_name, table_name), parameters={"var": "you're"})
rows = cursor.fetchall()
assert rows[0]["col_1"] == "you're"

def test_get_schemas(self):
with self.cursor({}) as cursor:
database_name = 'db_{uuid}'.format(uuid=str(uuid4()).replace('-', '_'))
Expand Down
150 changes: 150 additions & 0 deletionstests/unit/test_param_escaper.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
from datetime import date, datetime
import unittest, pytest

from databricks.sql.utils import ParamEscaper, inject_parameters

pe = ParamEscaper()

class TestIndividualFormatters(object):

# Test individual type escapers
def test_escape_number_integer(self):
"""This behaviour falls back to Python's default string formatting of numbers
"""
assert pe.escape_number(100) == 100

def test_escape_number_float(self):
"""This behaviour falls back to Python's default string formatting of numbers
"""
assert pe.escape_number(100.1234) == 100.1234

def test_escape_string_normal(self):
"""
"""

assert pe.escape_string("golly bob howdy") == "'golly bob howdy'"

def test_escape_string_that_includes_special_characters(self):
"""Tests for how special characters are treated.

When passed a string, the `escape_string` method wraps it in single quotes
and escapes any special characters with a back stroke (\)

Example:

IN : his name was 'robert palmer'
OUT: 'his name was \'robert palmer\''
"""

# Testing for the presence of these characters: '"/\😂

assert pe.escape_string("his name was 'robert palmer'") == r"'his name was \'robert palmer\''"

# These tests represent the same user input in the several ways it can be written in Python
# Each argument to `escape_string` evaluates to the same bytes. But Python lets us write it differently.
assert pe.escape_string("his name was \"robert palmer\"") == "'his name was \"robert palmer\"'"
assert pe.escape_string('his name was "robert palmer"') == "'his name was \"robert palmer\"'"
assert pe.escape_string('his name was {}'.format('"robert palmer"')) == "'his name was \"robert palmer\"'"

assert pe.escape_string("his name was robert / palmer") == r"'his name was robert / palmer'"

# If you need to include a single backslash, use an r-string to prevent Python from raising a
# DeprecationWarning for an invalid escape sequence
assert pe.escape_string("his name was robert \\/ palmer") == r"'his name was robert \\/ palmer'"
assert pe.escape_string("his name was robert \\ palmer") == r"'his name was robert \\ palmer'"
assert pe.escape_string("his name was robert \\\\ palmer") == r"'his name was robert \\\\ palmer'"

assert pe.escape_string("his name was robert palmer 😂") == r"'his name was robert palmer 😂'"

# Adding the test from PR #56 to prove escape behaviour

assert pe.escape_string("you're") == r"'you\'re'"

# Adding this test from #51 to prove escape behaviour when the target string involves repeated SQL escape chars
assert pe.escape_string("cat\\'s meow") == r"'cat\\\'s meow'"

# Tests from the docs: https://docs.databricks.com/sql/language-manual/data-types/string-type.html

assert pe.escape_string('Spark') == "'Spark'"
assert pe.escape_string("O'Connell") == r"'O\'Connell'"
assert pe.escape_string("Some\\nText") == r"'Some\\nText'"
assert pe.escape_string("Some\\\\nText") == r"'Some\\\\nText'"
assert pe.escape_string("서울시") == "'서울시'"
assert pe.escape_string("\\\\") == r"'\\\\'"

def test_escape_date_time(self):
INPUT = datetime(1991,8,3,21,55)
FORMAT = "%Y-%m-%d %H:%M:%S"
OUTPUT = "'1991-08-03 21:55:00'"
assert pe.escape_datetime(INPUT, FORMAT) == OUTPUT

def test_escape_date(self):
INPUT = date(1991,8,3)
FORMAT = "%Y-%m-%d"
OUTPUT = "'1991-08-03'"
assert pe.escape_datetime(INPUT, FORMAT) == OUTPUT

def test_escape_sequence_integer(self):
assert pe.escape_sequence([1,2,3,4]) == "(1,2,3,4)"

def test_escape_sequence_float(self):
assert pe.escape_sequence([1.1,2.2,3.3,4.4]) == "(1.1,2.2,3.3,4.4)"

def test_escape_sequence_string(self):
assert pe.escape_sequence(
["his", "name", "was", "robert", "palmer"]) == \
"('his','name','was','robert','palmer')"

def test_escape_sequence_sequence_of_strings(self):
# This is not valid SQL.
INPUT = [["his", "name"], ["was", "robert"], ["palmer"]]
OUTPUT = "(('his','name'),('was','robert'),('palmer'))"

assert pe.escape_sequence(INPUT) == OUTPUT


class TestFullQueryEscaping(object):

def test_simple(self):

INPUT = """
SELECT
field1,
field2,
field3
FROM
table
WHERE
field1 = %(param1)s
"""

OUTPUT = """
SELECT
field1,
field2,
field3
FROM
table
WHERE
field1 = ';DROP ALL TABLES'
"""

args = {"param1": ";DROP ALL TABLES"}

assert inject_parameters(INPUT, pe.escape_args(args)) == OUTPUT

@unittest.skipUnless(False, "Thrift server supports native parameter binding.")
def test_only_bind_in_where_clause(self):

INPUT = """
SELECT
%(field)s,
field2,
field3
FROM table
"""

args = {"field": "Some Value"}

with pytest.raises(Exception):
inject_parameters(INPUT, pe.escape_args(args))

[8]ページ先頭

©2009-2025 Movatter.jp