Merged
14 changes: 10 additions & 4 deletions docs/concepts/macros/macro_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ This example used one of SQLMesh's predefined variables, but you can also define

We describe SQLMesh's predefined variables below; user-defined macro variables are discussed in the [SQLMesh macros](./sqlmesh_macros.md#user-defined-variables) and [Jinja macros](./jinja_macros.md#user-defined-variables) pages.

## Predefined Variables
## Predefined variables
SQLMesh comes with predefined variables that can be used in your queries. They are automatically set by the SQLMesh runtime.

Most predefined variables are related to time and use a combination of prefixes (start, end, etc.) and postfixes (date, ds, ts, etc.). They are described in the next section; [other predefined variables](#runtime-variables) are discussed in the following section.
Expand Down Expand Up @@ -120,7 +120,7 @@ All predefined temporal macro variables:

### Runtime variables

SQLMesh provides two other predefined variables used to modify model behavior based on information available at runtime.
SQLMesh provides additional predefined variables used to modify model behavior based on information available at runtime.

* @runtime_stage - A string value denoting the current stage of the SQLMesh runtime. Typically used in models to conditionally execute pre/post-statements (learn more [here](../models/sql_models.md#optional-prepost-statements)). It returns one of these values:
* 'loading' - The project is being loaded into SQLMesh's runtime context.
Expand All @@ -133,5 +133,11 @@ SQLMesh provides two other predefined variables used to modify model behavior ba
* @this_model - A string value containing the name of the physical table the model view selects from. Typically used to create [generic audits](../audits.md#generic-audits). In the case of [on_virtual_update statements](../models/sql_models.md#optional-on-virtual-update-statements) it contains the qualified view name instead.
* Can be used in model definitions when SQLGlot cannot fully parse a statement and you need to reference the model's underlying physical table directly.
* Can be passed as an argument to macros that access or interact with the underlying physical table.
* @this_env - A string value containing the name of the current [environment](../environments.md). Only available in [`before_all` and `after_all` statements](../../guides/configuration.md#before_all-and-after_all-statements), as well as in macros invoked within them.
* @model_kind_name - A string value containing the name of the current model kind. Intended to be used in scenarios where you need to control the [physical properties in model defaults](../../reference/model_configuration.md#model-defaults).
* @model_kind_name - A string value containing the name of the current model kind. Intended to be used in scenarios where you need to control the [physical properties in model defaults](../../reference/model_configuration.md#model-defaults).

#### Before all and after all variables

The following variables are also available in [`before_all` and `after_all` statements](../../guides/configuration.md#before_all-and-after_all-statements), as well as in macros invoked within them.

* @this_env - A string value containing the name of the current [environment](../environments.md).
* @schemas - A list of the schema names of the [virtual layer](../../concepts/glossary.md#virtual-layer) of the current environment.
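As a rough illustration of how the value of `@schemas` might be consumed, the sketch below mirrors the `grant_usage_role` test macro added in this PR, written as a plain Python function. In a real project the function would be decorated with `@macro()` from `sqlmesh` and take an evaluator as its first argument. The function name `grant_usage_statements` and the sample schema names are assumptions made only for this example.

```python
from typing import List


def grant_usage_statements(schemas: List[str], role: str) -> List[str]:
    """Build one GRANT statement per virtual-layer schema.

    Hypothetical stand-in for a macro body that receives the value of
    @schemas; it does not call any SQLMesh API.
    """
    return [f"GRANT USAGE ON SCHEMA {schema} TO {role};" for schema in schemas]


# Sample input: assumed schema names for two environments.
statements = grant_usage_statements(["db", "db__dev"], "admin")
for statement in statements:
    print(statement)
```

Returning a list of statements, one per schema, matches the shape of the test macro in `tests/core/test_context.py`, where each returned string is rendered as a separate `after_all` statement.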
1 change: 0 additions & 1 deletion docs/integrations/dbt.md
Expand Up @@ -324,7 +324,6 @@ The dbt jinja methods that are not currently supported are:
* selected_sources
* adapter.expand_target_column_types
* adapter.rename_relation
* schemas
* graph.nodes.values
* graph.metrics.values

Expand Down
3 changes: 3 additions & 0 deletions examples/multi_dbt/bronze/dbt_project.yml
Expand Up @@ -19,3 +19,6 @@ require-dbt-version: [">=1.0.0", "<2.0.0"]
models:
start: "2024-01-01"
+materialized: table

on-run-start:
- 'CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);'
3 changes: 3 additions & 0 deletions examples/multi_dbt/silver/dbt_project.yml
Expand Up @@ -19,3 +19,6 @@ require-dbt-version: [">=1.0.0", "<2.0.0"]
models:
start: "2024-01-01"
+materialized: table

on-run-end:
- '{{ store_schemas(schemas) }}'
3 changes: 3 additions & 0 deletions examples/multi_dbt/silver/macros/store_schemas.sql
@@ -0,0 +1,3 @@
{% macro store_schemas(schemas) %}
create or replace table schema_table as select {{schemas}} as all_schemas;
{% endmacro %}
3 changes: 3 additions & 0 deletions sqlmesh/core/environment.py
Expand Up @@ -14,6 +14,7 @@
from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo, Snapshot
from sqlmesh.utils import word_characters_only
from sqlmesh.utils.date import TimeLike, now_timestamp
from sqlmesh.utils.jinja import JinjaMacroRegistry
from sqlmesh.utils.metaprogramming import Executable
from sqlmesh.utils.pydantic import PydanticModel, field_validator

Expand Down Expand Up @@ -218,6 +219,7 @@ class EnvironmentStatements(PydanticModel):
before_all: t.List[str]
after_all: t.List[str]
python_env: t.Dict[str, Executable]
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()


def execute_environment_statements(
Expand All @@ -239,6 +241,7 @@ def execute_environment_statements(
dialect=adapter.dialect,
default_catalog=default_catalog,
python_env=statements.python_env,
jinja_macros=statements.jinja_macros,
snapshots=snapshots,
start=start,
end=end,
Expand Down
18 changes: 16 additions & 2 deletions sqlmesh/core/renderer.py
Expand Up @@ -107,6 +107,18 @@ def _render(

if environment_naming_info := kwargs.get("environment_naming_info", None):
kwargs["this_env"] = getattr(environment_naming_info, "name")
if snapshots and (
schemas := set(
[
s.qualified_view_name.schema_for_environment(
environment_naming_info, dialect=self._dialect
)
for s in snapshots.values()
if s.is_model and not s.is_symbolic
]
)
):
kwargs["schemas"] = list(schemas)

this_model = kwargs.pop("this_model", None)

Expand Down Expand Up @@ -411,19 +423,21 @@ def render(

def render_statements(
statements: t.List[str],
dialect: DialectType = None,
dialect: str,
default_catalog: t.Optional[str] = None,
python_env: t.Optional[t.Dict[str, Executable]] = None,
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
**render_kwargs: t.Any,
) -> t.List[str]:
rendered_statements: t.List[str] = []
for statement in statements:
for expression in parse(statement, dialect=dialect):
for expression in d.parse(statement, default_dialect=dialect):
if expression:
rendered = ExpressionRenderer(
expression,
dialect,
[],
jinja_macro_registry=jinja_macros or JinjaMacroRegistry(),
python_env=python_env,
default_catalog=default_catalog,
quote_identifiers=False,
Expand Down
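The `schemas` computation in `_render` above can be pictured with a small standalone sketch: collect the environment-specific schema of every non-symbolic model snapshot and deduplicate. `SnapshotSketch`, `environment_schemas`, and the `__<env>` suffix rule are hypothetical stand-ins chosen for illustration; they are not the SQLMesh `Snapshot` class or its naming logic.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class SnapshotSketch:
    """Hypothetical stand-in for a snapshot; not the SQLMesh Snapshot class."""

    schema: str
    is_model: bool = True
    is_symbolic: bool = False

    def schema_for_environment(self, env_name: str) -> str:
        # Assumed rule for this sketch: prod keeps the schema name,
        # any other environment gets a "__<env>" suffix.
        return self.schema if env_name == "prod" else f"{self.schema}__{env_name}"


def environment_schemas(snapshots: Dict[str, SnapshotSketch], env_name: str) -> List[str]:
    """Collect the distinct schemas of non-symbolic model snapshots."""
    schemas = {
        s.schema_for_environment(env_name)
        for s in snapshots.values()
        if s.is_model and not s.is_symbolic
    }
    # Sorted for a deterministic result; the PR code returns list(set(...)) unsorted.
    return sorted(schemas)


snapshots = {
    "a": SnapshotSketch("sushi"),
    "b": SnapshotSketch("sushi"),
    "c": SnapshotSketch("snapshots"),
    "d": SnapshotSketch("raw", is_symbolic=True),  # excluded: symbolic
}
dev_schemas = environment_schemas(snapshots, "dev")
print(dev_schemas)
```

Because the PR converts a set to a list without sorting, callers that compare the result (as the tests in this PR do) sort it first.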
64 changes: 63 additions & 1 deletion sqlmesh/dbt/loader.py
Expand Up @@ -3,6 +3,8 @@
import logging
import sys
import typing as t
import sqlmesh.core.dialect as d
from sqlglot.optimizer.simplify import gen
from pathlib import Path
from sqlmesh.core import constants as c
from sqlmesh.core.config import (
Expand All @@ -11,9 +13,11 @@
GatewayConfig,
ModelDefaultsConfig,
)
from sqlmesh.core.environment import EnvironmentStatements
from sqlmesh.core.loader import CacheBase, LoadedProject, Loader
from sqlmesh.core.macros import MacroRegistry, macro
from sqlmesh.core.model import Model, ModelCache
from sqlmesh.core.model.common import make_python_env
from sqlmesh.core.signal import signal
from sqlmesh.dbt.basemodel import BMC, BaseModelConfig
from sqlmesh.dbt.context import DbtContext
Expand All @@ -23,7 +27,11 @@
from sqlmesh.dbt.target import TargetConfig
from sqlmesh.utils import UniqueKeyDict
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.jinja import JinjaMacroRegistry
from sqlmesh.utils.jinja import (
JinjaMacroRegistry,
MacroInfo,
extract_macro_references_and_variables,
)

if sys.version_info >= (3, 12):
from importlib import metadata
Expand Down Expand Up @@ -230,6 +238,60 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:

return requirements, excluded_requirements

def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None:
"""Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively."""

on_run_start: t.List[str] = []
on_run_end: t.List[str] = []
jinja_root_macros: t.Dict[str, MacroInfo] = {}
variables: t.Dict[str, t.Any] = self._get_variables()
dialect = self.config.dialect
for project in self._load_projects():
context = project.context.copy()
if manifest := context._manifest:
on_run_start.extend(manifest._on_run_start or [])
Review comment (Collaborator): If we access these outside the manifest instance, we shouldn't mark them as private. Or at the very least add a public property method to expose these.

on_run_end.extend(manifest._on_run_end or [])

if root_package := context.jinja_macros.root_package_name:
if root_macros := context.jinja_macros.packages.get(root_package):
jinja_root_macros |= root_macros
context.set_and_render_variables(context.variables, root_package)
variables |= context.variables

if statements := on_run_start + on_run_end:
jinja_macro_references, used_variables = extract_macro_references_and_variables(
*(gen(stmt) for stmt in statements)
)
jinja_macros = context.jinja_macros
jinja_macros.root_macros = jinja_root_macros
Review comment (Collaborator): Why do we do this gymnastics with root macros?

jinja_macros = (
jinja_macros.trim(jinja_macro_references)
if not jinja_macros.trimmed
else jinja_macros
)

python_env = make_python_env(
[s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)],
jinja_macro_references=jinja_macro_references,
module_path=self.config_path,
macros=macros,
variables=variables,
used_variables=used_variables,
path=self.config_path,
)

return EnvironmentStatements(
before_all=[
d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_start or []
],
after_all=[
d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_end or []
],
python_env=python_env,
jinja_macros=jinja_macros,
)
return None

def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, float]:
if not root.is_dir():
return {}
Expand Down
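The mapping performed by `_load_environment_statements` (dbt `on-run-start` hooks become `before_all` statements, `on-run-end` hooks become `after_all` statements, each wrapped as a Jinja statement) can be sketched without SQLMesh as follows. The helper names are hypothetical, and the wrapper string is copied from the expected values in this PR's tests; the real code produces it through `d.jinja_statement(stmt).sql(dialect=dialect)`, which may behave differently for other inputs.

```python
from typing import Dict, List, Optional


def wrap_jinja_statement(statement: str) -> str:
    """Wrap a raw hook in the marker format seen in this PR's test expectations."""
    return f"JINJA_STATEMENT_BEGIN;\n{statement}\nJINJA_END;"


def hooks_to_environment_statements(
    on_run_start: Optional[List[str]],
    on_run_end: Optional[List[str]],
) -> Optional[Dict[str, List[str]]]:
    """Map dbt hooks to before_all / after_all lists; None when there are no hooks."""
    start = on_run_start or []
    end = on_run_end or []
    if not start and not end:
        return None
    return {
        "before_all": [wrap_jinja_statement(s) for s in start],
        "after_all": [wrap_jinja_statement(s) for s in end],
    }


result = hooks_to_environment_statements(
    ["CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);"],
    ["{{ store_schemas(schemas) }}"],
)
print(result)
```

Returning `None` when no hooks are defined follows the shape of the loader method, which only builds an `EnvironmentStatements` object when at least one statement exists.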
8 changes: 8 additions & 0 deletions sqlmesh/dbt/manifest.py
Expand Up @@ -94,6 +94,9 @@ def __init__(
self.project_path / c.CACHE, "jinja_calls"
)

self._on_run_start: t.Optional[t.List[str]] = None
self._on_run_end: t.Optional[t.List[str]] = None

def tests(self, package_name: t.Optional[str] = None) -> TestConfigs:
self._load_all()
return self._tests_per_package[package_name or self._project_name]
Expand Down Expand Up @@ -312,6 +315,11 @@ def _load_manifest(self) -> Manifest:

runtime_config = RuntimeConfig.from_parts(project, profile, args)

if runtime_config.on_run_start:
self._on_run_start = runtime_config.on_run_start
if runtime_config.on_run_end:
self._on_run_end = runtime_config.on_run_end

self._project_name = project.project_name

if DBT_VERSION >= (1, 8):
Expand Down
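A reviewer in this PR suggests exposing the hook lists through public properties instead of reading `_on_run_start` and `_on_run_end` from outside the manifest. One possible shape for that is sketched below; `ManifestSketch` and its `load_hooks` method are hypothetical and are not part of the PR or of SQLMesh's `ManifestHelper`.

```python
from typing import List, Optional


class ManifestSketch:
    """Hypothetical illustration of read-only access to private hook lists."""

    def __init__(self) -> None:
        self._on_run_start: Optional[List[str]] = None
        self._on_run_end: Optional[List[str]] = None

    def load_hooks(self, on_run_start: List[str], on_run_end: List[str]) -> None:
        # Only store non-empty hook lists, as the PR's _load_manifest does.
        if on_run_start:
            self._on_run_start = on_run_start
        if on_run_end:
            self._on_run_end = on_run_end

    @property
    def on_run_start(self) -> List[str]:
        # Return a copy so callers cannot mutate internal state.
        return list(self._on_run_start or [])

    @property
    def on_run_end(self) -> List[str]:
        return list(self._on_run_end or [])


manifest = ManifestSketch()
manifest.load_hooks([], ["{{ create_tables(schemas) }}"])
print(manifest.on_run_start, manifest.on_run_end)
```

With properties that always return a list, a caller such as the loader would no longer need the `or []` fallback when extending its own hook lists.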
19 changes: 19 additions & 0 deletions tests/core/test_context.py
Expand Up @@ -1417,6 +1417,7 @@ def test_environment_statements(tmp_path: pathlib.Path):
after_all=[
"@grant_schema_usage()",
"@grant_select_privileges()",
"@grant_usage_role(@schemas, 'admin')",
],
)

Expand Down Expand Up @@ -1481,6 +1482,22 @@ def grant_schema_usage(evaluator):
""",
)

create_temp_file(
tmp_path,
pathlib.Path(macros_dir, "grant_usage_file.py"),
"""
from sqlmesh import macro

@macro()
def grant_usage_role(evaluator, schemas, role):
if evaluator._environment_naming_info:
return [
f"GRANT USAGE ON SCHEMA {schema} TO {role};"
for schema in schemas
]
""",
)

context = Context(paths=tmp_path, config=config)
snapshots = {s.name: s for s in context.snapshots.values()}

Expand Down Expand Up @@ -1515,6 +1532,7 @@ def grant_schema_usage(evaluator):
assert after_all_rendered == [
"GRANT USAGE ON SCHEMA db TO user_role",
"GRANT SELECT ON VIEW memory.db.test_after_model TO ROLE admin_role",
'GRANT USAGE ON SCHEMA "db" TO "admin"',
]

after_all_rendered_dev = render_statements(
Expand All @@ -1529,6 +1547,7 @@ def grant_schema_usage(evaluator):
assert after_all_rendered_dev == [
"GRANT USAGE ON SCHEMA db__dev TO user_role",
"GRANT SELECT ON VIEW memory.db__dev.test_after_model TO ROLE admin_role",
'GRANT USAGE ON SCHEMA "db__dev" TO "admin"',
]


Expand Down
18 changes: 18 additions & 0 deletions tests/core/test_integration.py
Expand Up @@ -4502,6 +4502,24 @@ def test_multi_dbt(mocker):
context.apply(plan)
validate_apply_basics(context, c.PROD, plan.snapshots.values())

environment_statements = context.state_sync.get_environment_statements(c.PROD)
assert len(environment_statements) == 2
bronze_statements = environment_statements[0]
assert bronze_statements.before_all == [
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;"
]
assert not bronze_statements.after_all
silver_statements = environment_statements[1]
assert not silver_statements.before_all
assert silver_statements.after_all == [
"JINJA_STATEMENT_BEGIN;\n{{ store_schemas(schemas) }}\nJINJA_END;"
]
assert "store_schemas" in silver_statements.jinja_macros.root_macros
analytics_table = context.fetchdf("select * from analytic_stats;")
assert sorted(analytics_table.columns) == sorted(["physical_table", "evaluation_time"])
schema_table = context.fetchdf("select * from schema_table;")
assert sorted(schema_table.all_schemas[0]) == sorted(["bronze", "silver"])


def test_multi_hybrid(mocker):
context = Context(
Expand Down
48 changes: 48 additions & 0 deletions tests/dbt/test_adapter.py
Expand Up @@ -13,6 +13,9 @@

from sqlmesh import Context
from sqlmesh.core.dialect import schema_
from sqlmesh.core.environment import EnvironmentNamingInfo
from sqlmesh.core.macros import RuntimeStage
from sqlmesh.core.renderer import render_statements
from sqlmesh.core.snapshot import SnapshotId
from sqlmesh.dbt.adapter import ParsetimeAdapter
from sqlmesh.dbt.project import Project
Expand Down Expand Up @@ -270,3 +273,48 @@ def test_quote_as_configured():
adapter.quote_as_configured("foo", "identifier") == '"foo"'
adapter.quote_as_configured("foo", "schema") == "foo"
adapter.quote_as_configured("foo", "database") == "foo"


def test_on_run_start_end(copy_to_temp_path):
project_root = "tests/fixtures/dbt/sushi_test"
sushi_context = Context(paths=copy_to_temp_path(project_root))
assert len(sushi_context._environment_statements) == 1
environment_statements = sushi_context._environment_statements[0]

assert environment_statements.before_all == [
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;"
]
assert environment_statements.after_all == [
"JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;"
]
assert "create_tables" in environment_statements.jinja_macros.root_macros

rendered_before_all = render_statements(
environment_statements.before_all,
dialect=sushi_context.default_dialect,
python_env=environment_statements.python_env,
jinja_macros=environment_statements.jinja_macros,
runtime_stage=RuntimeStage.BEFORE_ALL,
)

rendered_after_all = render_statements(
environment_statements.after_all,
dialect=sushi_context.default_dialect,
python_env=environment_statements.python_env,
jinja_macros=environment_statements.jinja_macros,
snapshots=sushi_context.snapshots,
runtime_stage=RuntimeStage.AFTER_ALL,
environment_naming_info=EnvironmentNamingInfo(name="dev"),
)

assert rendered_before_all == [
"CREATE TABLE IF NOT EXISTS analytic_stats (physical_table TEXT, evaluation_time TEXT)"
]

# The jinja macro should have resolved the schemas for this environment and generated corresponding statements
assert sorted(rendered_after_all) == sorted(
[
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
]
)
10 changes: 10 additions & 0 deletions tests/dbt/test_transformation.py
Expand Up @@ -997,6 +997,16 @@ def test_dbt_version(sushi_test_project: Project):
assert context.render("{{ dbt_version }}").startswith("1.")


@pytest.mark.xdist_group("dbt_manifest")
def test_dbt_on_run_start_end(sushi_test_project: Project):
context = sushi_test_project.context
assert context._manifest
assert context._manifest._on_run_start == [
"CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);"
]
assert context._manifest._on_run_end == ["{{ create_tables(schemas) }}"]


@pytest.mark.xdist_group("dbt_manifest")
def test_parsetime_adapter_call(
assert_exp_eq, sushi_test_project: Project, sushi_test_dbt_context: Context
Expand Down