Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
datetime_format = Column(String(100))
extra = Column(Text)

table: Mapped[SqlaTable] = relationship(
table: Mapped["SqlaTable"] = relationship(
"SqlaTable",
back_populates="columns",
)
Expand Down Expand Up @@ -1107,7 +1107,7 @@ class SqlMetric(AuditMixinNullable, ImportExportMixin, CertificationMixin, Model
expression = Column(utils.MediumText(), nullable=False)
extra = Column(Text)

table: Mapped[SqlaTable] = relationship(
table: Mapped["SqlaTable"] = relationship(
"SqlaTable",
back_populates="metrics",
)
Expand Down Expand Up @@ -1593,6 +1593,28 @@ def adhoc_metric_to_sqla(

return self.make_sqla_column_compatible(sqla_metric, label)

def _render_adhoc_expression_for_metadata_lookup(
self,
sql_expression: str,
template_processor: BaseTemplateProcessor | None,
) -> str:
"""Render Jinja in *sql_expression* so the result can be matched against
column metadata. Without this, a templated expression such as
``{{ filter_values('x')[0] }}`` is passed raw to ``get_column``, never
matches, and falls back to ``literal_column`` — which breaks for virtual
datasets because the rendered name isn't present in the FROM subquery."""
if not template_processor:
return sql_expression
try:
return template_processor.process_template(sql_expression)
except SupersetSyntaxErrorException as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in adhoc column: %(msg)s",
msg=str(ex),
)
) from ex

def adhoc_column_to_sqla( # pylint: disable=too-many-locals
self,
col: AdhocColumn,
Expand All @@ -1618,8 +1640,12 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
pdf = None
is_column_reference = col.get("isColumnReference", False)

metadata_lookup_key = self._render_adhoc_expression_for_metadata_lookup(
sql_expression, template_processor
)

# First, check if this is a column reference that exists in metadata
if col_in_metadata := self.get_column(sql_expression):
if col_in_metadata := self.get_column(metadata_lookup_key.strip()):
# Column exists in metadata - use it directly
sqla_column = col_in_metadata.get_sqla_col(
template_processor=template_processor
Expand Down Expand Up @@ -1660,7 +1686,7 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
# A small number of drivers (Druid, Pinot) instead build
# cursor.description by inspecting the first returned row;
# for those we fall back to LIMIT 1.
tbl, _ = self.get_from_clause(template_processor)
tbl, _unused_cte = self.get_from_clause(template_processor)
if self.db_engine_spec.type_probe_needs_row:
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
else:
Expand Down
76 changes: 75 additions & 1 deletion tests/unit_tests/models/helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING
from typing import cast, TYPE_CHECKING
from unittest.mock import patch

import pytest
Expand All @@ -33,6 +33,7 @@
from superset.superset_typing import AdhocColumn

if TYPE_CHECKING:
from superset.jinja_context import BaseTemplateProcessor
from superset.models.core import Database


Expand Down Expand Up @@ -1429,6 +1430,79 @@ def test_adhoc_column_to_sqla_with_column_reference(database: Database) -> None:
assert "Customer Name" in result_str or '"Customer Name"' in result_str


def test_virtual_dataset_calculated_column_selected_via_templated_adhoc_dimension(
database: Database,
) -> None:
"""
Calculated columns on virtual datasets must be resolvable when the column
reference comes from a templated `sqlExpression` (e.g. using `filter_values`).

Regression test for cases where selecting a calculated column by name would
fall back to treating the resolved name as a bare identifier (breaking SQL
execution because the calculated expression is not present in the virtual
dataset's FROM subquery).
"""
from superset.connectors.sqla.models import SqlaTable, TableColumn

table = SqlaTable(
database=database,
schema=None,
table_name="virtual_t",
# Non-empty `sql` makes the table a virtual dataset.
sql="SELECT random_value, category FROM t",
columns=[
TableColumn(column_name="random_value", type="INTEGER"),
TableColumn(column_name="category", type="TEXT"),
TableColumn(
column_name="gt_or_lt_50",
type="TEXT",
expression=(
"CASE WHEN random_value > 50 THEN 'GT 50' ELSE 'LT 50' END"
),
),
],
)

class DummyTemplateProcessor:
def process_template(self, sql_expression: str) -> str:
# Only resolve the templated column name; leave other expressions
# (like the calculated column's CASE expression) untouched.
if "filter_values('aggregation')" in sql_expression:
return "gt_or_lt_50"
return sql_expression

adhoc_col: AdhocColumn = {
"sqlExpression": (
"{{ filter_values('aggregation')[0] if "
"filter_values('aggregation') else \"'Total'\" }}"
),
"label": "breakdown_by",
"isColumnReference": True,
}

result = table.adhoc_column_to_sqla(
adhoc_col,
template_processor=cast("BaseTemplateProcessor", DummyTemplateProcessor()),
)
assert result is not None

# The calculated column expression should be inlined (not treated as a bare
# identifier), so SQL must contain the CASE expression.
with database.get_sqla_engine() as engine:
sql = str(
result.compile(
dialect=engine.dialect,
compile_kwargs={"literal_binds": True},
)
)

assert "CASE WHEN" in sql
assert "random_value" in sql
assert "'GT 50'" in sql
assert "'LT 50'" in sql
assert "gt_or_lt_50" not in sql


def test_adhoc_column_to_sqla_preserves_column_type_for_time_grain(
database: Database,
) -> None:
Expand Down
Loading