Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,7 @@ def scd_type_2_by_column(
self,
target_table: TableName,
source_table: QueryOrDF,
unique_key: t.Sequence[exp.Column],
unique_key: t.Sequence[exp.Expression],
valid_from_name: str,
valid_to_name: str,
execution_time: TimeLike,
Expand Down Expand Up @@ -1248,7 +1248,7 @@ def _scd_type_2(
self,
target_table: TableName,
source_table: QueryOrDF,
unique_key: t.Union[t.Sequence[exp.Expression], t.Sequence[exp.Column]],
unique_key: t.Sequence[exp.Expression],
valid_from_name: str,
valid_to_name: str,
execution_time: TimeLike,
Expand Down
3 changes: 1 addition & 2 deletions sqlmesh/core/model/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from sqlmesh.utils.pydantic import (
PydanticModel,
SQLGlotBool,
SQLGlotListOfColumns,
SQLGlotListOfColumnsOrStar,
SQLGlotListOfFields,
SQLGlotPositiveInt,
Expand Down Expand Up @@ -336,7 +335,7 @@ class FullKind(_ModelKind):


class _SCDType2Kind(_ModelKind):
unique_key: SQLGlotListOfColumns
unique_key: SQLGlotListOfFields
valid_from_name: SQLGlotString = "valid_from"
valid_to_name: SQLGlotString = "valid_to"
invalidate_hard_deletes: SQLGlotBool = False
Expand Down
11 changes: 3 additions & 8 deletions sqlmesh/core/model/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,14 +287,9 @@ def time_column(self) -> t.Optional[TimeColumn]:

@property
def unique_key(self) -> t.List[exp.Expression]:
if isinstance(self.kind, IncrementalByUniqueKeyKind):
return self.kind.unique_key
return []

@property
def unique_key_columns(self) -> t.List[exp.Column]:
if self.kind.is_scd_type_2:
assert isinstance(self.kind, (SCDType2ByTimeKind, SCDType2ByColumnKind))
if isinstance(
self.kind, (SCDType2ByTimeKind, SCDType2ByColumnKind, IncrementalByUniqueKeyKind)
):
return self.kind.unique_key
return []

Expand Down
8 changes: 4 additions & 4 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,7 @@ def insert(
self.adapter.scd_type_2_by_time(
target_table=name,
source_table=query_or_df,
unique_key=model.unique_key_columns,
unique_key=model.unique_key,
valid_from_name=model.kind.valid_from_name,
valid_to_name=model.kind.valid_to_name,
updated_at_name=model.kind.updated_at_name,
Expand All @@ -1295,7 +1295,7 @@ def insert(
self.adapter.scd_type_2_by_column(
target_table=name,
source_table=query_or_df,
unique_key=model.unique_key_columns,
unique_key=model.unique_key,
valid_from_name=model.kind.valid_from_name,
valid_to_name=model.kind.valid_to_name,
check_columns=model.kind.columns,
Expand Down Expand Up @@ -1325,7 +1325,7 @@ def append(
self.adapter.scd_type_2_by_time(
target_table=table_name,
source_table=query_or_df,
unique_key=model.unique_key_columns,
unique_key=model.unique_key,
valid_from_name=model.kind.valid_from_name,
valid_to_name=model.kind.valid_to_name,
updated_at_name=model.kind.updated_at_name,
Expand All @@ -1340,7 +1340,7 @@ def append(
self.adapter.scd_type_2_by_column(
target_table=table_name,
source_table=query_or_df,
unique_key=model.unique_key_columns,
unique_key=model.unique_key,
valid_from_name=model.kind.valid_from_name,
valid_to_name=model.kind.valid_to_name,
check_columns=model.kind.columns,
Expand Down
14 changes: 0 additions & 14 deletions sqlmesh/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from sqlmesh.core import dialect as d
from sqlmesh.utils import str_to_bool
from sqlmesh.utils.errors import SQLMeshError

if sys.version_info >= (3, 9):
from typing import Annotated
Expand Down Expand Up @@ -324,14 +323,6 @@ def list_of_fields_validator(v: t.Any, values: t.Any) -> t.List[exp.Expression]:
return _get_fields(v, values)


def list_of_columns_validator(v: t.Any, values: t.Any) -> t.List[exp.Column]:
expressions = _get_fields(v, values)
for expression in expressions:
if not isinstance(expression, exp.Column):
raise SQLMeshError(f"Invalid column {expression}. Value must be a column")
return t.cast(t.List[exp.Column], expressions)


def list_of_columns_or_star_validator(
v: t.Any, values: t.Any
) -> t.Union[exp.Star, t.List[exp.Column]]:
Expand All @@ -347,7 +338,6 @@ def list_of_columns_or_star_validator(
SQLGlotBool = bool
SQLGlotPositiveInt = int
SQLGlotListOfFields = t.List[exp.Expression]
SQLGlotListOfColumns = t.List[exp.Column]
SQLGlotListOfColumnsOrStar = t.Union[t.List[exp.Column], exp.Star]
elif PYDANTIC_MAJOR_VERSION >= 2:
from pydantic.functional_validators import BeforeValidator # type: ignore
Expand All @@ -359,7 +349,6 @@ def list_of_columns_or_star_validator(
SQLGlotListOfFields = Annotated[
t.List[exp.Expression], BeforeValidator(list_of_fields_validator)
]
SQLGlotListOfColumns = Annotated[t.List[exp.Column], BeforeValidator(list_of_columns_validator)]
SQLGlotListOfColumnsOrStar = Annotated[
t.Union[t.List[exp.Column], exp.Star], BeforeValidator(list_of_columns_or_star_validator)
]
Expand Down Expand Up @@ -387,8 +376,5 @@ class SQLGlotPositiveInt(PydanticTypeProxy[int]):
class SQLGlotListOfFields(PydanticTypeProxy[t.List[exp.Expression]]):
validate = list_of_fields_validator

class SQLGlotListOfColumns(PydanticTypeProxy[t.List[exp.Column]]):
validate = list_of_columns_validator

class SQLGlotListOfColumnsOrStar(PydanticTypeProxy[t.Union[exp.Star, t.List[exp.Column]]]):
validate = list_of_columns_or_star_validator
39 changes: 31 additions & 8 deletions tests/core/engine_adapter/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,10 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
source_table=t.cast(
exp.Select, parse_one("SELECT id, name, price, test_updated_at FROM source")
),
unique_key=[exp.func("COALESCE", "id", "''")],
unique_key=[
parse_one("""COALESCE("id", '') || '|' || COALESCE("name", '')"""),
parse_one("""COALESCE("name", '')"""),
],
valid_from_name="test_valid_from",
valid_to_name="test_valid_to",
updated_at_name="test_updated_at",
Expand All @@ -999,7 +1002,7 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
"""
CREATE OR REPLACE TABLE "target" AS
WITH "source" AS (
SELECT DISTINCT ON (COALESCE("id", ''))
SELECT DISTINCT ON (COALESCE("id", '') || '|' || COALESCE("name", ''), COALESCE("name", ''))
TRUE AS "_exists",
"id",
"name",
Expand Down Expand Up @@ -1045,17 +1048,24 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
"static"."test_valid_to"
FROM "static"
LEFT JOIN "latest"
ON COALESCE("static"."id", '') = COALESCE("latest"."id", '')
ON (
COALESCE("static"."id", '') || '|' || COALESCE("static"."name", '')
) = (
COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '')
)
AND COALESCE("static"."name", '') = COALESCE("latest"."name", '')
WHERE
"latest"."test_valid_to" IS NULL
), "latest_deleted" AS (
SELECT
TRUE AS "_exists",
COALESCE("id", '') AS "_key0",
COALESCE("id", '') || '|' || COALESCE("name", '') AS "_key0",
COALESCE("name", '') AS "_key1",
MAX("test_valid_to") AS "test_valid_to"
FROM "deleted"
GROUP BY
COALESCE("id", '')
COALESCE("id", '') || '|' || COALESCE("name", ''),
COALESCE("name", '')
), "joined" AS (
SELECT
"source"."_exists",
Expand All @@ -1071,7 +1081,12 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
"source"."test_updated_at" AS "test_updated_at"
FROM "latest"
LEFT JOIN "source"
ON COALESCE("latest"."id", '') = COALESCE("source"."id", '')
ON (
COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '')
) = (
COALESCE("source"."id", '') || '|' || COALESCE("source"."name", '')
)
AND COALESCE("latest"."name", '') = COALESCE("source"."name", '')
UNION
SELECT
"source"."_exists",
Expand All @@ -1087,7 +1102,12 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
"source"."test_updated_at" AS "test_updated_at"
FROM "latest"
RIGHT JOIN "source"
ON COALESCE("latest"."id", '') = COALESCE("source"."id", '')
ON (
COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '')
) = (
COALESCE("source"."id", '') || '|' || COALESCE("source"."name", '')
)
AND COALESCE("latest"."name", '') = COALESCE("source"."name", '')
), "updated_rows" AS (
SELECT
COALESCE("joined"."t_id", "joined"."id") AS "id",
Expand All @@ -1114,7 +1134,10 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
END AS "test_valid_to"
FROM "joined"
LEFT JOIN "latest_deleted"
ON COALESCE("joined"."id", '') = "latest_deleted"."_key0"
ON (
COALESCE("joined"."id", '') || '|' || COALESCE("joined"."name", '')
) = "latest_deleted"."_key0"
AND COALESCE("joined"."name", '') = "latest_deleted"."_key1"
), "inserted_rows" AS (
SELECT
"id",
Expand Down
2 changes: 1 addition & 1 deletion tests/core/engine_adapter/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ def test_scd_type_2_by_time(ctx: TestContext):
ctx.engine_adapter.scd_type_2_by_time(
table,
ctx.input_data(input_data, input_schema),
unique_key=[exp.to_identifier("id")],
unique_key=[parse_one("COALESCE(id, -1)")],
valid_from_name="valid_from",
valid_to_name="valid_to",
updated_at_name="updated_at",
Expand Down
13 changes: 8 additions & 5 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2460,7 +2460,7 @@ def test_scd_type_2_by_time_defaults():
MODEL (
name db.table,
kind SCD_TYPE_2 (
unique_key "ID",
unique_key (COALESCE("ID", '') || '|' || COALESCE("ds", ''), COALESCE("ds", '')),
),
);
SELECT
Expand All @@ -2473,7 +2473,10 @@ def test_scd_type_2_by_time_defaults():
"""
)
scd_type_2_model = load_sql_based_model(model_def)
assert scd_type_2_model.unique_key_columns == [exp.to_column("ID", quoted=True)]
assert scd_type_2_model.unique_key == [
parse_one("""COALESCE("ID", '') || '|' || COALESCE("ds", '')"""),
parse_one("""COALESCE("ds", '')"""),
]
assert scd_type_2_model.columns_to_types == {
"ID": exp.DataType.build("int"),
"ds": exp.DataType.build("varchar"),
Expand Down Expand Up @@ -2525,7 +2528,7 @@ def test_scd_type_2_by_time_overrides():
"""
)
scd_type_2_model = load_sql_based_model(model_def)
assert scd_type_2_model.unique_key_columns == [
assert scd_type_2_model.unique_key == [
exp.column("iD", quoted=True),
exp.column("ds", quoted=False),
]
Expand Down Expand Up @@ -2566,7 +2569,7 @@ def test_scd_type_2_by_column_defaults():
"""
)
scd_type_2_model = load_sql_based_model(model_def)
assert scd_type_2_model.unique_key_columns == [exp.to_column("ID", quoted=True)]
assert scd_type_2_model.unique_key == [exp.to_column("ID", quoted=True)]
assert scd_type_2_model.kind.columns == [exp.to_column("value_to_track", quoted=True)]
assert scd_type_2_model.columns_to_types == {
"ID": exp.DataType.build("int"),
Expand Down Expand Up @@ -2614,7 +2617,7 @@ def test_scd_type_2_by_column_overrides():
"""
)
scd_type_2_model = load_sql_based_model(model_def)
assert scd_type_2_model.unique_key_columns == [
assert scd_type_2_model.unique_key == [
exp.column("iD", quoted=True),
exp.column("ds", quoted=False),
]
Expand Down