diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index ebbf136cd1..38f662c76f 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -18,7 +18,7 @@ from sqlglot import Dialect, exp from sqlglot.errors import ErrorLevel -from sqlglot.helper import ensure_list +from sqlglot.helper import ensure_list, seq_get from sqlglot.optimizer.qualify_columns import quote_identifiers from sqlmesh.core.dialect import ( @@ -1772,7 +1772,7 @@ def scd_type_2_by_column( valid_from_col: exp.Column, valid_to_col: exp.Column, execution_time: t.Union[TimeLike, exp.Column], - check_columns: t.Union[exp.Star, t.Sequence[exp.Column]], + check_columns: t.Union[exp.Star, t.Sequence[exp.Expression]], invalidate_hard_deletes: bool = True, execution_time_as_valid_from: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -1810,7 +1810,7 @@ def _scd_type_2( execution_time: t.Union[TimeLike, exp.Column], invalidate_hard_deletes: bool = True, updated_at_col: t.Optional[exp.Column] = None, - check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None, + check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None, updated_at_as_valid_from: bool = False, execution_time_as_valid_from: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -1885,8 +1885,10 @@ def remove_managed_columns( # they are equal or not, the extra check is not a problem and we gain simplified logic here. # If we want to change this, then we just need to check the expressions in unique_key and pull out the # column names and then remove them from the unmanaged_columns - if check_columns and check_columns == exp.Star(): - check_columns = [exp.column(col) for col in unmanaged_columns_to_types] + if check_columns: + # Handle both Star directly and [Star()] (which can happen during serialization/deserialization) + if isinstance(seq_get(ensure_list(check_columns), 0), exp.Star): + check_columns = [exp.column(col) for col in unmanaged_columns_to_types] execution_ts = ( exp.cast(execution_time, time_data_type, dialect=self.dialect) if isinstance(execution_time, exp.Column) @@ -1923,7 +1925,8 @@ def remove_managed_columns( col_qualified.set("table", exp.to_identifier("joined")) t_col = col_qualified.copy() - t_col.this.set("this", f"t_{col.name}") + for column in t_col.find_all(exp.Column): + column.this.set("this", f"t_{column.name}") row_check_conditions.extend( [ diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 21846b8693..272c45c193 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -302,7 +302,7 @@ def _scd_type_2( execution_time: t.Union[TimeLike, exp.Column], invalidate_hard_deletes: bool = True, updated_at_col: t.Optional[exp.Column] = None, - check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None, + check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None, updated_at_as_valid_from: bool = False, execution_time_as_valid_from: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index ad5197a73a..9abaa9c650 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -23,7 +23,7 @@ PydanticModel, SQLGlotBool, SQLGlotColumn, - SQLGlotListOfColumnsOrStar, + SQLGlotListOfFieldsOrStar, SQLGlotListOfFields, SQLGlotPositiveInt, SQLGlotString, @@ -852,7 +852,7 @@ def to_expression( class SCDType2ByColumnKind(_SCDType2Kind): name: t.Literal[ModelKindName.SCD_TYPE_2_BY_COLUMN] = ModelKindName.SCD_TYPE_2_BY_COLUMN - columns: SQLGlotListOfColumnsOrStar + columns: SQLGlotListOfFieldsOrStar execution_time_as_valid_from: SQLGlotBool = False updated_at_name: t.Optional[SQLGlotColumn] = None diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py index 317e873aeb..2c9c570e5b 100644 --- a/sqlmesh/utils/pydantic.py +++ b/sqlmesh/utils/pydantic.py @@ -289,13 +289,13 @@ def column_validator(v: t.Any, values: t.Any) -> exp.Column: return expression -def list_of_columns_or_star_validator( +def list_of_fields_or_star_validator( v: t.Any, values: t.Any -) -> t.Union[exp.Star, t.List[exp.Column]]: +) -> t.Union[exp.Star, t.List[exp.Expression]]: expressions = _get_fields(v, values) if len(expressions) == 1 and isinstance(expressions[0], exp.Star): return t.cast(exp.Star, expressions[0]) - return t.cast(t.List[exp.Column], expressions) + return t.cast(t.List[exp.Expression], expressions) def cron_validator(v: t.Any) -> str: @@ -339,7 +339,7 @@ def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]: SQLGlotPositiveInt = int SQLGlotColumn = exp.Column SQLGlotListOfFields = t.List[exp.Expression] - SQLGlotListOfColumnsOrStar = t.Union[t.List[exp.Column], exp.Star] + SQLGlotListOfFieldsOrStar = t.Union[SQLGlotListOfFields, exp.Star] SQLGlotCron = str else: from pydantic.functional_validators import BeforeValidator @@ -352,7 +352,7 @@ def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]: SQLGlotListOfFields = t.Annotated[ t.List[exp.Expression], BeforeValidator(list_of_fields_validator) ] - SQLGlotListOfColumnsOrStar = t.Annotated[ - t.Union[t.List[exp.Column], exp.Star], BeforeValidator(list_of_columns_or_star_validator) + SQLGlotListOfFieldsOrStar = t.Annotated[ + t.Union[SQLGlotListOfFields, exp.Star], BeforeValidator(list_of_fields_or_star_validator) ] SQLGlotCron = t.Annotated[str, BeforeValidator(cron_validator)] diff --git a/tests/core/test_model.py b/tests/core/test_model.py index c3feef6095..1ada9bd4a5 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -949,7 +949,7 @@ def test_scd_type_2_by_col_serde(): model_json_parsed = json.loads(model.json()) assert model_json_parsed["kind"]["dialect"] == "bigquery" assert model_json_parsed["kind"]["unique_key"] == ["`a`"] - assert model_json_parsed["kind"]["columns"] == "*" + assert model_json_parsed["kind"]["columns"] == ["*"] # Bigquery converts TIMESTAMP -> DATETIME assert model_json_parsed["kind"]["time_data_type"] == "DATETIME" @@ -5427,7 +5427,7 @@ def scd_type_2_model(context, **kwargs): '["col1"]', [exp.to_column("col1", quoted=True)], ), - ("*", exp.Star()), + ("*", [exp.Star()]), ], ) def test_check_column_variants(input_columns, expected_columns): @@ -8360,7 +8360,7 @@ def test_model_kind_to_expression(): .kind.to_expression() .sql() == """SCD_TYPE_2_BY_COLUMN ( -columns *, +columns (*), execution_time_as_valid_from FALSE, unique_key ("a", "b"), valid_from_name "valid_from", diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index c0a7a01b51..f3a06c83bd 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -2490,7 +2490,7 @@ def test_insert_into_scd_type_2_by_column( target_columns_to_types=table_columns, table_format=None, unique_key=[exp.to_column("id", quoted=True)], - check_columns=exp.Star(), + check_columns=[exp.Star()], valid_from_col=exp.column("valid_from", quoted=True), valid_to_col=exp.column("valid_to", quoted=True), execution_time="2020-01-02", diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index 489d69683b..c946adc7ec 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -8,7 +8,7 @@ from sqlmesh import Context from sqlmesh.core.console import NoopConsole, get_console from sqlmesh.core.model import TimeColumn, IncrementalByTimeRangeKind -from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange +from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange, SCDType2ByColumnKind from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json from sqlmesh.core.config.common import VirtualEnvironmentMode from sqlmesh.core.model.meta import GrantsTargetLayer @@ -707,6 +707,29 @@ def test_load_multiple_snapshots_defined_in_same_file(sushi_test_dbt_context: Co assert context.get_model("snapshots.items_check_snapshot") +@pytest.mark.slow +def test_dbt_snapshot_with_check_cols_expressions(sushi_test_dbt_context: Context) -> None: + context = sushi_test_dbt_context + model = context.get_model("snapshots.items_check_with_cast_snapshot") + assert model is not None + assert isinstance(model.kind, SCDType2ByColumnKind) + + columns = model.kind.columns + assert isinstance(columns, list) + assert len(columns) == 1 + + # expression in check_cols is: ds::DATE + assert isinstance(columns[0], exp.Cast) + assert columns[0].sql() == 'CAST("ds" AS DATE)' + + context.load() + cached_model = context.get_model("snapshots.items_check_with_cast_snapshot") + assert cached_model is not None + assert isinstance(cached_model.kind, SCDType2ByColumnKind) + assert isinstance(cached_model.kind.columns, list) + assert len(cached_model.kind.columns) == 1 + + @pytest.mark.slow def test_dbt_jinja_macro_undefined_variable_error(create_empty_project): project_dir, model_dir = create_empty_project() diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 0a6db38361..c964b5b56b 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -292,6 +292,32 @@ def test_model_kind(): on_additive_change=OnAdditiveChange.ALLOW, ) + check_cols_with_cast = ModelConfig( + materialized=Materialization.SNAPSHOT, + unique_key=["id"], + strategy="check", + check_cols=["created_at::TIMESTAMPTZ"], + ).model_kind(context) + assert isinstance(check_cols_with_cast, SCDType2ByColumnKind) + assert check_cols_with_cast.execution_time_as_valid_from is True + assert len(check_cols_with_cast.columns) == 1 + assert isinstance(check_cols_with_cast.columns[0], exp.Cast) + assert check_cols_with_cast.columns[0].sql() == 'CAST("created_at" AS TIMESTAMPTZ)' + + check_cols_multiple_expr = ModelConfig( + materialized=Materialization.SNAPSHOT, + unique_key=["id"], + strategy="check", + check_cols=["created_at::TIMESTAMPTZ", "COALESCE(status, 'active')"], + ).model_kind(context) + assert isinstance(check_cols_multiple_expr, SCDType2ByColumnKind) + assert len(check_cols_multiple_expr.columns) == 2 + assert isinstance(check_cols_multiple_expr.columns[0], exp.Cast) + assert isinstance(check_cols_multiple_expr.columns[1], exp.Coalesce) + + assert check_cols_multiple_expr.columns[0].sql() == 'CAST("created_at" AS TIMESTAMPTZ)' + assert check_cols_multiple_expr.columns[1].sql() == "COALESCE(\"status\", 'active')" + assert ModelConfig(materialized=Materialization.INCREMENTAL, time_column="foo").model_kind( context ) == IncrementalByTimeRangeKind( diff --git a/tests/fixtures/dbt/sushi_test/snapshots/items_snapshots.sql b/tests/fixtures/dbt/sushi_test/snapshots/items_snapshots.sql index 77d79d03ba..fbce585edf 100644 --- a/tests/fixtures/dbt/sushi_test/snapshots/items_snapshots.sql +++ b/tests/fixtures/dbt/sushi_test/snapshots/items_snapshots.sql @@ -30,3 +30,19 @@ select * from {{ source('streaming', 'items') }} select * from {{ source('streaming', 'items') }} {% endsnapshot %} + +{% snapshot items_check_with_cast_snapshot %} + +{{ + config( + target_schema='snapshots', + unique_key='id', + strategy='check', + check_cols=['ds::DATE'], + invalidate_hard_deletes=True, + ) +}} + +select * from {{ source('streaming', 'items') }} + +{% endsnapshot %}