Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from sqlglot import Dialect, exp
from sqlglot.errors import ErrorLevel
from sqlglot.helper import ensure_list
from sqlglot.helper import ensure_list, seq_get
from sqlglot.optimizer.qualify_columns import quote_identifiers

from sqlmesh.core.dialect import (
Expand Down Expand Up @@ -1772,7 +1772,7 @@ def scd_type_2_by_column(
valid_from_col: exp.Column,
valid_to_col: exp.Column,
execution_time: t.Union[TimeLike, exp.Column],
check_columns: t.Union[exp.Star, t.Sequence[exp.Column]],
check_columns: t.Union[exp.Star, t.Sequence[exp.Expression]],
invalidate_hard_deletes: bool = True,
execution_time_as_valid_from: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
Expand Down Expand Up @@ -1810,7 +1810,7 @@ def _scd_type_2(
execution_time: t.Union[TimeLike, exp.Column],
invalidate_hard_deletes: bool = True,
updated_at_col: t.Optional[exp.Column] = None,
check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None,
check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None,
updated_at_as_valid_from: bool = False,
execution_time_as_valid_from: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
Expand Down Expand Up @@ -1885,8 +1885,10 @@ def remove_managed_columns(
# they are equal or not, the extra check is not a problem and we gain simplified logic here.
# If we want to change this, then we just need to check the expressions in unique_key and pull out the
# column names and then remove them from the unmanaged_columns
if check_columns and check_columns == exp.Star():
check_columns = [exp.column(col) for col in unmanaged_columns_to_types]
if check_columns:
# Handle both Star directly and [Star()] (which can happen during serialization/deserialization)
if isinstance(seq_get(ensure_list(check_columns), 0), exp.Star):
check_columns = [exp.column(col) for col in unmanaged_columns_to_types]
execution_ts = (
exp.cast(execution_time, time_data_type, dialect=self.dialect)
if isinstance(execution_time, exp.Column)
Expand Down Expand Up @@ -1923,7 +1925,8 @@ def remove_managed_columns(
col_qualified.set("table", exp.to_identifier("joined"))

t_col = col_qualified.copy()
t_col.this.set("this", f"t_{col.name}")
for column in t_col.find_all(exp.Column):
column.this.set("this", f"t_{column.name}")

row_check_conditions.extend(
[
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/engine_adapter/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _scd_type_2(
execution_time: t.Union[TimeLike, exp.Column],
invalidate_hard_deletes: bool = True,
updated_at_col: t.Optional[exp.Column] = None,
check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None,
check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None,
updated_at_as_valid_from: bool = False,
execution_time_as_valid_from: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/model/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
PydanticModel,
SQLGlotBool,
SQLGlotColumn,
SQLGlotListOfColumnsOrStar,
SQLGlotListOfFieldsOrStar,
SQLGlotListOfFields,
SQLGlotPositiveInt,
SQLGlotString,
Expand Down Expand Up @@ -852,7 +852,7 @@ def to_expression(

class SCDType2ByColumnKind(_SCDType2Kind):
name: t.Literal[ModelKindName.SCD_TYPE_2_BY_COLUMN] = ModelKindName.SCD_TYPE_2_BY_COLUMN
columns: SQLGlotListOfColumnsOrStar
columns: SQLGlotListOfFieldsOrStar
execution_time_as_valid_from: SQLGlotBool = False
updated_at_name: t.Optional[SQLGlotColumn] = None

Expand Down
12 changes: 6 additions & 6 deletions sqlmesh/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,13 +289,13 @@ def column_validator(v: t.Any, values: t.Any) -> exp.Column:
return expression


def list_of_fields_or_star_validator(
    v: t.Any, values: t.Any
) -> t.Union[exp.Star, t.List[exp.Expression]]:
    """Validate a list of SQLGlot field expressions, collapsing a lone star.

    Args:
        v: The raw value to validate (string, expression, or list thereof).
        values: The surrounding model's validation context (used by ``_get_fields``
            to resolve the dialect).

    Returns:
        The single ``exp.Star`` when the input is exactly ``*`` / ``["*"]``,
        otherwise the parsed list of expressions (plain columns, casts such as
        ``ds::DATE``, function calls, ...).
    """
    expressions = _get_fields(v, values)
    # A sole `*` means "check all columns"; surface it as a bare Star so
    # downstream code can detect that case with a single isinstance check.
    if len(expressions) == 1 and isinstance(expressions[0], exp.Star):
        return t.cast(exp.Star, expressions[0])
    return t.cast(t.List[exp.Expression], expressions)


def cron_validator(v: t.Any) -> str:
Expand Down Expand Up @@ -339,7 +339,7 @@ def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]:
SQLGlotPositiveInt = int
SQLGlotColumn = exp.Column
SQLGlotListOfFields = t.List[exp.Expression]
SQLGlotListOfColumnsOrStar = t.Union[t.List[exp.Column], exp.Star]
SQLGlotListOfFieldsOrStar = t.Union[SQLGlotListOfFields, exp.Star]
SQLGlotCron = str
else:
from pydantic.functional_validators import BeforeValidator
Expand All @@ -352,7 +352,7 @@ def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]:
SQLGlotListOfFields = t.Annotated[
t.List[exp.Expression], BeforeValidator(list_of_fields_validator)
]
SQLGlotListOfColumnsOrStar = t.Annotated[
t.Union[t.List[exp.Column], exp.Star], BeforeValidator(list_of_columns_or_star_validator)
SQLGlotListOfFieldsOrStar = t.Annotated[
t.Union[SQLGlotListOfFields, exp.Star], BeforeValidator(list_of_fields_or_star_validator)
]
SQLGlotCron = t.Annotated[str, BeforeValidator(cron_validator)]
6 changes: 3 additions & 3 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ def test_scd_type_2_by_col_serde():
model_json_parsed = json.loads(model.json())
assert model_json_parsed["kind"]["dialect"] == "bigquery"
assert model_json_parsed["kind"]["unique_key"] == ["`a`"]
assert model_json_parsed["kind"]["columns"] == "*"
assert model_json_parsed["kind"]["columns"] == ["*"]
# Bigquery converts TIMESTAMP -> DATETIME
assert model_json_parsed["kind"]["time_data_type"] == "DATETIME"

Expand Down Expand Up @@ -5427,7 +5427,7 @@ def scd_type_2_model(context, **kwargs):
'["col1"]',
[exp.to_column("col1", quoted=True)],
),
("*", exp.Star()),
("*", [exp.Star()]),
],
)
def test_check_column_variants(input_columns, expected_columns):
Expand Down Expand Up @@ -8360,7 +8360,7 @@ def test_model_kind_to_expression():
.kind.to_expression()
.sql()
== """SCD_TYPE_2_BY_COLUMN (
columns *,
columns (*),
execution_time_as_valid_from FALSE,
unique_key ("a", "b"),
valid_from_name "valid_from",
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2490,7 +2490,7 @@ def test_insert_into_scd_type_2_by_column(
target_columns_to_types=table_columns,
table_format=None,
unique_key=[exp.to_column("id", quoted=True)],
check_columns=exp.Star(),
check_columns=[exp.Star()],
valid_from_col=exp.column("valid_from", quoted=True),
valid_to_col=exp.column("valid_to", quoted=True),
execution_time="2020-01-02",
Expand Down
25 changes: 24 additions & 1 deletion tests/dbt/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlmesh import Context
from sqlmesh.core.console import NoopConsole, get_console
from sqlmesh.core.model import TimeColumn, IncrementalByTimeRangeKind
from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange
from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange, SCDType2ByColumnKind
from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json
from sqlmesh.core.config.common import VirtualEnvironmentMode
from sqlmesh.core.model.meta import GrantsTargetLayer
Expand Down Expand Up @@ -707,6 +707,29 @@ def test_load_multiple_snapshots_defined_in_same_file(sushi_test_dbt_context: Co
assert context.get_model("snapshots.items_check_snapshot")


@pytest.mark.slow
def test_dbt_snapshot_with_check_cols_expressions(sushi_test_dbt_context: Context) -> None:
    """check_cols entries that are expressions (``ds::DATE``) survive load and reload."""
    ctx = sushi_test_dbt_context

    snapshot = ctx.get_model("snapshots.items_check_with_cast_snapshot")
    assert snapshot is not None
    assert isinstance(snapshot.kind, SCDType2ByColumnKind)

    kind_columns = snapshot.kind.columns
    assert isinstance(kind_columns, list)
    assert len(kind_columns) == 1

    # The snapshot declares check_cols=['ds::DATE'], which parses to a CAST expression.
    assert isinstance(kind_columns[0], exp.Cast)
    assert kind_columns[0].sql() == 'CAST("ds" AS DATE)'

    # Reload to exercise the serialization/deserialization (model cache) path.
    ctx.load()
    reloaded = ctx.get_model("snapshots.items_check_with_cast_snapshot")
    assert reloaded is not None
    assert isinstance(reloaded.kind, SCDType2ByColumnKind)
    assert isinstance(reloaded.kind.columns, list)
    assert len(reloaded.kind.columns) == 1


@pytest.mark.slow
def test_dbt_jinja_macro_undefined_variable_error(create_empty_project):
project_dir, model_dir = create_empty_project()
Expand Down
26 changes: 26 additions & 0 deletions tests/dbt/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,32 @@ def test_model_kind():
on_additive_change=OnAdditiveChange.ALLOW,
)

check_cols_with_cast = ModelConfig(
materialized=Materialization.SNAPSHOT,
unique_key=["id"],
strategy="check",
check_cols=["created_at::TIMESTAMPTZ"],
).model_kind(context)
assert isinstance(check_cols_with_cast, SCDType2ByColumnKind)
assert check_cols_with_cast.execution_time_as_valid_from is True
assert len(check_cols_with_cast.columns) == 1
assert isinstance(check_cols_with_cast.columns[0], exp.Cast)
assert check_cols_with_cast.columns[0].sql() == 'CAST("created_at" AS TIMESTAMPTZ)'

check_cols_multiple_expr = ModelConfig(
materialized=Materialization.SNAPSHOT,
unique_key=["id"],
strategy="check",
check_cols=["created_at::TIMESTAMPTZ", "COALESCE(status, 'active')"],
).model_kind(context)
assert isinstance(check_cols_multiple_expr, SCDType2ByColumnKind)
assert len(check_cols_multiple_expr.columns) == 2
assert isinstance(check_cols_multiple_expr.columns[0], exp.Cast)
assert isinstance(check_cols_multiple_expr.columns[1], exp.Coalesce)

assert check_cols_multiple_expr.columns[0].sql() == 'CAST("created_at" AS TIMESTAMPTZ)'
assert check_cols_multiple_expr.columns[1].sql() == "COALESCE(\"status\", 'active')"

assert ModelConfig(materialized=Materialization.INCREMENTAL, time_column="foo").model_kind(
context
) == IncrementalByTimeRangeKind(
Expand Down
16 changes: 16 additions & 0 deletions tests/fixtures/dbt/sushi_test/snapshots/items_snapshots.sql
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,19 @@ select * from {{ source('streaming', 'items') }}
select * from {{ source('streaming', 'items') }}

{% endsnapshot %}

{% snapshot items_check_with_cast_snapshot %}

{#
  Fixture snapshot exercising expression-valued check_cols: `ds::DATE` is a
  cast expression rather than a plain column name, so loading this snapshot
  verifies that non-column expressions are accepted by the check strategy.
#}
{{
    config(
      target_schema='snapshots',
      unique_key='id',
      strategy='check',
      check_cols=['ds::DATE'],
      invalidate_hard_deletes=True,
    )
}}

select * from {{ source('streaming', 'items') }}

{% endsnapshot %}