From 8bf5f4800f7a880b09c2dcf6df65c7db1c3de77d Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Wed, 22 May 2024 14:32:01 +0100 Subject: [PATCH 1/4] fix: adds the ability to disallow SQL functions per engine --- superset/config.py | 8 ++++++ superset/db_engine_specs/base.py | 7 +++++- superset/exceptions.py | 15 ++++++++++++ superset/sql_parse.py | 38 +++++++++++++++++++++++++++++ tests/unit_tests/sql_parse_tests.py | 26 ++++++++++++++++++++ 5 files changed, 93 insertions(+), 1 deletion(-) diff --git a/superset/config.py b/superset/config.py index 10f075bb5fb2..7e1340598137 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1225,6 +1225,14 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None +# A set of disallowed SQL functions per engine. This is used to restrict the use of +# unsafe SQL functions in SQL Lab and Charts. The keys of the dictionary are the engine +# names, and the values are sets of disallowed functions. +DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = { + "postgresql": {"version", "query_to_xml", "inet_server_addr", "inet_client_addr"}, + "clickhouse": {"url"}, + "mysql": {"version"} +} # A function that intercepts the SQL to be executed and can alter it. # A common use case for this is around adding some sort of comment header to the SQL diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index fafea897f195..2698600d8a78 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -63,7 +63,7 @@ from superset.constants import TimeGrain as TimeGrainConstants from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import OAuth2Error, OAuth2RedirectError +from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError from superset.sql_parse import ParsedQuery, SQLScript, Table from superset.superset_typing import ( OAuth2ClientConfig, @@ -1819,6 +1819,11 @@ def execute( # pylint: disable=unused-argument """ if not cls.allows_sql_comments: query = sql_parse.strip_comments_from_sql(query, engine=cls.engine) + disallowed_functions = current_app.config["DISALLOWED_SQL_FUNCTIONS"].get( + cls.engine, set() + ) + if sql_parse.check_sql_functions_exist(query, disallowed_functions, cls.engine): + raise DisallowedSQLFunction(disallowed_functions) if cls.arraysize: cursor.arraysize = cls.arraysize diff --git a/superset/exceptions.py b/superset/exceptions.py index 0315ee30f4f4..47cd511f8f20 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -358,6 +358,21 @@ def __init__(self, error: str): ) +class DisallowedSQLFunction(SupersetErrorException): + """ + Disallowed function found on SQL statement + """ + + def __init__(self, functions: set[str]): + super().__init__( + SupersetError( + message=f"SQL statement contains disallowed function(s): {functions}", + error_type=SupersetErrorType.SYNTAX_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + class CreateKeyValueDistributedLockFailedException(Exception): """ Exception to signalize failure to acquire lock. diff --git a/superset/sql_parse.py b/superset/sql_parse.py index f32647042b0a..898e9f90d05c 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -39,6 +39,7 @@ from sqlparse import keywords from sqlparse.lexer import Lexer from sqlparse.sql import ( + Function, Identifier, IdentifierList, Parenthesis, @@ -223,6 +224,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]: return cte, remainder +def check_sql_functions_exist( + sql: str, function_list: set[str], engine: str | None = None +) -> bool: + """ + Check if the SQL statement contains any of the specified functions. + + :param sql: The SQL statement + :param function_list: The list of functions to search for + :param engine: The engine to use for parsing the SQL statement + """ + return ParsedQuery(sql, engine=engine).check_functions_exist(function_list) + + def strip_comments_from_sql(statement: str, engine: str | None = None) -> str: """ Strips comments from a SQL statement, does a simple test first @@ -743,6 +757,30 @@ def tables(self) -> set[Table]: self._tables = self._extract_tables_from_sql() return self._tables + def _check_functions_exist_in_token( + self, token: Token, functions: set[str] + ) -> bool: + if isinstance(token, Function) and token.get_name().lower() in functions: + return True + if hasattr(token, "tokens"): + for inner_token in token.tokens: + if self._check_functions_exist_in_token(inner_token, functions): + return True + return False + + def check_functions_exist(self, functions: set[str]) -> bool: + """ + Check if the SQL statement contains any of the specified functions. + + :param functions: A set of functions to search for + :return: True if the statement contains any of the specified functions + """ + for statement in self._parsed: + for token in statement.tokens: + if self._check_functions_exist_in_token(token, functions): + return True + return False + def _extract_tables_from_sql(self) -> set[Table]: """ Extract all table references in a query. diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 3b80b7e01d1a..6259d6272db6 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -32,6 +32,7 @@ ) from superset.sql_parse import ( add_table_name, + check_sql_functions_exist, extract_table_references, extract_tables_from_jinja_sql, get_rls_for_table, @@ -1215,6 +1216,31 @@ def test_strip_comments_from_sql() -> None: ) +def test_check_sql_functions_exist() -> None: + """ + Test that comments are stripped out correctly. + """ + assert not ( + check_sql_functions_exist("select a, b from version", {"version"}, "postgresql") + ) + + assert check_sql_functions_exist("select version()", {"version"}, "postgresql") + + assert check_sql_functions_exist( + "select version from version()", {"version"}, "postgresql" + ) + + assert check_sql_functions_exist( + "select 1, a.version from (select version from version()) as a", + {"version"}, + "postgresql", + ) + + assert check_sql_functions_exist( + "select 1, a.version from (select version()) as a", {"version"}, "postgresql" + ) + + def test_sanitize_clause_valid(): # regular clauses assert sanitize_clause("col = 1") == "col = 1" From 0f97b4509d57ded10eaa0b7b9f5a980bb84f3d7a Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Wed, 22 May 2024 15:03:34 +0100 Subject: [PATCH 2/4] fix test --- superset/sql_parse.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 898e9f90d05c..192a998c3fbd 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -760,7 +760,11 @@ def tables(self) -> set[Table]: def _check_functions_exist_in_token( self, token: Token, functions: set[str] ) -> bool: - if isinstance(token, Function) and token.get_name().lower() in functions: + if ( + isinstance(token, Function) + and token.get_name() is not None + and token.get_name().lower() in functions + ): return True if hasattr(token, "tokens"): for inner_token in token.tokens: From 96d04688fdcdfbb46eec7f777edc88b8322f6bf8 Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Wed, 22 May 2024 15:17:39 +0100 Subject: [PATCH 3/4] fix lint --- superset/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/superset/config.py b/superset/config.py index 7e1340598137..9b8b61bd14d1 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1231,9 +1231,10 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = { "postgresql": {"version", "query_to_xml", "inet_server_addr", "inet_client_addr"}, "clickhouse": {"url"}, - "mysql": {"version"} + "mysql": {"version"}, } + # A function that intercepts the SQL to be executed and can alter it. # A common use case for this is around adding some sort of comment header to the SQL # with information such as the username and worker node information From 80bb186c7d644bc5b5778ec28448aa90aa9efee3 Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Wed, 22 May 2024 16:19:16 +0100 Subject: [PATCH 4/4] fix trino thread --- superset/db_engine_specs/trino.py | 11 +++++--- .../unit_tests/db_engine_specs/test_trino.py | 26 +++++++++++-------- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 08a38894e664..73fdeef2a106 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -23,7 +23,7 @@ from typing import Any, TYPE_CHECKING import simplejson as json -from flask import current_app +from flask import current_app, Flask from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.exc import NoSuchTableError @@ -219,11 +219,14 @@ def execute_with_cursor( execute_result: dict[str, Any] = {} execute_event = threading.Event() - def _execute(results: dict[str, Any], event: threading.Event) -> None: + def _execute( + results: dict[str, Any], event: threading.Event, app: Flask + ) -> None: logger.debug("Query %d: Running query: %s", query_id, sql) try: - cls.execute(cursor, sql, query.database) + with app.app_context(): + cls.execute(cursor, sql, query.database) except Exception as ex: # pylint: disable=broad-except results["error"] = ex finally: @@ -231,7 +234,7 @@ def _execute(results: dict[str, Any], event: threading.Event) -> None: execute_thread = threading.Thread( target=_execute, - args=(execute_result, execute_event), + args=(execute_result, execute_event, current_app._get_current_object()), # pylint: disable=protected-access ) execute_thread.start() diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index e35615f57a4f..d46455b05a8a 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -400,7 +400,7 @@ def test_handle_cursor_early_cancel( assert cancel_query_mock.call_args is None -def test_execute_with_cursor_in_parallel(mocker: MockerFixture): +def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture): """Test that `execute_with_cursor` fetches query ID from the cursor""" from superset.db_engine_specs.trino import TrinoEngineSpec @@ -415,16 +415,20 @@ def _mock_execute(*args, **kwargs): mock_cursor.query_id = query_id mock_cursor.execute.side_effect = _mock_execute - - TrinoEngineSpec.execute_with_cursor( - cursor=mock_cursor, - sql="SELECT 1 FROM foo", - query=mock_query, - ) - - mock_query.set_extra_json_key.assert_called_once_with( - key=QUERY_CANCEL_KEY, value=query_id - ) + with patch.dict( + "superset.config.DISALLOWED_SQL_FUNCTIONS", + {}, + clear=True, + ): + TrinoEngineSpec.execute_with_cursor( + cursor=mock_cursor, + sql="SELECT 1 FROM foo", + query=mock_query, + ) + + mock_query.set_extra_json_key.assert_called_once_with( + key=QUERY_CANCEL_KEY, value=query_id + ) def test_get_columns(mocker: MockerFixture):