diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 4770a9f650..4ac87199c6 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -67,7 +67,7 @@ SnapshotTableCleanupTask, ) from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker -from sqlmesh.utils import random_id, CorrelationId +from sqlmesh.utils import random_id, CorrelationId, AttributeDict from sqlmesh.utils.concurrency import ( concurrent_apply_to_snapshots, concurrent_apply_to_values, @@ -2731,12 +2731,12 @@ def _execute_materialization( **kwargs: t.Any, ) -> None: jinja_macros = model.jinja_macros - existing_globals = jinja_macros.global_objs.copy() # For vdes we need to use the table, since we don't know the schema/table at parse time parts = exp.to_table(table_name, dialect=self.adapter.dialect) - relation_info = existing_globals.pop("this") + existing_globals = jinja_macros.global_objs + relation_info = existing_globals.get("this") if isinstance(relation_info, dict): relation_info["database"] = parts.catalog relation_info["identifier"] = parts.name @@ -2750,29 +2750,29 @@ def _execute_materialization( "identifier": parts.name, "target": existing_globals.get("target", {"type": self.adapter.dialect}), "execution_dt": kwargs.get("execution_time"), + "engine_adapter": self.adapter, + "sql": str(query_or_df), + "is_first_insert": is_first_insert, + "create_only": create_only, + # FIXME: Add support for transaction=False + "pre_hooks": [ + AttributeDict({"sql": s.this.this, "transaction": True}) + for s in model.pre_statements + ], + "post_hooks": [ + AttributeDict({"sql": s.this.this, "transaction": True}) + for s in model.post_statements + ], + "model_instance": model, + **kwargs, } - context = jinja_macros._create_builtin_globals( - {"engine_adapter": self.adapter, **jinja_globals} - ) - - context.update( - { - "sql": str(query_or_df), - "is_first_insert": is_first_insert, - "create_only": create_only, - "pre_hooks": model.render_pre_statements(**render_kwargs), - "post_hooks": model.render_post_statements(**render_kwargs), - **kwargs, - } - ) - try: - jinja_env = jinja_macros.build_environment(**context) + jinja_env = jinja_macros.build_environment(**jinja_globals) template = jinja_env.from_string(self.materialization_template) try: - template.render(**context) + template.render() except MacroReturnVal as ret: # this is a successful return from a macro call (dbt uses this list of Relations to update their relation cache) returned_relations = ret.value.get("relations", []) diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index a8b2b9af72..7f7c7eb4fb 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -99,12 +99,6 @@ def execute( ) -> t.Tuple[AdapterResponse, agate.Table]: """Executes the given SQL statement and returns the results as an agate table.""" - @abc.abstractmethod - def run_hooks( - self, hooks: t.List[str | exp.Expression], inside_transaction: bool = True - ) -> None: - """Executes the given hooks.""" - @abc.abstractmethod def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]: """Resolves the relation's schema to its physical schema.""" @@ -247,12 +241,6 @@ def execute( self._raise_parsetime_adapter_call_error("execute SQL") raise - def run_hooks( - self, hooks: t.List[str | exp.Expression], inside_transaction: bool = True - ) -> None: - self._raise_parsetime_adapter_call_error("run hooks") - raise - def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]: return relation.schema @@ -463,12 +451,6 @@ def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]: identifier = self._map_table_name(self._normalize(self._relation_to_table(relation))).name return identifier if identifier else None - def run_hooks( - self, hooks: t.List[str | exp.Expression], inside_transaction: bool = True - ) -> None: - # inside_transaction not yet supported similarly to transaction - self.engine_adapter.execute([exp.maybe_parse(hook) for hook in hooks]) - def _map_table_name(self, table: exp.Table) -> exp.Table: # Use the default dialect since this is the dialect used to normalize and quote keys in the # mapping table. diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 8690eb91fa..b8180bc011 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -544,9 +544,9 @@ def create_builtin_globals( "load_result": sql_execution.load_result, "run_query": sql_execution.run_query, "statement": sql_execution.statement, - "run_hooks": adapter.run_hooks, "graph": adapter.graph, "selected_resources": list(jinja_globals.get("selected_models") or []), + "write": lambda input: None, # We don't support writing yet } ) diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index f8e6e01fc4..17c5e91700 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -395,6 +395,12 @@ def _load_models_and_seeds(self) -> None: dependencies = dependencies.union( self._extra_dependencies(sql, node.package_name, track_all_model_attrs=True) ) + for hook in [*node_config.get("pre-hook", []), *node_config.get("post-hook", [])]: + dependencies = dependencies.union( + self._extra_dependencies( + hook["sql"], node.package_name, track_all_model_attrs=True + ) + ) dependencies = dependencies.union( self._flatten_dependencies_from_macros(dependencies.macros, node.package_name) ) diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index 508c6dce2d..59e9f6dd2f 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -369,6 +369,7 @@ def build_environment(self, **kwargs: t.Any) -> Environment: context.update(builtin_globals) context.update(root_macros) context.update(package_macros) + context["render"] = lambda input: env.from_string(input).render() env.globals.update(context) env.filters.update(self._environment.filters) diff --git a/tests/dbt/test_custom_materializations.py b/tests/dbt/test_custom_materializations.py index bd961136d2..9e7a94315c 100644 --- a/tests/dbt/test_custom_materializations.py +++ b/tests/dbt/test_custom_materializations.py @@ -37,7 +37,7 @@ def test_custom_materialization_manifest_loading(): assert custom_incremental.name == "custom_incremental" assert custom_incremental.adapter == "default" assert "make_temp_relation(new_relation)" in custom_incremental.definition - assert "run_hooks(pre_hooks, inside_transaction=False)" in custom_incremental.definition + assert "run_hooks(pre_hooks)" in custom_incremental.definition assert " {{ return({'relations': [new_relation]}) }}" in custom_incremental.definition diff --git a/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql b/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql index d39453f1c6..c61899c8ff 100644 --- a/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql +++ b/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql @@ -19,7 +19,7 @@ {%- set time_column = config.get('time_column') -%} {%- set interval_config = config.get('interval') -%} - {{ run_hooks(pre_hooks, inside_transaction=False) }} + {{ run_hooks(pre_hooks) }} {%- if existing_relation is none -%} {# The first insert creates new table if it doesn't exist #} @@ -55,7 +55,7 @@ {%- endcall -%} {%- endif -%} - {{ run_hooks(post_hooks, inside_transaction=False) }} + {{ run_hooks(post_hooks) }} {{ return({'relations': [new_relation]}) }} -{%- endmaterialization -%} \ No newline at end of file +{%- endmaterialization -%}