diff --git a/docs/concepts/macros/macro_variables.md b/docs/concepts/macros/macro_variables.md index 6db75ed17e..bea6036e0e 100644 --- a/docs/concepts/macros/macro_variables.md +++ b/docs/concepts/macros/macro_variables.md @@ -44,7 +44,7 @@ This example used one of SQLMesh's predefined variables, but you can also define We describe SQLMesh's predefined variables below; user-defined macro variables are discussed in the [SQLMesh macros](./sqlmesh_macros.md#user-defined-variables) and [Jinja macros](./jinja_macros.md#user-defined-variables) pages. -## Predefined Variables +## Predefined variables SQLMesh comes with predefined variables that can be used in your queries. They are automatically set by the SQLMesh runtime. Most predefined variables are related to time and use a combination of prefixes (start, end, etc.) and postfixes (date, ds, ts, etc.). They are described in the next section; [other predefined variables](#runtime-variables) are discussed in the following section. @@ -120,7 +120,7 @@ All predefined temporal macro variables: ### Runtime variables -SQLMesh provides two other predefined variables used to modify model behavior based on information available at runtime. +SQLMesh provides additional predefined variables used to modify model behavior based on information available at runtime. * @runtime_stage - A string value denoting the current stage of the SQLMesh runtime. Typically used in models to conditionally execute pre/post-statements (learn more [here](../models/sql_models.md#optional-prepost-statements)). It returns one of these values: * 'loading' - The project is being loaded into SQLMesh's runtime context. @@ -133,5 +133,11 @@ SQLMesh provides two other predefined variables used to modify model behavior ba * @this_model - A string value containing the name of the physical table the model view selects from. Typically used to create [generic audits](../audits.md#generic-audits). 
In the case of [on_virtual_update statements](../models/sql_models.md#optional-on-virtual-update-statements) it contains the qualified view name instead. * Can be used in model definitions when SQLGlot cannot fully parse a statement and you need to reference the model's underlying physical table directly. * Can be passed as an argument to macros that access or interact with the underlying physical table. -* @this_env - A string value containing the name of the current [environment](../environments.md). Only available in [`before_all` and `after_all` statements](../../guides/configuration.md#before_all-and-after_all-statements), as well as in macros invoked within them. -* @model_kind_name - A string value containing the name of the current model kind. Intended to be used in scenarios where you need to control the [physical properties in model defaults](../../reference/model_configuration.md#model-defaults). \ No newline at end of file +* @model_kind_name - A string value containing the name of the current model kind. Intended to be used in scenarios where you need to control the [physical properties in model defaults](../../reference/model_configuration.md#model-defaults). + +#### Before all and after all variables + +In addition to the variables above, the following variables are available in [`before_all` and `after_all` statements](../../guides/configuration.md#before_all-and-after_all-statements), as well as in macros invoked within them. + +* @this_env - A string value containing the name of the current [environment](../environments.md). +* @schemas - A list of the schema names of the [virtual layer](../../concepts/glossary.md#virtual-layer) of the current environment. 
\ No newline at end of file diff --git a/docs/integrations/dbt.md b/docs/integrations/dbt.md index 83bac1127e..bd1fa9a7f1 100644 --- a/docs/integrations/dbt.md +++ b/docs/integrations/dbt.md @@ -324,7 +324,6 @@ The dbt jinja methods that are not currently supported are: * selected_sources * adapter.expand_target_column_types * adapter.rename_relation -* schemas * graph.nodes.values * graph.metrics.values diff --git a/examples/multi_dbt/bronze/dbt_project.yml b/examples/multi_dbt/bronze/dbt_project.yml index 14f841251c..1fadcdc1cd 100644 --- a/examples/multi_dbt/bronze/dbt_project.yml +++ b/examples/multi_dbt/bronze/dbt_project.yml @@ -19,3 +19,6 @@ require-dbt-version: [">=1.0.0", "<2.0.0"] models: start: "2024-01-01" +materialized: table + +on-run-start: + - 'CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);' \ No newline at end of file diff --git a/examples/multi_dbt/silver/dbt_project.yml b/examples/multi_dbt/silver/dbt_project.yml index e78f4643d3..57edd1f72c 100644 --- a/examples/multi_dbt/silver/dbt_project.yml +++ b/examples/multi_dbt/silver/dbt_project.yml @@ -19,3 +19,6 @@ require-dbt-version: [">=1.0.0", "<2.0.0"] models: start: "2024-01-01" +materialized: table + +on-run-end: + - '{{ store_schemas(schemas) }}' \ No newline at end of file diff --git a/examples/multi_dbt/silver/macros/store_schemas.sql b/examples/multi_dbt/silver/macros/store_schemas.sql new file mode 100644 index 0000000000..564d2b24bb --- /dev/null +++ b/examples/multi_dbt/silver/macros/store_schemas.sql @@ -0,0 +1,3 @@ +{% macro store_schemas(schemas) %} + create or replace table schema_table as select {{schemas}} as all_schemas; +{% endmacro %} \ No newline at end of file diff --git a/sqlmesh/core/environment.py b/sqlmesh/core/environment.py index 8b667d11e2..1762a6275a 100644 --- a/sqlmesh/core/environment.py +++ b/sqlmesh/core/environment.py @@ -14,6 +14,7 @@ from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo, Snapshot from 
sqlmesh.utils import word_characters_only from sqlmesh.utils.date import TimeLike, now_timestamp +from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.utils.metaprogramming import Executable from sqlmesh.utils.pydantic import PydanticModel, field_validator @@ -218,6 +219,7 @@ class EnvironmentStatements(PydanticModel): before_all: t.List[str] after_all: t.List[str] python_env: t.Dict[str, Executable] + jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry() def execute_environment_statements( @@ -239,6 +241,7 @@ def execute_environment_statements( dialect=adapter.dialect, default_catalog=default_catalog, python_env=statements.python_env, + jinja_macros=statements.jinja_macros, snapshots=snapshots, start=start, end=end, diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index c232b11687..86308b9e97 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -107,6 +107,18 @@ def _render( if environment_naming_info := kwargs.get("environment_naming_info", None): kwargs["this_env"] = getattr(environment_naming_info, "name") + if snapshots and ( + schemas := set( + [ + s.qualified_view_name.schema_for_environment( + environment_naming_info, dialect=self._dialect + ) + for s in snapshots.values() + if s.is_model and not s.is_symbolic + ] + ) + ): + kwargs["schemas"] = list(schemas) this_model = kwargs.pop("this_model", None) @@ -411,19 +423,21 @@ def render( def render_statements( statements: t.List[str], - dialect: DialectType = None, + dialect: str, default_catalog: t.Optional[str] = None, python_env: t.Optional[t.Dict[str, Executable]] = None, + jinja_macros: t.Optional[JinjaMacroRegistry] = None, **render_kwargs: t.Any, ) -> t.List[str]: rendered_statements: t.List[str] = [] for statement in statements: - for expression in parse(statement, dialect=dialect): + for expression in d.parse(statement, default_dialect=dialect): if expression: rendered = ExpressionRenderer( expression, dialect, [], + jinja_macro_registry=jinja_macros or 
JinjaMacroRegistry(), python_env=python_env, default_catalog=default_catalog, quote_identifiers=False, diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index 7f147895d7..a270ca1745 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -3,6 +3,8 @@ import logging import sys import typing as t +import sqlmesh.core.dialect as d +from sqlglot.optimizer.simplify import gen from pathlib import Path from sqlmesh.core import constants as c from sqlmesh.core.config import ( @@ -11,9 +13,11 @@ GatewayConfig, ModelDefaultsConfig, ) +from sqlmesh.core.environment import EnvironmentStatements from sqlmesh.core.loader import CacheBase, LoadedProject, Loader from sqlmesh.core.macros import MacroRegistry, macro from sqlmesh.core.model import Model, ModelCache +from sqlmesh.core.model.common import make_python_env from sqlmesh.core.signal import signal from sqlmesh.dbt.basemodel import BMC, BaseModelConfig from sqlmesh.dbt.context import DbtContext @@ -23,7 +27,11 @@ from sqlmesh.dbt.target import TargetConfig from sqlmesh.utils import UniqueKeyDict from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.jinja import JinjaMacroRegistry +from sqlmesh.utils.jinja import ( + JinjaMacroRegistry, + MacroInfo, + extract_macro_references_and_variables, +) if sys.version_info >= (3, 12): from importlib import metadata @@ -230,6 +238,60 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]: return requirements, excluded_requirements + def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None: + """Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively.""" + + on_run_start: t.List[str] = [] + on_run_end: t.List[str] = [] + jinja_root_macros: t.Dict[str, MacroInfo] = {} + variables: t.Dict[str, t.Any] = self._get_variables() + dialect = self.config.dialect + for project in self._load_projects(): + context = project.context.copy() + if manifest := 
context._manifest: + on_run_start.extend(manifest._on_run_start or []) + on_run_end.extend(manifest._on_run_end or []) + + if root_package := context.jinja_macros.root_package_name: + if root_macros := context.jinja_macros.packages.get(root_package): + jinja_root_macros |= root_macros + context.set_and_render_variables(context.variables, root_package) + variables |= context.variables + + if statements := on_run_start + on_run_end: + jinja_macro_references, used_variables = extract_macro_references_and_variables( + *(gen(stmt) for stmt in statements) + ) + jinja_macros = context.jinja_macros + jinja_macros.root_macros = jinja_root_macros + jinja_macros = ( + jinja_macros.trim(jinja_macro_references) + if not jinja_macros.trimmed + else jinja_macros + ) + + python_env = make_python_env( + [s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)], + jinja_macro_references=jinja_macro_references, + module_path=self.config_path, + macros=macros, + variables=variables, + used_variables=used_variables, + path=self.config_path, + ) + + return EnvironmentStatements( + before_all=[ + d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_start or [] + ], + after_all=[ + d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_end or [] + ], + python_env=python_env, + jinja_macros=jinja_macros, + ) + return None + def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, float]: if not root.is_dir(): return {} diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index a04a14e313..83d3df1321 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -94,6 +94,9 @@ def __init__( self.project_path / c.CACHE, "jinja_calls" ) + self._on_run_start: t.Optional[t.List[str]] = None + self._on_run_end: t.Optional[t.List[str]] = None + def tests(self, package_name: t.Optional[str] = None) -> TestConfigs: self._load_all() return self._tests_per_package[package_name or self._project_name] @@ -312,6 +315,11 @@ def 
_load_manifest(self) -> Manifest: runtime_config = RuntimeConfig.from_parts(project, profile, args) + if runtime_config.on_run_start: + self._on_run_start = runtime_config.on_run_start + if runtime_config.on_run_end: + self._on_run_end = runtime_config.on_run_end + self._project_name = project.project_name if DBT_VERSION >= (1, 8): diff --git a/tests/core/test_context.py b/tests/core/test_context.py index d0db9c7bd9..b2c5090269 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -1417,6 +1417,7 @@ def test_environment_statements(tmp_path: pathlib.Path): after_all=[ "@grant_schema_usage()", "@grant_select_privileges()", + "@grant_usage_role(@schemas, 'admin')", ], ) @@ -1481,6 +1482,22 @@ def grant_schema_usage(evaluator): """, ) + create_temp_file( + tmp_path, + pathlib.Path(macros_dir, "grant_usage_file.py"), + """ +from sqlmesh import macro + +@macro() +def grant_usage_role(evaluator, schemas, role): + if evaluator._environment_naming_info: + return [ + f"GRANT USAGE ON SCHEMA {schema} TO {role};" + for schema in schemas + ] +""", + ) + context = Context(paths=tmp_path, config=config) snapshots = {s.name: s for s in context.snapshots.values()} @@ -1515,6 +1532,7 @@ def grant_schema_usage(evaluator): assert after_all_rendered == [ "GRANT USAGE ON SCHEMA db TO user_role", "GRANT SELECT ON VIEW memory.db.test_after_model TO ROLE admin_role", + 'GRANT USAGE ON SCHEMA "db" TO "admin"', ] after_all_rendered_dev = render_statements( @@ -1529,6 +1547,7 @@ def grant_schema_usage(evaluator): assert after_all_rendered_dev == [ "GRANT USAGE ON SCHEMA db__dev TO user_role", "GRANT SELECT ON VIEW memory.db__dev.test_after_model TO ROLE admin_role", + 'GRANT USAGE ON SCHEMA "db__dev" TO "admin"', ] diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 195db7fe7a..cc70cf302e 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -4502,6 +4502,24 @@ def test_multi_dbt(mocker): context.apply(plan) 
validate_apply_basics(context, c.PROD, plan.snapshots.values()) + environment_statements = context.state_sync.get_environment_statements(c.PROD) + assert len(environment_statements) == 2 + bronze_statements = environment_statements[0] + assert bronze_statements.before_all == [ + "JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;" + ] + assert not bronze_statements.after_all + silver_statements = environment_statements[1] + assert not silver_statements.before_all + assert silver_statements.after_all == [ + "JINJA_STATEMENT_BEGIN;\n{{ store_schemas(schemas) }}\nJINJA_END;" + ] + assert "store_schemas" in silver_statements.jinja_macros.root_macros + analytics_table = context.fetchdf("select * from analytic_stats;") + assert sorted(analytics_table.columns) == sorted(["physical_table", "evaluation_time"]) + schema_table = context.fetchdf("select * from schema_table;") + assert sorted(schema_table.all_schemas[0]) == sorted(["bronze", "silver"]) + def test_multi_hybrid(mocker): context = Context( diff --git a/tests/dbt/test_adapter.py b/tests/dbt/test_adapter.py index bfd715455f..d77a8f75fd 100644 --- a/tests/dbt/test_adapter.py +++ b/tests/dbt/test_adapter.py @@ -13,6 +13,9 @@ from sqlmesh import Context from sqlmesh.core.dialect import schema_ +from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.macros import RuntimeStage +from sqlmesh.core.renderer import render_statements from sqlmesh.core.snapshot import SnapshotId from sqlmesh.dbt.adapter import ParsetimeAdapter from sqlmesh.dbt.project import Project @@ -270,3 +273,48 @@ def test_quote_as_configured(): adapter.quote_as_configured("foo", "identifier") == '"foo"' adapter.quote_as_configured("foo", "schema") == "foo" adapter.quote_as_configured("foo", "database") == "foo" + + +def test_on_run_start_end(copy_to_temp_path): + project_root = "tests/fixtures/dbt/sushi_test" + sushi_context = 
Context(paths=copy_to_temp_path(project_root)) + assert len(sushi_context._environment_statements) == 1 + environment_statements = sushi_context._environment_statements[0] + + assert environment_statements.before_all == [ + "JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;" + ] + assert environment_statements.after_all == [ + "JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;" + ] + assert "create_tables" in environment_statements.jinja_macros.root_macros + + rendered_before_all = render_statements( + environment_statements.before_all, + dialect=sushi_context.default_dialect, + python_env=environment_statements.python_env, + jinja_macros=environment_statements.jinja_macros, + runtime_stage=RuntimeStage.BEFORE_ALL, + ) + + rendered_after_all = render_statements( + environment_statements.after_all, + dialect=sushi_context.default_dialect, + python_env=environment_statements.python_env, + jinja_macros=environment_statements.jinja_macros, + snapshots=sushi_context.snapshots, + runtime_stage=RuntimeStage.AFTER_ALL, + environment_naming_info=EnvironmentNamingInfo(name="dev"), + ) + + assert rendered_before_all == [ + "CREATE TABLE IF NOT EXISTS analytic_stats (physical_table TEXT, evaluation_time TEXT)" + ] + + # The jinja macro should have resolved the schemas for this environment and generated corresponding statements + assert sorted(rendered_after_all) == sorted( + [ + "CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema", + "CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema", + ] + ) diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 7de86e72b3..78af8882c1 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -997,6 +997,16 @@ def test_dbt_version(sushi_test_project: Project): assert context.render("{{ dbt_version }}").startswith("1.") 
+@pytest.mark.xdist_group("dbt_manifest") +def test_dbt_on_run_start_end(sushi_test_project: Project): + context = sushi_test_project.context + assert context._manifest + assert context._manifest._on_run_start == [ + "CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);" + ] + assert context._manifest._on_run_end == ["{{ create_tables(schemas) }}"] + + @pytest.mark.xdist_group("dbt_manifest") def test_parsetime_adapter_call( assert_exp_eq, sushi_test_project: Project, sushi_test_dbt_context: Context diff --git a/tests/fixtures/dbt/sushi_test/dbt_project.yml b/tests/fixtures/dbt/sushi_test/dbt_project.yml index 0e786a26e8..dd7486821e 100644 --- a/tests/fixtures/dbt/sushi_test/dbt_project.yml +++ b/tests/fixtures/dbt/sushi_test/dbt_project.yml @@ -25,14 +25,14 @@ models: +materialized: table +pre-hook: - '{{ log("pre-hook") }}' - +post-hook: + +post-hook: - '{{ log("post-hook") }}' seeds: sushi: +pre-hook: - '{{ log("pre-hook") }}' - +post-hook: + +post-hook: - '{{ log("post-hook") }}' vars: @@ -57,3 +57,9 @@ vars: value: 1 - name: 'item2' value: 2 + + +on-run-start: + - 'CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);' +on-run-end: + - '{{ create_tables(schemas) }}' \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/macros/create_tables.sql b/tests/fixtures/dbt/sushi_test/macros/create_tables.sql new file mode 100644 index 0000000000..57616b7389 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/macros/create_tables.sql @@ -0,0 +1,5 @@ +{% macro create_tables(schemas) %} + {% for schema in schemas %} + create or replace table schema_table_{{schema}} as select '{{schema}}' as schema; + {% endfor%} +{% endmacro %} \ No newline at end of file