Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 119 additions & 36 deletions superset/db_engine_specs/starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
import logging
import re
from re import Pattern
from typing import Any, Optional, Union
from typing import Any
from urllib import parse

from flask_babel import gettext as __
from sqlalchemy import Float, Integer, Numeric, types
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.sql.type_api import TypeEngine

Expand All @@ -31,6 +32,8 @@
from superset.models.core import Database
from superset.utils.core import GenericDataType

DEFAULT_CATALOG = "default_catalog"

# Regular expressions to catch custom errors
CONNECTION_ACCESS_DENIED_REGEX = re.compile(
"Access denied for user '(?P<username>.*?)'"
Expand Down Expand Up @@ -68,23 +71,23 @@ class ARRAY(TypeEngine):
__visit_name__ = "ARRAY"

@property
def python_type(self) -> Optional[type[list[Any]]]:
def python_type(self) -> type[list[Any]] | None:
return list


class MAP(TypeEngine):
__visit_name__ = "MAP"

@property
def python_type(self) -> Optional[type[dict[Any, Any]]]:
def python_type(self) -> type[dict[Any, Any]] | None:
return dict


class STRUCT(TypeEngine):
__visit_name__ = "STRUCT"

@property
def python_type(self) -> Optional[type[Any]]:
def python_type(self) -> type[Any] | None:
return None


Expand All @@ -93,9 +96,9 @@ class StarRocksEngineSpec(MySQLEngineSpec):
engine_name = "StarRocks"

default_driver = "starrocks"
sqlalchemy_uri_placeholder = (
"starrocks://user:password@host:port/catalog.db[?key=value&key=value...]"
)
sqlalchemy_uri_placeholder = "starrocks://user:password@host:port[/catalog.db]"
supports_dynamic_schema = True
supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True

column_type_mappings = ( # type: ignore
(
Expand Down Expand Up @@ -168,17 +171,39 @@ def adjust_engine_params(
cls,
uri: URL,
connect_args: dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
catalog: str | None = None,
schema: str | None = None,
) -> tuple[URL, dict[str, Any]]:
database = uri.database
if schema and database:
"""
Adjust engine parameters for StarRocks catalog and schema support.

StarRocks uses a "catalog.schema" format in the database field:
- "catalog.schema" - both specified
- "catalog." - catalog only (for browsing schemas)
- None - neither specified
"""
if uri.database and "." in uri.database:
current_catalog, current_schema = uri.database.split(".", 1)
elif uri.database:
current_catalog, current_schema = uri.database, None
else:
current_catalog, current_schema = None, None

if schema:
schema = parse.quote(schema, safe="")
if "." in database:
database = database.split(".")[0] + "." + schema
else:
database = "default_catalog." + schema
uri = uri.set(database=database)

effective_catalog = catalog or current_catalog or DEFAULT_CATALOG
# only use the schema/db from uri if we're not overriding catalog
effective_schema = schema
if not effective_schema and (not catalog or catalog == current_catalog):
effective_schema = current_schema

if effective_schema:
adjusted_database = f"{effective_catalog}.{effective_schema}"
else:
adjusted_database = f"{effective_catalog}."

uri = uri.set(database=adjusted_database)

return uri, connect_args

Expand All @@ -187,21 +212,87 @@ def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
connect_args: dict[str, Any],
) -> Optional[str]:
) -> str | None:
"""
Return the configured schema.
Extract schema from engine parameters.

For StarRocks the SQLAlchemy URI looks like this:
Returns the schema portion from formats like:
- "catalog.schema" -> "schema"
- "schema" -> None (ambiguous - could be catalog or schema)
- "" or None -> None
"""
if not sqlalchemy_uri.database:
return None

database = sqlalchemy_uri.database.strip("/")
if not database or "." not in database:
return None

starrocks://localhost:9030/catalog.schema
schema = database.split(".")[-1]
return parse.unquote(schema)

@classmethod
def get_default_catalog(cls, database: Database) -> str:
"""
database = sqlalchemy_uri.database.strip("/")
Return the default catalog.

if "." not in database:
return None
Extracts catalog from URI (e.g., "iceberg" from "iceberg.schema"),
otherwise returns DEFAULT_CATALOG.
"""
if database.url_object.database and "." in database.url_object.database:
return database.url_object.database.split(".")[0]

return DEFAULT_CATALOG

@classmethod
def get_catalog_names(
    cls,
    database: Database,
    inspector: Inspector,
) -> set[str]:
    """
    Return the names of all catalogs reported by ``SHOW CATALOGS``.

    The statement is executed on ``inspector.bind``. For every row the
    catalog name is read from the ``Catalog`` key when the row exposes
    ``keys()``, otherwise from a ``Catalog`` attribute, otherwise from the
    first column.

    The lookup is best-effort: a row whose name cannot be extracted is
    logged and skipped, and any other failure is logged and results in an
    empty set instead of an exception.
    """
    names: set[str] = set()

    try:
        for row in inspector.bind.execute("SHOW CATALOGS"):
            try:
                # Try the mapping-style, attribute-style and positional
                # access in that order, using the first one the row offers.
                if hasattr(row, "keys") and "Catalog" in row.keys():
                    name = row["Catalog"]
                elif hasattr(row, "Catalog"):
                    name = row.Catalog
                else:
                    name = row[0]
                names.add(name)
            except (AttributeError, TypeError, IndexError, KeyError) as ex:
                logger.warning(
                    "Unable to extract catalog name from row: %s (%s)", row, ex
                )
    except Exception as ex:  # pylint: disable=broad-except
        logger.exception("Error fetching catalog names from SHOW CATALOGS: %s", ex)
        return set()

    return names

@classmethod
def get_schema_names(cls, inspector: Inspector) -> set[str]:
"""
Get all schemas/databases using SHOW DATABASES.

return parse.unquote(database.split(".")[1])
The catalog context is set via the database field in the connection URL
(e.g., "catalog." sets the context to that catalog).
"""
try:
result = inspector.bind.execute("SHOW DATABASES")
return {row[0] for row in result}
except Exception as ex: # pylint: disable=broad-except
logger.exception("Error fetching schema names from SHOW DATABASES: %s", ex)
return set()

@classmethod
def impersonate_user(
Expand All @@ -225,21 +316,13 @@ def impersonate_user(
def get_prequeries(
cls,
database: Database,
catalog: Union[str, None] = None,
schema: Union[str, None] = None,
catalog: str | None = None,
schema: str | None = None,
) -> list[str]:
"""
Return pre-session queries.

These are currently used as an alternative to ``adjust_engine_params`` for
databases where the selected schema cannot be specified in the SQLAlchemy URI or
connection arguments.

For example, in order to specify a default schema in RDS we need to run a query
at the beginning of the session:

sql> set search_path = my_schema;
Get pre-session queries.

For StarRocks with user impersonation enabled, returns an EXECUTE AS statement.
"""
if database.impersonate_user:
username = database.get_effective_user(database.url_object)
Expand Down
116 changes: 113 additions & 3 deletions tests/unit_tests/db_engine_specs/test_starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_get_column_spec(
(
"starrocks://user:password@host/db1",
{"param1": "some_value"},
"db1",
"db1.",  # Single value is treated as a catalog (no schema selected)
{"param1": "some_value"},
),
(
Expand All @@ -88,12 +88,18 @@ def test_get_column_spec(
"catalog1.db1",
{"param1": "some_value"},
),
(
"starrocks://user:password@host",
{"param1": "some_value"},
"default_catalog.",
{"param1": "some_value"},
),
],
)
def test_adjust_engine_params(
sqlalchemy_uri: str,
connect_args: dict[str, Any],
return_schema: str,
return_schema: Optional[str],
return_connect_args: dict[str, Any],
) -> None:
from superset.db_engine_specs.starrocks import StarRocksEngineSpec
Expand All @@ -112,6 +118,7 @@ def test_get_schema_from_engine_params() -> None:
"""
from superset.db_engine_specs.starrocks import StarRocksEngineSpec

# With catalog.schema format
assert (
StarRocksEngineSpec.get_schema_from_engine_params(
make_url("starrocks://localhost:9030/hive.default"),
Expand All @@ -120,9 +127,19 @@ def test_get_schema_from_engine_params() -> None:
== "default"
)

# With only catalog (no schema) - should return None
assert (
StarRocksEngineSpec.get_schema_from_engine_params(
make_url("starrocks://localhost:9030/sales"),
{},
)
is None
)

# With no database - should return None
assert (
StarRocksEngineSpec.get_schema_from_engine_params(
make_url("starrocks://localhost:9030/hive"),
make_url("starrocks://localhost:9030"),
{},
)
is None
Expand Down Expand Up @@ -173,3 +190,96 @@ def test_impersonation_disabled(mocker: MockerFixture) -> None:
) == (make_url("starrocks://service_user@localhost:9030/hive.default"), {})

assert StarRocksEngineSpec.get_prequeries(database) == []


def test_get_default_catalog(mocker: MockerFixture) -> None:
    """
    Test the ``get_default_catalog`` method.

    The catalog is the part of the URI database before the first "."; when
    the database has no "." or is missing entirely, the default catalog is
    returned.
    """
    from superset.db_engine_specs.starrocks import StarRocksEngineSpec

    cases = [
        # "catalog.schema" in the URI: the catalog comes from the URI
        ("hive.default", "hive"),
        # Single value without a ".": falls back to the default catalog
        ("default", "default_catalog"),
        # No database in the URI at all: falls back to the default catalog
        (None, "default_catalog"),
    ]

    for uri_database, expected_catalog in cases:
        database = mocker.MagicMock()
        database.url_object.database = uri_database

        assert StarRocksEngineSpec.get_default_catalog(database) == expected_catalog


def test_get_catalog_names(mocker: MockerFixture) -> None:
"""
Test the ``get_catalog_names`` method.
"""
from superset.db_engine_specs.starrocks import StarRocksEngineSpec

database = mocker.MagicMock()
inspector = mocker.MagicMock()

# Mock the actual StarRocks SHOW CATALOGS format
# StarRocks returns rows with keys: ['Catalog', 'Type', 'Comment']
mock_row_1 = mocker.MagicMock()
mock_row_1.keys.return_value = ["Catalog", "Type", "Comment"]
mock_row_1.__getitem__ = (
lambda self, key: "default_catalog" if key == "Catalog" else None
)

mock_row_2 = mocker.MagicMock()
mock_row_2.keys.return_value = ["Catalog", "Type", "Comment"]
mock_row_2.__getitem__ = lambda self, key: "hive" if key == "Catalog" else None

mock_row_3 = mocker.MagicMock()
mock_row_3.keys.return_value = ["Catalog", "Type", "Comment"]
mock_row_3.__getitem__ = lambda self, key: "iceberg" if key == "Catalog" else None
Comment on lines +225 to +237
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: The mocks for catalog rows override __getitem__ on MagicMock instances, but special methods are looked up on the class, so row["Catalog"] still returns MagicMock objects instead of the expected strings, causing get_catalog_names to return a set of mocks instead of {"default_catalog", "hive", "iceberg"} and making the test assertions fail. [logic error]

Severity Level: Minor ⚠️

Suggested change
mock_row_1 = mocker.MagicMock()
mock_row_1.keys.return_value = ["Catalog", "Type", "Comment"]
mock_row_1.__getitem__ = lambda self, key: "default_catalog" if key == "Catalog" else None
mock_row_2 = mocker.MagicMock()
mock_row_2.keys.return_value = ["Catalog", "Type", "Comment"]
mock_row_2.__getitem__ = lambda self, key: "hive" if key == "Catalog" else None
mock_row_3 = mocker.MagicMock()
mock_row_3.keys.return_value = ["Catalog", "Type", "Comment"]
mock_row_3.__getitem__ = lambda self, key: "iceberg" if key == "Catalog" else None
mock_row_1 = {"Catalog": "default_catalog", "Type": None, "Comment": None}
mock_row_2 = {"Catalog": "hive", "Type": None, "Comment": None}
mock_row_3 = {"Catalog": "iceberg", "Type": None, "Comment": None}
Why it matters? ⭐

The claim is correct: assigning `__getitem__` on MagicMock instances does not affect Python's special-method lookup, which resolves special methods on the object's class. As written the test may yield MagicMock values when code does row["Catalog"] (or similar), causing the assertion to fail or be flaky. Replacing the MagicMock rows with plain dicts (or configuring the mock's class-level `__getitem__`) produces real strings and makes the test deterministic. This is a real logic fix for the unit test rather than a cosmetic change.

Prompt for AI Agent 🤖
This is a comment left during a code review.

**Path:** tests/unit_tests/db_engine_specs/test_starrocks.py
**Line:** 225:235
**Comment:**
	*Logic Error: The mocks for catalog rows override `__getitem__` on `MagicMock` instances, but special methods are looked up on the class, so `row["Catalog"]` still returns MagicMock objects instead of the expected strings, causing `get_catalog_names` to return a set of mocks instead of `{"default_catalog", "hive", "iceberg"}` and making the test assertions fail.

Validate the correctness of the flagged issue. If correct, How can I resolve this? If you propose a fix, implement it and please make it concise.


inspector.bind.execute.return_value = [mock_row_1, mock_row_2, mock_row_3]

catalogs = StarRocksEngineSpec.get_catalog_names(database, inspector)
assert catalogs == {"default_catalog", "hive", "iceberg"}


@pytest.mark.parametrize(
    "uri,catalog,schema,expected_database",
    [
        # Catalog and schema/db in URI, no overrides: kept unchanged
        ("starrocks://host/hive.sales", None, None, "hive.sales"),
        # Overriding the catalog drops the schema/db taken from the URI
        ("starrocks://host/hive.sales", "iceberg", None, "iceberg."),
        # "Overriding" with the catalog already in the URI keeps the URI schema/db
        ("starrocks://host/hive.sales", "hive", None, "hive.sales"),
        # Overriding the schema/db keeps the catalog from the URI
        ("starrocks://host/hive.sales", None, "marketing", "hive.marketing"),
        # Overriding both
        ("starrocks://host/hive.sales", "iceberg", "marketing", "iceberg.marketing"),
        # Only a catalog in the URI (no schema/db), add a schema
        ("starrocks://host/hive", None, "marketing", "hive.marketing"),
        # Only a catalog in the URI, override the catalog
        ("starrocks://host/hive", "iceberg", None, "iceberg."),
        # No catalog/database in the URI, overriding the catalog
        ("starrocks://host", "iceberg", None, "iceberg."),
        # No catalog/database in the URI, catalog and schema/db both given
        ("starrocks://host", "iceberg", "sales", "iceberg.sales"),
        # Empty database and no overrides: uses the default catalog
        ("starrocks://host", None, None, "default_catalog."),
        # Schema only when the URI has no database: uses default_catalog
        ("starrocks://host", None, "sales", "default_catalog.sales"),
    ],
)
def test_adjust_engine_params_with_catalog(
    uri: str,
    catalog: Optional[str],
    schema: Optional[str],
    expected_database: Optional[str],
) -> None:
    """
    Test the ``adjust_engine_params`` method with catalog parameter.

    Each case checks the ``catalog.schema`` value written to the URL's
    database field for a given URI and catalog/schema override.
    """
    from superset.db_engine_specs.starrocks import StarRocksEngineSpec

    adjusted_url, _ = StarRocksEngineSpec.adjust_engine_params(
        make_url(uri), {}, catalog=catalog, schema=schema
    )
    assert adjusted_url.database == expected_database
Loading