Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sqlmesh/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
DEFAULT_SCHEMA = "default"

SQLMESH_VARS = "__sqlmesh__vars__"
SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__"

VAR = "var"
GATEWAY = "gateway"

Expand Down
13 changes: 12 additions & 1 deletion sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,15 @@ def __init__(
default_dialect: t.Optional[str] = None,
default_catalog: t.Optional[str] = None,
variables: t.Optional[t.Dict[str, t.Any]] = None,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
):
self.snapshots = snapshots
self.deployability_index = deployability_index
self._engine_adapter = engine_adapter
self._default_catalog = default_catalog
self._default_dialect = default_dialect
self._variables = variables or {}
self._blueprint_variables = blueprint_variables or {}

@property
def default_dialect(self) -> t.Optional[str]:
Expand Down Expand Up @@ -288,7 +290,15 @@ def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.
"""Returns a variable value."""
return self._variables.get(var_name.lower(), default)

def with_variables(self, variables: t.Dict[str, t.Any]) -> ExecutionContext:
def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
"""Returns a blueprint variable value."""
return self._blueprint_variables.get(var_name.lower(), default)

def with_variables(
self,
variables: t.Dict[str, t.Any],
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
) -> ExecutionContext:
"""Returns a new ExecutionContext with additional variables."""
return ExecutionContext(
self._engine_adapter,
Expand All @@ -297,6 +307,7 @@ def with_variables(self, variables: t.Dict[str, t.Any]) -> ExecutionContext:
self._default_dialect,
self._default_catalog,
variables=variables,
blueprint_variables=blueprint_variables,
)


Expand Down
68 changes: 51 additions & 17 deletions sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from sqlmesh.utils.date import DatetimeRanges, to_datetime, to_date
from sqlmesh.utils.errors import MacroEvalError, SQLMeshError
from sqlmesh.utils.jinja import JinjaMacroRegistry, has_jinja
from sqlmesh.utils.metaprogramming import Executable, prepare_env, print_exception
from sqlmesh.utils.metaprogramming import Executable, SqlValue, prepare_env, print_exception

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
Expand Down Expand Up @@ -173,14 +173,15 @@ def __init__(
"MacroEvaluator": MacroEvaluator,
}
self.python_env = python_env or {}
self._jinja_env: t.Optional[Environment] = jinja_env
self.macros = {normalize_macro_name(k): v.func for k, v in macro.get_registry().items()}
self.columns_to_types_called = False
self.default_catalog = default_catalog

self._jinja_env: t.Optional[Environment] = jinja_env
self._schema = schema
self._resolve_table = resolve_table
self._resolve_tables = resolve_tables
self.columns_to_types_called = False
self._snapshots = snapshots if snapshots is not None else {}
self.default_catalog = default_catalog
self._path = path
self._environment_naming_info = environment_naming_info

