diff --git a/sqlmesh/core/constants.py b/sqlmesh/core/constants.py index abc096c9a6..1ae55672b5 100644 --- a/sqlmesh/core/constants.py +++ b/sqlmesh/core/constants.py @@ -76,6 +76,8 @@ DEFAULT_SCHEMA = "default" SQLMESH_VARS = "__sqlmesh__vars__" +SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__" + VAR = "var" GATEWAY = "gateway" diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 87217fe64e..632c71c5e1 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -253,6 +253,7 @@ def __init__( default_dialect: t.Optional[str] = None, default_catalog: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, ): self.snapshots = snapshots self.deployability_index = deployability_index @@ -260,6 +261,7 @@ def __init__( self._default_catalog = default_catalog self._default_dialect = default_dialect self._variables = variables or {} + self._blueprint_variables = blueprint_variables or {} @property def default_dialect(self) -> t.Optional[str]: @@ -288,7 +290,15 @@ def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t. 
"""Returns a variable value.""" return self._variables.get(var_name.lower(), default) - def with_variables(self, variables: t.Dict[str, t.Any]) -> ExecutionContext: + def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: + """Returns a blueprint variable value.""" + return self._blueprint_variables.get(var_name.lower(), default) + + def with_variables( + self, + variables: t.Dict[str, t.Any], + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + ) -> ExecutionContext: """Returns a new ExecutionContext with additional variables.""" return ExecutionContext( self._engine_adapter, @@ -297,6 +307,7 @@ def with_variables(self, variables: t.Dict[str, t.Any]) -> ExecutionContext: self._default_dialect, self._default_catalog, variables=variables, + blueprint_variables=blueprint_variables, ) diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index 234290cdde..9b0cdfba8a 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -41,7 +41,7 @@ from sqlmesh.utils.date import DatetimeRanges, to_datetime, to_date from sqlmesh.utils.errors import MacroEvalError, SQLMeshError from sqlmesh.utils.jinja import JinjaMacroRegistry, has_jinja -from sqlmesh.utils.metaprogramming import Executable, prepare_env, print_exception +from sqlmesh.utils.metaprogramming import Executable, SqlValue, prepare_env, print_exception if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType @@ -173,14 +173,15 @@ def __init__( "MacroEvaluator": MacroEvaluator, } self.python_env = python_env or {} - self._jinja_env: t.Optional[Environment] = jinja_env self.macros = {normalize_macro_name(k): v.func for k, v in macro.get_registry().items()} + self.columns_to_types_called = False + self.default_catalog = default_catalog + + self._jinja_env: t.Optional[Environment] = jinja_env self._schema = schema self._resolve_table = resolve_table self._resolve_tables = resolve_tables - self.columns_to_types_called = False self._snapshots = 
snapshots if snapshots is not None else {} - self.default_catalog = default_catalog self._path = path self._environment_naming_info = environment_naming_info @@ -191,7 +192,18 @@ def __init__( elif v.is_import and getattr(self.env.get(k), c.SQLMESH_MACRO, None): self.macros[normalize_macro_name(k)] = self.env[k] elif v.is_value: - self.locals[k] = self.env[k] + value = self.env[k] + if k in (c.SQLMESH_VARS, c.SQLMESH_BLUEPRINT_VARS): + value = { + var_name: ( + self.parse_one(var_value.sql) + if isinstance(var_value, SqlValue) + else var_value + ) + for var_name, var_value in value.items() + } + + self.locals[k] = value def send( self, name: str, *args: t.Any, **kwargs: t.Any @@ -219,13 +231,15 @@ def evaluate_macros( if isinstance(node, MacroVar): changed = True - variables = self.locals.get(c.SQLMESH_VARS, {}) + variables = self.variables + if node.name not in self.locals and node.name.lower() not in variables: if not isinstance(node.parent, StagedFilePath): raise SQLMeshError(f"Macro variable '{node.name}' is undefined.") return node + # Precedence order is locals (e.g. @DEF) > blueprint variables > config variables value = self.locals.get(node.name, variables.get(node.name.lower())) if isinstance(value, list): return exp.convert( @@ -233,6 +247,7 @@ def evaluate_macros( self.transform(v) if isinstance(v, exp.Expression) else v for v in value ) ) + return exp.convert( self.transform(value) if isinstance(value, exp.Expression) else value ) @@ -279,17 +294,12 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str: Returns: The rendered string. 
""" - mapping = {} - - variables = self.locals.get(c.SQLMESH_VARS, {}) - - for k, v in chain(variables.items(), self.locals.items(), local_variables.items()): - # try to convert all variables into sqlglot expressions - # because they're going to be converted into strings in sql - # we don't convert strings because that would result in adding quotes - if k != c.SQLMESH_VARS: - mapping[k] = convert_sql(v, self.dialect) - + # We try to convert all variables into sqlglot expressions because they're going to be converted + # into strings; in sql we don't convert strings because that would result in adding quotes + mapping = { + k: convert_sql(v, self.dialect) + for k, v in chain(self.variables.items(), self.locals.items(), local_variables.items()) + } return MacroStrTemplate(str(text)).safe_substitute(mapping) def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None: @@ -467,6 +477,17 @@ def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t. 
"""Returns the value of the specified variable, or the default value if it doesn't exist.""" return (self.locals.get(c.SQLMESH_VARS) or {}).get(var_name.lower(), default) + def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: + """Returns the value of the specified blueprint variable, or the default value if it doesn't exist.""" + return (self.locals.get(c.SQLMESH_BLUEPRINT_VARS) or {}).get(var_name.lower(), default) + + @property + def variables(self) -> t.Dict[str, t.Any]: + return { + **self.locals.get(c.SQLMESH_VARS, {}), + **self.locals.get(c.SQLMESH_BLUEPRINT_VARS, {}), + } + def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any: """Coerces the given expression to the specified type on a best-effort basis.""" return _coerce(expr, typ, self.dialect, self._path, strict) @@ -1054,6 +1075,19 @@ def var( return exp.convert(evaluator.var(var_name.this, default)) +@macro("BLUEPRINT_VAR") +def blueprint_var( + evaluator: MacroEvaluator, var_name: exp.Expression, default: t.Optional[exp.Expression] = None +) -> exp.Expression: + """Returns the value of a blueprint variable or the default value if the variable is not set.""" + if not var_name.is_string: + raise SQLMeshError( + f"Invalid blueprint variable name '{var_name.sql()}'. Expected a string literal." 
+ ) + + return exp.convert(evaluator.blueprint_var(var_name.this, default)) + + @macro() def deduplicate( evaluator: MacroEvaluator, diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index d4bed645a2..09e036f55d 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -13,10 +13,17 @@ from sqlmesh.core.macros import MacroRegistry, MacroStrTemplate from sqlmesh.utils import str_to_bool from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error -from sqlmesh.utils.metaprogramming import Executable, build_env, prepare_env, serialize_env +from sqlmesh.utils.metaprogramming import ( + Executable, + SqlValue, + build_env, + prepare_env, + serialize_env, +) from sqlmesh.utils.pydantic import ValidationInfo, field_validator if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType from sqlmesh.utils.jinja import MacroReference @@ -30,6 +37,8 @@ def make_python_env( path: t.Optional[str | Path] = None, python_env: t.Optional[t.Dict[str, Executable]] = None, strict_resolution: bool = True, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + dialect: DialectType = None, ) -> t.Dict[str, Executable]: python_env = {} if python_env is None else python_env variables = variables or {} @@ -86,6 +95,8 @@ def make_python_env( python_env, used_variables, variables, + blueprint_variables=blueprint_variables, + dialect=dialect, strict_resolution=strict_resolution, ) @@ -95,6 +106,8 @@ def _add_variables_to_python_env( used_variables: t.Optional[t.Set[str]], variables: t.Optional[t.Dict[str, t.Any]], strict_resolution: bool = True, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + dialect: DialectType = None, ) -> t.Dict[str, Executable]: _, python_used_variables = parse_dependencies( python_env, @@ -107,6 +120,13 @@ def _add_variables_to_python_env( if variables: python_env[c.SQLMESH_VARS] = Executable.value(variables) + if blueprint_variables: + blueprint_variables = { + k: 
SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v + for k, v in blueprint_variables.items() + } + python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(blueprint_variables) + return python_env diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 4deb1076b7..3c94e5bf9e 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -122,6 +122,7 @@ def model( default_catalog: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, infer_names: t.Optional[bool] = False, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, ) -> Model: """Get the model registered by this function.""" env: t.Dict[str, t.Any] = {} @@ -155,6 +156,7 @@ def model( path=path, dialect=dialect, default_catalog=default_catalog, + blueprint_variables=blueprint_variables, ) rendered_name = rendered_fields["name"] @@ -193,6 +195,7 @@ def model( "macros": macros, "jinja_macros": jinja_macros, "audit_definitions": audit_definitions, + "blueprint_variables": blueprint_variables, **rendered_fields, } diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 1b766e28ec..366a6304df 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -52,6 +52,7 @@ from sqlmesh.utils.pydantic import PydanticModel, PRIVATE_FIELDS from sqlmesh.utils.metaprogramming import ( Executable, + SqlValue, build_env, prepare_env, serialize_env, @@ -1749,6 +1750,10 @@ def render( variables = env.get(c.SQLMESH_VARS, {}) variables.update(kwargs.pop("variables", {})) + blueprint_variables = { + k: d.parse_one(v.sql, dialect=self.dialect) if isinstance(v, SqlValue) else v + for k, v in env.get(c.SQLMESH_BLUEPRINT_VARS, {}).items() + } try: kwargs = { **variables, @@ -1759,7 +1764,7 @@ def render( "latest": execution_time, # TODO: Preserved for backward compatibility. Remove in 1.0.0. 
} df_or_iter = env[self.entrypoint]( - context=context.with_variables(variables), + context=context.with_variables(variables, blueprint_variables=blueprint_variables), **kwargs, ) @@ -1855,18 +1860,14 @@ def _extract_blueprints(blueprints: t.Any, path: Path) -> t.List[t.Any]: return [] # This is unreachable, but is done to satisfy mypy -def _extract_blueprint_variables( - blueprint: t.Any, - dialect: DialectType, - path: Path, -) -> t.Dict[str, str]: +def _extract_blueprint_variables(blueprint: t.Any, path: Path) -> t.Dict[str, t.Any]: if not blueprint: return {} if isinstance(blueprint, (exp.Paren, exp.PropertyEQ)): blueprint = blueprint.unnest() - return {blueprint.left.name: blueprint.right.sql(dialect=dialect)} + return {blueprint.left.name: blueprint.right} if isinstance(blueprint, (exp.Tuple, exp.Array)): - return {e.left.name: e.right.sql(dialect=dialect) for e in blueprint.expressions} + return {e.left.name: e.right for e in blueprint.expressions} if isinstance(blueprint, dict): return blueprint @@ -1889,7 +1890,7 @@ def create_models_from_blueprints( ) -> t.List[Model]: model_blueprints: t.List[Model] = [] for blueprint in _extract_blueprints(blueprints, path): - variables = _extract_blueprint_variables(blueprint, dialect, path) + blueprint_variables = _extract_blueprint_variables(blueprint, path) if gateway: rendered_gateway = render_expression( @@ -1897,10 +1898,10 @@ def create_models_from_blueprints( module_path=module_path, macros=loader_kwargs.get("macros"), jinja_macros=loader_kwargs.get("jinja_macros"), - variables=variables, path=path, dialect=dialect, default_catalog=loader_kwargs.get("default_catalog"), + blueprint_variables=blueprint_variables, ) gateway_name = rendered_gateway[0].name if rendered_gateway else None else: @@ -1911,7 +1912,8 @@ def create_models_from_blueprints( path=path, module_path=module_path, dialect=dialect, - variables={**get_variables(gateway_name), **variables}, + variables=get_variables(gateway_name), + 
blueprint_variables=blueprint_variables, **loader_kwargs, ) ) @@ -1983,6 +1985,7 @@ def load_sql_based_model( default_catalog: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, infer_names: t.Optional[bool] = False, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, **kwargs: t.Any, ) -> Model: """Load a model from a parsed SQLMesh model SQL file. @@ -2059,6 +2062,7 @@ def load_sql_based_model( path=path, dialect=dialect, default_catalog=default_catalog, + blueprint_variables=blueprint_variables, ) if rendered_meta_exprs is None or len(rendered_meta_exprs) != 1: @@ -2143,6 +2147,7 @@ def load_sql_based_model( variables=variables, default_audits=default_audits, inline_audits=inline_audits, + blueprint_variables=blueprint_variables, **meta_fields, ) @@ -2247,6 +2252,7 @@ def create_python_model( module_path: Path = Path(), depends_on: t.Optional[t.Set[str]] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, **kwargs: t.Any, ) -> Model: """Creates a Python model. @@ -2259,6 +2265,7 @@ def create_python_model( path: An optional path to the model definition file. depends_on: The custom set of model's upstream dependencies. variables: The variables to pass to the model. + blueprint_variables: The blueprint's variables to pass to the model. 
""" # Find dependencies for python models by parsing code if they are not explicitly defined # Also remove self-references that are found @@ -2307,6 +2314,7 @@ def create_python_model( jinja_macros=jinja_macros, module_path=module_path, variables=variables, + blueprint_variables=blueprint_variables, **kwargs, ) @@ -2361,6 +2369,7 @@ def _create_model( macros: t.Optional[MacroRegistry] = None, signal_definitions: t.Optional[SignalRegistry] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, **kwargs: t.Any, ) -> Model: _validate_model_fields(klass, {"name", *kwargs} - {"grain", "table_properties"}, path) @@ -2469,6 +2478,8 @@ def _create_model( path=path, python_env=python_env, strict_resolution=depends_on is None, + blueprint_variables=blueprint_variables, + dialect=dialect, ) env: t.Dict[str, t.Any] = {} @@ -2632,6 +2643,7 @@ def render_meta_fields( dialect: DialectType, variables: t.Optional[t.Dict[str, t.Any]], default_catalog: t.Optional[str], + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, ) -> t.Dict[str, t.Any]: def render_field_value(value: t.Any) -> t.Any: if isinstance(value, exp.Expression) or (isinstance(value, str) and "@" in value): @@ -2645,6 +2657,7 @@ def render_field_value(value: t.Any) -> t.Any: path=path, dialect=dialect, default_catalog=default_catalog, + blueprint_variables=blueprint_variables, ) if not rendered_expr: raise SQLMeshError( @@ -2752,6 +2765,7 @@ def render_expression( dialect: DialectType = None, variables: t.Optional[t.Dict[str, t.Any]] = None, default_catalog: t.Optional[str] = None, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, ) -> t.Optional[t.List[exp.Expression]]: meta_python_env = make_python_env( expressions=expression, @@ -2760,6 +2774,7 @@ def render_expression( macros=macros or macro.get_registry(), variables=variables, path=path, + blueprint_variables=blueprint_variables, ) return ExpressionRenderer( expression, diff 
--git a/sqlmesh/core/test/context.py b/sqlmesh/core/test/context.py index 6f4563cf51..30fcb318db 100644 --- a/sqlmesh/core/test/context.py +++ b/sqlmesh/core/test/context.py @@ -26,6 +26,7 @@ def __init__( default_dialect: t.Optional[str] = None, default_catalog: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, ): self._engine_adapter = engine_adapter self._models = models @@ -33,6 +34,7 @@ def __init__( self._default_catalog = default_catalog self._default_dialect = default_dialect self._variables = variables or {} + self._blueprint_variables = blueprint_variables or {} @cached_property def _model_tables(self) -> t.Dict[str, str]: @@ -41,7 +43,11 @@ def _model_tables(self) -> t.Dict[str, str]: name: self._test._test_fixture_table(name).sql() for name, model in self._models.items() } - def with_variables(self, variables: t.Dict[str, t.Any]) -> TestExecutionContext: + def with_variables( + self, + variables: t.Dict[str, t.Any], + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + ) -> TestExecutionContext: """Returns a new TestExecutionContext with additional variables.""" return TestExecutionContext( self._engine_adapter, @@ -50,4 +56,5 @@ def with_variables(self, variables: t.Dict[str, t.Any]) -> TestExecutionContext: self._default_dialect, self._default_catalog, variables=variables, + blueprint_variables=blueprint_variables, ) diff --git a/sqlmesh/utils/metaprogramming.py b/sqlmesh/utils/metaprogramming.py index eb70592342..e54c1da69f 100644 --- a/sqlmesh/utils/metaprogramming.py +++ b/sqlmesh/utils/metaprogramming.py @@ -11,6 +11,7 @@ import textwrap import types import typing as t +from dataclasses import dataclass from enum import Enum from numbers import Number from pathlib import Path @@ -349,6 +350,13 @@ def walk(obj: t.Any, name: str) -> None: walk(obj, name) +@dataclass +class SqlValue: + """A SQL string representing a generated SQLGlot AST.""" + + sql: str + + class
ExecutableKind(str, Enum): """The kind of of executable. The order of the members is used when serializing the python model to text.""" @@ -490,11 +498,12 @@ def prepare_env( python_env.items(), key=lambda item: 0 if item[1].is_import else 1 ): if executable.is_value: - env[name] = ast.literal_eval(executable.payload) + env[name] = eval(executable.payload) else: exec(executable.payload, env) if executable.alias and executable.name: env[executable.alias] = env[executable.name] + return env diff --git a/tests/core/test_model.py b/tests/core/test_model.py index cd749b11ea..4556b77b9c 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -64,7 +64,7 @@ from sqlmesh.utils.date import TimeLike, to_datetime, to_ds, to_timestamp from sqlmesh.utils.errors import ConfigError, SQLMeshError, LinterError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo, MacroExtractor -from sqlmesh.utils.metaprogramming import Executable +from sqlmesh.utils.metaprogramming import Executable, SqlValue from sqlmesh.core.macros import RuntimeStage from tests.utils.test_helpers import use_terminal_console @@ -8200,6 +8200,7 @@ def identity(evaluator, value): ) def entrypoint(context, *args, **kwargs): x_var = context.var("x") + assert context.blueprint_var("blueprint").startswith("gw") return pd.DataFrame({"x": [x_var]})""" ) blueprint_pysql = tmp_path / "models" / "blueprint_sql.py" @@ -8218,6 +8219,7 @@ def entrypoint(context, *args, **kwargs): ) def entrypoint(evaluator): x_var = evaluator.var("x") + assert evaluator.blueprint_var("blueprint", default="").startswith("gw") return f'SELECT {x_var} AS x'""" ) @@ -8231,11 +8233,21 @@ def entrypoint(evaluator): for model_name in ("test_model_sql", "test_model_pydf", "test_model_pysql"): for gateway_no in range(1, 3): - model = models.get(f'"db"."gw{gateway_no}"."{model_name}"') + blueprint_value = f"gw{gateway_no}" + model = models.get(f'"db"."{blueprint_value}"."{model_name}"') assert model is not None assert 
"blueprints" not in model.all_fields() - assert model.python_env.get(c.SQLMESH_VARS) == Executable.value({"x": gateway_no}) + + python_env = model.python_env + serialized_blueprint = ( + SqlValue(sql=blueprint_value) if model_name == "test_model_sql" else blueprint_value + ) + assert python_env.get(c.SQLMESH_VARS) == Executable.value({"x": gateway_no}) + assert python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( + {"blueprint": serialized_blueprint} + ) + assert context.fetchdf(f"from {model.fqn}").to_dict() == {"x": {0: gateway_no}} multi_variable_blueprint_example = tmp_path / "models" / "multi_variable_blueprint_example.sql" @@ -8245,15 +8257,17 @@ def entrypoint(evaluator): MODEL ( name @{customer}.my_table, blueprints ( - (customer := customer1, foo := 'bar'), - (customer := customer2, foo := qux), + (customer := customer1, customer_field := 'bar'), + (customer := customer2, customer_field := qux), ), kind FULL ); SELECT - @VAR('foo') AS foo, - FROM @VAR('customer').my_source + @customer_field AS foo, + @{customer_field} AS foo2, + @BLUEPRINT_VAR('customer_field') AS foo3, + FROM @{customer}.my_source """ ) @@ -8264,23 +8278,29 @@ def entrypoint(evaluator): assert len(models) == 8 customer1_model = models.get('"db"."customer1"."my_table"') - assert customer1_model is not None - assert customer1_model.python_env.get(c.SQLMESH_VARS) == Executable.value( - {"customer": "customer1", "foo": "'bar'"} + + customer1_python_env = customer1_model.python_env + assert customer1_python_env.get(c.SQLMESH_VARS) is None + assert customer1_python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( + {"customer": SqlValue(sql="customer1"), "customer_field": SqlValue(sql="'bar'")} ) + assert t.cast(exp.Expression, customer1_model.render_query()).sql() == ( - """SELECT '''bar''' AS "foo" FROM "db"."customer1"."my_source" AS "my_source\"""" + """SELECT 'bar' AS "foo", "bar" AS "foo2", 'bar' AS "foo3" FROM "db"."customer1"."my_source" AS "my_source\"""" ) 
customer2_model = models.get('"db"."customer2"."my_table"') - assert customer2_model is not None - assert customer2_model.python_env.get(c.SQLMESH_VARS) == Executable.value( - {"customer": "customer2", "foo": "qux"} + + customer2_python_env = customer2_model.python_env + assert customer2_python_env.get(c.SQLMESH_VARS) is None + assert customer2_python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( + {"customer": SqlValue(sql="customer2"), "customer_field": SqlValue(sql="qux")} ) + assert t.cast(exp.Expression, customer2_model.render_query()).sql() == ( - '''SELECT 'qux' AS "foo" FROM "db"."customer2"."my_source" AS "my_source"''' + '''SELECT "qux" AS "foo", "qux" AS "foo2", "qux" AS "foo3" FROM "db"."customer2"."my_source" AS "my_source"''' ) @@ -8352,6 +8372,174 @@ def test_single_blueprint(tmp_path: Path) -> None: assert '"memory"."bar"."some_table"' in ctx.models +def test_blueprinting_with_quotes(tmp_path: Path) -> None: + init_example_project(tmp_path, dialect="duckdb", template=ProjectTemplate.EMPTY) + + template_with_quoted_vars = tmp_path / "models/template_with_quoted_vars.sql" + template_with_quoted_vars.parent.mkdir(parents=True, exist_ok=True) + template_with_quoted_vars.write_text( + """ + MODEL ( + name m.@{bp_var}, + blueprints ( + (bp_var := "a b"), + (bp_var := 'c d'), + ), + ); + + SELECT @bp_var AS c1, @{bp_var} AS c2 + """ + ) + + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), paths=tmp_path + ) + assert len(ctx.models) == 2 + + m1 = ctx.get_model('m."a b"', raise_if_missing=True) + m2 = ctx.get_model('m."c d"', raise_if_missing=True) + + assert m1.name == 'm."a b"' + assert m2.name == 'm."c d"' + assert t.cast(exp.Query, m1.render_query()).sql() == '''SELECT "a b" AS "c1", "a b" AS "c2"''' + assert t.cast(exp.Query, m2.render_query()).sql() == '''SELECT 'c d' AS "c1", "c d" AS "c2"''' + + +def test_blueprint_variable_precedence_sql(tmp_path: Path, assert_exp_eq: t.Callable) -> None: + 
init_example_project(tmp_path, dialect="duckdb", template=ProjectTemplate.EMPTY) + + blueprint_variables = tmp_path / "models/blueprint_variables.sql" + blueprint_variables.parent.mkdir(parents=True, exist_ok=True) + blueprint_variables.write_text( + """ + MODEL ( + name s.@{bp_name}, + blueprints ( + (bp_name := m1, var1 := 'v1', var2 := 'v2'), + (bp_name := m2, var1 := 'v3'), + ), + ); + + @DEF(bp_name, override); + + SELECT + @var1 AS var1_macro_var, + @{var1} AS var1_identifier, + @VAR('var1') AS var1_var_macro_func, + @BLUEPRINT_VAR('var1') AS var1_blueprint_var_macro_func, + + @var2 AS var2_macro_var, + @{var2} AS var2_identifier, + @VAR('var2') AS var2_var_macro_func, + @BLUEPRINT_VAR('var2') AS var2_blueprint_var_macro_func, + + @bp_name AS bp_name_macro_var, + @{bp_name} AS bp_name_identifier, + @VAR('bp_name') AS bp_name_var_macro_func, + @BLUEPRINT_VAR('bp_name') AS bp_name_blueprint_var_macro_func, + """ + ) + + ctx = Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"var2": "1"}, + ), + paths=tmp_path, + ) + assert len(ctx.models) == 2 + + m1 = ctx.get_model("s.m1", raise_if_missing=True) + m2 = ctx.get_model("s.m2", raise_if_missing=True) + + assert_exp_eq( + m1.render_query(), + """ + SELECT + 'v1' AS "var1_macro_var", + "v1" AS "var1_identifier", + NULL AS "var1_var_macro_func", + 'v1' AS "var1_blueprint_var_macro_func", + 'v2' AS "var2_macro_var", + "v2" AS "var2_identifier", + '1' AS "var2_var_macro_func", + 'v2' AS "var2_blueprint_var_macro_func", + "override" AS "bp_name_macro_var", + "override" AS "bp_name_identifier", + NULL AS "bp_name_var_macro_func", + "m1" AS "bp_name_blueprint_var_macro_func" + """, + ) + assert_exp_eq( + m2.render_query(), + """ + SELECT + 'v3' AS "var1_macro_var", + "v3" AS "var1_identifier", + NULL AS "var1_var_macro_func", + 'v3' AS "var1_blueprint_var_macro_func", + '1' AS "var2_macro_var", + "1" AS "var2_identifier", + '1' AS "var2_var_macro_func", + NULL AS 
"var2_blueprint_var_macro_func", + "override" AS "bp_name_macro_var", + "override" AS "bp_name_identifier", + NULL AS "bp_name_var_macro_func", + "m2" AS "bp_name_blueprint_var_macro_func" + """, + ) + + +def test_blueprint_variable_precedence_python(tmp_path: Path, mocker: MockerFixture) -> None: + init_example_project(tmp_path, dialect="duckdb", template=ProjectTemplate.EMPTY) + + blueprint_variables = tmp_path / "models/blueprint_variables.py" + blueprint_variables.parent.mkdir(parents=True, exist_ok=True) + blueprint_variables.write_text( + """ +import pandas as pd +from sqlglot import exp +from sqlmesh import model + + +@model( + "s.@{bp_name}", + blueprints=[{"bp_name": "m", "var1": exp.to_column("v1"), "var2": 1}], + kind="FULL", + columns={"x": "INT"}, +) +def entrypoint(context, *args, **kwargs): + assert "bp_name" not in kwargs + assert "var1" not in kwargs + assert kwargs.get("var2") == "1" + + assert context.var("bp_name") is None + assert context.var("var1") is None + assert context.var("var2") == "1" + + assert context.blueprint_var("bp_name") == "m" + assert context.blueprint_var("var1") == exp.to_column("v1") + assert context.blueprint_var("var2") == 1 + + return pd.DataFrame({"x": [1]}) + """ + ) + + ctx = Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"var2": "1"}, + ), + paths=tmp_path, + ) + assert len(ctx.models) == 1 + + m = ctx.get_model("s.m", raise_if_missing=True) + context = ExecutionContext(mocker.Mock(), {}, None, None) + + assert t.cast(pd.DataFrame, list(m.render(context=context))[0]).to_dict() == {"x": {0: 1}} + + @time_machine.travel("2020-01-01 00:00:00 UTC") def test_dynamic_date_spine_model(assert_exp_eq): @macro()