diff --git a/docs/concepts/macros/macro_variables.md b/docs/concepts/macros/macro_variables.md index fc6dcf215d..6db75ed17e 100644 --- a/docs/concepts/macros/macro_variables.md +++ b/docs/concepts/macros/macro_variables.md @@ -134,3 +134,4 @@ SQLMesh provides two other predefined variables used to modify model behavior ba * Can be used in model definitions when SQLGlot cannot fully parse a statement and you need to reference the model's underlying physical table directly. * Can be passed as an argument to macros that access or interact with the underlying physical table. * @this_env - A string value containing the name of the current [environment](../environments.md). Only available in [`before_all` and `after_all` statements](../../guides/configuration.md#before_all-and-after_all-statements), as well as in macros invoked within them. +* @model_kind_name - A string value containing the name of the current model kind. Intended to be used in scenarios where you need to control the [physical properties in model defaults](../../reference/model_configuration.md#model-defaults). \ No newline at end of file diff --git a/docs/reference/model_configuration.md b/docs/reference/model_configuration.md index 31a874acc7..bd92478212 100644 --- a/docs/reference/model_configuration.md +++ b/docs/reference/model_configuration.md @@ -107,6 +107,34 @@ To override `partition_expiration_days`, add a new `creatable_type` property and ) ``` +You can also use the `@model_kind_name` variable to fine-tune control over `physical_properties` in `model_defaults`. This holds the current model's kind name and is useful for conditionally assigning a property. 
For example, to disable `creatable_type` for your project's `VIEW` kind models: + +=== "YAML" + + ```yaml linenums="1" + model_defaults: + dialect: snowflake + start: 2022-01-01 + physical_properties: + creatable_type: "@IF(@model_kind_name != 'VIEW', 'TRANSIENT', NULL)" + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import Config, ModelDefaultsConfig + + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="snowflake", + start="2022-01-01", + physical_properties={ + "creatable_type": "@IF(@model_kind_name != 'VIEW', 'TRANSIENT', NULL)", + }, + ), + ) + ``` + The SQLMesh project-level `model_defaults` key supports the following options, described in the [general model properties](#general-model-properties) table above: diff --git a/sqlmesh/core/config/model.py b/sqlmesh/core/config/model.py index 84e1fa8b76..ea6713b5a9 100644 --- a/sqlmesh/core/config/model.py +++ b/sqlmesh/core/config/model.py @@ -10,9 +10,9 @@ model_kind_validator, on_destructive_change_validator, ) +from sqlmesh.core.model.meta import FunctionCall from sqlmesh.core.node import IntervalUnit from sqlmesh.utils.date import TimeLike -from sqlmesh.core.model.meta import FunctionCall from sqlmesh.utils.pydantic import field_validator @@ -56,10 +56,10 @@ class ModelDefaultsConfig(BaseConfig): virtual_properties: t.Optional[t.Dict[str, t.Any]] = None session_properties: t.Optional[t.Dict[str, t.Any]] = None audits: t.Optional[t.List[FunctionCall]] = None - optimize_query: t.Optional[bool] = None - allow_partials: t.Optional[bool] = None - interval_unit: t.Optional[IntervalUnit] = None - enabled: t.Optional[bool] = None + optimize_query: t.Optional[t.Union[str, bool]] = None + allow_partials: t.Optional[t.Union[str, bool]] = None + interval_unit: t.Optional[t.Union[str, IntervalUnit]] = None + enabled: t.Optional[t.Union[str, bool]] = None _model_kind_validator = model_kind_validator _on_destructive_change_validator = on_destructive_change_validator diff --git 
a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index b83199e090..e201006a6c 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -18,7 +18,9 @@ create_sql_model, create_models_from_blueprints, get_model_name, + parse_defaults_properties, render_meta_fields, + render_model_defaults, ) from sqlmesh.core.model.kind import ModelKindName, _ModelKind from sqlmesh.utils import registry_decorator, DECORATOR_RETURN_TYPE @@ -159,8 +161,25 @@ def model( if isinstance(rendered_name, exp.Expression): rendered_fields["name"] = rendered_name.sql(dialect=dialect) + rendered_defaults = ( + render_model_defaults( + defaults=defaults, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + variables=variables, + path=path, + dialect=dialect, + default_catalog=default_catalog, + ) + if defaults + else {} + ) + + rendered_defaults = parse_defaults_properties(rendered_defaults, dialect=dialect) + common_kwargs = { - "defaults": defaults, + "defaults": rendered_defaults, "path": path, "time_column_format": time_column_format, "python_env": serialize_env(env, path=module_path), diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 9f8c7882bf..d2f59da77d 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -60,7 +60,7 @@ if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType - from sqlmesh.core._typing import Self, TableName + from sqlmesh.core._typing import Self, TableName, SessionProperties from sqlmesh.core.context import ExecutionContext from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.engine_adapter._typing import QueryOrDF @@ -71,14 +71,15 @@ logger = logging.getLogger(__name__) +PROPERTIES = {"physical_properties", "session_properties", "virtual_properties"} + RUNTIME_RENDERED_MODEL_FIELDS = { "audits", "signals", "description", "cron", - "physical_properties", "merge_filter", -} +} | PROPERTIES class 
_Model(ModelMeta, frozen=True): @@ -657,16 +658,20 @@ def render_merge_filter( raise SQLMeshError(f"Expected one expression but got {len(rendered_exprs)}") return rendered_exprs[0].transform(d.replace_merge_table_aliases) - def render_physical_properties(self, **render_kwargs: t.Any) -> t.Dict[str, exp.Expression]: - def _render(expression: exp.Expression) -> exp.Expression: + def _render_properties( + self, properties: t.Dict[str, exp.Expression] | SessionProperties, **render_kwargs: t.Any + ) -> t.Dict[str, t.Any]: + def _render(expression: exp.Expression) -> exp.Expression | None: # note: we use the _statement_renderer instead of _create_renderer because it sets model_fqn which # in turn makes @this_model available in the evaluation context rendered_exprs = self._statement_renderer(expression).render(**render_kwargs) - if not rendered_exprs: - raise SQLMeshError( + # Warn instead of raising for cases where a property is conditionally assigned + if not rendered_exprs or rendered_exprs[0].sql().lower() in {"none", "null"}: + logger.warning( f"Expected rendering '{expression.sql(dialect=self.dialect)}' to return an expression" ) + return None if len(rendered_exprs) != 1: raise SQLMeshError( @@ -675,7 +680,20 @@ def _render(expression: exp.Expression) -> exp.Expression: return rendered_exprs[0] - return {k: _render(v) for k, v in self.physical_properties.items()} + return { + k: rendered + for k, v in properties.items() + if (rendered := (_render(v) if isinstance(v, exp.Expression) else v)) + } + + def render_physical_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]: + return self._render_properties(properties=self.physical_properties, **render_kwargs) + + def render_virtual_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]: + return self._render_properties(properties=self.virtual_properties, **render_kwargs) + + def render_session_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]: + return 
self._render_properties(properties=self.session_properties, **render_kwargs) def _create_renderer(self, expression: exp.Expression) -> ExpressionRenderer: return ExpressionRenderer( @@ -1989,8 +2007,21 @@ def load_sql_based_model( unrendered_merge_filter = None for prop in meta.expressions: + # Macro functions that programmatically generate the key-value pair properties should be rendered + # This is needed in the odd case where a macro shares the name of one of the properties + # e.g. `@session_properties()` Test: `test_macros_in_model_statement` Reference PR: #2574 + if isinstance(prop, d.MacroFunc): + continue + prop_name = prop.name.lower() - if prop_name in ("signals", "audits", "physical_properties"): + if ( + prop_name + in { + "signals", + "audits", + } + | PROPERTIES + ): unrendered_properties[prop_name] = prop.args.get("value") elif ( prop.name.lower() == "kind" @@ -2021,6 +2052,23 @@ def load_sql_based_model( rendered_meta = rendered_meta_exprs[0] + rendered_defaults = ( + render_model_defaults( + defaults=defaults, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + variables=variables, + path=path, + dialect=dialect, + default_catalog=default_catalog, + ) + if defaults + else {} + ) + + rendered_defaults = parse_defaults_properties(rendered_defaults, dialect=dialect) + # Extract the query and any pre/post statements query_or_seed_insert, pre_statements, post_statements, on_virtual_update, inline_audits = ( _split_sql_model_statements(expressions[1:], path, dialect=dialect) @@ -2066,7 +2114,7 @@ def load_sql_based_model( pre_statements=pre_statements, post_statements=post_statements, on_virtual_update=on_virtual_update, - defaults=defaults, + defaults=rendered_defaults, path=path, module_path=module_path, macros=macros, @@ -2226,9 +2274,9 @@ def create_python_model( for dep in t.cast(t.List[exp.Expression], depends_on_rendered)[0].expressions } - variables = {k: v for k, v in (variables or {}).items() if k in referenced_variables} - 
if variables: - python_env[c.SQLMESH_VARS] = Executable.value(variables) + used_variables = {k: v for k, v in (variables or {}).items() if k in referenced_variables} + if used_variables: + python_env[c.SQLMESH_VARS] = Executable.value(used_variables) return _create_model( PythonModel, @@ -2302,11 +2350,7 @@ def _create_model( _validate_model_fields(klass, {"name", *kwargs} - {"grain", "table_properties"}, path) - for prop in [ - "session_properties", - "physical_properties", - "virtual_properties", - ]: + for prop in PROPERTIES: kwargs[prop] = _resolve_properties((defaults or {}).get(prop), kwargs.get(prop)) dialect = dialect or "" @@ -2338,10 +2382,12 @@ def _create_model( statements.extend(kwargs["post_statements"]) if "on_virtual_update" in kwargs: statements.extend(kwargs["on_virtual_update"]) - if physical_properties := kwargs.get("physical_properties"): - # to allow variables like @gateway to be used in physical_properties - # since rendering shifted from load time to run time - statements.extend(physical_properties) + + # to allow variables like @gateway to be used in these properties + # since rendering shifted from load time to run time + for property_name in PROPERTIES: + if property_values := kwargs.get(property_name): + statements.extend(property_values) jinja_macro_references, used_variables = extract_macro_references_and_variables( *(gen(e) for e in statements) @@ -2573,9 +2619,7 @@ def render_meta_fields( default_catalog: t.Optional[str], ) -> t.Dict[str, t.Any]: def render_field_value(value: t.Any) -> t.Any: - if isinstance(value, exp.Expression) or ( - isinstance(value, str) and d.SQLMESH_MACRO_PREFIX in value - ): + if isinstance(value, exp.Expression) or (isinstance(value, str) and "@" in value): expression = exp.maybe_parse(value, dialect=dialect) rendered_expr = render_expression( expression=expression, @@ -2587,16 +2631,20 @@ def render_field_value(value: t.Any) -> t.Any: dialect=dialect, default_catalog=default_catalog, ) - if rendered_expr 
is None: + if not rendered_expr: raise SQLMeshError( - f"Failed to render model attribute `{fields['name']}` at `{path}`\n" - f"'{expression.sql(dialect=dialect)}' must return an expression" + f"Rendering `{expression.sql(dialect=dialect)}` did not return an expression" ) + if len(rendered_expr) != 1: raise SQLMeshError( - f"Failed to render model attribute `{fields['name']}` at `{path}`.\n" - f"`{expression.sql(dialect=dialect)}` must return one result, but got {len(rendered_expr)}" + f"Rendering `{expression.sql(dialect=dialect)}` must return one result, but got {len(rendered_expr)}" ) + + # For cases where a property is conditionally assigned + if rendered_expr[0].sql().lower() in {"none", "null"}: + return None + return rendered_expr[0] return value @@ -2605,17 +2653,81 @@ def render_field_value(value: t.Any) -> t.Any: field = field_info.alias or field_name if field not in RUNTIME_RENDERED_MODEL_FIELDS and (field_value := fields.get(field)): if isinstance(field_value, dict): - for key in list(field_value.keys()): - if key not in RUNTIME_RENDERED_MODEL_FIELDS: - fields[field][key] = render_field_value(field_value[key]) + rendered_dict = {} + for key, value in field_value.items(): + if key in RUNTIME_RENDERED_MODEL_FIELDS: + rendered_dict[key] = value + elif rendered := render_field_value(value): + rendered_dict[key] = rendered + if rendered_dict: + fields[field] = rendered_dict + else: + fields.pop(field) elif isinstance(field_value, list): - fields[field] = [render_field_value(value) for value in field_value] + if rendered_list := [ + rendered for value in field_value if (rendered := render_field_value(value)) + ]: + fields[field] = rendered_list + else: + fields.pop(field) else: - fields[field] = render_field_value(field_value) + if rendered_field := render_field_value(field_value): + fields[field] = rendered_field + else: + fields.pop(field) return fields +def render_model_defaults( + defaults: t.Dict[str, t.Any], + module_path: Path, + path: Path, + 
jinja_macros: t.Optional[JinjaMacroRegistry], + macros: t.Optional[MacroRegistry], + dialect: DialectType, + variables: t.Optional[t.Dict[str, t.Any]], + default_catalog: t.Optional[str], +) -> t.Dict[str, t.Any]: + rendered_defaults = render_meta_fields( + fields=defaults, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + variables=variables, + path=path, + dialect=dialect, + default_catalog=default_catalog, + ) + + # Validate defaults that have macros are rendered to boolean + for boolean in {"optimize_query", "allow_partials", "enabled"}: + if var := rendered_defaults.get(boolean): + if not isinstance(var, (exp.Boolean, bool)): + raise ConfigError(f"Expected boolean for '{var}', got '{type(var)}' instead") + + # Validate the 'interval_unit' if present is an Interval Unit + if (var := rendered_defaults.get("interval_unit")) and isinstance(var, str): + try: + rendered_defaults["interval_unit"] = IntervalUnit(var) + except ValueError as e: + raise ConfigError(f"Invalid interval unit: {var}") from e + + return rendered_defaults + + +def parse_defaults_properties( + defaults: t.Dict[str, t.Any], dialect: DialectType +) -> t.Dict[str, t.Any]: + for prop in PROPERTIES: + if default_properties := defaults.get(prop): + for key, value in default_properties.items(): + if isinstance(key, str) and d.SQLMESH_MACRO_PREFIX in str(value): + defaults[prop][key] = exp.maybe_parse(value, dialect=dialect) + + return defaults + + def render_expression( expression: exp.Expression, module_path: Path, diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index efbebef761..3e73615d0d 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -110,14 +110,16 @@ def _render( this_model = kwargs.pop("this_model", None) + this_snapshot = (snapshots or {}).get(self._model_fqn) if self._model_fqn else None if not this_model and self._model_fqn: - this_snapshot = (snapshots or {}).get(self._model_fqn) this_model = self._resolve_table( 
self._model_fqn, snapshots={self._model_fqn: this_snapshot} if this_snapshot else None, deployability_index=deployability_index, table_mapping=table_mapping, ) + if this_snapshot and (kind := this_snapshot.model_kind_name): + kwargs["model_kind_name"] = kind.name expressions = [self._expression] diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 1527fa1dc1..dc8d86c333 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -667,7 +667,9 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: physical_properties=rendered_physical_properties, ) - with adapter.transaction(), adapter.session(snapshot.model.session_properties): + with adapter.transaction(), adapter.session( + snapshot.model.render_session_properties(**render_statements_kwargs) + ): wap_id: t.Optional[str] = None if ( table_name @@ -766,7 +768,9 @@ def _create_snapshot( deployability_index=deployability_index, ) - with adapter.transaction(), adapter.session(snapshot.model.session_properties): + with adapter.transaction(), adapter.session( + snapshot.model.render_session_properties(**create_render_kwargs) + ): rendered_physical_properties = snapshot.model.render_physical_properties( **create_render_kwargs ) @@ -886,7 +890,9 @@ def _migrate_snapshot( runtime_stage=RuntimeStage.CREATING, deployability_index=deployability_index, ) - with adapter.transaction(), adapter.session(snapshot.model.session_properties): + with adapter.transaction(), adapter.session( + snapshot.model.render_session_properties(**render_kwargs) + ): self._execute_create( snapshot=snapshot, table_name=target_table_name, @@ -917,12 +923,6 @@ def _promote_snapshot( view_name = snapshot.qualified_view_name.for_environment( environment_naming_info, dialect=adapter.dialect ) - _evaluation_strategy(snapshot, adapter).promote( - table_name=table_name, - view_name=view_name, - model=snapshot.model, - environment=environment_naming_info.name, - ) render_kwargs: 
t.Dict[str, t.Any] = dict( start=start, end=end, @@ -933,6 +933,13 @@ def _promote_snapshot( table_mapping=table_mapping, runtime_stage=RuntimeStage.PROMOTING, ) + _evaluation_strategy(snapshot, adapter).promote( + table_name=table_name, + view_name=view_name, + model=snapshot.model, + environment=environment_naming_info.name, + **render_kwargs, + ) adapter.execute(snapshot.model.render_on_virtual_update(**render_kwargs)) if on_complete is not None: @@ -1366,12 +1373,22 @@ def promote( ) -> None: is_prod = environment == c.PROD logger.info("Updating view '%s' to point at table '%s'", view_name, table_name) + render_kwargs: t.Dict[str, t.Any] = dict( + start=kwargs.get("start"), + end=kwargs.get("end"), + execution_time=kwargs.get("execution_time"), + engine_adapter=kwargs.get("engine_adapter"), + snapshots=kwargs.get("snapshots"), + deployability_index=kwargs.get("deployability_index"), + table_mapping=kwargs.get("table_mapping"), + runtime_stage=kwargs.get("runtime_stage"), + ) self.adapter.create_view( view_name, exp.select("*").from_(table_name, dialect=self.adapter.dialect), table_description=model.description if is_prod else None, column_descriptions=model.column_descriptions if is_prod else None, - view_properties=model.virtual_properties, + view_properties=model.render_virtual_properties(**render_kwargs), ) def demote(self, view_name: str, **kwargs: t.Any) -> None: diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 9431eef30e..c14c1996b7 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -2064,7 +2064,7 @@ def my_model(context, **kwargs): assert m.depends_on == {'"foo"."table_name"'} -def test_python_model_with_properties(): +def test_python_model_with_properties(make_snapshot): @model( name="python_model_prop", kind="full", @@ -2087,7 +2087,8 @@ def python_model_prop(context, **kwargs): }, "physical_properties": { "partition_expiration_days": 13, - "creatable_type": "TRANSIENT", + "creatable_type": 
"@IF(@model_kind_name != 'view', 'TRANSIENT', NULL)", + "conditional_prop": "@IF(@model_kind_name == 'view', 'view_prop', NULL)", }, "virtual_properties": { "creatable_type": "SECURE", @@ -2102,6 +2103,20 @@ def python_model_prop(context, **kwargs): } assert m.physical_properties == { + "partition_expiration_days": exp.convert(7), + "creatable_type": exp.maybe_parse( + "@IF(@model_kind_name != 'view', 'TRANSIENT', NULL)", dialect="duckdb" + ), + "conditional_prop": exp.maybe_parse( + "@IF(@model_kind_name == 'view', 'view_prop', NULL)", dialect="duckdb" + ), + } + + snapshot: Snapshot = make_snapshot(m) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + # Rendering the properties will result in a TRANSIENT creatable_type and the removal of the conditional prop + assert m.render_physical_properties(snapshots={m.fqn: snapshot}, python_env=m.python_env) == { "partition_expiration_days": exp.convert(7), "creatable_type": exp.convert("TRANSIENT"), } @@ -3611,6 +3626,71 @@ def test_project_level_properties(sushi_context): assert model_2.cron == "@daily" +def test_conditional_physical_properties(make_snapshot): + model_defaults = ModelDefaultsConfig( + physical_properties={ + "creatable_type": "@IF(@model_kind_name != 'VIEW', 'TRANSIENT', NULL)" + }, + ) + + full_model = load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_full_model, + kind FULL, + ); + SELECT a FROM tbl; + """, + default_dialect="snowflake", + ), + defaults=model_defaults.dict(), + ) + + view_model = load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_view_model_kind, + kind VIEW, + ); + SELECT a FROM tbl; + """, + default_dialect="snowflake", + ), + defaults=model_defaults.dict(), + ) + + # load time is a no-op + assert ( + view_model.physical_properties + == full_model.physical_properties + == { + "creatable_type": exp.maybe_parse( + "@IF(@model_kind_name != 'VIEW', 'TRANSIENT', NULL)", dialect="snowflake" + ) + } + ) + + # substitution occurs at 
runtime + snapshot: Snapshot = make_snapshot(full_model) + snapshot_view: Snapshot = make_snapshot(view_model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + # Validate use of TRANSIENT type for FULL model + assert full_model.render_physical_properties( + snapshots={full_model.fqn: snapshot, view_model.fqn: snapshot_view} + ) == {"creatable_type": exp.Literal(this="TRANSIENT", is_string=True)} + + # Validate disabling the creatable_type property for VIEW model + assert ( + view_model.render_physical_properties( + snapshots={full_model.fqn: snapshot, view_model.fqn: snapshot_view} + ) + == {} + ) + + def test_project_level_properties_python_model(): model_defaults = { "physical_properties": { @@ -3653,6 +3733,236 @@ def python_model_prop(context, **kwargs): assert m.interval_unit == IntervalUnit.QUARTER_HOUR +def test_model_defaults_macros(make_snapshot): + model_defaults = ModelDefaultsConfig( + table_format="@IF(@gateway = 'dev', 'iceberg', NULL)", + storage_format="@IF(@gateway = 'local', 'parquet', NULL)", + optimize_query="@IF(@gateway = 'dev', True, False)", + enabled="@IF(@gateway = 'dev', True, False)", + allow_partials="@IF(@gateway = 'local', True, False)", + interval_unit="@IF(@gateway = 'local', 'quarter_hour', 'day')", + start="@IF(@gateway = 'dev', '1 month ago', '2024-01-01')", + session_properties={ + "spark.executor.cores": "@IF(@gateway = 'dpev', 1, 2)", + "spark.executor.memory": "1G", + "unset_property": "@IF(@model_kind_name = 'FULL', 'NOTSET', NULL)", + }, + physical_properties={ + "partition_expiration_days": 13, + "creatable_type": "@IF(@model_kind_name = 'FULL', 'TRANSIENT', NULL)", + }, + virtual_properties={ + "creatable_type": "@create_type", + "unset_property": "@IF(@model_kind_name = 'FULL', 'NOTSET', NULL)", + }, + ) + + model = load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + physical_properties ( + target_lag = '1 hour' + ), + ); + SELECT a FROM tbl; + """, + 
default_dialect="snowflake", + ), + defaults=model_defaults.dict(), + variables={"gateway": "dev", "create_type": "SECURE"}, + ) + + snapshot: Snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + # Validate rendering of model defaults + assert model.optimize_query + assert model.enabled + assert model.start == "1 month ago" + assert not model.allow_partials + assert model.interval_unit == IntervalUnit.DAY + assert model.table_format == "iceberg" + + # Validate disabling of conditional model default + assert not model.storage_format + + # The model defaults properties won't be rendered at load time + assert model.session_properties == { + "spark.executor.cores": exp.maybe_parse( + "@IF(@gateway = 'dpev', 1, 2)", dialect="snowflake" + ), + "spark.executor.memory": "1G", + "unset_property": exp.maybe_parse( + "@IF(@model_kind_name = 'FULL', 'NOTSET', NULL)", dialect="snowflake" + ), + } + + assert model.physical_properties == { + "partition_expiration_days": exp.convert(13), + "creatable_type": exp.maybe_parse( + "@IF(@model_kind_name = 'FULL', 'TRANSIENT', NULL)", dialect="snowflake" + ), + "target_lag": exp.convert("1 hour"), + } + + assert model.virtual_properties == { + "creatable_type": d.MacroVar(this="create_type"), + "unset_property": exp.maybe_parse( + "@IF(@model_kind_name = 'FULL', 'NOTSET', NULL)", dialect="snowflake" + ), + } + + # Validate the correct rendering and removal of conditional properties + assert model.render_session_properties(snapshots={model.fqn: snapshot}) == { + "spark.executor.cores": exp.convert(2), + "spark.executor.memory": "1G", + } + + assert model.render_physical_properties(snapshots={model.fqn: snapshot}) == { + "partition_expiration_days": exp.convert(13), + "target_lag": exp.convert("1 hour"), + } + + assert model.render_virtual_properties(snapshots={model.fqn: snapshot}) == { + "creatable_type": exp.convert("SECURE"), + } + + +def test_model_defaults_macros_python_model(make_snapshot): + 
model_defaults = { + "physical_properties": { + "partition_expiration_days": 13, + "creatable_type": "@IF(@model_kind_name = 'FULL', 'TRANSIENT', NULL)", + }, + "table_format": "@IF(@gateway = 'local', 'iceberg', NULL)", + "storage_format": "@IF(@gateway = 'dev', 'parquet', NULL)", + "optimize_query": "@IF(@gateway = 'local', True, False)", + "enabled": "@IF(@gateway = 'local', True, False)", + "allow_partials": "@IF(@gateway = 'local', True, False)", + "interval_unit": "@IF(@gateway = 'local', 'quarter_hour', 'day')", + "start": "@IF(@gateway = 'dev', '1 month ago', '2024-01-01')", + "virtual_properties": {"creatable_type": "@create_type"}, + "session_properties": { + "spark.executor.cores": "@IF(@gateway = 'dev', 1, 2)", + "spark.executor.memory": "1G", + }, + } + + @model( + name="python_model_defaults_macro", + kind="full", + columns={"some_col": "int"}, + physical_properties={"partition_expiration_days": 7}, + ) + def python_model_prop_macro(context, **kwargs): + context.resolve_table("foo") + + m = model.get_registry()["python_model_defaults_macro"].model( + module_path=Path("."), + path=Path("."), + dialect="duckdb", + defaults=model_defaults, + variables={"gateway": "local", "create_type": "SECURE"}, + ) + + # Even if in the project wide defaults this is ignored for python models + assert not m.optimize_query + + # Validate rendering of model defaults + assert m.enabled + assert m.start == "2024-01-01" + assert m.allow_partials + assert m.interval_unit == IntervalUnit.QUARTER_HOUR + assert m.table_format == "iceberg" + + # Validate disabling attribute dynamically + assert not m.storage_format + + snapshot: Snapshot = make_snapshot(m) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + # Ensure properties are not rendered at load time + assert m.physical_properties == { + "partition_expiration_days": exp.convert(7), + "creatable_type": exp.maybe_parse( + "@IF(@model_kind_name = 'FULL', 'TRANSIENT', NULL)", dialect="duckdb" + ), + } + + # 
Substitution occurs at runtime for properties so here these will be unrendered + assert m.render_physical_properties( + snapshots={ + m.fqn: snapshot, + } + ) == { + "partition_expiration_days": exp.convert(7), + "creatable_type": exp.convert("TRANSIENT"), + } + + assert m.session_properties == { + "spark.executor.cores": exp.maybe_parse("@IF(@gateway = 'dev', 1, 2)", dialect="duckdb"), + "spark.executor.memory": "1G", + } + + assert m.virtual_properties["creatable_type"] == d.MacroVar(this="create_type") + + # Validate rendering of properties + assert m.render_session_properties( + snapshots={ + m.fqn: snapshot, + }, + ) == { + "spark.executor.cores": exp.convert(2), + "spark.executor.memory": "1G", + } + + assert m.render_virtual_properties( + snapshots={ + m.fqn: snapshot, + } + ) == {"creatable_type": exp.convert("SECURE")} + + assert m.render_physical_properties(snapshots={m.fqn: snapshot}) == { + "partition_expiration_days": exp.convert(7), + "creatable_type": exp.convert("TRANSIENT"), + } + + +@pytest.mark.parametrize( + "optimize_query, enabled, allow_partials, interval_unit, expected_error", + [ + ("string", "string", "string", "string", r"^Expected boolean for*"), + (True, "string", "string", "string", r"^Expected boolean for*"), + (True, True, "string", "string", r"^Expected boolean for*"), + (True, True, True, "string", r"^Invalid interval unitr*"), + ], +) +def test_model_defaults_validations( + optimize_query, enabled, allow_partials, interval_unit, expected_error +): + model_defaults = ModelDefaultsConfig( + optimize_query=optimize_query, + enabled=enabled, + allow_partials=allow_partials, + interval_unit=interval_unit, + ) + + with pytest.raises(ConfigError, match=expected_error): + load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + ); + SELECT a FROM tbl; + """, + ), + defaults=model_defaults.dict(), + ) + + def test_model_session_properties(sushi_context): assert 
sushi_context.models['"memory"."sushi"."items"'].session_properties == { "string_prop": "some_value", @@ -5375,13 +5685,13 @@ def model_with_macros(evaluator, **kwargs): assert "location1" in python_sql_model.physical_properties assert "location2" in python_sql_model.physical_properties + # The properties will stay unrendered at load time assert python_sql_model.session_properties == { - "spark.executor.cores": 1, + "spark.executor.cores": "@IF(@gateway = 'dev', 1, 2)", "spark.executor.memory": "1G", } - assert python_sql_model.virtual_properties["creatable_type"].this == "SECURE" + assert python_sql_model.virtual_properties["creatable_type"] == exp.convert("@{create_type}") - # The physical_properties will stay unrendered at load time assert ( python_sql_model.physical_properties["location1"].text("this") == "@'s3://bucket/prefix/@{schema_name}/@{table_name}'" @@ -5727,7 +6037,8 @@ def test_macros_in_physical_properties(make_snapshot): @resolve_template('hdfs://@{catalog_name}/@{schema_name}/dev/@{table_name}'), @resolve_template('s3://prod/@{table_name}') ), - sort_order = @IF(@gateway = 'prod', 'desc', 'asc') + sort_order = @IF(@gateway = 'prod', 'desc', 'asc'), + conditional_prop = @IF(@gateway == 'prod', 'PROD_PROP', NULL) ) ); @@ -5743,11 +6054,13 @@ def test_macros_in_physical_properties(make_snapshot): assert "location1" in model.physical_properties assert "location2" in model.physical_properties assert "sort_order" in model.physical_properties + assert "conditional_prop" in model.physical_properties # load time is a no-op assert isinstance(model.physical_properties["location1"], d.MacroFunc) assert isinstance(model.physical_properties["location2"], d.MacroFunc) assert isinstance(model.physical_properties["sort_order"], d.MacroFunc) + assert isinstance(model.physical_properties["conditional_prop"], d.MacroFunc) # substitution occurs at runtime snapshot = make_snapshot(model) @@ -5769,6 +6082,9 @@ def test_macros_in_physical_properties(make_snapshot): ) 
assert rendered_physical_properties["sort_order"].text("this") == "asc" + # the conditional_prop will be disabled for "dev" gateway + assert "conditional_prop" not in rendered_physical_properties + def test_macros_in_model_statement(sushi_context, assert_exp_eq): @macro()