Expand All @@ -191,7 +192,18 @@ def __init__(
elif v.is_import and getattr(self.env.get(k), c.SQLMESH_MACRO, None):
self.macros[normalize_macro_name(k)] = self.env[k]
elif v.is_value:
self.locals[k] = self.env[k]
value = self.env[k]
if k in (c.SQLMESH_VARS, c.SQLMESH_BLUEPRINT_VARS):
value = {
var_name: (
self.parse_one(var_value.sql)
if isinstance(var_value, SqlValue)
else var_value
)
for var_name, var_value in value.items()
}

self.locals[k] = value

def send(
self, name: str, *args: t.Any, **kwargs: t.Any
Expand Down Expand Up @@ -219,20 +231,23 @@ def evaluate_macros(

if isinstance(node, MacroVar):
changed = True
variables = self.locals.get(c.SQLMESH_VARS, {})
variables = self.variables

if node.name not in self.locals and node.name.lower() not in variables:
if not isinstance(node.parent, StagedFilePath):
raise SQLMeshError(f"Macro variable '{node.name}' is undefined.")

return node

# Precedence order is locals (e.g. @DEF) > blueprint variables > config variables
value = self.locals.get(node.name, variables.get(node.name.lower()))
if isinstance(value, list):
return exp.convert(
tuple(
self.transform(v) if isinstance(v, exp.Expression) else v for v in value
)
)

return exp.convert(
self.transform(value) if isinstance(value, exp.Expression) else value
)
Expand Down Expand Up @@ -279,17 +294,12 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
Returns:
The rendered string.
"""
mapping = {}

variables = self.locals.get(c.SQLMESH_VARS, {})

for k, v in chain(variables.items(), self.locals.items(), local_variables.items()):
# try to convert all variables into sqlglot expressions
# because they're going to be converted into strings in sql
# we don't convert strings because that would result in adding quotes
if k != c.SQLMESH_VARS:
mapping[k] = convert_sql(v, self.dialect)

# We try to convert all variables into sqlglot expressions because they're going to be converted
# into strings; in sql we don't convert strings because that would result in adding quotes
mapping = {
k: convert_sql(v, self.dialect)
for k, v in chain(self.variables.items(), self.locals.items(), local_variables.items())
}
return MacroStrTemplate(str(text)).safe_substitute(mapping)

def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None:
Expand Down Expand Up @@ -467,6 +477,17 @@ def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.
"""Returns the value of the specified variable, or the default value if it doesn't exist."""
return (self.locals.get(c.SQLMESH_VARS) or {}).get(var_name.lower(), default)

def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
"""Returns the value of the specified blueprint variable, or the default value if it doesn't exist."""
return (self.locals.get(c.SQLMESH_BLUEPRINT_VARS) or {}).get(var_name.lower(), default)

@property
def variables(self) -> t.Dict[str, t.Any]:
return {
**self.locals.get(c.SQLMESH_VARS, {}),
**self.locals.get(c.SQLMESH_BLUEPRINT_VARS, {}),
}

def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any:
"""Coerces the given expression to the specified type on a best-effort basis."""
return _coerce(expr, typ, self.dialect, self._path, strict)
Expand Down Expand Up @@ -1054,6 +1075,19 @@ def var(
return exp.convert(evaluator.var(var_name.this, default))


@macro("BLUEPRINT_VAR")
def blueprint_var(
evaluator: MacroEvaluator, var_name: exp.Expression, default: t.Optional[exp.Expression] = None
) -> exp.Expression:
"""Returns the value of a blueprint variable or the default value if the variable is not set."""
if not var_name.is_string:
raise SQLMeshError(
f"Invalid blueprint variable name '{var_name.sql()}'. Expected a string literal."
)

return exp.convert(evaluator.blueprint_var(var_name.this, default))


@macro()
def deduplicate(
evaluator: MacroEvaluator,
Expand Down
22 changes: 21 additions & 1 deletion sqlmesh/core/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@
from sqlmesh.core.macros import MacroRegistry, MacroStrTemplate
from sqlmesh.utils import str_to_bool
from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error
from sqlmesh.utils.metaprogramming import Executable, build_env, prepare_env, serialize_env
from sqlmesh.utils.metaprogramming import (
Executable,
SqlValue,
build_env,
prepare_env,
serialize_env,
)
from sqlmesh.utils.pydantic import ValidationInfo, field_validator

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
from sqlmesh.utils.jinja import MacroReference


Expand All @@ -30,6 +37,8 @@ def make_python_env(
path: t.Optional[str | Path] = None,
python_env: t.Optional[t.Dict[str, Executable]] = None,
strict_resolution: bool = True,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
dialect: DialectType = None,
) -> t.Dict[str, Executable]:
python_env = {} if python_env is None else python_env
variables = variables or {}
Expand Down Expand Up @@ -86,6 +95,8 @@ def make_python_env(
python_env,
used_variables,
variables,
blueprint_variables=blueprint_variables,
dialect=dialect,
strict_resolution=strict_resolution,
)

Expand All @@ -95,6 +106,8 @@ def _add_variables_to_python_env(
used_variables: t.Optional[t.Set[str]],
variables: t.Optional[t.Dict[str, t.Any]],
strict_resolution: bool = True,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
dialect: DialectType = None,
) -> t.Dict[str, Executable]:
_, python_used_variables = parse_dependencies(
python_env,
Expand All @@ -107,6 +120,13 @@ def _add_variables_to_python_env(
if variables:
python_env[c.SQLMESH_VARS] = Executable.value(variables)

if blueprint_variables:
blueprint_variables = {
k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
for k, v in blueprint_variables.items()
}
python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(blueprint_variables)

return python_env


Expand Down
3 changes: 3 additions & 0 deletions sqlmesh/core/model/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def model(
default_catalog: t.Optional[str] = None,
variables: t.Optional[t.Dict[str, t.Any]] = None,
infer_names: t.Optional[bool] = False,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
) -> Model:
"""Get the model registered by this function."""
env: t.Dict[str, t.Any] = {}
Expand Down Expand Up @@ -155,6 +156,7 @@ def model(
path=path,
dialect=dialect,
default_catalog=default_catalog,
blueprint_variables=blueprint_variables,
)

rendered_name = rendered_fields["name"]
Expand Down Expand Up @@ -193,6 +195,7 @@ def model(
"macros": macros,
"jinja_macros": jinja_macros,
"audit_definitions": audit_definitions,
"blueprint_variables": blueprint_variables,
**rendered_fields,
}

Expand Down
Loading