From b182c02376c4c0f7cc3c611b9ac74254393e6efe Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+Themiscodes@users.noreply.github.com> Date: Fri, 12 Jul 2024 15:35:55 +0300 Subject: [PATCH 1/8] Feat: support pre-post-statements in python at creation --- sqlmesh/core/model/definition.py | 170 +++++++++++++------------------ 1 file changed, 71 insertions(+), 99 deletions(-) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 5d9e603775..ad3334c81f 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -117,6 +117,14 @@ class _Model(ModelMeta, frozen=True): mapping_schema: t.Dict[str, t.Any] = {} _full_depends_on: t.Optional[t.Set[str]] = None + __statement_renderers: t.Dict[int, ExpressionRenderer] = {} + + pre_statements_: t.Optional[t.List[exp.Expression]] = Field( + default=None, alias="pre_statements" + ) + post_statements_: t.Optional[t.List[exp.Expression]] = Field( + default=None, alias="post_statements" + ) _expressions_validator = expression_validator @@ -335,7 +343,17 @@ def render_pre_statements( Returns: The list of rendered expressions. """ - return [] + return self._render_statements( + self.pre_statements, + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + expand=expand, + deployability_index=deployability_index, + engine_adapter=engine_adapter, + **kwargs, + ) def render_post_statements( self, @@ -367,7 +385,58 @@ def render_post_statements( Returns: The list of rendered expressions. """ - return [] + return self._render_statements( + self.post_statements, + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + expand=expand, + deployability_index=deployability_index, + engine_adapter=engine_adapter, + **kwargs, + ) + + @property + def pre_statements(self) -> t.List[exp.Expression]: + return self.pre_statements_ or [] + + @property + def post_statements(self) -> t.List[exp.Expression]: + return self.post_statements_ or [] + + @property + def macro_definitions(self) -> t.List[d.MacroDef]: + """All macro definitions from the list of expressions.""" + return [s for s in self.pre_statements + self.post_statements if isinstance(s, d.MacroDef)] + + def _render_statements( + self, + statements: t.Iterable[exp.Expression], + **kwargs: t.Any, + ) -> t.List[exp.Expression]: + rendered = ( + self._statement_renderer(statement).render(**kwargs) + for statement in statements + if not isinstance(statement, d.MacroDef) + ) + return [r for expressions in rendered if expressions for r in expressions] + + def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer: + expression_key = id(expression) + if expression_key not in self.__statement_renderers: + self.__statement_renderers[expression_key] = ExpressionRenderer( + expression, + self.dialect, + self.macro_definitions, + path=self._path, + jinja_macro_registry=self.jinja_macros, + python_env=self.python_env, + only_execution_time=self.kind.only_execution_time, + default_catalog=self.default_catalog, + model_fqn=self.fqn, + ) + return self.__statement_renderers[expression_key] def render_signals( self, @@ -857,16 +926,8 @@ def full_depends_on(self) -> t.Set[str]: class _SqlBasedModel(_Model): - pre_statements_: t.Optional[t.List[exp.Expression]] = Field( - default=None, alias="pre_statements" - ) - post_statements_: t.Optional[t.List[exp.Expression]] = Field( - default=None, alias="post_statements" - ) inline_audits_: t.Dict[str, t.Any] = Field(default={}, alias="inline_audits") - __statement_renderers: t.Dict[int, ExpressionRenderer] = {} - _expression_validator = expression_validator @field_validator("inline_audits_", mode="before") @@ -887,99 +948,10 @@ def _inline_audits_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Any return inline_audits - def render_pre_statements( - self, - *, - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - execution_time: t.Optional[TimeLike] = None, - snapshots: t.Optional[t.Collection[Snapshot]] = None, - expand: t.Iterable[str] = tuple(), - deployability_index: t.Optional[DeployabilityIndex] = None, - engine_adapter: t.Optional[EngineAdapter] = None, - **kwargs: t.Any, - ) -> t.List[exp.Expression]: - return self._render_statements( - self.pre_statements, - start=start, - end=end, - execution_time=execution_time, - snapshots=snapshots, - expand=expand, - deployability_index=deployability_index, - engine_adapter=engine_adapter, - **kwargs, - ) - - def render_post_statements( - self, - *, - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - execution_time: t.Optional[TimeLike] = None, - snapshots: t.Optional[t.Collection[Snapshot]] = None, - expand: t.Iterable[str] = tuple(), - deployability_index: t.Optional[DeployabilityIndex] = None, - engine_adapter: t.Optional[EngineAdapter] = None, - **kwargs: t.Any, - ) -> t.List[exp.Expression]: - return self._render_statements( - self.post_statements, - start=start, - end=end, - execution_time=execution_time, - snapshots=snapshots, - expand=expand, - deployability_index=deployability_index, - engine_adapter=engine_adapter, - **kwargs, - ) - - @property - def pre_statements(self) -> t.List[exp.Expression]: - return self.pre_statements_ or [] - - @property - def post_statements(self) -> t.List[exp.Expression]: - return self.post_statements_ or [] - - @property - def macro_definitions(self) -> t.List[d.MacroDef]: - """All macro definitions from the list of expressions.""" - return [s for s in self.pre_statements + self.post_statements if isinstance(s, d.MacroDef)] - @property def inline_audits(self) -> t.Dict[str, ModelAudit]: return self.inline_audits_ - def _render_statements( - self, - statements: t.Iterable[exp.Expression], - **kwargs: t.Any, - ) -> t.List[exp.Expression]: - rendered = ( - self._statement_renderer(statement).render(**kwargs) - for statement in statements - if not isinstance(statement, d.MacroDef) - ) - return [r for expressions in rendered if expressions for r in expressions] - - def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer: - expression_key = id(expression) - if expression_key not in self.__statement_renderers: - self.__statement_renderers[expression_key] = ExpressionRenderer( - expression, - self.dialect, - self.macro_definitions, - path=self._path, - jinja_macro_registry=self.jinja_macros, - python_env=self.python_env, - only_execution_time=self.kind.only_execution_time, - default_catalog=self.default_catalog, - model_fqn=self.fqn, - ) - return self.__statement_renderers[expression_key] - @property def _data_hash_values(self) -> t.List[str]: data_hash_values = super()._data_hash_values From 52de55e29de102de36f6d422c7ccc3b50d8ccdee Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+Themiscodes@users.noreply.github.com> Date: Tue, 16 Jul 2024 18:44:57 +0300 Subject: [PATCH 2/8] Load macros in python models --- sqlmesh/core/loader.py | 7 +++++-- tests/core/test_model.py | 15 ++++++++++++++- tests/integrations/jupyter/test_magics.py | 4 ++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index a77e5b4e9d..269f5ce49d 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -329,7 +329,7 @@ def _load_models( """ models = self._load_sql_models(macros, jinja_macros, audits) models.update(self._load_external_models(gateway)) - models.update(self._load_python_models()) + models.update(self._load_python_models(macros)) return models @@ -392,7 +392,7 @@ def _load() -> Model: return models - def _load_python_models(self) -> UniqueKeyDict[str, Model]: + def _load_python_models(self, macros: MacroRegistry) -> UniqueKeyDict[str, Model]: """Loads the python models into a Dict""" models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") registry = model_registry.registry() @@ -414,6 +414,9 @@ def _load_python_models(self) -> UniqueKeyDict[str, Model]: new = registry.keys() - registered registered |= new for name in new: + if macros: + macro.set_registry(macros) + model = registry[name].model( path=path, module_path=context_path, diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 7d5827273f..25d3357a32 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1638,7 +1638,14 @@ def test_parse(assert_exp_eq): def test_python_model(assert_exp_eq) -> None: from functools import reduce - @model(name="my_model", kind="full", columns={'"COL"': "int"}, enabled=True) + @model( + name="my_model", + kind="full", + columns={'"COL"': "int"}, + pre_statements=["CACHE TABLE x AS SELECT 1;"], + post_statements=["DROP TABLE x;"], + enabled=True, + ) def my_model(context, **kwargs): context.table("foo") context.table(model_name=CONST + ".baz") @@ -1652,6 +1659,12 @@ def my_model(context, **kwargs): dialect="duckdb", ) + assert list(m.pre_statements) == [ + d.parse_one("CACHE TABLE x AS SELECT 1"), + ] + assert list(m.post_statements) == [ + d.parse_one("DROP TABLE x"), + ] assert m.enabled assert m.dialect == "duckdb" assert m.depends_on == {'"foo"', '"bar"."baz"'} diff --git a/tests/integrations/jupyter/test_magics.py b/tests/integrations/jupyter/test_magics.py index ee6b6c5519..0ee081a637 100644 --- a/tests/integrations/jupyter/test_magics.py +++ b/tests/integrations/jupyter/test_magics.py @@ -552,7 +552,7 @@ def test_info(notebook, sushi_context, convert_all_html_output_to_text, get_all_ assert len(output.outputs) == 3 assert convert_all_html_output_to_text(output) == [ "Models: 17", - "Macros: 5", + "Macros: 0", "Data warehouse connection succeeded", ] assert get_all_html_output(output) == [ @@ -582,7 +582,7 @@ def test_info(notebook, sushi_context, convert_all_html_output_to_text, get_all_ h( "span", {"style": f"{NEUTRAL_STYLE}; font-weight: bold"}, - "5", + "0", autoescape=False, ) ), From 306656be601b678ea13347da50f72b8b7f7df6cd Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+Themiscodes@users.noreply.github.com> Date: Wed, 17 Jul 2024 14:13:15 +0300 Subject: [PATCH 3/8] Adapt logic for macros --- sqlmesh/core/loader.py | 8 ++++++-- sqlmesh/core/model/decorator.py | 13 ++++++++++++- sqlmesh/core/model/definition.py | 19 +++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 269f5ce49d..de5f9ac281 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -329,7 +329,7 @@ def _load_models( """ models = self._load_sql_models(macros, jinja_macros, audits) models.update(self._load_external_models(gateway)) - models.update(self._load_python_models(macros)) + models.update(self._load_python_models(macros, jinja_macros)) return models @@ -392,7 +392,9 @@ def _load() -> Model: return models - def _load_python_models(self, macros: MacroRegistry) -> UniqueKeyDict[str, Model]: + def _load_python_models( + self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry + ) -> UniqueKeyDict[str, Model]: """Loads the python models into a Dict""" models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") registry = model_registry.registry() @@ -421,6 +423,8 @@ def _load_python_models(self, macros: MacroRegistry) -> UniqueKeyDict[str, Model path=path, module_path=context_path, defaults=config.model_defaults.dict(), + macros=macros, + jinja_macros=jinja_macros, dialect=config.model_defaults.dialect, time_column_format=config.time_column_format, physical_schema_override=config.physical_schema_override, diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 6f4f0c9f18..4bdf73c86f 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -8,6 +8,8 @@ from sqlglot import exp from sqlglot.dialects.dialect import DialectType +from sqlmesh.core.macros import MacroRegistry +from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.core import constants as c from sqlmesh.core.dialect import MacroFunc from sqlmesh.core.model.definition import ( @@ -75,6 +77,8 @@ def model( module_path: Path, path: Path, defaults: t.Optional[t.Dict[str, t.Any]] = None, + macros: t.Optional[MacroRegistry] = None, + jinja_macros: t.Optional[JinjaMacroRegistry] = None, dialect: t.Optional[str] = None, time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, physical_schema_override: t.Optional[t.Dict[str, str]] = None, @@ -132,5 +136,12 @@ def model( ) return create_python_model( - self.name, entrypoint, columns=self.columns, dialect=dialect, **common_kwargs + self.name, + entrypoint, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + columns=self.columns, + dialect=dialect, + **common_kwargs, ) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index ad3334c81f..3ac72963d5 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -1838,8 +1838,11 @@ def create_seed_model( def create_python_model( name: str, entrypoint: str, + module_path: Path, python_env: t.Dict[str, Executable], *, + macros: t.Optional[MacroRegistry] = None, + jinja_macros: t.Optional[JinjaMacroRegistry] = None, defaults: t.Optional[t.Dict[str, t.Any]] = None, path: Path = Path(), time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, @@ -1862,6 +1865,21 @@ def create_python_model( """ # Find dependencies for python models by parsing code if they are not explicitly defined # Also remove self-references that are found + + pre_statements = kwargs.get("pre_statements", None) or [] + post_statements = kwargs.get("post_statements", None) or [] + + python_env.update( + _python_env( + expressions=[*pre_statements, *post_statements], + jinja_macro_references=None, + module_path=module_path, + macros=macros or macro.get_registry(), + variables=variables, + path=path, + ) + ) + parsed_depends_on, referenced_variables = ( _parse_dependencies(python_env, entrypoint) if python_env is not None else (set(), set()) ) @@ -1881,6 +1899,7 @@ def create_python_model( depends_on=depends_on, entrypoint=entrypoint, python_env=python_env, + jinja_macros=jinja_macros, physical_schema_override=physical_schema_override, **kwargs, ) From 50829c67cc90c64f80a7f61b9e8b2ea1c67438e8 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+Themiscodes@users.noreply.github.com> Date: Thu, 1 Aug 2024 15:39:19 +0300 Subject: [PATCH 4/8] Add unit test for pre/post staments and macro call --- sqlmesh/core/loader.py | 3 -- tests/core/test_snapshot_evaluator.py | 66 ++++++++++++++++++++++++++- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index de5f9ac281..30200ec986 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -416,9 +416,6 @@ def _load_python_models( new = registry.keys() - registered registered |= new for name in new: - if macros: - macro.set_registry(macros) - model = registry[name].model( path=path, module_path=context_path, diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index cac20b35ff..5440f9b832 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -4,6 +4,7 @@ import logging import pytest +import pandas as pd from pathlib import Path from pytest_mock.plugin import MockerFixture from sqlglot import expressions as exp @@ -20,7 +21,7 @@ InsertOverwriteStrategy, ) from sqlmesh.core.environment import EnvironmentNamingInfo -from sqlmesh.core.macros import RuntimeStage, macro +from sqlmesh.core.macros import RuntimeStage, macro, MacroEvaluator, MacroFunc from sqlmesh.core.model import ( Model, FullKind, @@ -33,6 +34,7 @@ ViewKind, load_sql_based_model, ExternalModel, + model, ) from sqlmesh.core.model.kind import OnDestructiveChange, ExternalKind from sqlmesh.core.node import IntervalUnit @@ -2215,6 +2217,68 @@ def test_create_post_statements_use_deployable_table( assert post_calls[0].sql(dialect="postgres") == expected_call +def test_create_pre_post_statements_python_model( + mocker: MockerFixture, adapter_mock, make_snapshot +): + evaluator = SnapshotEvaluator(adapter_mock) + + @macro() + def create_index( + evaluator: MacroEvaluator, + index_name: str, + model_name: str, + column: str, + ): + if evaluator.runtime_stage == "creating": + return f"CREATE INDEX IF NOT EXISTS {index_name} ON {model_name}({column});" + + @model( + "db.test_model", + kind="full", + columns={"id": "string", "name": "string"}, + pre_statements=["CREATE INDEX IF NOT EXISTS idx ON db.test_model(id);"], + post_statements=["@CREATE_INDEX('idx', 'db.test_model', id)"], + ) + def model_with_statements(context, **kwargs): + return pd.DataFrame( + [ + { + "id": context.var("1"), + "name": context.var("var"), + } + ] + ) + + python_model = model.get_registry()["db.test_model"].model( + module_path=Path("."), + path=Path("."), + macros=macro.get_registry(), + dialect="postgres", + ) + + assert len(python_model.python_env) == 3 + assert len(python_model.pre_statements) == 1 + assert len(python_model.post_statements) == 1 + assert isinstance(python_model.python_env["create_index"], Executable) + assert isinstance(python_model.pre_statements[0], exp.Create) + assert isinstance(python_model.post_statements[0], MacroFunc) + + snapshot = make_snapshot(python_model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable()) + expected_call = f'CREATE INDEX IF NOT EXISTS "idx" ON "sqlmesh__db"."db__test_model__{snapshot.version}" /* db.test_model */("id")' + + call_args = adapter_mock.execute.call_args_list + pre_calls = call_args[0][0][0] + assert len(pre_calls) == 1 + assert pre_calls[0].sql(dialect="postgres") == expected_call + + post_calls = call_args[1][0][0] + assert len(post_calls) == 1 + assert post_calls[0].sql(dialect="postgres") == expected_call + + def test_evaluate_incremental_by_partition(mocker: MockerFixture, make_snapshot, adapter_mock): model = SqlModel( name="test_schema.test_model", From 38748f79dcf935a158cf0b128f890c80aff6e297 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+Themiscodes@users.noreply.github.com> Date: Thu, 1 Aug 2024 19:12:33 +0300 Subject: [PATCH 5/8] Adjust failing test --- tests/integrations/jupyter/test_magics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integrations/jupyter/test_magics.py b/tests/integrations/jupyter/test_magics.py index 0ee081a637..ee6b6c5519 100644 --- a/tests/integrations/jupyter/test_magics.py +++ b/tests/integrations/jupyter/test_magics.py @@ -552,7 +552,7 @@ def test_info(notebook, sushi_context, convert_all_html_output_to_text, get_all_ assert len(output.outputs) == 3 assert convert_all_html_output_to_text(output) == [ "Models: 17", - "Macros: 0", + "Macros: 5", "Data warehouse connection succeeded", ] assert get_all_html_output(output) == [ @@ -582,7 +582,7 @@ def test_info(notebook, sushi_context, convert_all_html_output_to_text, get_all_ h( "span", {"style": f"{NEUTRAL_STYLE}; font-weight: bold"}, - "0", + "5", autoescape=False, ) ), From 8d1116d6452c2c326532b67d4d3629c0cad603b9 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+Themiscodes@users.noreply.github.com> Date: Tue, 6 Aug 2024 12:24:54 +0300 Subject: [PATCH 6/8] Adapt to use jinja; add unit test --- sqlmesh/core/model/decorator.py | 4 +-- sqlmesh/core/model/definition.py | 31 ++++++++++------ tests/core/test_model.py | 62 +++++++++++++++++++++++++++++++- 3 files changed, 84 insertions(+), 13 deletions(-) diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 4bdf73c86f..987cee24cc 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -11,7 +11,7 @@ from sqlmesh.core.macros import MacroRegistry from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.core import constants as c -from sqlmesh.core.dialect import MacroFunc +from sqlmesh.core.dialect import MacroFunc, parse_one from sqlmesh.core.model.definition import ( Model, create_python_model, @@ -127,7 +127,7 @@ def model( for key in ("pre_statements", "post_statements"): statements = common_kwargs.get(key) if statements: - common_kwargs[key] = [exp.maybe_parse(s, dialect=dialect) for s in statements] + common_kwargs[key] = [parse_one(s, dialect=dialect) for s in statements] if self.is_sql: query = MacroFunc(this=exp.Anonymous(this=entrypoint)) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 3ac72963d5..ce000aec42 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -1838,13 +1838,13 @@ def create_seed_model( def create_python_model( name: str, entrypoint: str, - module_path: Path, python_env: t.Dict[str, Executable], *, macros: t.Optional[MacroRegistry] = None, jinja_macros: t.Optional[JinjaMacroRegistry] = None, defaults: t.Optional[t.Dict[str, t.Any]] = None, path: Path = Path(), + module_path: Path = Path(), time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, depends_on: t.Optional[t.Set[str]] = None, physical_schema_override: t.Optional[t.Dict[str, str]] = None, @@ -1869,16 +1869,27 @@ def create_python_model( pre_statements = kwargs.get("pre_statements", None) or [] post_statements = kwargs.get("post_statements", None) or [] - python_env.update( - _python_env( - expressions=[*pre_statements, *post_statements], - jinja_macro_references=None, - module_path=module_path, - macros=macros or macro.get_registry(), - variables=variables, - path=path, + if pre_statements or post_statements: + jinja_macro_references, used_variables = extract_macro_references_and_variables( + *(gen(e) for e in pre_statements), + *(gen(e) for e in post_statements), + ) + + jinja_macros = (jinja_macros or JinjaMacroRegistry()).trim(jinja_macro_references) + for jinja_macro in jinja_macros.root_macros.values(): + used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1]) + + python_env.update( + _python_env( + [*pre_statements, *post_statements], + jinja_macro_references, + module_path, + macros or macro.get_registry(), + variables=variables, + used_variables=used_variables, + path=path, + ) ) - ) parsed_depends_on, referenced_variables = ( _parse_dependencies(python_env, entrypoint) if python_env is not None else (set(), set()) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 25d3357a32..c46ea0d216 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -49,7 +49,7 @@ from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory from sqlmesh.utils.date import TimeLike, to_datetime, to_ds, to_timestamp from sqlmesh.utils.errors import ConfigError, SQLMeshError -from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo +from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo, MacroExtractor from sqlmesh.utils.metaprogramming import Executable @@ -947,6 +947,66 @@ def test_seed_with_special_characters_in_column(tmp_path, assert_exp_eq): ) +def test_python_model_jinja_pre_post_statements(): + macros = """ + {% macro test_macro(v) %}{{ v }}{% endmacro %} + {% macro extra_macro(v) %}{{ v + 1 }}{% endmacro %} + """ + + jinja_macros = JinjaMacroRegistry() + jinja_macros.add_macros(MacroExtractor().extract(macros)) + + @model( + "db.test_model", + kind="full", + columns={"id": "string", "name": "string"}, + pre_statements=[ + "JINJA_STATEMENT_BEGIN;\n{% set table_name = 'x' %}\nCREATE OR REPLACE TABLE {{table_name}}{{ 1 + 1 }};\nJINJA_END;" + ], + post_statements=[ + "JINJA_STATEMENT_BEGIN;\nCREATE INDEX {{test_macro('idx')}} ON db.test_model(id);\nJINJA_END;", + "DROP TABLE x2;", + ], + ) + def model_with_statements(context, **kwargs): + return pd.DataFrame( + [ + { + "id": context.var("1"), + "name": context.var("var"), + } + ] + ) + + python_model = model.get_registry()["db.test_model"].model( + module_path=Path("."), path=Path("."), dialect="duckdb", jinja_macros=jinja_macros + ) + + assert len(jinja_macros.root_macros) == 2 + assert len(python_model.jinja_macros.root_macros) == 1 + assert "test_macro" in python_model.jinja_macros.root_macros + assert "extra_macro" not in python_model.jinja_macros.root_macros + + expected_pre = [ + d.jinja_statement( + "{% set table_name = 'x' %}\nCREATE OR REPLACE TABLE {{table_name}}{{ 1 + 1 }};" + ), + ] + assert python_model.pre_statements == expected_pre + assert python_model.render_pre_statements()[0].sql() == 'CREATE OR REPLACE TABLE "x2"' + + expected_post = [ + d.jinja_statement("CREATE INDEX {{test_macro('idx')}} ON db.test_model(id);"), + *d.parse("DROP TABLE x2;"), + ] + assert python_model.post_statements == expected_post + assert ( + python_model.render_post_statements()[0].sql() + == 'CREATE INDEX "idx" ON "db"."test_model"("id" NULLS LAST)' + ) + assert python_model.render_post_statements()[1].sql() == 'DROP TABLE "x2"' + + def test_audits(): expressions = d.parse( """ From c4b988771018901cdca2a67dc34e6bcbad153f8b Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+Themiscodes@users.noreply.github.com> Date: Tue, 6 Aug 2024 17:27:18 +0300 Subject: [PATCH 7/8] Re-enable support for expressions in statements; update tests accordingly --- sqlmesh/core/model/decorator.py | 4 +++- tests/core/test_model.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 987cee24cc..21a2f0ec15 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -127,7 +127,9 @@ def model( for key in ("pre_statements", "post_statements"): statements = common_kwargs.get(key) if statements: - common_kwargs[key] = [parse_one(s, dialect=dialect) for s in statements] + common_kwargs[key] = [ + parse_one(s, dialect=dialect) if isinstance(s, str) else s for s in statements + ] if self.is_sql: query = MacroFunc(this=exp.Anonymous(this=entrypoint)) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index c46ea0d216..c2df9809e1 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -965,7 +965,7 @@ def test_python_model_jinja_pre_post_statements(): ], post_statements=[ "JINJA_STATEMENT_BEGIN;\nCREATE INDEX {{test_macro('idx')}} ON db.test_model(id);\nJINJA_END;", - "DROP TABLE x2;", + parse_one("DROP TABLE x2;"), ], ) def model_with_statements(context, **kwargs): From 30d607a2e201f24ebfaa67a3b62569ece5e1f652 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+Themiscodes@users.noreply.github.com> Date: Wed, 7 Aug 2024 19:19:20 +0300 Subject: [PATCH 8/8] Inclue pre/post statements in python models data hash --- sqlmesh/core/model/definition.py | 68 +++++++++++++------------------- 1 file changed, 28 insertions(+), 40 deletions(-) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index ce000aec42..aa0f430e16 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -824,6 +824,18 @@ def _data_hash_values(self) -> t.List[str]: data.append(key) data.append(gen(value)) + for statement in (*self.pre_statements, *self.post_statements): + statement_exprs: t.List[exp.Expression] = [] + if not isinstance(statement, d.MacroDef): + rendered = self._statement_renderer(statement).render() + if self._is_metadata_statement(statement): + continue + if rendered: + statement_exprs = rendered + else: + statement_exprs = [statement] + data.extend(gen(e) for e in statement_exprs) + return data # type: ignore def metadata_hash(self, audits: t.Dict[str, ModelAudit]) -> str: @@ -908,8 +920,24 @@ def _additional_metadata(self) -> t.List[str]: if metadata_only_macros: additional_metadata.append(str(metadata_only_macros)) + for statement in (*self.pre_statements, *self.post_statements): + if self._is_metadata_statement(statement): + additional_metadata.append(gen(statement)) + return additional_metadata + def _is_metadata_statement(self, statement: exp.Expression) -> bool: + if isinstance(statement, d.MacroDef): + return True + if isinstance(statement, d.MacroFunc): + target_macro = macro.get_registry().get(statement.name) + if target_macro: + return target_macro.metadata_only + target_macro = self.python_env.get(statement.name) + if target_macro: + return bool(target_macro.is_metadata) + return False + @property def full_depends_on(self) -> t.Set[str]: if not self._full_depends_on: @@ -952,46 +980,6 @@ def _inline_audits_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Any def inline_audits(self) -> t.Dict[str, ModelAudit]: return self.inline_audits_ - @property - def _data_hash_values(self) -> t.List[str]: - data_hash_values = super()._data_hash_values - - for statement in (*self.pre_statements, *self.post_statements): - statement_exprs: t.List[exp.Expression] = [] - if not isinstance(statement, d.MacroDef): - rendered = self._statement_renderer(statement).render() - if self._is_metadata_statement(statement): - continue - if rendered: - statement_exprs = rendered - else: - statement_exprs = [statement] - data_hash_values.extend(gen(e) for e in statement_exprs) - - return data_hash_values - - @property - def _additional_metadata(self) -> t.List[str]: - additional_metadata = super()._additional_metadata - - for statement in (*self.pre_statements, *self.post_statements): - if self._is_metadata_statement(statement): - additional_metadata.append(gen(statement)) - - return additional_metadata - - def _is_metadata_statement(self, statement: exp.Expression) -> bool: - if isinstance(statement, d.MacroDef): - return True - if isinstance(statement, d.MacroFunc): - target_macro = macro.get_registry().get(statement.name) - if target_macro: - return target_macro.metadata_only - target_macro = self.python_env.get(statement.name) - if target_macro: - return bool(target_macro.is_metadata) - return False - class SqlModel(_SqlBasedModel): """The model definition which relies on a SQL query to fetch the data.