From 0a735d7ddc2350877032568b5814df1f7b26fa56 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Tue, 4 Feb 2025 10:59:21 +0200 Subject: [PATCH 1/8] Feat: Introduce model linting --- examples/multi/repo_1/config.yaml | 5 + examples/multi/repo_1/linter/user.py | 15 ++ examples/multi/repo_2/config.yaml | 5 + examples/sushi/linter/user.py | 13 ++ sqlmesh/core/config/__init__.py | 1 + sqlmesh/core/config/linter.py | 44 +++++ sqlmesh/core/config/model.py | 2 - sqlmesh/core/config/root.py | 3 + sqlmesh/core/constants.py | 1 + sqlmesh/core/context.py | 32 +++- sqlmesh/core/linter/__init__.py | 0 sqlmesh/core/linter/definition.py | 72 ++++++++ sqlmesh/core/linter/rule.py | 97 ++++++++++ sqlmesh/core/linter/rules/__init__.py | 1 + sqlmesh/core/linter/rules/builtin.py | 51 +++++ sqlmesh/core/loader.py | 27 +++ sqlmesh/core/model/cache.py | 13 +- sqlmesh/core/model/common.py | 1 - sqlmesh/core/model/definition.py | 21 ++- sqlmesh/core/model/meta.py | 13 +- sqlmesh/core/model/schema.py | 6 +- sqlmesh/core/renderer.py | 35 +--- .../migrations/v0074_remove_validate_query.py | 82 +++++++++ tests/core/test_context.py | 134 +++++++++++++- tests/core/test_integration.py | 15 +- tests/core/test_model.py | 167 +++++++---------- tests/core/test_snapshot.py | 6 +- tests/core/test_state_sync.py | 174 ++++++++++-------- 28 files changed, 803 insertions(+), 233 deletions(-) create mode 100644 examples/multi/repo_1/linter/user.py create mode 100644 examples/sushi/linter/user.py create mode 100644 sqlmesh/core/config/linter.py create mode 100644 sqlmesh/core/linter/__init__.py create mode 100644 sqlmesh/core/linter/definition.py create mode 100644 sqlmesh/core/linter/rule.py create mode 100644 sqlmesh/core/linter/rules/__init__.py create mode 100644 sqlmesh/core/linter/rules/builtin.py create mode 100644 sqlmesh/migrations/v0074_remove_validate_query.py diff --git a/examples/multi/repo_1/config.yaml b/examples/multi/repo_1/config.yaml index c84b849958..0f35441b86 100644 --- a/examples/multi/repo_1/config.yaml +++ b/examples/multi/repo_1/config.yaml @@ -20,3 +20,8 @@ after_all: model_defaults: dialect: 'duckdb' + +linter: + enabled: True + + warn_rules: "ALL" \ No newline at end of file diff --git a/examples/multi/repo_1/linter/user.py b/examples/multi/repo_1/linter/user.py new file mode 100644 index 0000000000..1dfc7c8ae2 --- /dev/null +++ b/examples/multi/repo_1/linter/user.py @@ -0,0 +1,15 @@ +"""Contains all the standard rules included with SQLMesh""" + +from __future__ import annotations + +import typing as t + +from sqlmesh.core.linter.rule import Rule, RuleViolation +from sqlmesh.core.model import Model + + +class NoMissingDescription(Rule): + """All models should be documented.""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + return self.violation() if not model.description else None diff --git a/examples/multi/repo_2/config.yaml b/examples/multi/repo_2/config.yaml index bc4603f0f5..23bec6d8fe 100644 --- a/examples/multi/repo_2/config.yaml +++ b/examples/multi/repo_2/config.yaml @@ -20,3 +20,8 @@ after_all: model_defaults: dialect: 'duckdb' + +linter: + enabled: True + + ignored_rules: "ALL" \ No newline at end of file diff --git a/examples/sushi/linter/user.py b/examples/sushi/linter/user.py new file mode 100644 index 0000000000..a4a33c0efc --- /dev/null +++ b/examples/sushi/linter/user.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import typing as t + +from sqlmesh.core.linter.rule import Rule, RuleViolation +from sqlmesh.core.model import Model + + +class NoMissingOwner(Rule): + """All models should have an owner specified.""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + return self.violation() if not model.owner else None diff --git a/sqlmesh/core/config/__init__.py b/sqlmesh/core/config/__init__.py index b72a271f9d..6017fc8895 100644 --- a/sqlmesh/core/config/__init__.py +++ b/sqlmesh/core/config/__init__.py @@ -30,6 +30,7 @@ from sqlmesh.core.config.migration import MigrationConfig as MigrationConfig from sqlmesh.core.config.model import ModelDefaultsConfig as ModelDefaultsConfig from sqlmesh.core.config.naming import NameInferenceConfig as NameInferenceConfig +from sqlmesh.core.config.linter import LinterConfig as LinterConfig from sqlmesh.core.config.plan import PlanConfig as PlanConfig from sqlmesh.core.config.root import Config as Config from sqlmesh.core.config.run import RunConfig as RunConfig diff --git a/sqlmesh/core/config/linter.py b/sqlmesh/core/config/linter.py new file mode 100644 index 0000000000..c2a40e09aa --- /dev/null +++ b/sqlmesh/core/config/linter.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp +from sqlglot.helper import ensure_collection + +from sqlmesh.core.config.base import BaseConfig + +from sqlmesh.utils.pydantic import field_validator + + +class LinterConfig(BaseConfig): + """Configuration for model linting + + Args: + enabled: Flag indicating whether the linter should run + + rules: A list of error rules to be applied on model + warn_rules: A list of rules to be applied on models but produce warnings instead of raising errors. + ignored_rules: A list of rules to be excluded/ignored + + """ + + enabled: bool = False + + rules: t.Set[str] = set() + warn_rules: t.Set[str] = set() + ignored_rules: t.Set[str] = set() + + @classmethod + def _validate_rules(cls, v: t.Any) -> t.Set[str]: + if isinstance(v, exp.Paren): + v = v.unnest().name + elif isinstance(v, (exp.Tuple, exp.Array)): + v = [e.name for e in v.expressions] + elif isinstance(v, exp.Expression): + v = v.name + + return {name.lower() for name in ensure_collection(v)} + + @field_validator("rules", "warn_rules", "ignored_rules", mode="before") + def rules_validator(cls, vs: t.Any) -> t.Set[str]: + return cls._validate_rules(vs) diff --git a/sqlmesh/core/config/model.py b/sqlmesh/core/config/model.py index 62ec7c5935..84e1fa8b76 100644 --- a/sqlmesh/core/config/model.py +++ b/sqlmesh/core/config/model.py @@ -38,7 +38,6 @@ class ModelDefaultsConfig(BaseConfig): session_properties: A key-value mapping of properties specific to the target engine that are applied to the engine session. audits: The audits to be applied globally to all models in the project. optimize_query: Whether the SQL models should be optimized. - validate_query: Whether the SQL models should be validated at compile time. allow_partials: Whether the models can process partial (incomplete) data intervals. enabled: Whether the models are enabled. interval_unit: The temporal granularity of the models data intervals. By default computed from cron. @@ -58,7 +57,6 @@ class ModelDefaultsConfig(BaseConfig): session_properties: t.Optional[t.Dict[str, t.Any]] = None audits: t.Optional[t.List[FunctionCall]] = None optimize_query: t.Optional[bool] = None - validate_query: t.Optional[bool] = None allow_partials: t.Optional[bool] = None interval_unit: t.Optional[IntervalUnit] = None enabled: t.Optional[bool] = None diff --git a/sqlmesh/core/config/root.py b/sqlmesh/core/config/root.py index b3fa1cb801..c1642ccd43 100644 --- a/sqlmesh/core/config/root.py +++ b/sqlmesh/core/config/root.py @@ -28,6 +28,7 @@ from sqlmesh.core.config.migration import MigrationConfig from sqlmesh.core.config.model import ModelDefaultsConfig from sqlmesh.core.config.naming import NameInferenceConfig as NameInferenceConfig +from sqlmesh.core.config.linter import LinterConfig as LinterConfig from sqlmesh.core.config.plan import PlanConfig from sqlmesh.core.config.run import RunConfig from sqlmesh.core.config.scheduler import ( @@ -124,6 +125,7 @@ class Config(BaseConfig): disable_anonymized_analytics: bool = False before_all: t.Optional[t.List[str]] = None after_all: t.Optional[t.List[str]] = None + linter: LinterConfig = LinterConfig() _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = { "gateways": UpdateStrategy.NESTED_UPDATE, @@ -141,6 +143,7 @@ class Config(BaseConfig): "plan": UpdateStrategy.NESTED_UPDATE, "before_all": UpdateStrategy.EXTEND, "after_all": UpdateStrategy.EXTEND, + "linter": UpdateStrategy.NESTED_UPDATE, } _connection_config_validator = connection_config_validator diff --git a/sqlmesh/core/constants.py b/sqlmesh/core/constants.py index df204cfbd9..abc096c9a6 100644 --- a/sqlmesh/core/constants.py +++ b/sqlmesh/core/constants.py @@ -60,6 +60,7 @@ AUDITS = "audits" CACHE = ".cache" EXTERNAL_MODELS = "external_models" +LINTER = "linter" MACROS = "macros" MATERIALIZATIONS = "materializations" METRICS = "metrics" diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index ee7b75bfea..9fd6f0366b 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -75,6 +75,8 @@ from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements from sqlmesh.core.loader import Loader +from sqlmesh.core.linter.definition import Linter +from sqlmesh.core.linter.rules import BUILTIN_RULES from sqlmesh.core.macros import ExecutableOrMacro, macro from sqlmesh.core.metric import Metric, rewrite from sqlmesh.core.model import Model, update_model_schemas @@ -349,6 +351,7 @@ def __init__( self._environment_statements: t.List[EnvironmentStatements] = [] self._excluded_requirements: t.Set[str] = set() self._default_catalog: t.Optional[str] = None + self._linters: t.Dict[str, Linter] = {} self._loaded: bool = False self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items()))) @@ -490,13 +493,22 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model: } ) - update_model_schemas(self.dag, models=self._models, context_path=self.path) + update_model_schemas( + self.dag, + models=self._models, + context_path=self.path, + linters=self._linters, + ) if model.dialect: self._all_dialects.add(model.dialect) model.validate_definition() + # Linter may be `None` if the context is not loaded yet + if linter := self._linters.get(model.project): + linter.lint_model(model) + return model def scheduler(self, environment: t.Optional[str] = None) -> Scheduler: @@ -576,9 +588,10 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: self._metrics.clear() self._requirements.clear() self._excluded_requirements.clear() + self._linters.clear() self._environment_statements = [] - for project in loaded_projects: + for loader, project in zip(self._loaders, loaded_projects): self._jinja_macros = self._jinja_macros.merge(project.jinja_macros) self._macros.update(project.macros) self._models.update(project.models) @@ -590,6 +603,11 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: if project.environment_statements: self._environment_statements.append(project.environment_statements) + config = loader.config + self._linters[config.project] = Linter.from_rules( + BUILTIN_RULES.union(project.user_rules), config.linter + ) + uncached = set() if any(self._projects): @@ -621,12 +639,20 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: self._models.update({fqn: model.copy(update={"mapping_schema": {}})}) continue - update_model_schemas(self.dag, models=self._models, context_path=self.path) + update_model_schemas( + self.dag, + models=self._models, + context_path=self.path, + linters=self._linters, + ) for model in self.models.values(): # The model definition can be validated correctly only after the schema is set. model.validate_definition() + if linter := self._linters.get(model.project): + linter.lint_model(model) + duplicates = set(self._models) & set(self._standalone_audits) if duplicates: raise ConfigError( diff --git a/sqlmesh/core/linter/__init__.py b/sqlmesh/core/linter/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sqlmesh/core/linter/definition.py b/sqlmesh/core/linter/definition.py new file mode 100644 index 0000000000..9334e40b5f --- /dev/null +++ b/sqlmesh/core/linter/definition.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import typing as t + +from sqlmesh.core.config.linter import LinterConfig + +from sqlmesh.core.model import Model + +from sqlmesh.utils.errors import raise_config_error +from sqlmesh.core.console import get_console +from sqlmesh.core.linter.rule import RuleSet + + +def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet: + if "all" in rule_names: + return all_rules + + rules = set() + for rule_name in rule_names: + if rule_name not in all_rules: + raise_config_error(f"Rule {rule_name} could not be found") + + rules.add(all_rules[rule_name]) + + return RuleSet(rules) + + +class Linter: + def __init__( + self, enabled: bool, all_rules: RuleSet, rules: RuleSet, warn_rules: RuleSet + ) -> None: + self.enabled = enabled + self.all_rules = all_rules + self.rules = rules + self.warn_rules = warn_rules + + @classmethod + def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter: + ignored_rules = select_rules(all_rules, config.ignored_rules) + included_rules = all_rules.difference(ignored_rules) + + rules = select_rules(included_rules, config.rules) + warn_rules = select_rules(included_rules, config.warn_rules) + + if overlapping := rules.intersection(warn_rules): + overlapping_rules = ", ".join(rule for rule in overlapping) + raise_config_error( + f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]" + ) + + return Linter(config.enabled, all_rules, rules, warn_rules) + + def lint_model(self, model: Model) -> None: + if not self.enabled: + return + + ignored_rules = select_rules(self.all_rules, model.ignored_rules) + + rules = self.rules.difference(ignored_rules) + warn_rules = self.warn_rules.difference(ignored_rules) + + error_violations = rules.check_model(model) + warn_violations = warn_rules.check_model(model) + + if warn_violations: + warn_msg = "\n".join(f" - {warn_violation}" for warn_violation in warn_violations) + get_console().log_warning(f"Linter warnings for {model._path}:\n{warn_msg}") + + if error_violations: + error_msg = "\n".join(f" - {error_violations}" for error_violations in error_violations) + + raise_config_error(f"Linter error for {model._path}:\n{error_msg}") diff --git a/sqlmesh/core/linter/rule.py b/sqlmesh/core/linter/rule.py new file mode 100644 index 0000000000..93d80a52f4 --- /dev/null +++ b/sqlmesh/core/linter/rule.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import abc + +import operator as op +from collections.abc import Iterator, Iterable, Set, Mapping, Callable +from functools import reduce + +from sqlmesh.core.model import Model + +from typing import Type + +import typing as t + + +class _Rule(abc.ABCMeta): + def __new__(cls: Type[_Rule], clsname: str, bases: t.Tuple, attrs: t.Dict) -> _Rule: + attrs["name"] = clsname.lower() + return super().__new__(cls, clsname, bases, attrs) + + +class Rule(abc.ABC, metaclass=_Rule): + """The base class for a rule.""" + + name = "rule" + + @abc.abstractmethod + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + """The evaluation function that'll check for a violation of this rule.""" + + @property + def summary(self) -> str: + """A summary of what this rule checks for.""" + return self.__doc__ or "" + + def violation(self, violation_msg: t.Optional[str] = None) -> RuleViolation: + """Create a RuleViolation instance for this rule""" + return RuleViolation(rule=self, violation_msg=violation_msg or self.summary) + + def __repr__(self) -> str: + return self.name + + +class RuleViolation: + def __init__(self, rule: Rule, violation_msg: str) -> None: + self.rule = rule + self.violation_msg = violation_msg + + def __repr__(self) -> str: + return f"{self.rule.name}: {self.violation_msg}" + + +class RuleSet(Mapping[str, type[Rule]]): + def __init__(self, rules: Iterable[type[Rule]] = ()) -> None: + self._underlying = {rule.name: rule for rule in rules} + + def check_model(self, model: Model) -> t.List[RuleViolation]: + violations = [] + + for rule in self._underlying.values(): + violation = rule().check_model(model) + + if violation: + violations.append(violation) + + return violations + + def __iter__(self) -> Iterator[str]: + return iter(self._underlying) + + def __len__(self) -> int: + return len(self._underlying) + + def __getitem__(self, rule: str | type[Rule]) -> type[Rule]: + key = rule if isinstance(rule, str) else rule.name + return self._underlying[key] + + def __op( + self, + op: Callable[[Set[type[Rule]], Set[type[Rule]]], Set[type[Rule]]], + other: RuleSet, + /, + ) -> RuleSet: + rules = set() + for rule in op(set(self.values()), set(other.values())): + rules.add(other[rule] if rule in other else self[rule]) + + return RuleSet(rules) + + def union(self, *others: RuleSet) -> RuleSet: + return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others)) + + def intersection(self, *others: RuleSet) -> RuleSet: + return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others)) + + def difference(self, *others: RuleSet) -> RuleSet: + return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others)) diff --git a/sqlmesh/core/linter/rules/__init__.py b/sqlmesh/core/linter/rules/__init__.py new file mode 100644 index 0000000000..43812479a5 --- /dev/null +++ b/sqlmesh/core/linter/rules/__init__.py @@ -0,0 +1 @@ +from sqlmesh.core.linter.rules.builtin import BUILTIN_RULES as BUILTIN_RULES diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py new file mode 100644 index 0000000000..664bcc9b23 --- /dev/null +++ b/sqlmesh/core/linter/rules/builtin.py @@ -0,0 +1,51 @@ +"""Contains all the standard rules included with SQLMesh""" + +from __future__ import annotations + +import typing as t + +from sqlglot.helper import subclasses + +from sqlmesh.core.linter.rule import Rule, RuleViolation, RuleSet +from sqlmesh.core.model import Model, SqlModel + + +class NoSelectStar(Rule): + """Query should not contain SELECT * on its outer most projections, even if it can be expanded.""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + if not isinstance(model, SqlModel): + return None + + return self.violation() if model.query.is_star else None + + +class InvalidSelectStarExpansion(Rule): + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + deps = model.violated_rules_for_query.get(InvalidSelectStarExpansion) + if not deps: + return None + + violation_msg = ( + f"SELECT * cannot be expanded due to missing schema(s) for model(s): {deps}. " + "Run `sqlmesh create_external_models` and / or make sure that the model " + f"'{model.fqn}' can be rendered at parse time." + ) + + return self.violation(violation_msg) + + +class AmbiguousOrInvalidColumn(Rule): + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + sqlglot_err = model.violated_rules_for_query.get(AmbiguousOrInvalidColumn) + if not sqlglot_err: + return None + + violation_msg = ( + f"{sqlglot_err} for model '{model.fqn}', the column may not exist or is ambiguous." + ) + + return self.violation(violation_msg) + + +BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, (Rule,))) diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index f90462d9ee..bb9a669ccd 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -12,11 +12,13 @@ from sqlglot.errors import SqlglotError from sqlglot import exp +from sqlglot.helper import subclasses from sqlmesh.core import constants as c from sqlmesh.core.audit import Audit, ModelAudit, StandaloneAudit, load_multiple_audits from sqlmesh.core.dialect import parse from sqlmesh.core.environment import EnvironmentStatements +from sqlmesh.core.linter.rule import RuleSet, Rule from sqlmesh.core.macros import MacroRegistry, macro from sqlmesh.core.metric import Metric, MetricMeta, expand_metrics, load_metric_ddl from sqlmesh.core.model import ( @@ -54,6 +56,7 @@ class LoadedProject: requirements: t.Dict[str, str] excluded_requirements: t.Set[str] environment_statements: t.Optional[EnvironmentStatements] + user_rules: RuleSet class Loader(abc.ABC): @@ -120,6 +123,8 @@ def load(self) -> LoadedProject: environment_statements = self._load_environment_statements(macros=macros) + user_rules = self._load_linting_rules() + project = LoadedProject( macros=macros, jinja_macros=jinja_macros, @@ -130,6 +135,7 @@ def load(self) -> LoadedProject: requirements=requirements, excluded_requirements=excluded_requirements, environment_statements=environment_statements, + user_rules=user_rules, ) return project @@ -265,6 +271,10 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]: return requirements, excluded_requirements + def _load_linting_rules(self) -> RuleSet: + """Loads user linting rules""" + return RuleSet() + def _glob_paths( self, path: Path, @@ -630,6 +640,23 @@ def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStat return EnvironmentStatements(**statements, python_env=python_env) return None + def _load_linting_rules(self) -> RuleSet: + user_rules: UniqueKeyDict[str, type[Rule]] = UniqueKeyDict("rules") + + for path in self._glob_paths( + self.config_path / c.LINTER, + ignore_patterns=self.config.ignore_patterns, + extension=".py", + ): + if os.path.getsize(path): + self._track_file(path) + module = import_python_file(path, self.config_path) + module_rules = subclasses(module.__name__, Rule, (Rule,)) + for user_rule in module_rules: + user_rules[user_rule.name] = user_rule + + return RuleSet(user_rules.values()) + class _Cache: def __init__(self, loader: SqlMeshLoader, config_path: Path): self._loader = loader diff --git a/sqlmesh/core/model/cache.py b/sqlmesh/core/model/cache.py index 268abad966..c5d29bb6a8 100644 --- a/sqlmesh/core/model/cache.py +++ b/sqlmesh/core/model/cache.py @@ -22,6 +22,7 @@ if t.TYPE_CHECKING: from sqlmesh.core.snapshot import SnapshotId + from sqlmesh.core.linter.definition import Linter T = t.TypeVar("T") @@ -80,8 +81,9 @@ class OptimizedQueryCache: path: The path to the cache folder. """ - def __init__(self, path: Path): + def __init__(self, path: Path, linters: t.Optional[t.Dict[str, Linter]] = None): self.path = path + self.linters = linters self._file_cache: FileCache[OptimizedQueryCacheEntry] = FileCache( path, prefix="optimized_query" ) @@ -130,6 +132,14 @@ def put(self, model: Model) -> t.Optional[str]: def _put(self, name: str, model: SqlModel) -> None: optimized_query = model.render_query() + + if self.linters: + linter = self.linters.get(model.project) + if linter and linter.rules.keys() & model.violated_rules_for_query.keys(): + # Do not cache the optimized query if the renderer came across lint errors + # Note: The ordering of the intersection check matters here + return None + new_entry = OptimizedQueryCacheEntry(optimized_rendered_query=optimized_query) self._file_cache.put(name, value=new_entry) @@ -140,7 +150,6 @@ def _entry_name(model: SqlModel) -> str: hash_data.append(str([gen(d) for d in model.macro_definitions])) hash_data.append(str([(k, v) for k, v in model.sorted_python_env])) hash_data.extend(model.jinja_macros.data_hash_values) - hash_data.extend(str(model.validate_query)) return f"{model.name}_{crc32(hash_data)}" diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 4122d6d2ff..d4bed645a2 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -321,7 +321,6 @@ def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[ "allow_partials", "enabled", "optimize_query", - "validate_query", mode="before", check_fields=False, )(parse_bool) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index d3d212e0e1..9f1e2ad5b1 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -64,6 +64,7 @@ from sqlmesh.core.context import ExecutionContext from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.engine_adapter._typing import QueryOrDF + from sqlmesh.core.linter.rule import Rule from sqlmesh.core.snapshot import DeployabilityIndex, Node, Snapshot from sqlmesh.utils.jinja import MacroReference @@ -237,7 +238,7 @@ def render_definition( "enabled", "inline_audits", "optimize_query", - "validate_query", + "ignored_rules_", ): expressions.append( exp.Property( @@ -980,12 +981,6 @@ def validate_definition(self) -> None: self._path, ) - if self.validate_query: - raise_config_error( - "Query validation can only be enabled for SQL models", - self._path, - ) - if isinstance(self.kind, CustomKind): from sqlmesh.core.snapshot.evaluator import get_custom_materialization_type_or_raise @@ -1087,7 +1082,6 @@ def metadata_hash(self) -> str: self.project, str(self.allow_partials), gen(self.session_properties_) if self.session_properties_ else None, - str(self.validate_query) if self.validate_query is not None else None, *[gen(g) for g in self.grains], ] @@ -1228,6 +1222,10 @@ def _is_time_column_in_partitioned_by(self) -> bool: col for expr in self.partitioned_by_ for col in expr.find_all(exp.Column) } + @property + def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: + return {} + class SqlModel(_Model): """The model definition which relies on a SQL query to fetch the data. @@ -1287,6 +1285,7 @@ def render_query( engine_adapter=engine_adapter, **kwargs, ) + return query def render_definition( @@ -1475,7 +1474,6 @@ def _query_renderer(self) -> QueryRenderer: default_catalog=self.default_catalog, quote_identifiers=not no_quote_identifiers, optimize_query=self.optimize_query, - validate_query=self.validate_query, ) @property @@ -1491,6 +1489,10 @@ def _data_hash_values(self) -> t.List[str]: def _additional_metadata(self) -> t.List[str]: return [*super()._additional_metadata, gen(self.query)] + @property + def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: + return self._query_renderer._violated_rules + class SeedModel(_Model): """The model definition which uses a pre-built static dataset to source the data from. @@ -2325,7 +2327,6 @@ def _create_model( defaults = {k: v for k, v in (defaults or {}).items() if k in klass.all_fields()} if not issubclass(klass, SqlModel): defaults.pop("optimize_query", None) - defaults.pop("validate_query", None) statements = [] diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index 3b89b63441..3c62d3db79 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -10,6 +10,7 @@ from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh.core import dialect as d +from sqlmesh.core.config.linter import LinterConfig from sqlmesh.core.dialect import normalize_model_name, extract_func_call from sqlmesh.core.model.common import ( bool_validator, @@ -76,7 +77,9 @@ class ModelMeta(_Node): physical_version: t.Optional[str] = None gateway: t.Optional[str] = None optimize_query: t.Optional[bool] = None - validate_query: t.Optional[bool] = None + ignored_rules_: t.Optional[t.Set[str]] = Field( + default=None, exclude=True, alias="ignored_rules" + ) _bool_validator = bool_validator _model_kind_validator = model_kind_validator @@ -285,6 +288,10 @@ def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expressi return refs + @field_validator("ignored_rules_", mode="before") + def ignored_rules_validator(cls, vs: t.Any) -> t.Any: + return LinterConfig._validate_rules(vs) + @model_validator(mode="before") def _pre_root_validator(cls, data: t.Any) -> t.Any: if not isinstance(data, dict): @@ -459,3 +466,7 @@ def fqn(self) -> str: @property def on_destructive_change(self) -> OnDestructiveChange: return getattr(self.kind, "on_destructive_change", OnDestructiveChange.ALLOW) + + @property + def ignored_rules(self) -> t.Set[str]: + return self.ignored_rules_ or set() diff --git a/sqlmesh/core/model/schema.py b/sqlmesh/core/model/schema.py index 86c628f610..c8efc76cee 100644 --- a/sqlmesh/core/model/schema.py +++ b/sqlmesh/core/model/schema.py @@ -18,15 +18,19 @@ from sqlmesh.core.model.definition import Model from sqlmesh.utils import UniqueKeyDict from sqlmesh.utils.dag import DAG + from sqlmesh.core.linter.definition import Linter def update_model_schemas( dag: DAG[str], models: UniqueKeyDict[str, Model], context_path: Path, + linters: t.Optional[t.Dict[str, Linter]] = None, ) -> None: schema = MappingSchema(normalize=False) - optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE) + optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache( + context_path / c.CACHE, linters + ) if c.MAX_FORK_WORKERS == 1: _update_model_schemas_sequential(dag, models, schema, optimized_query_cache) diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 3857b56d9f..a36992d073 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -30,6 +30,7 @@ from sqlglot.dialects.dialect import DialectType from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot + from sqlmesh.core.linter.rule import Rule logger = logging.getLogger(__name__) @@ -51,7 +52,6 @@ def __init__( model_fqn: t.Optional[str] = None, normalize_identifiers: bool = True, optimize_query: t.Optional[bool] = True, - validate_query: t.Optional[bool] = False, ): self._expression = expression self._dialect = dialect @@ -67,7 +67,7 @@ def __init__( self._cache: t.List[t.Optional[exp.Expression]] = [] self._model_fqn = model_fqn self._optimize_query_flag = optimize_query is not False - self._validate_query = validate_query + self._violated_rules: t.Dict[type[Rule], t.Any] = {} def update_schema(self, schema: t.Dict[str, t.Any]) -> None: self.schema = d.normalize_mapping_schema(schema, dialect=self._dialect) @@ -531,7 +531,7 @@ def render( query = self._optimize_query(query, deps) - if should_cache: + if should_cache and not self._violated_rules: self._optimized_cache = query if needs_optimization: @@ -559,6 +559,11 @@ def update_cache(self, expression: t.Optional[exp.Expression], optimized: bool = super().update_cache(expression) def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: + from sqlmesh.core.linter.rules.builtin import ( + AmbiguousOrInvalidColumn, + InvalidSelectStarExpansion, + ) + # We don't want to normalize names in the schema because that's handled by the optimizer original = query missing_deps = set() @@ -571,20 +576,8 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: missing_deps.add(dep) if self._model_fqn and not should_optimize and any(s.is_star for s in query.selects): - from sqlmesh.core.console import get_console - deps = ", ".join(f"'{dep}'" for dep in sorted(missing_deps)) - - warning = ( - f"SELECT * cannot be expanded due to missing schema(s) for model(s): {deps}. " - "Run `sqlmesh create_external_models` and / or make sure that the model " - f"'{self._model_fqn}' can be rendered at parse time." - ) - - if self._validate_query: - raise_config_error(warning, self._path) - - get_console().log_warning(warning) + self._violated_rules[InvalidSelectStarExpansion] = deps try: if should_optimize: @@ -603,18 +596,10 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: ) ) except SqlglotError as ex: - from sqlmesh.core.console import get_console - - warning = ( - f"{ex} for model '{self._model_fqn}', the column may not exist or is ambiguous." - ) - - if self._validate_query: - raise_config_error(warning, self._path) + self._violated_rules[AmbiguousOrInvalidColumn] = ex query = original - get_console().log_warning(warning) except Exception as ex: raise_config_error( f"Failed to optimize query, please file an issue at https://github.com/TobikoData/sqlmesh/issues/new. {ex}", diff --git a/sqlmesh/migrations/v0074_remove_validate_query.py b/sqlmesh/migrations/v0074_remove_validate_query.py new file mode 100644 index 0000000000..3c637676d8 --- /dev/null +++ b/sqlmesh/migrations/v0074_remove_validate_query.py @@ -0,0 +1,82 @@ +"""Remove validate_query from existing snapshots.""" + +import json + +import pandas as pd +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type +from sqlmesh.utils.migration import blob_text_type + + +def migrate(state_sync, **kwargs): # type: ignore + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + snapshots_table = "_snapshots" + index_type = index_text_type(engine_adapter.dialect) + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + + parsed_snapshot["node"].pop("validate_query", None) + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) diff --git a/tests/core/test_context.py b/tests/core/test_context.py index d54b4ded9b..66beb60d78 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -23,6 +23,7 @@ EnvironmentSuffixTarget, ModelDefaultsConfig, SnowflakeConnectionConfig, + LinterConfig, load_configs, ) from sqlmesh.core.context import Context @@ -30,7 +31,9 @@ from sqlmesh.core.dialect import parse, schema_ from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements -from sqlmesh.core.model import load_sql_based_model, model +from sqlmesh.core.macros import MacroEvaluator +from sqlmesh.core.model import load_sql_based_model, model, SqlModel +from sqlmesh.core.model.cache import OptimizedQueryCache from sqlmesh.core.renderer import render_statements from sqlmesh.core.model.kind import ModelKindName from sqlmesh.core.plan import BuiltInPlanEvaluator, PlanBuilder @@ -1635,3 +1638,132 @@ def test_environment_statements_dialect(tmp_path: Path): with pytest.raises(ParseError, match=r"Invalid expression / Unexpected token*"): config.model_defaults.dialect = "duckdb" ctx = Context(paths=[tmp_path], config=config) + + +@pytest.mark.slow +def test_model_linting(tmp_path: pathlib.Path, sushi_context) -> None: + cfg = LinterConfig(enabled=True, rules="ALL") + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), linter=cfg), + paths=tmp_path, + ) + + # Case: Ensure load DOES NOT work if linter is enabled + for query in ("SELECT * FROM tbl", "SELECT t.* FROM tbl"): + with pytest.raises( + ConfigError, + match=r".*Query should not contain SELECT.*", + ): + ctx.upsert_model(load_sql_based_model(d.parse(f"MODEL (name test); {query}"))) + + error_model = load_sql_based_model(d.parse("MODEL (name test2); SELECT col")) + with pytest.raises( + ConfigError, + match=r""".*Column '"col"' could not be resolved for model.*""", + ): + ctx.upsert_model(error_model) + + # Case: Ensure optimized query is not cached if the model did not pass linting + cache = OptimizedQueryCache(tmp_path / c.CACHE, linters=ctx._linters) + + error_model = t.cast(SqlModel, error_model) + assert error_model._query_renderer._optimized_cache is None + assert not cache._file_cache.exists(cache._entry_name(error_model)) + + # Case: Ensure NoSelectStar only raises for top-level SELECTs, new model shouldn't raise + # and thus should also be cached + model2 = load_sql_based_model(d.parse("MODEL (name test); SELECT col FROM (SELECT * FROM tbl)")) + ctx.upsert_model(model2) + + model2 = t.cast(SqlModel, model2) + assert cache._file_cache.exists(cache._entry_name(model2)) + + # Case: Ensure load WORKS if linter is enabled but the rules are not + create_temp_file( + tmp_path, + pathlib.Path(pathlib.Path("models"), "test.sql"), + "MODEL(name test); SELECT * FROM (SELECT 1 AS col);", + ) + + ignore_or_warn_cfgs = [ + LinterConfig(enabled=True, warn_rules=["noselectstar"]), + LinterConfig(enabled=True, ignored_rules=["noselectstar"]), + ] + for cfg in ignore_or_warn_cfgs: + ctx.config.linter = cfg + ctx.load() + + # Case: Ensure load DOES NOT work if LinterConfig has overlapping rules + with pytest.raises( + ConfigError, + match=r"Rules cannot simultaneously warn and raise an error: \[noselectstar\]", + ): + ctx.config.linter = LinterConfig( + enabled=True, rules=["noselectstar"], warn_rules=["noselectstar"] + ) + ctx.load() + + # Case: Ensure model attribute overrides global config + ctx.config.linter = LinterConfig(enabled=True, rules=["noselectstar"]) + + create_temp_file( + tmp_path, + pathlib.Path(pathlib.Path("models"), "test.sql"), + "MODEL(name test, ignored_rules ['ALL']); SELECT * FROM (SELECT 1 AS col);", + ) + + create_temp_file( + tmp_path, + pathlib.Path(pathlib.Path("models"), "test2.sql"), + "MODEL(name test2, ignored_rules ['noselectstar']); SELECT * FROM (SELECT 1 AS col);", + ) + + ctx.load() + + # Case: Ensure we can load & use the user-defined rules + sushi_context.config.linter = LinterConfig(enabled=True, rules=["aLl"]) + sushi_context.upsert_model( + load_sql_based_model( + d.parse("MODEL (name sushi.test); SELECT col FROM (SELECT * FROM tbl)"), + default_catalog="memory", + ) + ) + with pytest.raises( + ConfigError, + match=r".*All models should have an owner.*", + ): + sushi_context.load() + + # Case: Ensure the Linter also picks up Python model violations + @model(name="memory.sushi.model3", is_sql=True, kind="full", dialect="snowflake") + def model3_entrypoint(evaluator: MacroEvaluator) -> str: + return "select * from model1" + + model3 = model.get_registry()["memory.sushi.model3"].model( + module_path=Path("."), path=Path(".") + ) + + @model(name="memory.sushi.model4", columns={"col": "int"}) + def model4_entrypoint(context, **kwargs): + yield pd.DataFrame({"col": []}) + + model4 = model.get_registry()["memory.sushi.model4"].model( + module_path=Path("."), path=Path(".") + ) + + for python_model in (model3, model4): + with pytest.raises( + ConfigError, + match=r".*All models should have an owner.*", + ): + sushi_context.upsert_model(python_model) + + @model(name="memory.sushi.model5", columns={"col": "int"}, owner="test") + def model5_entrypoint(context, **kwargs): + yield pd.DataFrame({"col": []}) + + model5 = model.get_registry()["memory.sushi.model5"].model( + module_path=Path("."), path=Path(".") + ) + + sushi_context.upsert_model(model5) diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 43bbdfe231..874e7b107f 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -4,6 +4,7 @@ from collections import Counter from datetime import timedelta from unittest import mock +from unittest.mock import patch import numpy as np import pandas as pd @@ -25,7 +26,7 @@ ModelDefaultsConfig, DuckDBConnectionConfig, ) -from sqlmesh.core.console import Console +from sqlmesh.core.console import Console, get_console from sqlmesh.core.context import Context from sqlmesh.core.config.categorizer import CategorizerConfig from sqlmesh.core.engine_adapter import EngineAdapter @@ -4266,7 +4267,17 @@ def test_auto_categorization(sushi_context: Context): def test_multi(mocker): - context = Context(paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory") + context = Context( + paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory", load=False + ) + + with patch.object(get_console(), "log_warning") as mock_logger: + context.load() + warnings = mock_logger.call_args[0][0] + repo1_path, repo2_path = context.configs.keys() + assert f"Linter warnings for {repo1_path}" in warnings + assert f"Linter warnings for {repo2_path}" not in warnings + assert ( context.render("bronze.a").sql() == '''SELECT 1 AS "col_a", 'b' AS "col_b", 1 AS "one", 'repo_1' AS "dup"''' diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 06a8fd8044..a48c0ad19d 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -27,6 +27,7 @@ GatewayConfig, NameInferenceConfig, ModelDefaultsConfig, + LinterConfig, ) from sqlmesh.core.context import Context, ExecutionContext from sqlmesh.core.dialect import parse @@ -280,7 +281,7 @@ def test_model_validation_union_query(): model.validate_definition() -def test_model_qualification(): +def test_model_qualification(tmp_path: Path): with patch.object(get_console(), "log_warning") as mock_logger: expressions = d.parse( """ @@ -293,11 +294,14 @@ def test_model_qualification(): """ ) - model = load_sql_based_model(expressions) - model.render_query(needs_optimization=True) + ctx = Context( + config=Config(linter=LinterConfig(enabled=True, warn_rules=["ALL"])), paths=tmp_path + ) + ctx.upsert_model(load_sql_based_model(expressions)) + assert ( - mock_logger.call_args[0][0] - == """Column '"a"' could not be resolved for model '"db"."table"', the column may not exist or is ambiguous.""" + """Column '"a"' could not be resolved for model '"db"."table"', the column may not exist or is ambiguous.""" + in mock_logger.call_args[0][0] ) @@ -2726,7 +2730,7 @@ def runtime_macro(**kwargs) -> None: model.validate_definition() -def test_update_schema(): +def test_update_schema(tmp_path: Path): expressions = d.parse( """ MODEL (name db.table); @@ -2743,10 +2747,14 @@ def test_update_schema(): model.update_schema(schema) assert model.mapping_schema == {'"table_a"': {"a": "INT"}} + ctx = Context( + config=Config(linter=LinterConfig(enabled=True, warn_rules=["ALL"])), paths=tmp_path + ) with patch.object(get_console(), "log_warning") as mock_logger: - model.render_query(needs_optimization=True) - assert mock_logger.call_args[0][0] == missing_schema_warning_msg( - '"db"."table"', ('"table_b"',) + ctx.upsert_model(model) + assert ( + missing_schema_warning_msg('"db"."table"', ('"table_b"',)) + in mock_logger.call_args[0][0] ) schema.add_table('"table_b"', {"b": exp.DataType.build("int")}) @@ -2758,7 +2766,7 @@ def test_update_schema(): model.render_query(needs_optimization=True) -def test_missing_schema_warnings(): +def test_missing_schema_warnings(tmp_path: Path): full_schema = MappingSchema( { "a": {"x": exp.DataType.build("int")}, @@ -2775,6 +2783,10 @@ def test_missing_schema_warnings(): console = get_console() + ctx = Context( + config=Config(linter=LinterConfig(enabled=True, warn_rules=["ALL"])), paths=tmp_path + ) + # star, no schema, no deps with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM (SELECT 1 a) x")) @@ -2792,14 +2804,15 @@ def test_missing_schema_warnings(): with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM a CROSS JOIN b")) model.update_schema(partial_schema) - model.render_query(needs_optimization=True) - assert mock_logger.call_args[0][0] == missing_schema_warning_msg('"test"', ('"b"',)) + ctx.upsert_model(model) + + assert missing_schema_warning_msg('"test"', ('"b"',)) in mock_logger.call_args[0][0] # star, no schema with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM b JOIN a")) - model.render_query(needs_optimization=True) - assert mock_logger.call_args[0][0] == missing_schema_warning_msg('"test"', ('"a"', '"b"')) + ctx.upsert_model(model) + assert missing_schema_warning_msg('"test"', ('"a"', '"b"')) in mock_logger.call_args[0][0] # no star, full schema with patch.object(console, "log_warning") as mock_logger: @@ -3502,7 +3515,6 @@ def test_project_level_properties(sushi_context): enabled=False, allow_partials=True, interval_unit="quarter_hour", - validate_query=True, optimize_query=True, cron="@hourly", ) @@ -3529,7 +3541,6 @@ def test_project_level_properties(sushi_context): assert model.allow_partials assert model.interval_unit == IntervalUnit.QUARTER_HOUR assert model.optimize_query - assert model.validate_query assert model.cron == "@hourly" assert model.session_properties == { @@ -3580,7 +3591,6 @@ def test_project_level_properties_python_model(): "enabled": False, "allow_partials": True, "interval_unit": "quarter_hour", - "validate_query": True, "optimize_query": True, } @@ -3607,7 +3617,6 @@ def python_model_prop(context, **kwargs): # Even if in the project wide defaults these are ignored for python models assert not m.optimize_query - assert not m.validate_query assert not m.enabled assert m.allow_partials @@ -7647,30 +7656,38 @@ def model_with_virtual_statements(context, **kwargs): def test_compile_time_checks(tmp_path: Path, assert_exp_eq): + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), paths=tmp_path + ) + # Strict SELECT * expansion + linter_cfg = LinterConfig( + enabled=True, rules=["ambiguousorinvalidcolumn", "invalidselectstarexpansion"] + ) + ctx.config.linter = linter_cfg strict_query = d.parse( """ MODEL ( name test, - validate_query True, ); SELECT * FROM tbl """ ) + ctx.load() + with pytest.raises( ConfigError, match=r".*cannot be expanded due to missing schema.*", ): - load_sql_based_model(strict_query).render_query() + ctx.upsert_model(load_sql_based_model(strict_query)) # Strict column resolution strict_query = d.parse( """ MODEL ( name test, - validate_query True, ); SELECT foo @@ -7681,88 +7698,7 @@ def test_compile_time_checks(tmp_path: Path, assert_exp_eq): ConfigError, match=r"""Column '"foo"' could not be resolved for model.*""", ): - load_sql_based_model(strict_query).render_query() - - # Non-strict model with strict defaults raises error, otherwise can still render - strict_default = ModelDefaultsConfig(validate_query=True).dict() - query = d.parse( - """ - MODEL ( - name test, - ); - - SELECT * FROM tbl - """ - ) - - with pytest.raises( - ConfigError, - match=r".*cannot be expanded due to missing schema.*", - ): - load_sql_based_model(query, defaults=strict_default).render_query() - - assert_exp_eq(load_sql_based_model(query).render_query(), 'SELECT * FROM "tbl" AS "tbl"') - - # Ensure plan works for valid queries & cache is invalidated if strict changes - context = Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))) - - query = d.parse( - """ - MODEL ( - name db.test, - validate_query True, - ); - - SELECT 1 AS col - """ - ) - - context.upsert_model(load_sql_based_model(query, default_catalog=context.default_catalog)) - context.plan(auto_apply=True, no_prompts=True) - - context.upsert_model("db.test", validate_query=False) - plan = context.plan(no_prompts=True, auto_apply=True) - - snapshots = list(plan.snapshots.values()) - assert len(snapshots) == 1 - - snapshot = snapshots[0] - assert len(snapshot.previous_versions) == 1 - assert snapshot.change_category == SnapshotChangeCategory.METADATA - - # Ensure non-SQLModels raise if strict mode is set to True - seed_path = tmp_path / "seed.csv" - model_kind = SeedKind(path=str(seed_path.absolute())) - with open(seed_path, "w", encoding="utf-8") as fd: - fd.write( - """ -col_a,col_b,col_c -1,text_a,1.0""" - ) - model = create_seed_model("test_db.test_seed_model", model_kind, validate_query=True) - context = Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))) - - with pytest.raises( - ConfigError, - match=r"Query validation can only be enabled for SQL models at", - ): - context.upsert_model(model) - context.plan(auto_apply=True, no_prompts=True) - - model = create_seed_model("test_db.test_seed_model", model_kind, validate_query=False) - context.upsert_model(model) - context.plan(auto_apply=True, no_prompts=True) - - # Ensure strict defaults don't break all non SQL models to which they weren't applicable in the first place - seed_strict_defaults = create_seed_model( - "test_db.test_seed_model", model_kind, defaults=strict_default - ) - external_strict_defaults = create_external_model( - "test_db.test_external_model", columns={"a": "int", "limit": "int"}, defaults=strict_default - ) - context.upsert_model(seed_strict_defaults) - context.upsert_model(external_strict_defaults) - context.plan(auto_apply=True, no_prompts=True) + ctx.upsert_model(load_sql_based_model(strict_query)) def test_partition_interval_unit(): @@ -8023,3 +7959,28 @@ def test_missing_column_data_in_columns_key(): ) with pytest.raises(ConfigError, match="Missing data type for column 'culprit'."): load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + + +def test_ignored_rules_serialization(): + expressions = d.parse( + """ + MODEL( + name test_model, + ignored_rules ['foo', 'bar'] + ); + + SELECT * FROM tbl; + """, + default_dialect="bigquery", + ) + + model = load_sql_based_model(expressions) + + model_json = model.json() + model_json_parsed = json.loads(model_json) + + assert "ignored_rules" not in model_json_parsed + assert "ignored_rules_" not in model_json_parsed + + deserialized_model = SqlModel.parse_raw(model_json) + assert deserialized_model.dict() == model.dict() diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 06cd7c159b..5b622d264c 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -860,7 +860,7 @@ def test_fingerprint(model: Model, parent_model: Model): original_fingerprint = SnapshotFingerprint( data_hash="1312415267", - metadata_hash="2906564841", + metadata_hash="221611364", ) assert fingerprint == original_fingerprint @@ -921,7 +921,7 @@ def test_fingerprint_seed_model(): expected_fingerprint = SnapshotFingerprint( data_hash="1909791099", - metadata_hash="1153541408", + metadata_hash="3403817841", ) model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) @@ -960,7 +960,7 @@ def test_fingerprint_jinja_macros(model: Model): ) original_fingerprint = SnapshotFingerprint( data_hash="923305614", - metadata_hash="2906564841", + metadata_hash="221611364", ) fingerprint = fingerprint_from_node(model, nodes={}) diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py index 26792ec296..b19869a0e5 100644 --- a/tests/core/test_state_sync.py +++ b/tests/core/test_state_sync.py @@ -81,6 +81,10 @@ def snapshots(make_snapshot: t.Callable) -> t.List[Snapshot]: ] +def compare_snapshot_intervals(x: SnapshotIntervals) -> str: + return x.identifier or "" + + def promote_snapshots( state_sync: EngineAdapterStateSync, snapshots: t.List[Snapshot], @@ -1512,24 +1516,27 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_version( # Check all intervals assert sorted( state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), - key=lambda x: x.identifier or "", - ) == [ - SnapshotIntervals( - name='"a"', - identifier=snapshot.identifier, - version=snapshot.version, - dev_version=snapshot.dev_version, - intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], - dev_intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], - ), - SnapshotIntervals( - name='"a"', - identifier=new_snapshot.identifier, - version=snapshot.version, - dev_version=new_snapshot.dev_version, - intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-06"))], - ), - ] + key=compare_snapshot_intervals, + ) == sorted( + [ + SnapshotIntervals( + name='"a"', + identifier=snapshot.identifier, + version=snapshot.version, + dev_version=snapshot.dev_version, + intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + dev_intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + ), + SnapshotIntervals( + name='"a"', + identifier=new_snapshot.identifier, + version=snapshot.version, + dev_version=new_snapshot.dev_version, + intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-06"))], + ), + ], + key=compare_snapshot_intervals, + ) # Delete the expired snapshot assert state_sync.delete_expired_snapshots() == [ @@ -1547,25 +1554,28 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_version( # Check all intervals assert sorted( state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), - key=lambda x: x.identifier or "", - ) == [ - # The intervals of the old snapshot is preserved with the null identifier - SnapshotIntervals( - name='"a"', - identifier=None, - version=snapshot.version, - dev_version=None, - intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], - ), - # The intervals of the new snapshot has identifier - SnapshotIntervals( - name='"a"', - identifier=new_snapshot.identifier, - version=snapshot.version, - dev_version=new_snapshot.dev_version, - intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-06"))], - ), - ] + key=compare_snapshot_intervals, + ) == sorted( + [ + # The intervals of the old snapshot is preserved with the null identifier + SnapshotIntervals( + name='"a"', + identifier=None, + version=snapshot.version, + dev_version=None, + intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + ), + # The intervals of the new snapshot has identifier + SnapshotIntervals( + name='"a"', + identifier=new_snapshot.identifier, + version=snapshot.version, + dev_version=new_snapshot.dev_version, + intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-06"))], + ), + ], + key=compare_snapshot_intervals, + ) def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version( @@ -1625,24 +1635,27 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version( # Check all intervals assert sorted( state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), - key=lambda x: x.identifier or "", - ) == [ - SnapshotIntervals( - name='"a"', - identifier=snapshot.identifier, - version=snapshot.version, - dev_version=snapshot.dev_version, - intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], - dev_intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-08"))], - ), - SnapshotIntervals( - name='"a"', - identifier=new_snapshot.identifier, - version=snapshot.version, - dev_version=new_snapshot.dev_version, - dev_intervals=[(to_timestamp("2023-01-08"), to_timestamp("2023-01-10"))], - ), - ] + key=compare_snapshot_intervals, + ) == sorted( + [ + SnapshotIntervals( + name='"a"', + identifier=snapshot.identifier, + version=snapshot.version, + dev_version=snapshot.dev_version, + intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + dev_intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-08"))], + ), + SnapshotIntervals( + name='"a"', + identifier=new_snapshot.identifier, + version=snapshot.version, + dev_version=new_snapshot.dev_version, + dev_intervals=[(to_timestamp("2023-01-08"), to_timestamp("2023-01-10"))], + ), + ], + key=compare_snapshot_intervals, + ) # Delete the expired snapshot assert state_sync.delete_expired_snapshots() == [] @@ -1660,30 +1673,33 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version( # Check all intervals assert sorted( state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), - key=lambda x: x.identifier or "", - ) == [ - SnapshotIntervals( - name='"a"', - identifier=None, - version=snapshot.version, - dev_version=None, - intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], - ), - SnapshotIntervals( - name='"a"', - identifier=None, - version=snapshot.version, - dev_version=snapshot.dev_version, - dev_intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-08"))], - ), - SnapshotIntervals( - name='"a"', - identifier=new_snapshot.identifier, - version=snapshot.version, - dev_version=new_snapshot.dev_version, - dev_intervals=[(to_timestamp("2023-01-08"), to_timestamp("2023-01-10"))], - ), - ] + key=compare_snapshot_intervals, + ) == sorted( + [ + SnapshotIntervals( + name='"a"', + identifier=None, + version=snapshot.version, + dev_version=None, + intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + ), + SnapshotIntervals( + name='"a"', + identifier=None, + version=snapshot.version, + dev_version=snapshot.dev_version, + dev_intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-08"))], + ), + SnapshotIntervals( + name='"a"', + identifier=new_snapshot.identifier, + version=snapshot.version, + dev_version=new_snapshot.dev_version, + dev_intervals=[(to_timestamp("2023-01-08"), to_timestamp("2023-01-10"))], + ), + ], + key=compare_snapshot_intervals, + ) def test_compact_intervals_after_cleanup( From 52fdbdc468f55cba9dc42ceb3458b680e683e9f4 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Wed, 5 Mar 2025 19:19:43 +0200 Subject: [PATCH 2/8] PR Feedback 12 --- sqlmesh/core/console.py | 25 +++++++++++++++++++++++++ sqlmesh/core/context.py | 23 +++++++++++++++++------ sqlmesh/core/linter/definition.py | 12 ++++++------ tests/core/test_context.py | 25 ++++++++----------------- tests/core/test_integration.py | 7 ++++++- tests/core/test_model.py | 28 ++++++++++++++++++++-------- 6 files changed, 82 insertions(+), 38 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index adb8f20a54..9c889e8df6 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -26,6 +26,8 @@ from rich.tree import Tree from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.linter.rule import RuleViolation +from sqlmesh.core.model import Model from sqlmesh.core.snapshot import ( Snapshot, SnapshotChangeCategory, @@ -319,6 +321,12 @@ def _limit_model_names(self, tree: Tree, verbose: bool = False) -> Tree: ] return tree + @abc.abstractmethod + def show_linter_violations( + self, violations: t.List[RuleViolation], model: Model, is_error: bool = False + ) -> None: + """Prints all linter violations depending on their severity""" + class NoopConsole(Console): def start_plan_evaluation(self, plan: EvaluatablePlan) -> None: @@ -481,6 +489,11 @@ def show_row_diff( def print_environments(self, environments_summary: t.Dict[str, int]) -> None: pass + def show_linter_violations( + self, violations: t.List[RuleViolation], model: Model, is_error: bool = False + ) -> None: + pass + def make_progress_bar(message: str, console: t.Optional[RichConsole] = None) -> Progress: return Progress( @@ -1548,6 +1561,18 @@ def _snapshot_change_choices( } return labeled_choices + def show_linter_violations( + self, violations: t.List[RuleViolation], model: Model, is_error: bool = False + ) -> None: + severity = "errors" if is_error else "warnings" + violations_msg = "\n".join(f" - {violation}" for violation in violations) + msg = f"Linter {severity} for {model._path}:\n{violations_msg}" + + if is_error: + self.log_error(msg) + else: + self.log_warning(msg) + def add_to_layout_widget(target_widget: LayoutWidget, *widgets: widgets.Widget) -> LayoutWidget: """Helper function to add a widget to a layout widget. diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 9fd6f0366b..9ba48af743 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -505,9 +505,7 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model: model.validate_definition() - # Linter may be `None` if the context is not loaded yet - if linter := self._linters.get(model.project): - linter.lint_model(model) + self.lint_models(model) return model @@ -646,12 +644,12 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: linters=self._linters, ) - for model in self.models.values(): + models = self.models.values() + for model in models: # The model definition can be validated correctly only after the schema is set. model.validate_definition() - if linter := self._linters.get(model.project): - linter.lint_model(model) + self.lint_models(*models) duplicates = set(self._models) & set(self._standalone_audits) if duplicates: @@ -2379,6 +2377,19 @@ def _get_models_for_interval_end( ) return models_for_interval_end + def lint_models(self, *models: Model) -> None: + found_error = False + + for model in models: + # Linter may be `None` if the context is not loaded yet + if linter := self._linters.get(model.project): + found_error = linter.lint_model(model) or found_error + + if found_error: + raise ConfigError( + "Linter detected errors in the code. Please fix them before proceeding." + ) + class Context(GenericContext[Config]): CONFIG_TYPE = Config diff --git a/sqlmesh/core/linter/definition.py b/sqlmesh/core/linter/definition.py index 9334e40b5f..27d33fd5e1 100644 --- a/sqlmesh/core/linter/definition.py +++ b/sqlmesh/core/linter/definition.py @@ -50,9 +50,9 @@ def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter: return Linter(config.enabled, all_rules, rules, warn_rules) - def lint_model(self, model: Model) -> None: + def lint_model(self, model: Model) -> bool: if not self.enabled: - return + return False ignored_rules = select_rules(self.all_rules, model.ignored_rules) @@ -63,10 +63,10 @@ def lint_model(self, model: Model) -> None: warn_violations = warn_rules.check_model(model) if warn_violations: - warn_msg = "\n".join(f" - {warn_violation}" for warn_violation in warn_violations) - get_console().log_warning(f"Linter warnings for {model._path}:\n{warn_msg}") + get_console().show_linter_violations(warn_violations, model) if error_violations: - error_msg = "\n".join(f" - {error_violations}" for error_violations in error_violations) + get_console().show_linter_violations(error_violations, model, is_error=True) + return True - raise_config_error(f"Linter error for {model._path}:\n{error_msg}") + return False diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 66beb60d78..99c93f8fb4 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -1641,26 +1641,22 @@ def test_environment_statements_dialect(tmp_path: Path): @pytest.mark.slow -def test_model_linting(tmp_path: pathlib.Path, sushi_context) -> None: +def test_model_linting(tmp_path: pathlib.Path, sushi_context, capsys) -> None: cfg = LinterConfig(enabled=True, rules="ALL") ctx = Context( config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), linter=cfg), paths=tmp_path, ) + config_err = "Linter detected errors in the code. Please fix them before proceeding." + # Case: Ensure load DOES NOT work if linter is enabled for query in ("SELECT * FROM tbl", "SELECT t.* FROM tbl"): - with pytest.raises( - ConfigError, - match=r".*Query should not contain SELECT.*", - ): + with pytest.raises(ConfigError, match=config_err): ctx.upsert_model(load_sql_based_model(d.parse(f"MODEL (name test); {query}"))) error_model = load_sql_based_model(d.parse("MODEL (name test2); SELECT col")) - with pytest.raises( - ConfigError, - match=r""".*Column '"col"' could not be resolved for model.*""", - ): + with pytest.raises(ConfigError, match=config_err): ctx.upsert_model(error_model) # Case: Ensure optimized query is not cached if the model did not pass linting @@ -1728,10 +1724,8 @@ def test_model_linting(tmp_path: pathlib.Path, sushi_context) -> None: default_catalog="memory", ) ) - with pytest.raises( - ConfigError, - match=r".*All models should have an owner.*", - ): + + with pytest.raises(ConfigError, match=config_err): sushi_context.load() # Case: Ensure the Linter also picks up Python model violations @@ -1752,10 +1746,7 @@ def model4_entrypoint(context, **kwargs): ) for python_model in (model3, model4): - with pytest.raises( - ConfigError, - match=r".*All models should have an owner.*", - ): + with pytest.raises(ConfigError, match=config_err): sushi_context.upsert_model(python_model) @model(name="memory.sushi.model5", columns={"col": "int"}, owner="test") diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 874e7b107f..17122c146d 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -26,7 +26,7 @@ ModelDefaultsConfig, DuckDBConnectionConfig, ) -from sqlmesh.core.console import Console, get_console +from sqlmesh.core.console import Console, set_console, get_console, TerminalConsole from sqlmesh.core.context import Context from sqlmesh.core.config.categorizer import CategorizerConfig from sqlmesh.core.engine_adapter import EngineAdapter @@ -4267,6 +4267,9 @@ def test_auto_categorization(sushi_context: Context): def test_multi(mocker): + orig_console = get_console() + set_console(TerminalConsole()) + context = Context( paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory", load=False ) @@ -4335,6 +4338,8 @@ def test_multi(mocker): "CREATE TABLE IF NOT EXISTS after_1 AS select @dup()" ] + set_console(orig_console) + def test_multi_dbt(mocker): context = Context(paths=["examples/multi_dbt/bronze", "examples/multi_dbt/silver"]) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index a48c0ad19d..49d10f303a 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -14,6 +14,7 @@ from sqlglot.schema import MappingSchema from sqlmesh.cli.example_project import init_example_project, ProjectTemplate from sqlmesh.core.model.kind import TimeColumn, ModelKindName +from sqlmesh.core.console import set_console, get_console, TerminalConsole from sqlmesh import CustomMaterialization, CustomKind from pydantic import model_validator @@ -282,6 +283,9 @@ def test_model_validation_union_query(): def test_model_qualification(tmp_path: Path): + orig_console = get_console() + set_console(TerminalConsole()) + with patch.object(get_console(), "log_warning") as mock_logger: expressions = d.parse( """ @@ -304,6 +308,8 @@ def test_model_qualification(tmp_path: Path): in mock_logger.call_args[0][0] ) + set_console(orig_console) + @pytest.mark.parametrize( "partition_by_input, partition_by_output, output_dialect, expected_exception", @@ -2731,6 +2737,9 @@ def runtime_macro(**kwargs) -> None: def test_update_schema(tmp_path: Path): + orig_console = get_console() + set_console(TerminalConsole()) + expressions = d.parse( """ MODEL (name db.table); @@ -2765,8 +2774,13 @@ def test_update_schema(tmp_path: Path): } model.render_query(needs_optimization=True) + set_console(orig_console) + def test_missing_schema_warnings(tmp_path: Path): + orig_console = get_console() + set_console(TerminalConsole()) + full_schema = MappingSchema( { "a": {"x": exp.DataType.build("int")}, @@ -2840,6 +2854,8 @@ def test_missing_schema_warnings(tmp_path: Path): model.render_query(needs_optimization=True) mock_logger.assert_not_called() + set_console(orig_console) + def test_user_provided_depends_on(): for l_delim, r_delim in (("(", ")"), ("[", "]")): @@ -7660,6 +7676,8 @@ def test_compile_time_checks(tmp_path: Path, assert_exp_eq): config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), paths=tmp_path ) + cfg_err = "Linter detected errors in the code. Please fix them before proceeding." + # Strict SELECT * expansion linter_cfg = LinterConfig( enabled=True, rules=["ambiguousorinvalidcolumn", "invalidselectstarexpansion"] @@ -7677,10 +7695,7 @@ def test_compile_time_checks(tmp_path: Path, assert_exp_eq): ctx.load() - with pytest.raises( - ConfigError, - match=r".*cannot be expanded due to missing schema.*", - ): + with pytest.raises(ConfigError, match=cfg_err): ctx.upsert_model(load_sql_based_model(strict_query)) # Strict column resolution @@ -7694,10 +7709,7 @@ def test_compile_time_checks(tmp_path: Path, assert_exp_eq): """ ) - with pytest.raises( - ConfigError, - match=r"""Column '"foo"' could not be resolved for model.*""", - ): + with pytest.raises(ConfigError, match=cfg_err): ctx.upsert_model(load_sql_based_model(strict_query)) From 0f6250d9b8bde0ab6cb79b2fe91d975bbe6d29e1 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 6 Mar 2025 10:30:34 +0200 Subject: [PATCH 3/8] Add rule violations to OptimizedQueryCache --- sqlmesh/core/model/cache.py | 21 +++++++++---------- sqlmesh/core/renderer.py | 14 ++++++++----- tests/core/test_context.py | 40 +++++++++++++++++++++++++++++++++---- 3 files changed, 56 insertions(+), 19 deletions(-) diff --git a/sqlmesh/core/model/cache.py b/sqlmesh/core/model/cache.py index c5d29bb6a8..9fcebf3a41 100644 --- a/sqlmesh/core/model/cache.py +++ b/sqlmesh/core/model/cache.py @@ -23,6 +23,7 @@ if t.TYPE_CHECKING: from sqlmesh.core.snapshot import SnapshotId from sqlmesh.core.linter.definition import Linter + from sqlmesh.core.linter.rule import Rule T = t.TypeVar("T") @@ -72,6 +73,7 @@ def get_or_load( @dataclass class OptimizedQueryCacheEntry: optimized_rendered_query: t.Optional[exp.Expression] + renderer_violations: t.Optional[t.Dict[type[Rule], t.Any]] class OptimizedQueryCache: @@ -102,15 +104,11 @@ def with_optimized_query(self, model: Model, name: t.Optional[str] = None) -> bo cache_entry = self._file_cache.get(name) if cache_entry: try: - if cache_entry.optimized_rendered_query: - model._query_renderer.update_cache( - cache_entry.optimized_rendered_query, optimized=True - ) - else: - # If the optimized rendered query is None, then there are likely adapter calls in the query - # that prevent us from rendering it at load time. This means that we can safely set the - # unoptimized cache to None as well to prevent attempts to render it downstream. - model._query_renderer.update_cache(None, optimized=False) + # If the optimized rendered query is None, then there are likely adapter calls in the query + # that prevent us from rendering it at load time. This means that we can safely set the + # unoptimized cache to None as well to prevent attempts to render it downstream. + optimized = cache_entry.optimized_rendered_query is not None + model._query_renderer.update_cache(cache_entry, optimized=optimized) return True except Exception as ex: logger.warning("Failed to load a cache entry '%s': %s", name, ex) @@ -140,7 +138,10 @@ def _put(self, name: str, model: SqlModel) -> None: # Note: The ordering of the intersection check matters here return None - new_entry = OptimizedQueryCacheEntry(optimized_rendered_query=optimized_query) + new_entry = OptimizedQueryCacheEntry( + optimized_rendered_query=optimized_query, + renderer_violations=model.violated_rules_for_query, + ) self._file_cache.put(name, value=new_entry) @staticmethod diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index a36992d073..bd26da5be2 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -29,8 +29,9 @@ from sqlglot._typing import E from sqlglot.dialects.dialect import DialectType - from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot from sqlmesh.core.linter.rule import Rule + from sqlmesh.core.model.cache import OptimizedQueryCacheEntry + from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot logger = logging.getLogger(__name__) @@ -247,8 +248,8 @@ def _render( self._cache = resolved_expressions return resolved_expressions - def update_cache(self, expression: t.Optional[exp.Expression]) -> None: - self._cache = [expression] + def update_cache(self, cache_entry: OptimizedQueryCacheEntry) -> None: + self._cache = [cache_entry.optimized_rendered_query] def _resolve_table( self, @@ -550,13 +551,16 @@ def render( return query - def update_cache(self, expression: t.Optional[exp.Expression], optimized: bool = False) -> None: + def update_cache(self, cache_entry: OptimizedQueryCacheEntry, optimized: bool = False) -> None: + expression = cache_entry.optimized_rendered_query if optimized: if not isinstance(expression, exp.Query): raise SQLMeshError(f"Expected a Query but got: {expression}") self._optimized_cache = expression else: - super().update_cache(expression) + super().update_cache(cache_entry) + + self._violated_rules = cache_entry.renderer_violations or {} def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: from sqlmesh.core.linter.rules.builtin import ( diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 99c93f8fb4..c44d228e1a 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -27,7 +27,7 @@ load_configs, ) from sqlmesh.core.context import Context -from sqlmesh.core.console import create_console +from sqlmesh.core.console import create_console, get_console, set_console, TerminalConsole from sqlmesh.core.dialect import parse, schema_ from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements @@ -1641,7 +1641,10 @@ def test_environment_statements_dialect(tmp_path: Path): @pytest.mark.slow -def test_model_linting(tmp_path: pathlib.Path, sushi_context, capsys) -> None: +def test_model_linting(tmp_path: pathlib.Path, sushi_context) -> None: + orig_console = get_console() + set_console(TerminalConsole()) + cfg = LinterConfig(enabled=True, rules="ALL") ctx = Context( config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), linter=cfg), @@ -1655,7 +1658,7 @@ def test_model_linting(tmp_path: pathlib.Path, sushi_context, capsys) -> None: with pytest.raises(ConfigError, match=config_err): ctx.upsert_model(load_sql_based_model(d.parse(f"MODEL (name test); {query}"))) - error_model = load_sql_based_model(d.parse("MODEL (name test2); SELECT col")) + error_model = load_sql_based_model(d.parse("MODEL (name test); SELECT col")) with pytest.raises(ConfigError, match=config_err): ctx.upsert_model(error_model) @@ -1668,12 +1671,39 @@ def test_model_linting(tmp_path: pathlib.Path, sushi_context, capsys) -> None: # Case: Ensure NoSelectStar only raises for top-level SELECTs, new model shouldn't raise # and thus should also be cached - model2 = load_sql_based_model(d.parse("MODEL (name test); SELECT col FROM (SELECT * FROM tbl)")) + model2 = load_sql_based_model( + d.parse("MODEL (name test2); SELECT col FROM (SELECT * FROM tbl)") + ) ctx.upsert_model(model2) model2 = t.cast(SqlModel, model2) assert cache._file_cache.exists(cache._entry_name(model2)) + # Case: Ensure renderer violations are found again even if the optimized query is cached + ctx.config.linter = LinterConfig(enabled=True, warn_rules="ALL") + ctx.load() + + def assert_cached_violations_exist(): + cache_entry = cache._file_cache.get(cache._entry_name(error_model)) + assert cache_entry is not None + assert cache_entry.optimized_rendered_query is not None + assert cache_entry.renderer_violations is not None + + for i in range(3): + with patch.object(get_console(), "log_warning") as mock_logger: + if i > 1: + # Model's violations have been cached from the previous upserts + assert_cached_violations_exist() + + ctx.upsert_model(error_model) + assert ( + """Column '"col"' could not be resolved for model '"test"'""" + in mock_logger.call_args[0][0] + ) + + # Model's violations have been cached after the former upsert + assert_cached_violations_exist() + # Case: Ensure load WORKS if linter is enabled but the rules are not create_temp_file( tmp_path, @@ -1758,3 +1788,5 @@ def model5_entrypoint(context, **kwargs): ) sushi_context.upsert_model(model5) + + set_console(orig_console) From 8280209f7344bb66e1b9e45a5d5852512455ed5f Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 6 Mar 2025 18:45:31 +0200 Subject: [PATCH 4/8] Lift caching constraint for error violations --- sqlmesh/core/context.py | 2 -- sqlmesh/core/model/cache.py | 12 +----------- sqlmesh/core/model/schema.py | 6 +----- sqlmesh/core/renderer.py | 2 +- tests/core/test_context.py | 29 ++++++++++++++--------------- 5 files changed, 17 insertions(+), 34 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 9ba48af743..97af9172e2 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -497,7 +497,6 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model: self.dag, models=self._models, context_path=self.path, - linters=self._linters, ) if model.dialect: @@ -641,7 +640,6 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: self.dag, models=self._models, context_path=self.path, - linters=self._linters, ) models = self.models.values() diff --git a/sqlmesh/core/model/cache.py b/sqlmesh/core/model/cache.py index 9fcebf3a41..a156839e6f 100644 --- a/sqlmesh/core/model/cache.py +++ b/sqlmesh/core/model/cache.py @@ -22,7 +22,6 @@ if t.TYPE_CHECKING: from sqlmesh.core.snapshot import SnapshotId - from sqlmesh.core.linter.definition import Linter from sqlmesh.core.linter.rule import Rule T = t.TypeVar("T") @@ -83,9 +82,8 @@ class OptimizedQueryCache: path: The path to the cache folder. """ - def __init__(self, path: Path, linters: t.Optional[t.Dict[str, Linter]] = None): + def __init__(self, path: Path): self.path = path - self.linters = linters self._file_cache: FileCache[OptimizedQueryCacheEntry] = FileCache( path, prefix="optimized_query" ) @@ -130,14 +128,6 @@ def put(self, model: Model) -> t.Optional[str]: def _put(self, name: str, model: SqlModel) -> None: optimized_query = model.render_query() - - if self.linters: - linter = self.linters.get(model.project) - if linter and linter.rules.keys() & model.violated_rules_for_query.keys(): - # Do not cache the optimized query if the renderer came across lint errors - # Note: The ordering of the intersection check matters here - return None - new_entry = OptimizedQueryCacheEntry( optimized_rendered_query=optimized_query, renderer_violations=model.violated_rules_for_query, diff --git a/sqlmesh/core/model/schema.py b/sqlmesh/core/model/schema.py index c8efc76cee..86c628f610 100644 --- a/sqlmesh/core/model/schema.py +++ b/sqlmesh/core/model/schema.py @@ -18,19 +18,15 @@ from sqlmesh.core.model.definition import Model from sqlmesh.utils import UniqueKeyDict from sqlmesh.utils.dag import DAG - from sqlmesh.core.linter.definition import Linter def update_model_schemas( dag: DAG[str], models: UniqueKeyDict[str, Model], context_path: Path, - linters: t.Optional[t.Dict[str, Linter]] = None, ) -> None: schema = MappingSchema(normalize=False) - optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache( - context_path / c.CACHE, linters - ) + optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE) if c.MAX_FORK_WORKERS == 1: _update_model_schemas_sequential(dag, models, schema, optimized_query_cache) diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index bd26da5be2..93ded18453 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -532,7 +532,7 @@ def render( query = self._optimize_query(query, deps) - if should_cache and not self._violated_rules: + if should_cache: self._optimized_cache = query if needs_optimization: diff --git a/tests/core/test_context.py b/tests/core/test_context.py index c44d228e1a..535c826709 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -32,7 +32,7 @@ from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements from sqlmesh.core.macros import MacroEvaluator -from sqlmesh.core.model import load_sql_based_model, model, SqlModel +from sqlmesh.core.model import load_sql_based_model, model, SqlModel, Model from sqlmesh.core.model.cache import OptimizedQueryCache from sqlmesh.core.renderer import render_statements from sqlmesh.core.model.kind import ModelKindName @@ -1645,6 +1645,13 @@ def test_model_linting(tmp_path: pathlib.Path, sushi_context) -> None: orig_console = get_console() set_console(TerminalConsole()) + def assert_cached_violations_exist(cache: OptimizedQueryCache, model: Model): + model = t.cast(SqlModel, model) + cache_entry = cache._file_cache.get(cache._entry_name(model)) + assert cache_entry is not None + assert cache_entry.optimized_rendered_query is not None + assert cache_entry.renderer_violations is not None + cfg = LinterConfig(enabled=True, rules="ALL") ctx = Context( config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), linter=cfg), @@ -1662,12 +1669,10 @@ def test_model_linting(tmp_path: pathlib.Path, sushi_context) -> None: with pytest.raises(ConfigError, match=config_err): ctx.upsert_model(error_model) - # Case: Ensure optimized query is not cached if the model did not pass linting - cache = OptimizedQueryCache(tmp_path / c.CACHE, linters=ctx._linters) + # Case: Ensure error violations are cached if the model did not pass linting + cache = OptimizedQueryCache(tmp_path / c.CACHE) - error_model = t.cast(SqlModel, error_model) - assert error_model._query_renderer._optimized_cache is None - assert not cache._file_cache.exists(cache._entry_name(error_model)) + assert_cached_violations_exist(cache, error_model) # Case: Ensure NoSelectStar only raises for top-level SELECTs, new model shouldn't raise # and thus should also be cached @@ -1679,21 +1684,15 @@ def test_model_linting(tmp_path: pathlib.Path, sushi_context) -> None: model2 = t.cast(SqlModel, model2) assert cache._file_cache.exists(cache._entry_name(model2)) - # Case: Ensure renderer violations are found again even if the optimized query is cached + # Case: Ensure warning violations are found again even if the optimized query is cached ctx.config.linter = LinterConfig(enabled=True, warn_rules="ALL") ctx.load() - def assert_cached_violations_exist(): - cache_entry = cache._file_cache.get(cache._entry_name(error_model)) - assert cache_entry is not None - assert cache_entry.optimized_rendered_query is not None - assert cache_entry.renderer_violations is not None - for i in range(3): with patch.object(get_console(), "log_warning") as mock_logger: if i > 1: # Model's violations have been cached from the previous upserts - assert_cached_violations_exist() + assert_cached_violations_exist(cache, model2) ctx.upsert_model(error_model) assert ( @@ -1702,7 +1701,7 @@ def assert_cached_violations_exist(): ) # Model's violations have been cached after the former upsert - assert_cached_violations_exist() + assert_cached_violations_exist(cache, model2) # Case: Ensure load WORKS if linter is enabled but the rules are not create_temp_file( From d3a377c3cfb27bfe6ebb43fce54b68a650357d6c Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 6 Mar 2025 19:42:53 +0200 Subject: [PATCH 5/8] Refactor console setting to a decorator --- tests/core/test_context.py | 9 +++------ tests/core/test_integration.py | 10 +++------- tests/core/test_model.py | 20 ++++---------------- tests/utils/test_helpers.py | 16 ++++++++++++++++ 4 files changed, 26 insertions(+), 29 deletions(-) diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 535c826709..6395aae886 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -27,7 +27,7 @@ load_configs, ) from sqlmesh.core.context import Context -from sqlmesh.core.console import create_console, get_console, set_console, TerminalConsole +from sqlmesh.core.console import create_console, get_console from sqlmesh.core.dialect import parse, schema_ from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements @@ -49,6 +49,7 @@ ) from sqlmesh.utils.errors import ConfigError, SQLMeshError from sqlmesh.utils.metaprogramming import Executable +from tests.utils.test_helpers import use_terminal_console from tests.utils.test_filesystem import create_temp_file @@ -1641,10 +1642,8 @@ def test_environment_statements_dialect(tmp_path: Path): @pytest.mark.slow +@use_terminal_console def test_model_linting(tmp_path: pathlib.Path, sushi_context) -> None: - orig_console = get_console() - set_console(TerminalConsole()) - def assert_cached_violations_exist(cache: OptimizedQueryCache, model: Model): model = t.cast(SqlModel, model) cache_entry = cache._file_cache.get(cache._entry_name(model)) @@ -1787,5 +1786,3 @@ def model5_entrypoint(context, **kwargs): ) sushi_context.upsert_model(model5) - - set_console(orig_console) diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 17122c146d..71eafcd5a1 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -26,7 +26,7 @@ ModelDefaultsConfig, DuckDBConnectionConfig, ) -from sqlmesh.core.console import Console, set_console, get_console, TerminalConsole +from sqlmesh.core.console import Console, get_console from sqlmesh.core.context import Context from sqlmesh.core.config.categorizer import CategorizerConfig from sqlmesh.core.engine_adapter import EngineAdapter @@ -60,7 +60,7 @@ from sqlmesh.utils.errors import NoChangesPlanError from sqlmesh.utils.pydantic import validate_string from tests.conftest import DuckDBMetadata, SushiDataValidator - +from tests.utils.test_helpers import use_terminal_console if t.TYPE_CHECKING: from sqlmesh import QueryOrDF @@ -4266,10 +4266,8 @@ def test_auto_categorization(sushi_context: Context): ) +@use_terminal_console def test_multi(mocker): - orig_console = get_console() - set_console(TerminalConsole()) - context = Context( paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory", load=False ) @@ -4338,8 +4336,6 @@ def test_multi(mocker): "CREATE TABLE IF NOT EXISTS after_1 AS select @dup()" ] - set_console(orig_console) - def test_multi_dbt(mocker): context = Context(paths=["examples/multi_dbt/bronze", "examples/multi_dbt/silver"]) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 49d10f303a..47e8964f3d 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -14,7 +14,6 @@ from sqlglot.schema import MappingSchema from sqlmesh.cli.example_project import init_example_project, ProjectTemplate from sqlmesh.core.model.kind import TimeColumn, ModelKindName -from sqlmesh.core.console import set_console, get_console, TerminalConsole from sqlmesh import CustomMaterialization, CustomKind from pydantic import model_validator @@ -66,6 +65,7 @@ from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo, MacroExtractor from sqlmesh.utils.metaprogramming import Executable from sqlmesh.core.macros import RuntimeStage +from tests.utils.test_helpers import use_terminal_console def missing_schema_warning_msg(model, deps): @@ -282,10 +282,8 @@ def test_model_validation_union_query(): model.validate_definition() +@use_terminal_console def test_model_qualification(tmp_path: Path): - orig_console = get_console() - set_console(TerminalConsole()) - with patch.object(get_console(), "log_warning") as mock_logger: expressions = d.parse( """ @@ -308,8 +306,6 @@ def test_model_qualification(tmp_path: Path): in mock_logger.call_args[0][0] ) - set_console(orig_console) - @pytest.mark.parametrize( "partition_by_input, partition_by_output, output_dialect, expected_exception", @@ -2736,10 +2732,8 @@ def runtime_macro(**kwargs) -> None: model.validate_definition() +@use_terminal_console def test_update_schema(tmp_path: Path): - orig_console = get_console() - set_console(TerminalConsole()) - expressions = d.parse( """ MODEL (name db.table); @@ -2774,13 +2768,9 @@ def test_update_schema(tmp_path: Path): } model.render_query(needs_optimization=True) - set_console(orig_console) - +@use_terminal_console def test_missing_schema_warnings(tmp_path: Path): - orig_console = get_console() - set_console(TerminalConsole()) - full_schema = MappingSchema( { "a": {"x": exp.DataType.build("int")}, @@ -2854,8 +2844,6 @@ def test_missing_schema_warnings(tmp_path: Path): model.render_query(needs_optimization=True) mock_logger.assert_not_called() - set_console(orig_console) - def test_user_provided_depends_on(): for l_delim, r_delim in (("(", ")"), ("[", "]")): diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 48008064a9..586d8abb6d 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -1,7 +1,10 @@ import pytest +from functools import wraps from sqlglot import expressions from sqlglot.optimizer.annotate_types import annotate_types +from sqlmesh.core.console import set_console, get_console, TerminalConsole + from sqlmesh.utils import columns_to_types_all_known @@ -72,3 +75,16 @@ ) def test_columns_to_types_all_known(columns_to_types, expected) -> None: assert columns_to_types_all_known(columns_to_types) == expected + + +def use_terminal_console(func): + @wraps(func) + def test_wrapper(*args, **kwargs): + orig_console = get_console() + try: + set_console(TerminalConsole()) + func(*args, **kwargs) + finally: + set_console(orig_console) + + return test_wrapper From a5a39d834be1dd6861f842a0fce71a3cba5b24a8 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 6 Mar 2025 22:37:36 +0200 Subject: [PATCH 6/8] PR Feedback 13 --- sqlmesh/core/context.py | 3 ++- sqlmesh/core/model/cache.py | 6 +++++- sqlmesh/core/model/definition.py | 1 + sqlmesh/core/renderer.py | 17 ++++++++++------- sqlmesh/utils/errors.py | 4 ++++ tests/core/test_context.py | 10 +++++----- tests/core/test_model.py | 8 ++++---- 7 files changed, 31 insertions(+), 18 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 97af9172e2..273fd9938a 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -123,6 +123,7 @@ PlanError, SQLMeshError, UncategorizedPlanError, + LinterError, ) from sqlmesh.utils.config import print_config from sqlmesh.utils.jinja import JinjaMacroRegistry @@ -2384,7 +2385,7 @@ def lint_models(self, *models: Model) -> None: found_error = linter.lint_model(model) or found_error if found_error: - raise ConfigError( + raise LinterError( "Linter detected errors in the code. Please fix them before proceeding." ) diff --git a/sqlmesh/core/model/cache.py b/sqlmesh/core/model/cache.py index a156839e6f..6bb5c31846 100644 --- a/sqlmesh/core/model/cache.py +++ b/sqlmesh/core/model/cache.py @@ -106,7 +106,11 @@ def with_optimized_query(self, model: Model, name: t.Optional[str] = None) -> bo # that prevent us from rendering it at load time. This means that we can safely set the # unoptimized cache to None as well to prevent attempts to render it downstream. optimized = cache_entry.optimized_rendered_query is not None - model._query_renderer.update_cache(cache_entry, optimized=optimized) + model._query_renderer.update_cache( + cache_entry.optimized_rendered_query, + cache_entry.renderer_violations, + optimized=optimized, + ) return True except Exception as ex: logger.warning("Failed to load a cache entry '%s': %s", name, ex) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 9f1e2ad5b1..caadeba4cb 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -1491,6 +1491,7 @@ def _additional_metadata(self) -> t.List[str]: @property def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: + self.render_query() return self._query_renderer._violated_rules diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 93ded18453..5dcf4f4d97 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -30,7 +30,6 @@ from sqlglot.dialects.dialect import DialectType from sqlmesh.core.linter.rule import Rule - from sqlmesh.core.model.cache import OptimizedQueryCacheEntry from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot @@ -248,8 +247,8 @@ def _render( self._cache = resolved_expressions return resolved_expressions - def update_cache(self, cache_entry: OptimizedQueryCacheEntry) -> None: - self._cache = [cache_entry.optimized_rendered_query] + def update_cache(self, expression: t.Optional[exp.Expression]) -> None: + self._cache = [expression] def _resolve_table( self, @@ -551,16 +550,20 @@ def render( return query - def update_cache(self, cache_entry: OptimizedQueryCacheEntry, optimized: bool = False) -> None: - expression = cache_entry.optimized_rendered_query + def update_cache( + self, + expression: t.Optional[exp.Expression], + violated_rules: t.Optional[t.Dict[type[Rule], t.Any]] = None, + optimized: bool = False, + ) -> None: if optimized: if not isinstance(expression, exp.Query): raise SQLMeshError(f"Expected a Query but got: {expression}") self._optimized_cache = expression else: - super().update_cache(cache_entry) + super().update_cache(expression) - self._violated_rules = cache_entry.renderer_violations or {} + self._violated_rules = violated_rules or {} def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: from sqlmesh.core.linter.rules.builtin import ( diff --git a/sqlmesh/utils/errors.py b/sqlmesh/utils/errors.py index 4d93ef7cd4..0159c42e3a 100644 --- a/sqlmesh/utils/errors.py +++ b/sqlmesh/utils/errors.py @@ -176,6 +176,10 @@ class MissingDefaultCatalogError(SQLMeshError): pass +class LinterError(SQLMeshError): + pass + + def raise_config_error( msg: str, location: t.Optional[str | Path] = None, diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 6395aae886..eae4e71936 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -47,7 +47,7 @@ to_timestamp, yesterday_ds, ) -from sqlmesh.utils.errors import ConfigError, SQLMeshError +from sqlmesh.utils.errors import ConfigError, SQLMeshError, LinterError from sqlmesh.utils.metaprogramming import Executable from tests.utils.test_helpers import use_terminal_console from tests.utils.test_filesystem import create_temp_file @@ -1661,11 +1661,11 @@ def assert_cached_violations_exist(cache: OptimizedQueryCache, model: Model): # Case: Ensure load DOES NOT work if linter is enabled for query in ("SELECT * FROM tbl", "SELECT t.* FROM tbl"): - with pytest.raises(ConfigError, match=config_err): + with pytest.raises(LinterError, match=config_err): ctx.upsert_model(load_sql_based_model(d.parse(f"MODEL (name test); {query}"))) error_model = load_sql_based_model(d.parse("MODEL (name test); SELECT col")) - with pytest.raises(ConfigError, match=config_err): + with pytest.raises(LinterError, match=config_err): ctx.upsert_model(error_model) # Case: Ensure error violations are cached if the model did not pass linting @@ -1753,7 +1753,7 @@ def assert_cached_violations_exist(cache: OptimizedQueryCache, model: Model): ) ) - with pytest.raises(ConfigError, match=config_err): + with pytest.raises(LinterError, match=config_err): sushi_context.load() # Case: Ensure the Linter also picks up Python model violations @@ -1774,7 +1774,7 @@ def model4_entrypoint(context, **kwargs): ) for python_model in (model3, model4): - with pytest.raises(ConfigError, match=config_err): + with pytest.raises(LinterError, match=config_err): sushi_context.upsert_model(python_model) @model(name="memory.sushi.model5", columns={"col": "int"}, owner="test") diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 47e8964f3d..8c6e763f39 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -61,7 +61,7 @@ from sqlmesh.core.signal import signal from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory from sqlmesh.utils.date import TimeLike, to_datetime, to_ds, to_timestamp -from sqlmesh.utils.errors import ConfigError, SQLMeshError +from sqlmesh.utils.errors import ConfigError, SQLMeshError, LinterError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo, MacroExtractor from sqlmesh.utils.metaprogramming import Executable from sqlmesh.core.macros import RuntimeStage @@ -7659,7 +7659,7 @@ def model_with_virtual_statements(context, **kwargs): ) -def test_compile_time_checks(tmp_path: Path, assert_exp_eq): +def test_compile_time_checks(tmp_path: Path): ctx = Context( config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), paths=tmp_path ) @@ -7683,7 +7683,7 @@ def test_compile_time_checks(tmp_path: Path, assert_exp_eq): ctx.load() - with pytest.raises(ConfigError, match=cfg_err): + with pytest.raises(LinterError, match=cfg_err): ctx.upsert_model(load_sql_based_model(strict_query)) # Strict column resolution @@ -7697,7 +7697,7 @@ def test_compile_time_checks(tmp_path: Path, assert_exp_eq): """ ) - with pytest.raises(ConfigError, match=cfg_err): + with pytest.raises(LinterError, match=cfg_err): ctx.upsert_model(load_sql_based_model(strict_query)) From 3394afa8d64dce046fcfad2f8e6dcf21e8f7b08d Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 6 Mar 2025 22:52:44 +0200 Subject: [PATCH 7/8] Move violated_rules to QueryRenderer, formatting fixes --- sqlmesh/core/context.py | 12 ++---------- sqlmesh/core/renderer.py | 2 +- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 273fd9938a..2b8302a69b 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -494,11 +494,7 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model: } ) - update_model_schemas( - self.dag, - models=self._models, - context_path=self.path, - ) + update_model_schemas(self.dag, models=self._models, context_path=self.path) if model.dialect: self._all_dialects.add(model.dialect) @@ -637,11 +633,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: self._models.update({fqn: model.copy(update={"mapping_schema": {}})}) continue - update_model_schemas( - self.dag, - models=self._models, - context_path=self.path, - ) + update_model_schemas(self.dag, models=self._models, context_path=self.path) models = self.models.values() for model in models: diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 5dcf4f4d97..efbebef761 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -67,7 +67,6 @@ def __init__( self._cache: t.List[t.Optional[exp.Expression]] = [] self._model_fqn = model_fqn self._optimize_query_flag = optimize_query is not False - self._violated_rules: t.Dict[type[Rule], t.Any] = {} def update_schema(self, schema: t.Dict[str, t.Any]) -> None: self.schema = d.normalize_mapping_schema(schema, dialect=self._dialect) @@ -443,6 +442,7 @@ class QueryRenderer(BaseExpressionRenderer): def __init__(self, *args: t.Any, **kwargs: t.Any): super().__init__(*args, **kwargs) self._optimized_cache: t.Optional[exp.Query] = None + self._violated_rules: t.Dict[type[Rule], t.Any] = {} def update_schema(self, schema: t.Dict[str, t.Any]) -> None: super().update_schema(schema) From f1d701907894e30434befb722548ffe3db1f2628 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 6 Mar 2025 22:57:29 +0200 Subject: [PATCH 8/8] Bump migration script --- ...74_remove_validate_query.py => v0075_remove_validate_query.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sqlmesh/migrations/{v0074_remove_validate_query.py => v0075_remove_validate_query.py} (100%) diff --git a/sqlmesh/migrations/v0074_remove_validate_query.py b/sqlmesh/migrations/v0075_remove_validate_query.py similarity index 100% rename from sqlmesh/migrations/v0074_remove_validate_query.py rename to sqlmesh/migrations/v0075_remove_validate_query.py