From 239242d19f05005390d5d321c2647393a2ea08d9 Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Wed, 21 Feb 2024 15:31:31 -0800 Subject: [PATCH] fix!: allow expressions in scd type 2 model unique keys --- sqlmesh/core/engine_adapter/base.py | 4 +- sqlmesh/core/model/kind.py | 3 +- sqlmesh/core/model/meta.py | 11 ++---- sqlmesh/core/snapshot/evaluator.py | 8 ++-- sqlmesh/utils/pydantic.py | 14 ------- tests/core/engine_adapter/test_base.py | 39 +++++++++++++++---- tests/core/engine_adapter/test_integration.py | 2 +- tests/core/test_model.py | 13 ++++--- 8 files changed, 50 insertions(+), 44 deletions(-) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 794b446664..b0366d5732 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -1217,7 +1217,7 @@ def scd_type_2_by_column( self, target_table: TableName, source_table: QueryOrDF, - unique_key: t.Sequence[exp.Column], + unique_key: t.Sequence[exp.Expression], valid_from_name: str, valid_to_name: str, execution_time: TimeLike, @@ -1248,7 +1248,7 @@ def _scd_type_2( self, target_table: TableName, source_table: QueryOrDF, - unique_key: t.Union[t.Sequence[exp.Expression], t.Sequence[exp.Column]], + unique_key: t.Sequence[exp.Expression], valid_from_name: str, valid_to_name: str, execution_time: TimeLike, diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index bfc1835cc0..83ef88b8d7 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -15,7 +15,6 @@ from sqlmesh.utils.pydantic import ( PydanticModel, SQLGlotBool, - SQLGlotListOfColumns, SQLGlotListOfColumnsOrStar, SQLGlotListOfFields, SQLGlotPositiveInt, @@ -336,7 +335,7 @@ class FullKind(_ModelKind): class _SCDType2Kind(_ModelKind): - unique_key: SQLGlotListOfColumns + unique_key: SQLGlotListOfFields valid_from_name: SQLGlotString = "valid_from" valid_to_name: SQLGlotString = "valid_to" invalidate_hard_deletes: SQLGlotBool = False diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index 2954eda698..3105d6fe76 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -287,14 +287,9 @@ def time_column(self) -> t.Optional[TimeColumn]: @property def unique_key(self) -> t.List[exp.Expression]: - if isinstance(self.kind, IncrementalByUniqueKeyKind): - return self.kind.unique_key - return [] - - @property - def unique_key_columns(self) -> t.List[exp.Column]: - if self.kind.is_scd_type_2: - assert isinstance(self.kind, (SCDType2ByTimeKind, SCDType2ByColumnKind)) + if isinstance( + self.kind, (SCDType2ByTimeKind, SCDType2ByColumnKind, IncrementalByUniqueKeyKind) + ): return self.kind.unique_key return [] diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 17d191d20c..164da28a6d 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -1280,7 +1280,7 @@ def insert( self.adapter.scd_type_2_by_time( target_table=name, source_table=query_or_df, - unique_key=model.unique_key_columns, + unique_key=model.unique_key, valid_from_name=model.kind.valid_from_name, valid_to_name=model.kind.valid_to_name, updated_at_name=model.kind.updated_at_name, @@ -1295,7 +1295,7 @@ def insert( self.adapter.scd_type_2_by_column( target_table=name, source_table=query_or_df, - unique_key=model.unique_key_columns, + unique_key=model.unique_key, valid_from_name=model.kind.valid_from_name, valid_to_name=model.kind.valid_to_name, check_columns=model.kind.columns, @@ -1325,7 +1325,7 @@ def append( self.adapter.scd_type_2_by_time( target_table=table_name, source_table=query_or_df, - unique_key=model.unique_key_columns, + unique_key=model.unique_key, valid_from_name=model.kind.valid_from_name, valid_to_name=model.kind.valid_to_name, updated_at_name=model.kind.updated_at_name, @@ -1340,7 +1340,7 @@ def append( self.adapter.scd_type_2_by_column( target_table=table_name, source_table=query_or_df, - unique_key=model.unique_key_columns, + unique_key=model.unique_key, valid_from_name=model.kind.valid_from_name, valid_to_name=model.kind.valid_to_name, check_columns=model.kind.columns, diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py index 302f352d1e..e408434bd6 100644 --- a/sqlmesh/utils/pydantic.py +++ b/sqlmesh/utils/pydantic.py @@ -13,7 +13,6 @@ from sqlmesh.core import dialect as d from sqlmesh.utils import str_to_bool -from sqlmesh.utils.errors import SQLMeshError if sys.version_info >= (3, 9): from typing import Annotated @@ -324,14 +323,6 @@ def list_of_fields_validator(v: t.Any, values: t.Any) -> t.List[exp.Expression]: return _get_fields(v, values) -def list_of_columns_validator(v: t.Any, values: t.Any) -> t.List[exp.Column]: - expressions = _get_fields(v, values) - for expression in expressions: - if not isinstance(expression, exp.Column): - raise SQLMeshError(f"Invalid column {expression}. Value must be a column") - return t.cast(t.List[exp.Column], expressions) - - def list_of_columns_or_star_validator( v: t.Any, values: t.Any ) -> t.Union[exp.Star, t.List[exp.Column]]: @@ -347,7 +338,6 @@ def list_of_columns_or_star_validator( SQLGlotBool = bool SQLGlotPositiveInt = int SQLGlotListOfFields = t.List[exp.Expression] - SQLGlotListOfColumns = t.List[exp.Column] SQLGlotListOfColumnsOrStar = t.Union[t.List[exp.Column], exp.Star] elif PYDANTIC_MAJOR_VERSION >= 2: from pydantic.functional_validators import BeforeValidator # type: ignore @@ -359,7 +349,6 @@ def list_of_columns_or_star_validator( SQLGlotListOfFields = Annotated[ t.List[exp.Expression], BeforeValidator(list_of_fields_validator) ] - SQLGlotListOfColumns = Annotated[t.List[exp.Column], BeforeValidator(list_of_columns_validator)] SQLGlotListOfColumnsOrStar = Annotated[ t.Union[t.List[exp.Column], exp.Star], BeforeValidator(list_of_columns_or_star_validator) ] @@ -387,8 +376,5 @@ class SQLGlotPositiveInt(PydanticTypeProxy[int]): class SQLGlotListOfFields(PydanticTypeProxy[t.List[exp.Expression]]): validate = list_of_fields_validator - class SQLGlotListOfColumns(PydanticTypeProxy[t.List[exp.Column]]): - validate = list_of_columns_validator - class SQLGlotListOfColumnsOrStar(PydanticTypeProxy[t.Union[exp.Star, t.List[exp.Column]]]): validate = list_of_columns_or_star_validator diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index fa0fa39632..bea51f86dd 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -978,7 +978,10 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): source_table=t.cast( exp.Select, parse_one("SELECT id, name, price, test_updated_at FROM source") ), - unique_key=[exp.func("COALESCE", "id", "''")], + unique_key=[ + parse_one("""COALESCE("id", '') || '|' || COALESCE("name", '')"""), + parse_one("""COALESCE("name", '')"""), + ], valid_from_name="test_valid_from", valid_to_name="test_valid_to", updated_at_name="test_updated_at", @@ -999,7 +1002,7 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): """ CREATE OR REPLACE TABLE "target" AS WITH "source" AS ( - SELECT DISTINCT ON (COALESCE("id", '')) + SELECT DISTINCT ON (COALESCE("id", '') || '|' || COALESCE("name", ''), COALESCE("name", '')) TRUE AS "_exists", "id", "name", @@ -1045,17 +1048,24 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): "static"."test_valid_to" FROM "static" LEFT JOIN "latest" - ON COALESCE("static"."id", '') = COALESCE("latest"."id", '') + ON ( + COALESCE("static"."id", '') || '|' || COALESCE("static"."name", '') + ) = ( + COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '') + ) + AND COALESCE("static"."name", '') = COALESCE("latest"."name", '') WHERE "latest"."test_valid_to" IS NULL ), "latest_deleted" AS ( SELECT TRUE AS "_exists", - COALESCE("id", '') AS "_key0", + COALESCE("id", '') || '|' || COALESCE("name", '') AS "_key0", + COALESCE("name", '') AS "_key1", MAX("test_valid_to") AS "test_valid_to" FROM "deleted" GROUP BY - COALESCE("id", '') + COALESCE("id", '') || '|' || COALESCE("name", ''), + COALESCE("name", '') ), "joined" AS ( SELECT "source"."_exists", @@ -1071,7 +1081,12 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): "source"."test_updated_at" AS "test_updated_at" FROM "latest" LEFT JOIN "source" - ON COALESCE("latest"."id", '') = COALESCE("source"."id", '') + ON ( + COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '') + ) = ( + COALESCE("source"."id", '') || '|' || COALESCE("source"."name", '') + ) + AND COALESCE("latest"."name", '') = COALESCE("source"."name", '') UNION SELECT "source"."_exists", @@ -1087,7 +1102,12 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): "source"."test_updated_at" AS "test_updated_at" FROM "latest" RIGHT JOIN "source" - ON COALESCE("latest"."id", '') = COALESCE("source"."id", '') + ON ( + COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '') + ) = ( + COALESCE("source"."id", '') || '|' || COALESCE("source"."name", '') + ) + AND COALESCE("latest"."name", '') = COALESCE("source"."name", '') ), "updated_rows" AS ( SELECT COALESCE("joined"."t_id", "joined"."id") AS "id", @@ -1114,7 +1134,10 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): END AS "test_valid_to" FROM "joined" LEFT JOIN "latest_deleted" - ON COALESCE("joined"."id", '') = "latest_deleted"."_key0" + ON ( + COALESCE("joined"."id", '') || '|' || COALESCE("joined"."name", '') + ) = "latest_deleted"."_key0" + AND COALESCE("joined"."name", '') = "latest_deleted"."_key1" ), "inserted_rows" AS ( SELECT "id", diff --git a/tests/core/engine_adapter/test_integration.py b/tests/core/engine_adapter/test_integration.py index 894b597c52..11597f1f69 100644 --- a/tests/core/engine_adapter/test_integration.py +++ b/tests/core/engine_adapter/test_integration.py @@ -1179,7 +1179,7 @@ def test_scd_type_2_by_time(ctx: TestContext): ctx.engine_adapter.scd_type_2_by_time( table, ctx.input_data(input_data, input_schema), - unique_key=[exp.to_identifier("id")], + unique_key=[parse_one("COALESCE(id, -1)")], valid_from_name="valid_from", valid_to_name="valid_to", updated_at_name="updated_at", diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 172d73100d..befb56c8ab 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -2460,7 +2460,7 @@ def test_scd_type_2_by_time_defaults(): MODEL ( name db.table, kind SCD_TYPE_2 ( - unique_key "ID", + unique_key (COALESCE("ID", '') || '|' || COALESCE("ds", ''), COALESCE("ds", '')), ), ); SELECT @@ -2473,7 +2473,10 @@ def test_scd_type_2_by_time_defaults(): """ ) scd_type_2_model = load_sql_based_model(model_def) - assert scd_type_2_model.unique_key_columns == [exp.to_column("ID", quoted=True)] + assert scd_type_2_model.unique_key == [ + parse_one("""COALESCE("ID", '') || '|' || COALESCE("ds", '')"""), + parse_one("""COALESCE("ds", '')"""), + ] assert scd_type_2_model.columns_to_types == { "ID": exp.DataType.build("int"), "ds": exp.DataType.build("varchar"), @@ -2525,7 +2528,7 @@ def test_scd_type_2_by_time_overrides(): """ ) scd_type_2_model = load_sql_based_model(model_def) - assert scd_type_2_model.unique_key_columns == [ + assert scd_type_2_model.unique_key == [ exp.column("iD", quoted=True), exp.column("ds", quoted=False), ] @@ -2566,7 +2569,7 @@ def test_scd_type_2_by_column_defaults(): """ ) scd_type_2_model = load_sql_based_model(model_def) - assert scd_type_2_model.unique_key_columns == [exp.to_column("ID", quoted=True)] + assert scd_type_2_model.unique_key == [exp.to_column("ID", quoted=True)] assert scd_type_2_model.kind.columns == [exp.to_column("value_to_track", quoted=True)] assert scd_type_2_model.columns_to_types == { "ID": exp.DataType.build("int"), @@ -2614,7 +2617,7 @@ def test_scd_type_2_by_column_overrides(): """ ) scd_type_2_model = load_sql_based_model(model_def) - assert scd_type_2_model.unique_key_columns == [ + assert scd_type_2_model.unique_key == [ exp.column("iD", quoted=True), exp.column("ds", quoted=False), ]