Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,15 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
#
DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None

# A set of disallowed SQL functions per engine. This is used to restrict the use of
# unsafe SQL functions in SQL Lab and Charts. The keys of the dictionary are the engine
# names, and the values are sets of disallowed functions.
DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm unsure where the best place for this deny list, i.e., here in the configuration or within the extra JSON payload of the database.

Additionally should this be engine (dialect) specific or database specific? If it's the later then maybe the extra JSON payload field is preferable.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JSON payload at the database level is more dynamic and would avoid having to change the config to add remove disallowed functions. But on the other hand the user that actually registers the db could have intentions to "abuse" these functions.

"postgresql": {"version", "query_to_xml", "inet_server_addr", "inet_client_addr"},
"clickhouse": {"url"},
"mysql": {"version"},
}


# A function that intercepts the SQL to be executed and can alter it.
# A common use case for this is around adding some sort of comment header to the SQL
Expand Down
7 changes: 6 additions & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from superset.constants import TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
from superset.sql_parse import ParsedQuery, SQLScript, Table
from superset.superset_typing import (
OAuth2ClientConfig,
Expand Down Expand Up @@ -1818,6 +1818,11 @@ def execute( # pylint: disable=unused-argument
"""
if not cls.allows_sql_comments:
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
disallowed_functions = current_app.config["DISALLOWED_SQL_FUNCTIONS"].get(
cls.engine, set()
)
if sql_parse.check_sql_functions_exist(query, disallowed_functions, cls.engine):
raise DisallowedSQLFunction(disallowed_functions)

if cls.arraysize:
cursor.arraysize = cls.arraysize
Expand Down
11 changes: 7 additions & 4 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import time
from typing import Any, TYPE_CHECKING

from flask import current_app
from flask import current_app, Flask
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
Expand Down Expand Up @@ -218,19 +218,22 @@ def execute_with_cursor(
execute_result: dict[str, Any] = {}
execute_event = threading.Event()

def _execute(results: dict[str, Any], event: threading.Event) -> None:
def _execute(
results: dict[str, Any], event: threading.Event, app: Flask
) -> None:
logger.debug("Query %d: Running query: %s", query_id, sql)

try:
cls.execute(cursor, sql, query.database)
with app.app_context():
cls.execute(cursor, sql, query.database)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, interesting.

except Exception as ex: # pylint: disable=broad-except
results["error"] = ex
finally:
event.set()

execute_thread = threading.Thread(
target=_execute,
args=(execute_result, execute_event),
args=(execute_result, execute_event, current_app._get_current_object()), # pylint: disable=protected-access
)
execute_thread.start()

Expand Down
15 changes: 15 additions & 0 deletions superset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,21 @@ def __init__(self, error: str):
)


class DisallowedSQLFunction(SupersetErrorException):
"""
Disallowed function found on SQL statement
"""

def __init__(self, functions: set[str]):
super().__init__(
SupersetError(
message=f"SQL statement contains disallowed function(s): {functions}",
error_type=SupersetErrorType.SYNTAX_ERROR,
level=ErrorLevel.ERROR,
)
)


class CreateKeyValueDistributedLockFailedException(Exception):
"""
Exception to signalize failure to acquire lock.
Expand Down
42 changes: 42 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
Function,
Identifier,
IdentifierList,
Parenthesis,
Expand Down Expand Up @@ -223,6 +224,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
return cte, remainder


def check_sql_functions_exist(
sql: str, function_list: set[str], engine: str | None = None
) -> bool:
"""
Check if the SQL statement contains any of the specified functions.

:param sql: The SQL statement
:param function_list: The list of functions to search for
:param engine: The engine to use for parsing the SQL statement
"""
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We (probably me) will have to convert this to use sqlglot and the SQLStatement class (#26786) but I'm happy to do it, seems simple enough.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to do it myself, can I just not use sqlparse? implement something using the same pattern as extract_tables_from_statement?



def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
Expand Down Expand Up @@ -743,6 +757,34 @@ def tables(self) -> set[Table]:
self._tables = self._extract_tables_from_sql()
return self._tables

def _check_functions_exist_in_token(
self, token: Token, functions: set[str]
) -> bool:
if (
isinstance(token, Function)
and token.get_name() is not None
and token.get_name().lower() in functions
):
return True
if hasattr(token, "tokens"):
for inner_token in token.tokens:
if self._check_functions_exist_in_token(inner_token, functions):
return True
return False

def check_functions_exist(self, functions: set[str]) -> bool:
"""
Check if the SQL statement contains any of the specified functions.

:param functions: A set of functions to search for
:return: True if the statement contains any of the specified functions
"""
for statement in self._parsed:
for token in statement.tokens:
if self._check_functions_exist_in_token(token, functions):
return True
return False

def _extract_tables_from_sql(self) -> set[Table]:
"""
Extract all table references in a query.
Expand Down
24 changes: 14 additions & 10 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def test_handle_cursor_early_cancel(
assert cancel_query_mock.call_args is None


def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
"""Test that `execute_with_cursor` fetches query ID from the cursor"""
from superset.db_engine_specs.trino import TrinoEngineSpec

Expand All @@ -416,16 +416,20 @@ def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id

mock_cursor.execute.side_effect = _mock_execute
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)

TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)

mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)
mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)


def test_get_columns(mocker: MockerFixture):
Expand Down
26 changes: 26 additions & 0 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from superset.sql_parse import (
add_table_name,
check_sql_functions_exist,
extract_table_references,
extract_tables_from_jinja_sql,
get_rls_for_table,
Expand Down Expand Up @@ -1215,6 +1216,31 @@ def test_strip_comments_from_sql() -> None:
)


def test_check_sql_functions_exist() -> None:
"""
Test that comments are stripped out correctly.
"""
assert not (
check_sql_functions_exist("select a, b from version", {"version"}, "postgresql")
)

assert check_sql_functions_exist("select version()", {"version"}, "postgresql")

assert check_sql_functions_exist(
"select version from version()", {"version"}, "postgresql"
)

assert check_sql_functions_exist(
"select 1, a.version from (select version from version()) as a",
{"version"},
"postgresql",
)

assert check_sql_functions_exist(
"select 1, a.version from (select version()) as a", {"version"}, "postgresql"
)


def test_sanitize_clause_valid():
# regular clauses
assert sanitize_clause("col = 1") == "col = 1"
Expand Down