diff --git a/examples/multi/repo_1/config.yaml b/examples/multi/repo_1/config.yaml
index c84b849958..0f35441b86 100644
--- a/examples/multi/repo_1/config.yaml
+++ b/examples/multi/repo_1/config.yaml
@@ -20,3 +20,8 @@ after_all:
 
 model_defaults:
   dialect: 'duckdb'
+
+linter:
+  enabled: True
+
+  warn_rules: "ALL"
\ No newline at end of file
diff --git a/examples/multi/repo_1/linter/user.py b/examples/multi/repo_1/linter/user.py
new file mode 100644
index 0000000000..1dfc7c8ae2
--- /dev/null
+++ b/examples/multi/repo_1/linter/user.py
@@ -0,0 +1,15 @@
+"""Contains user-defined linting rules for repo_1"""
+
+from __future__ import annotations
+
+import typing as t
+
+from sqlmesh.core.linter.rule import Rule, RuleViolation
+from sqlmesh.core.model import Model
+
+
+class NoMissingDescription(Rule):
+    """All models should be documented."""
+
+    def check_model(self, model: Model) -> t.Optional[RuleViolation]:
+        return self.violation() if not model.description else None
diff --git a/examples/multi/repo_2/config.yaml b/examples/multi/repo_2/config.yaml
index bc4603f0f5..23bec6d8fe 100644
--- a/examples/multi/repo_2/config.yaml
+++ b/examples/multi/repo_2/config.yaml
@@ -20,3 +20,8 @@ after_all:
 
 model_defaults:
   dialect: 'duckdb'
+
+linter:
+  enabled: True
+
+  ignored_rules: "ALL"
\ No newline at end of file
diff --git a/examples/sushi/linter/user.py b/examples/sushi/linter/user.py
new file mode 100644
index 0000000000..a4a33c0efc
--- /dev/null
+++ b/examples/sushi/linter/user.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlmesh.core.linter.rule import Rule, RuleViolation
+from sqlmesh.core.model import Model
+
+
+class NoMissingOwner(Rule):
+    """All models should have an owner specified."""
+
+    def check_model(self, model: Model) -> t.Optional[RuleViolation]:
+        return self.violation() if not model.owner else None
diff --git a/sqlmesh/core/config/__init__.py b/sqlmesh/core/config/__init__.py
index b72a271f9d..6017fc8895 100644
--- a/sqlmesh/core/config/__init__.py
+++ b/sqlmesh/core/config/__init__.py
@@ -30,6 +30,7 @@
 from sqlmesh.core.config.migration import MigrationConfig as MigrationConfig
 from sqlmesh.core.config.model import ModelDefaultsConfig as ModelDefaultsConfig
 from sqlmesh.core.config.naming import NameInferenceConfig as NameInferenceConfig
+from sqlmesh.core.config.linter import LinterConfig as LinterConfig
 from sqlmesh.core.config.plan import PlanConfig as PlanConfig
 from sqlmesh.core.config.root import Config as Config
 from sqlmesh.core.config.run import RunConfig as RunConfig
diff --git a/sqlmesh/core/config/linter.py b/sqlmesh/core/config/linter.py
new file mode 100644
index 0000000000..c2a40e09aa
--- /dev/null
+++ b/sqlmesh/core/config/linter.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp
+from sqlglot.helper import ensure_collection
+
+from sqlmesh.core.config.base import BaseConfig
+
+from sqlmesh.utils.pydantic import field_validator
+
+
+class LinterConfig(BaseConfig):
+    """Configuration for model linting.
+
+    Args:
+        enabled: Flag indicating whether the linter should run.
+        rules: A list of rules whose violations are treated as errors.
+        warn_rules: A list of rules whose violations produce warnings instead of raising errors.
+        ignored_rules: A list of rules to be excluded/ignored.
+    """
+
+    enabled: bool = False
+
+    rules: t.Set[str] = set()
+    warn_rules: t.Set[str] = set()
+    ignored_rules: t.Set[str] = set()
+
+    @classmethod
+    def _validate_rules(cls, v: t.Any) -> t.Set[str]:
+        if isinstance(v, exp.Paren):
+            v = v.unnest().name
+        elif isinstance(v, (exp.Tuple, exp.Array)):
+            v = [e.name for e in v.expressions]
+        elif isinstance(v, exp.Expression):
+            v = v.name
+
+        return {name.lower() for name in ensure_collection(v)}
+
+    @field_validator("rules", "warn_rules", "ignored_rules", mode="before")
+    def rules_validator(cls, vs: t.Any) -> t.Set[str]:
+        return cls._validate_rules(vs)
diff --git a/sqlmesh/core/config/model.py b/sqlmesh/core/config/model.py
index 62ec7c5935..84e1fa8b76 100644
--- a/sqlmesh/core/config/model.py
+++ b/sqlmesh/core/config/model.py
@@ -38,7 +38,6 @@ class ModelDefaultsConfig(BaseConfig):
         session_properties: A key-value mapping of properties specific to the target engine that are applied to the engine session.
         audits: The audits to be applied globally to all models in the project.
         optimize_query: Whether the SQL models should be optimized.
-        validate_query: Whether the SQL models should be validated at compile time.
         allow_partials: Whether the models can process partial (incomplete) data intervals.
         enabled: Whether the models are enabled.
         interval_unit: The temporal granularity of the models data intervals. By default computed from cron.
@@ -58,7 +57,6 @@ class ModelDefaultsConfig(BaseConfig):
     session_properties: t.Optional[t.Dict[str, t.Any]] = None
     audits: t.Optional[t.List[FunctionCall]] = None
     optimize_query: t.Optional[bool] = None
-    validate_query: t.Optional[bool] = None
     allow_partials: t.Optional[bool] = None
     interval_unit: t.Optional[IntervalUnit] = None
     enabled: t.Optional[bool] = None
diff --git a/sqlmesh/core/config/root.py b/sqlmesh/core/config/root.py
index b3fa1cb801..c1642ccd43 100644
--- a/sqlmesh/core/config/root.py
+++ b/sqlmesh/core/config/root.py
@@ -28,6 +28,7 @@
 from sqlmesh.core.config.migration import MigrationConfig
 from sqlmesh.core.config.model import ModelDefaultsConfig
 from sqlmesh.core.config.naming import NameInferenceConfig as NameInferenceConfig
+from sqlmesh.core.config.linter import LinterConfig as LinterConfig
 from sqlmesh.core.config.plan import PlanConfig
 from sqlmesh.core.config.run import RunConfig
 from sqlmesh.core.config.scheduler import (
@@ -124,6 +125,7 @@ class Config(BaseConfig):
     disable_anonymized_analytics: bool = False
     before_all: t.Optional[t.List[str]] = None
     after_all: t.Optional[t.List[str]] = None
+    linter: LinterConfig = LinterConfig()
 
     _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
         "gateways": UpdateStrategy.NESTED_UPDATE,
@@ -141,6 +143,7 @@ class Config(BaseConfig):
         "plan": UpdateStrategy.NESTED_UPDATE,
         "before_all": UpdateStrategy.EXTEND,
         "after_all": UpdateStrategy.EXTEND,
+        "linter": UpdateStrategy.NESTED_UPDATE,
     }
 
     _connection_config_validator = connection_config_validator
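A custom rule only needs to subclass `Rule` and implement `check_model`, exactly like the two example rule files above. As a sketch, a hypothetical `NoMissingGrain` rule (not part of this change) would look like:

```python
from __future__ import annotations

import typing as t

from sqlmesh.core.linter.rule import Rule, RuleViolation
from sqlmesh.core.model import Model


class NoMissingGrain(Rule):  # hypothetical example rule, registered as "nomissinggrain"
    """All models should declare at least one grain."""

    def check_model(self, model: Model) -> t.Optional[RuleViolation]:
        # violation() defaults to the class docstring as the message
        return self.violation() if not model.grains else None
```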
diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py
index adb8f20a54..9c889e8df6 100644
--- a/sqlmesh/core/console.py
+++ b/sqlmesh/core/console.py
@@ -26,6 +26,8 @@
 from rich.tree import Tree
 
 from sqlmesh.core.environment import EnvironmentNamingInfo
+from sqlmesh.core.linter.rule import RuleViolation
+from sqlmesh.core.model import Model
 from sqlmesh.core.snapshot import (
     Snapshot,
     SnapshotChangeCategory,
@@ -319,6 +321,12 @@ def _limit_model_names(self, tree: Tree, verbose: bool = False) -> Tree:
         ]
         return tree
 
+    @abc.abstractmethod
+    def show_linter_violations(
+        self, violations: t.List[RuleViolation], model: Model, is_error: bool = False
+    ) -> None:
+        """Prints all linter violations depending on their severity"""
+
 
 class NoopConsole(Console):
     def start_plan_evaluation(self, plan: EvaluatablePlan) -> None:
@@ -481,6 +489,11 @@ def show_row_diff(
     def print_environments(self, environments_summary: t.Dict[str, int]) -> None:
         pass
 
+    def show_linter_violations(
+        self, violations: t.List[RuleViolation], model: Model, is_error: bool = False
+    ) -> None:
+        pass
+
 
 def make_progress_bar(message: str, console: t.Optional[RichConsole] = None) -> Progress:
     return Progress(
@@ -1548,6 +1561,18 @@ def _snapshot_change_choices(
         }
         return labeled_choices
 
+    def show_linter_violations(
+        self, violations: t.List[RuleViolation], model: Model, is_error: bool = False
+    ) -> None:
+        severity = "errors" if is_error else "warnings"
+        violations_msg = "\n".join(f" - {violation}" for violation in violations)
+        msg = f"Linter {severity} for {model._path}:\n{violations_msg}"
+
+        if is_error:
+            self.log_error(msg)
+        else:
+            self.log_warning(msg)
+
 
 def add_to_layout_widget(target_widget: LayoutWidget, *widgets: widgets.Widget) -> LayoutWidget:
     """Helper function to add a widget to a layout widget.
diff --git a/sqlmesh/core/constants.py b/sqlmesh/core/constants.py
index df204cfbd9..abc096c9a6 100644
--- a/sqlmesh/core/constants.py
+++ b/sqlmesh/core/constants.py
@@ -60,6 +60,7 @@
 AUDITS = "audits"
 CACHE = ".cache"
 EXTERNAL_MODELS = "external_models"
+LINTER = "linter"
 MACROS = "macros"
 MATERIALIZATIONS = "materializations"
 METRICS = "metrics"
diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py
index ee7b75bfea..2b8302a69b 100644
--- a/sqlmesh/core/context.py
+++ b/sqlmesh/core/context.py
@@ -75,6 +75,8 @@
 from sqlmesh.core.engine_adapter import EngineAdapter
 from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements
 from sqlmesh.core.loader import Loader
+from sqlmesh.core.linter.definition import Linter
+from sqlmesh.core.linter.rules import BUILTIN_RULES
 from sqlmesh.core.macros import ExecutableOrMacro, macro
 from sqlmesh.core.metric import Metric, rewrite
 from sqlmesh.core.model import Model, update_model_schemas
@@ -121,6 +123,7 @@
     PlanError,
     SQLMeshError,
     UncategorizedPlanError,
+    LinterError,
 )
 from sqlmesh.utils.config import print_config
 from sqlmesh.utils.jinja import JinjaMacroRegistry
@@ -349,6 +352,7 @@ def __init__(
         self._environment_statements: t.List[EnvironmentStatements] = []
         self._excluded_requirements: t.Set[str] = set()
         self._default_catalog: t.Optional[str] = None
+        self._linters: t.Dict[str, Linter] = {}
         self._loaded: bool = False
 
         self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items())))
@@ -497,6 +501,8 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:
 
         model.validate_definition()
 
+        self.lint_models(model)
+
         return model
 
     def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
@@ -576,9 +582,10 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
             self._metrics.clear()
             self._requirements.clear()
             self._excluded_requirements.clear()
+            self._linters.clear()
             self._environment_statements = []
 
-            for project in loaded_projects:
+            for loader, project in zip(self._loaders, loaded_projects):
                 self._jinja_macros = self._jinja_macros.merge(project.jinja_macros)
                 self._macros.update(project.macros)
                 self._models.update(project.models)
@@ -590,6 +597,11 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
                 if project.environment_statements:
                     self._environment_statements.append(project.environment_statements)
 
+                config = loader.config
+                self._linters[config.project] = Linter.from_rules(
+                    BUILTIN_RULES.union(project.user_rules), config.linter
+                )
+
             uncached = set()
 
             if any(self._projects):
@@ -623,10 +635,13 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
 
                 update_model_schemas(self.dag, models=self._models, context_path=self.path)
 
-                for model in self.models.values():
+                models = self.models.values()
+                for model in models:
                     # The model definition can be validated correctly only after the schema is set.
                     model.validate_definition()
 
+                self.lint_models(*models)
+
             duplicates = set(self._models) & set(self._standalone_audits)
             if duplicates:
                 raise ConfigError(
@@ -2353,6 +2368,19 @@ def _get_models_for_interval_end(
         )
         return models_for_interval_end
 
+    def lint_models(self, *models: Model) -> None:
+        found_error = False
+
+        for model in models:
+            # The linter may be missing if the context has not been loaded yet
+            if linter := self._linters.get(model.project):
+                found_error = linter.lint_model(model) or found_error
+
+        if found_error:
+            raise LinterError(
+                "Linter detected errors in the code. Please fix them before proceeding."
+            )
+
 
 class Context(GenericContext[Config]):
     CONFIG_TYPE = Config
diff --git a/sqlmesh/core/linter/__init__.py b/sqlmesh/core/linter/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
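At the API level the new hook surfaces as follows; a minimal sketch, assuming a hypothetical project directory whose config enables the linter with error-severity rules:

```python
from sqlmesh.core.context import Context
from sqlmesh.utils.errors import LinterError

ctx = Context(paths=["./my_project"], load=False)  # hypothetical project path

try:
    ctx.load()  # runs lint_models(*models) once model schemas are set
except LinterError as ex:
    # raised when any error-severity rule is violated; warnings only go to the console
    print(ex)
```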
diff --git a/sqlmesh/core/linter/definition.py b/sqlmesh/core/linter/definition.py
new file mode 100644
index 0000000000..27d33fd5e1
--- /dev/null
+++ b/sqlmesh/core/linter/definition.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlmesh.core.config.linter import LinterConfig
+
+from sqlmesh.core.model import Model
+
+from sqlmesh.utils.errors import raise_config_error
+from sqlmesh.core.console import get_console
+from sqlmesh.core.linter.rule import RuleSet
+
+
+def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
+    if "all" in rule_names:
+        return all_rules
+
+    rules = set()
+    for rule_name in rule_names:
+        if rule_name not in all_rules:
+            raise_config_error(f"Rule {rule_name} could not be found")
+
+        rules.add(all_rules[rule_name])
+
+    return RuleSet(rules)
+
+
+class Linter:
+    def __init__(
+        self, enabled: bool, all_rules: RuleSet, rules: RuleSet, warn_rules: RuleSet
+    ) -> None:
+        self.enabled = enabled
+        self.all_rules = all_rules
+        self.rules = rules
+        self.warn_rules = warn_rules
+
+    @classmethod
+    def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
+        ignored_rules = select_rules(all_rules, config.ignored_rules)
+        included_rules = all_rules.difference(ignored_rules)
+
+        rules = select_rules(included_rules, config.rules)
+        warn_rules = select_rules(included_rules, config.warn_rules)
+
+        if overlapping := rules.intersection(warn_rules):
+            overlapping_rules = ", ".join(rule for rule in overlapping)
+            raise_config_error(
+                f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]"
+            )
+
+        return Linter(config.enabled, all_rules, rules, warn_rules)
+
+    def lint_model(self, model: Model) -> bool:
+        if not self.enabled:
+            return False
+
+        ignored_rules = select_rules(self.all_rules, model.ignored_rules)
+
+        rules = self.rules.difference(ignored_rules)
+        warn_rules = self.warn_rules.difference(ignored_rules)
+
+        error_violations = rules.check_model(model)
+        warn_violations = warn_rules.check_model(model)
+
+        if warn_violations:
+            get_console().show_linter_violations(warn_violations, model)
+
+        if error_violations:
+            get_console().show_linter_violations(error_violations, model, is_error=True)
+            return True
+
+        return False
diff --git a/sqlmesh/core/linter/rule.py b/sqlmesh/core/linter/rule.py
new file mode 100644
index 0000000000..93d80a52f4
--- /dev/null
+++ b/sqlmesh/core/linter/rule.py
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+import abc
+
+import operator as op
+from collections.abc import Iterator, Iterable, Set, Mapping, Callable
+from functools import reduce
+
+from sqlmesh.core.model import Model
+
+from typing import Type
+
+import typing as t
+
+
+class _Rule(abc.ABCMeta):
+    def __new__(cls: Type[_Rule], clsname: str, bases: t.Tuple, attrs: t.Dict) -> _Rule:
+        attrs["name"] = clsname.lower()
+        return super().__new__(cls, clsname, bases, attrs)
+
+
+class Rule(abc.ABC, metaclass=_Rule):
+    """The base class for a rule."""
+
+    name = "rule"
+
+    @abc.abstractmethod
+    def check_model(self, model: Model) -> t.Optional[RuleViolation]:
+        """The evaluation function that will check for a violation of this rule."""
+
+    @property
+    def summary(self) -> str:
+        """A summary of what this rule checks for."""
+        return self.__doc__ or ""
+
+    def violation(self, violation_msg: t.Optional[str] = None) -> RuleViolation:
+        """Create a RuleViolation instance for this rule."""
+        return RuleViolation(rule=self, violation_msg=violation_msg or self.summary)
+
+    def __repr__(self) -> str:
+        return self.name
+
+
+class RuleViolation:
+    def __init__(self, rule: Rule, violation_msg: str) -> None:
+        self.rule = rule
+        self.violation_msg = violation_msg
+
+    def __repr__(self) -> str:
+        return f"{self.rule.name}: {self.violation_msg}"
+
+
+class RuleSet(Mapping[str, type[Rule]]):
+    def __init__(self, rules: Iterable[type[Rule]] = ()) -> None:
+        self._underlying = {rule.name: rule for rule in rules}
+
+    def check_model(self, model: Model) -> t.List[RuleViolation]:
+        violations = []
+
+        for rule in self._underlying.values():
+            violation = rule().check_model(model)
+
+            if violation:
+                violations.append(violation)
+
+        return violations
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._underlying)
+
+    def __len__(self) -> int:
+        return len(self._underlying)
+
+    def __getitem__(self, rule: str | type[Rule]) -> type[Rule]:
+        key = rule if isinstance(rule, str) else rule.name
+        return self._underlying[key]
+
+    def __op(
+        self,
+        op: Callable[[Set[type[Rule]], Set[type[Rule]]], Set[type[Rule]]],
+        other: RuleSet,
+        /,
+    ) -> RuleSet:
+        rules = set()
+        for rule in op(set(self.values()), set(other.values())):
+            rules.add(other[rule] if rule in other else self[rule])
+
+        return RuleSet(rules)
+
+    def union(self, *others: RuleSet) -> RuleSet:
+        return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others))
+
+    def intersection(self, *others: RuleSet) -> RuleSet:
+        return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others))
+
+    def difference(self, *others: RuleSet) -> RuleSet:
+        return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others))
diff --git a/sqlmesh/core/linter/rules/__init__.py b/sqlmesh/core/linter/rules/__init__.py
new file mode 100644
index 0000000000..43812479a5
--- /dev/null
+++ b/sqlmesh/core/linter/rules/__init__.py
@@ -0,0 +1 @@
+from sqlmesh.core.linter.rules.builtin import BUILTIN_RULES as BUILTIN_RULES
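`RuleSet` behaves as both a `Mapping` keyed by lowercased rule names and a set-algebra container, which is what `Linter.from_rules` and `Context.load` rely on; a small sketch of those semantics:

```python
from sqlmesh.core.linter.rule import RuleSet
from sqlmesh.core.linter.rules.builtin import BUILTIN_RULES, NoSelectStar

user_rules = RuleSet([NoSelectStar])         # keys are lowercased class names
all_rules = BUILTIN_RULES.union(user_rules)  # mirrors BUILTIN_RULES.union(project.user_rules)

assert "noselectstar" in all_rules            # Mapping lookup by name
remaining = all_rules.difference(user_rules)  # the same operation backs ignored_rules
assert "noselectstar" not in remaining
```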
diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py
new file mode 100644
index 0000000000..664bcc9b23
--- /dev/null
+++ b/sqlmesh/core/linter/rules/builtin.py
@@ -0,0 +1,51 @@
+"""Contains all the standard rules included with SQLMesh"""
+
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot.helper import subclasses
+
+from sqlmesh.core.linter.rule import Rule, RuleViolation, RuleSet
+from sqlmesh.core.model import Model, SqlModel
+
+
+class NoSelectStar(Rule):
+    """Query should not contain SELECT * in its outermost projections, even if it can be expanded."""
+
+    def check_model(self, model: Model) -> t.Optional[RuleViolation]:
+        if not isinstance(model, SqlModel):
+            return None
+
+        return self.violation() if model.query.is_star else None
+
+
+class InvalidSelectStarExpansion(Rule):
+    def check_model(self, model: Model) -> t.Optional[RuleViolation]:
+        deps = model.violated_rules_for_query.get(InvalidSelectStarExpansion)
+        if not deps:
+            return None
+
+        violation_msg = (
+            f"SELECT * cannot be expanded due to missing schema(s) for model(s): {deps}. "
+            "Run `sqlmesh create_external_models` and / or make sure that the model "
+            f"'{model.fqn}' can be rendered at parse time."
+        )
+
+        return self.violation(violation_msg)
+
+
+class AmbiguousOrInvalidColumn(Rule):
+    def check_model(self, model: Model) -> t.Optional[RuleViolation]:
+        sqlglot_err = model.violated_rules_for_query.get(AmbiguousOrInvalidColumn)
+        if not sqlglot_err:
+            return None
+
+        violation_msg = (
+            f"{sqlglot_err} for model '{model.fqn}', the column may not exist or is ambiguous."
+        )
+
+        return self.violation(violation_msg)
+
+
+BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, (Rule,)))
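For reference, `NoSelectStar` only fires when the star is in the outermost projection (`model.query.is_star`), so nested stars pass; a sketch using hypothetical model and table names:

```python
from sqlmesh.core import dialect as d
from sqlmesh.core.linter.rules.builtin import NoSelectStar
from sqlmesh.core.model import load_sql_based_model

flagged = load_sql_based_model(d.parse("MODEL (name db.t1); SELECT * FROM src"))
passing = load_sql_based_model(d.parse("MODEL (name db.t2); SELECT col FROM (SELECT * FROM src)"))

assert NoSelectStar().check_model(flagged) is not None  # outer star -> violation
assert NoSelectStar().check_model(passing) is None      # nested star only -> ok
```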
diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py
index f90462d9ee..bb9a669ccd 100644
--- a/sqlmesh/core/loader.py
+++ b/sqlmesh/core/loader.py
@@ -12,11 +12,13 @@
 from sqlglot.errors import SqlglotError
 from sqlglot import exp
+from sqlglot.helper import subclasses
 
 from sqlmesh.core import constants as c
 from sqlmesh.core.audit import Audit, ModelAudit, StandaloneAudit, load_multiple_audits
 from sqlmesh.core.dialect import parse
 from sqlmesh.core.environment import EnvironmentStatements
+from sqlmesh.core.linter.rule import RuleSet, Rule
 from sqlmesh.core.macros import MacroRegistry, macro
 from sqlmesh.core.metric import Metric, MetricMeta, expand_metrics, load_metric_ddl
 from sqlmesh.core.model import (
@@ -54,6 +56,7 @@ class LoadedProject:
     requirements: t.Dict[str, str]
     excluded_requirements: t.Set[str]
     environment_statements: t.Optional[EnvironmentStatements]
+    user_rules: RuleSet
 
 
 class Loader(abc.ABC):
@@ -120,6 +123,8 @@ def load(self) -> LoadedProject:
 
         environment_statements = self._load_environment_statements(macros=macros)
 
+        user_rules = self._load_linting_rules()
+
         project = LoadedProject(
             macros=macros,
             jinja_macros=jinja_macros,
@@ -130,6 +135,7 @@ def load(self) -> LoadedProject:
             requirements=requirements,
             excluded_requirements=excluded_requirements,
             environment_statements=environment_statements,
+            user_rules=user_rules,
         )
         return project
@@ -265,6 +271,10 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
 
         return requirements, excluded_requirements
 
+    def _load_linting_rules(self) -> RuleSet:
+        """Loads user linting rules"""
+        return RuleSet()
+
     def _glob_paths(
         self,
         path: Path,
@@ -630,6 +640,23 @@ def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStat
             return EnvironmentStatements(**statements, python_env=python_env)
         return None
 
+    def _load_linting_rules(self) -> RuleSet:
+        user_rules: UniqueKeyDict[str, type[Rule]] = UniqueKeyDict("rules")
+
+        for path in self._glob_paths(
+            self.config_path / c.LINTER,
+            ignore_patterns=self.config.ignore_patterns,
+            extension=".py",
+        ):
+            if os.path.getsize(path):
+                self._track_file(path)
+                module = import_python_file(path, self.config_path)
+                module_rules = subclasses(module.__name__, Rule, (Rule,))
+                for user_rule in module_rules:
+                    user_rules[user_rule.name] = user_rule
+
+        return RuleSet(user_rules.values())
+
 
 class _Cache:
     def __init__(self, loader: SqlMeshLoader, config_path: Path):
         self._loader = loader
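Rule discovery mirrors macro loading: every non-empty `.py` file under the project's `linter/` directory is imported and its `Rule` subclasses are registered by lowercased class name, so the example files above need no extra wiring. A hypothetical layout:

```python
# my_project/
# ├── config.yaml      # linter: enabled, rules / warn_rules / ignored_rules
# ├── models/
# └── linter/
#     └── user.py      # class NoMissingOwner(Rule) -> registered as "nomissingowner"
```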
diff --git a/sqlmesh/core/model/cache.py b/sqlmesh/core/model/cache.py
index 268abad966..6bb5c31846 100644
--- a/sqlmesh/core/model/cache.py
+++ b/sqlmesh/core/model/cache.py
@@ -22,6 +22,7 @@
 
 if t.TYPE_CHECKING:
     from sqlmesh.core.snapshot import SnapshotId
+    from sqlmesh.core.linter.rule import Rule
 
 T = t.TypeVar("T")
@@ -71,6 +72,7 @@ def get_or_load(
 @dataclass
 class OptimizedQueryCacheEntry:
     optimized_rendered_query: t.Optional[exp.Expression]
+    renderer_violations: t.Optional[t.Dict[type[Rule], t.Any]]
 
 
 class OptimizedQueryCache:
@@ -100,15 +102,15 @@ def with_optimized_query(self, model: Model, name: t.Optional[str] = None) -> bo
             cache_entry = self._file_cache.get(name)
             if cache_entry:
                 try:
-                    if cache_entry.optimized_rendered_query:
-                        model._query_renderer.update_cache(
-                            cache_entry.optimized_rendered_query, optimized=True
-                        )
-                    else:
-                        # If the optimized rendered query is None, then there are likely adapter calls in the query
-                        # that prevent us from rendering it at load time. This means that we can safely set the
-                        # unoptimized cache to None as well to prevent attempts to render it downstream.
-                        model._query_renderer.update_cache(None, optimized=False)
+                    # If the optimized rendered query is None, then there are likely adapter calls in the query
+                    # that prevent us from rendering it at load time. This means that we can safely set the
+                    # unoptimized cache to None as well to prevent attempts to render it downstream.
+                    optimized = cache_entry.optimized_rendered_query is not None
+                    model._query_renderer.update_cache(
+                        cache_entry.optimized_rendered_query,
+                        cache_entry.renderer_violations,
+                        optimized=optimized,
+                    )
                     return True
                 except Exception as ex:
                     logger.warning("Failed to load a cache entry '%s': %s", name, ex)
@@ -130,7 +132,10 @@ def put(self, model: Model) -> t.Optional[str]:
 
     def _put(self, name: str, model: SqlModel) -> None:
         optimized_query = model.render_query()
-        new_entry = OptimizedQueryCacheEntry(optimized_rendered_query=optimized_query)
+        new_entry = OptimizedQueryCacheEntry(
+            optimized_rendered_query=optimized_query,
+            renderer_violations=model.violated_rules_for_query,
+        )
         self._file_cache.put(name, value=new_entry)
 
     @staticmethod
@@ -140,7 +145,6 @@ def _entry_name(model: SqlModel) -> str:
         hash_data.append(str([gen(d) for d in model.macro_definitions]))
         hash_data.append(str([(k, v) for k, v in model.sorted_python_env]))
         hash_data.extend(model.jinja_macros.data_hash_values)
-        hash_data.extend(str(model.validate_query))
         return f"{model.name}_{crc32(hash_data)}"
diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py
index 4122d6d2ff..d4bed645a2 100644
--- a/sqlmesh/core/model/common.py
+++ b/sqlmesh/core/model/common.py
@@ -321,7 +321,6 @@ def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[
     "allow_partials",
     "enabled",
     "optimize_query",
-    "validate_query",
     mode="before",
     check_fields=False,
 )(parse_bool)
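Persisting the violations next to the optimized query is what lets warning-severity violations resurface on cache hits without re-optimizing; a rough sketch of the round-trip (the cache directory and `model` are illustrative):

```python
from sqlmesh.core.model.cache import OptimizedQueryCache

cache = OptimizedQueryCache(project_path / ".cache")  # hypothetical location
cache.put(model)                   # stores the query plus model.violated_rules_for_query
cache.with_optimized_query(model)  # restores both through update_cache(...)
```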
diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py
index d3d212e0e1..caadeba4cb 100644
--- a/sqlmesh/core/model/definition.py
+++ b/sqlmesh/core/model/definition.py
@@ -64,6 +64,7 @@
     from sqlmesh.core.context import ExecutionContext
     from sqlmesh.core.engine_adapter import EngineAdapter
     from sqlmesh.core.engine_adapter._typing import QueryOrDF
+    from sqlmesh.core.linter.rule import Rule
     from sqlmesh.core.snapshot import DeployabilityIndex, Node, Snapshot
     from sqlmesh.utils.jinja import MacroReference
@@ -237,7 +238,7 @@ def render_definition(
             "enabled",
             "inline_audits",
             "optimize_query",
-            "validate_query",
+            "ignored_rules_",
         ):
             expressions.append(
                 exp.Property(
@@ -980,12 +981,6 @@ def validate_definition(self) -> None:
                 self._path,
             )
 
-        if self.validate_query:
-            raise_config_error(
-                "Query validation can only be enabled for SQL models",
-                self._path,
-            )
-
         if isinstance(self.kind, CustomKind):
             from sqlmesh.core.snapshot.evaluator import get_custom_materialization_type_or_raise
@@ -1087,7 +1082,6 @@ def metadata_hash(self) -> str:
                 self.project,
                 str(self.allow_partials),
                 gen(self.session_properties_) if self.session_properties_ else None,
-                str(self.validate_query) if self.validate_query is not None else None,
                 *[gen(g) for g in self.grains],
             ]
@@ -1228,6 +1222,10 @@ def _is_time_column_in_partitioned_by(self) -> bool:
             col for expr in self.partitioned_by_ for col in expr.find_all(exp.Column)
         }
 
+    @property
+    def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]:
+        return {}
+
 
 class SqlModel(_Model):
     """The model definition which relies on a SQL query to fetch the data.
@@ -1287,6 +1285,7 @@ def render_query(
             engine_adapter=engine_adapter,
             **kwargs,
         )
+
         return query
 
     def render_definition(
@@ -1475,7 +1474,6 @@ def _query_renderer(self) -> QueryRenderer:
             default_catalog=self.default_catalog,
             quote_identifiers=not no_quote_identifiers,
             optimize_query=self.optimize_query,
-            validate_query=self.validate_query,
         )
 
     @property
@@ -1491,6 +1489,11 @@ def _data_hash_values(self) -> t.List[str]:
     def _additional_metadata(self) -> t.List[str]:
         return [*super()._additional_metadata, gen(self.query)]
 
+    @property
+    def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]:
+        self.render_query()
+        return self._query_renderer._violated_rules
+
 
 class SeedModel(_Model):
     """The model definition which uses a pre-built static dataset to source the data from.
@@ -2325,7 +2328,6 @@ def _create_model(
     defaults = {k: v for k, v in (defaults or {}).items() if k in klass.all_fields()}
     if not issubclass(klass, SqlModel):
         defaults.pop("optimize_query", None)
-        defaults.pop("validate_query", None)
 
     statements = []
diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py
index 3b89b63441..3c62d3db79 100644
--- a/sqlmesh/core/model/meta.py
+++ b/sqlmesh/core/model/meta.py
@@ -10,6 +10,7 @@
 from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
 
 from sqlmesh.core import dialect as d
+from sqlmesh.core.config.linter import LinterConfig
 from sqlmesh.core.dialect import normalize_model_name, extract_func_call
 from sqlmesh.core.model.common import (
     bool_validator,
@@ -76,7 +77,9 @@ class ModelMeta(_Node):
     physical_version: t.Optional[str] = None
     gateway: t.Optional[str] = None
     optimize_query: t.Optional[bool] = None
-    validate_query: t.Optional[bool] = None
+    ignored_rules_: t.Optional[t.Set[str]] = Field(
+        default=None, exclude=True, alias="ignored_rules"
+    )
 
     _bool_validator = bool_validator
     _model_kind_validator = model_kind_validator
@@ -285,6 +288,10 @@ def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expressi
 
         return refs
 
+    @field_validator("ignored_rules_", mode="before")
+    def ignored_rules_validator(cls, vs: t.Any) -> t.Any:
+        return LinterConfig._validate_rules(vs)
+
     @model_validator(mode="before")
     def _pre_root_validator(cls, data: t.Any) -> t.Any:
         if not isinstance(data, dict):
@@ -459,3 +466,7 @@ def fqn(self) -> str:
     @property
     def on_destructive_change(self) -> OnDestructiveChange:
         return getattr(self.kind, "on_destructive_change", OnDestructiveChange.ALLOW)
+
+    @property
+    def ignored_rules(self) -> t.Set[str]:
+        return self.ignored_rules_ or set()
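`ignored_rules` is accepted as a model property, lowercased by the validator, and excluded from serialization and hashing; a quick sketch of the resulting behavior (model and table names hypothetical):

```python
from sqlmesh.core import dialect as d
from sqlmesh.core.model import load_sql_based_model

model = load_sql_based_model(
    d.parse("MODEL (name db.t, ignored_rules ['NoSelectStar']); SELECT * FROM src")
)
assert model.ignored_rules == {"noselectstar"}  # names are lowercased by the validator
```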
diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py
index 3857b56d9f..efbebef761 100644
--- a/sqlmesh/core/renderer.py
+++ b/sqlmesh/core/renderer.py
@@ -29,6 +29,7 @@
     from sqlglot._typing import E
     from sqlglot.dialects.dialect import DialectType
 
+    from sqlmesh.core.linter.rule import Rule
     from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot
@@ -51,7 +52,6 @@ def __init__(
         model_fqn: t.Optional[str] = None,
         normalize_identifiers: bool = True,
         optimize_query: t.Optional[bool] = True,
-        validate_query: t.Optional[bool] = False,
     ):
         self._expression = expression
         self._dialect = dialect
@@ -67,7 +67,6 @@ def __init__(
         self._cache: t.List[t.Optional[exp.Expression]] = []
         self._model_fqn = model_fqn
         self._optimize_query_flag = optimize_query is not False
-        self._validate_query = validate_query
 
     def update_schema(self, schema: t.Dict[str, t.Any]) -> None:
         self.schema = d.normalize_mapping_schema(schema, dialect=self._dialect)
@@ -443,6 +442,7 @@ class QueryRenderer(BaseExpressionRenderer):
     def __init__(self, *args: t.Any, **kwargs: t.Any):
         super().__init__(*args, **kwargs)
         self._optimized_cache: t.Optional[exp.Query] = None
+        self._violated_rules: t.Dict[type[Rule], t.Any] = {}
 
     def update_schema(self, schema: t.Dict[str, t.Any]) -> None:
         super().update_schema(schema)
@@ -550,7 +550,12 @@ def render(
 
         return query
 
-    def update_cache(self, expression: t.Optional[exp.Expression], optimized: bool = False) -> None:
+    def update_cache(
+        self,
+        expression: t.Optional[exp.Expression],
+        violated_rules: t.Optional[t.Dict[type[Rule], t.Any]] = None,
+        optimized: bool = False,
+    ) -> None:
         if optimized:
             if not isinstance(expression, exp.Query):
                 raise SQLMeshError(f"Expected a Query but got: {expression}")
@@ -558,7 +563,14 @@ def update_cache(self, expression: t.Optional[exp.Expression], optimized: bool =
         else:
             super().update_cache(expression)
 
+        self._violated_rules = violated_rules or {}
+
     def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query:
+        from sqlmesh.core.linter.rules.builtin import (
+            AmbiguousOrInvalidColumn,
+            InvalidSelectStarExpansion,
+        )
+
         # We don't want to normalize names in the schema because that's handled by the optimizer
         original = query
         missing_deps = set()
@@ -571,20 +583,8 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query:
                 missing_deps.add(dep)
 
         if self._model_fqn and not should_optimize and any(s.is_star for s in query.selects):
-            from sqlmesh.core.console import get_console
-
             deps = ", ".join(f"'{dep}'" for dep in sorted(missing_deps))
-
-            warning = (
-                f"SELECT * cannot be expanded due to missing schema(s) for model(s): {deps}. "
-                "Run `sqlmesh create_external_models` and / or make sure that the model "
-                f"'{self._model_fqn}' can be rendered at parse time."
-            )
-
-            if self._validate_query:
-                raise_config_error(warning, self._path)
-
-            get_console().log_warning(warning)
+            self._violated_rules[InvalidSelectStarExpansion] = deps
 
         try:
             if should_optimize:
@@ -603,18 +603,10 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query:
                     )
                 )
         except SqlglotError as ex:
-            from sqlmesh.core.console import get_console
-
-            warning = (
-                f"{ex} for model '{self._model_fqn}', the column may not exist or is ambiguous."
-            )
-
-            if self._validate_query:
-                raise_config_error(warning, self._path)
+            self._violated_rules[AmbiguousOrInvalidColumn] = ex
 
             query = original
-            get_console().log_warning(warning)
         except Exception as ex:
             raise_config_error(
                 f"Failed to optimize query, please file an issue at https://github.com/TobikoData/sqlmesh/issues/new. {ex}",
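Instead of logging inline, the renderer now records each failure keyed by the rule class, and the corresponding built-in rules read the payload back; a condensed sketch of the hand-off, assuming a `SqlModel` whose query cannot be fully expanded:

```python
from sqlmesh.core.linter.rules.builtin import InvalidSelectStarExpansion

model.render_query()  # populates QueryRenderer._violated_rules as a side effect
deps = model.violated_rules_for_query.get(InvalidSelectStarExpansion)
if deps is not None:
    # during linting, InvalidSelectStarExpansion.check_model turns this payload
    # into a RuleViolation carrying the same message that used to be logged here
    print(f"missing schemas: {deps}")
```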
{ex}", diff --git a/sqlmesh/migrations/v0075_remove_validate_query.py b/sqlmesh/migrations/v0075_remove_validate_query.py new file mode 100644 index 0000000000..3c637676d8 --- /dev/null +++ b/sqlmesh/migrations/v0075_remove_validate_query.py @@ -0,0 +1,82 @@ +"""Remove validate_query from existing snapshots.""" + +import json + +import pandas as pd +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type +from sqlmesh.utils.migration import blob_text_type + + +def migrate(state_sync, **kwargs): # type: ignore + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + snapshots_table = "_snapshots" + index_type = index_text_type(engine_adapter.dialect) + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + + parsed_snapshot["node"].pop("validate_query", None) + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) diff --git a/sqlmesh/utils/errors.py b/sqlmesh/utils/errors.py index 4d93ef7cd4..0159c42e3a 100644 --- a/sqlmesh/utils/errors.py +++ b/sqlmesh/utils/errors.py @@ -176,6 +176,10 @@ class MissingDefaultCatalogError(SQLMeshError): pass +class LinterError(SQLMeshError): + pass + + def raise_config_error( msg: str, location: t.Optional[str | Path] = None, diff --git a/tests/core/test_context.py b/tests/core/test_context.py index d54b4ded9b..eae4e71936 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -23,14 +23,17 @@ EnvironmentSuffixTarget, ModelDefaultsConfig, SnowflakeConnectionConfig, + LinterConfig, load_configs, ) from sqlmesh.core.context import Context -from sqlmesh.core.console import create_console +from sqlmesh.core.console import create_console, get_console from sqlmesh.core.dialect import parse, schema_ from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements -from sqlmesh.core.model import load_sql_based_model, model +from sqlmesh.core.macros import MacroEvaluator +from sqlmesh.core.model import load_sql_based_model, model, SqlModel, Model +from sqlmesh.core.model.cache import OptimizedQueryCache from sqlmesh.core.renderer import render_statements from sqlmesh.core.model.kind import ModelKindName from 
diff --git a/tests/core/test_context.py b/tests/core/test_context.py
index d54b4ded9b..eae4e71936 100644
--- a/tests/core/test_context.py
+++ b/tests/core/test_context.py
@@ -23,14 +23,17 @@
     EnvironmentSuffixTarget,
     ModelDefaultsConfig,
     SnowflakeConnectionConfig,
+    LinterConfig,
     load_configs,
 )
 from sqlmesh.core.context import Context
-from sqlmesh.core.console import create_console
+from sqlmesh.core.console import create_console, get_console
 from sqlmesh.core.dialect import parse, schema_
 from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter
 from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements
-from sqlmesh.core.model import load_sql_based_model, model
+from sqlmesh.core.macros import MacroEvaluator
+from sqlmesh.core.model import load_sql_based_model, model, SqlModel, Model
+from sqlmesh.core.model.cache import OptimizedQueryCache
 from sqlmesh.core.renderer import render_statements
 from sqlmesh.core.model.kind import ModelKindName
 from sqlmesh.core.plan import BuiltInPlanEvaluator, PlanBuilder
@@ -44,8 +47,9 @@
     to_timestamp,
     yesterday_ds,
 )
-from sqlmesh.utils.errors import ConfigError, SQLMeshError
+from sqlmesh.utils.errors import ConfigError, SQLMeshError, LinterError
 from sqlmesh.utils.metaprogramming import Executable
+from tests.utils.test_helpers import use_terminal_console
 from tests.utils.test_filesystem import create_temp_file
@@ -1635,3 +1639,150 @@ def test_environment_statements_dialect(tmp_path: Path):
     with pytest.raises(ParseError, match=r"Invalid expression / Unexpected token*"):
         config.model_defaults.dialect = "duckdb"
         ctx = Context(paths=[tmp_path], config=config)
+
+
+@pytest.mark.slow
+@use_terminal_console
+def test_model_linting(tmp_path: pathlib.Path, sushi_context) -> None:
+    def assert_cached_violations_exist(cache: OptimizedQueryCache, model: Model):
+        model = t.cast(SqlModel, model)
+        cache_entry = cache._file_cache.get(cache._entry_name(model))
+        assert cache_entry is not None
+        assert cache_entry.optimized_rendered_query is not None
+        assert cache_entry.renderer_violations is not None
+
+    cfg = LinterConfig(enabled=True, rules="ALL")
+    ctx = Context(
+        config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), linter=cfg),
+        paths=tmp_path,
+    )
+
+    config_err = "Linter detected errors in the code. Please fix them before proceeding."
+
+    # Case: Ensure load DOES NOT work if linter is enabled
+    for query in ("SELECT * FROM tbl", "SELECT t.* FROM tbl"):
+        with pytest.raises(LinterError, match=config_err):
+            ctx.upsert_model(load_sql_based_model(d.parse(f"MODEL (name test); {query}")))
+
+    error_model = load_sql_based_model(d.parse("MODEL (name test); SELECT col"))
+    with pytest.raises(LinterError, match=config_err):
+        ctx.upsert_model(error_model)
+
+    # Case: Ensure error violations are cached if the model did not pass linting
+    cache = OptimizedQueryCache(tmp_path / c.CACHE)
+
+    assert_cached_violations_exist(cache, error_model)
+
+    # Case: Ensure NoSelectStar only raises for top-level SELECTs, new model shouldn't raise
+    # and thus should also be cached
+    model2 = load_sql_based_model(
+        d.parse("MODEL (name test2); SELECT col FROM (SELECT * FROM tbl)")
+    )
+    ctx.upsert_model(model2)
+
+    model2 = t.cast(SqlModel, model2)
+    assert cache._file_cache.exists(cache._entry_name(model2))
+
+    # Case: Ensure warning violations are found again even if the optimized query is cached
+    ctx.config.linter = LinterConfig(enabled=True, warn_rules="ALL")
+    ctx.load()
+
+    for i in range(3):
+        with patch.object(get_console(), "log_warning") as mock_logger:
+            if i > 1:
+                # Model's violations have been cached from the previous upserts
+                assert_cached_violations_exist(cache, model2)
+
+            ctx.upsert_model(error_model)
+            assert (
+                """Column '"col"' could not be resolved for model '"test"'"""
+                in mock_logger.call_args[0][0]
+            )
+
+            # Model's violations have been cached after the former upsert
+            assert_cached_violations_exist(cache, model2)
+
+    # Case: Ensure load WORKS if linter is enabled but the rules are not
+    create_temp_file(
+        tmp_path,
+        pathlib.Path(pathlib.Path("models"), "test.sql"),
+        "MODEL(name test); SELECT * FROM (SELECT 1 AS col);",
+    )
+
+    ignore_or_warn_cfgs = [
+        LinterConfig(enabled=True, warn_rules=["noselectstar"]),
+        LinterConfig(enabled=True, ignored_rules=["noselectstar"]),
+    ]
+    for cfg in ignore_or_warn_cfgs:
+        ctx.config.linter = cfg
+        ctx.load()
+
+    # Case: Ensure load DOES NOT work if LinterConfig has overlapping rules
+    with pytest.raises(
+        ConfigError,
+        match=r"Rules cannot simultaneously warn and raise an error: \[noselectstar\]",
+    ):
+        ctx.config.linter = LinterConfig(
+            enabled=True, rules=["noselectstar"], warn_rules=["noselectstar"]
+        )
+        ctx.load()
+
+    # Case: Ensure model attribute overrides global config
+    ctx.config.linter = LinterConfig(enabled=True, rules=["noselectstar"])
+
+    create_temp_file(
+        tmp_path,
+        pathlib.Path(pathlib.Path("models"), "test.sql"),
+        "MODEL(name test, ignored_rules ['ALL']); SELECT * FROM (SELECT 1 AS col);",
+    )
+
+    create_temp_file(
+        tmp_path,
+        pathlib.Path(pathlib.Path("models"), "test2.sql"),
+        "MODEL(name test2, ignored_rules ['noselectstar']); SELECT * FROM (SELECT 1 AS col);",
+    )
+
+    ctx.load()
+
+    # Case: Ensure we can load & use the user-defined rules
+    sushi_context.config.linter = LinterConfig(enabled=True, rules=["aLl"])
+    sushi_context.upsert_model(
+        load_sql_based_model(
+            d.parse("MODEL (name sushi.test); SELECT col FROM (SELECT * FROM tbl)"),
+            default_catalog="memory",
+        )
+    )
+
+    with pytest.raises(LinterError, match=config_err):
+        sushi_context.load()
+
+    # Case: Ensure the Linter also picks up Python model violations
+    @model(name="memory.sushi.model3", is_sql=True, kind="full", dialect="snowflake")
+    def model3_entrypoint(evaluator: MacroEvaluator) -> str:
+        return "select * from model1"
+
+    model3 = model.get_registry()["memory.sushi.model3"].model(
+        module_path=Path("."), path=Path(".")
+    )
+
+    @model(name="memory.sushi.model4", columns={"col": "int"})
+    def model4_entrypoint(context, **kwargs):
+        yield pd.DataFrame({"col": []})
+
+    model4 = model.get_registry()["memory.sushi.model4"].model(
+        module_path=Path("."), path=Path(".")
+    )
+
+    for python_model in (model3, model4):
+        with pytest.raises(LinterError, match=config_err):
+            sushi_context.upsert_model(python_model)
+
+    @model(name="memory.sushi.model5", columns={"col": "int"}, owner="test")
+    def model5_entrypoint(context, **kwargs):
+        yield pd.DataFrame({"col": []})
+
+    model5 = model.get_registry()["memory.sushi.model5"].model(
+        module_path=Path("."), path=Path(".")
+    )
+
+    sushi_context.upsert_model(model5)
diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py
index 43bbdfe231..71eafcd5a1 100644
--- a/tests/core/test_integration.py
+++ b/tests/core/test_integration.py
@@ -4,6 +4,7 @@
 from collections import Counter
 from datetime import timedelta
 from unittest import mock
+from unittest.mock import patch
 
 import numpy as np
 import pandas as pd
@@ -25,7 +26,7 @@
     ModelDefaultsConfig,
     DuckDBConnectionConfig,
 )
-from sqlmesh.core.console import Console
+from sqlmesh.core.console import Console, get_console
 from sqlmesh.core.context import Context
 from sqlmesh.core.config.categorizer import CategorizerConfig
 from sqlmesh.core.engine_adapter import EngineAdapter
@@ -59,7 +60,7 @@
 from sqlmesh.utils.errors import NoChangesPlanError
 from sqlmesh.utils.pydantic import validate_string
 from tests.conftest import DuckDBMetadata, SushiDataValidator
-
+from tests.utils.test_helpers import use_terminal_console
 
 if t.TYPE_CHECKING:
     from sqlmesh import QueryOrDF
@@ -4265,8 +4266,19 @@ def test_auto_categorization(sushi_context: Context):
     )
 
 
+@use_terminal_console
 def test_multi(mocker):
-    context = Context(paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory")
+    context = Context(
+        paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory", load=False
+    )
+
+    with patch.object(get_console(), "log_warning") as mock_logger:
+        context.load()
+        warnings = mock_logger.call_args[0][0]
+        repo1_path, repo2_path = context.configs.keys()
+        assert f"Linter warnings for {repo1_path}" in warnings
+        assert f"Linter warnings for {repo2_path}" not in warnings
+
     assert (
         context.render("bronze.a").sql()
         == '''SELECT 1 AS "col_a", 'b' AS "col_b", 1 AS "one", 'repo_1' AS "dup"'''
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index 06a8fd8044..8c6e763f39 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -27,6 +27,7 @@
     GatewayConfig,
     NameInferenceConfig,
     ModelDefaultsConfig,
+    LinterConfig,
 )
 from sqlmesh.core.context import Context, ExecutionContext
 from sqlmesh.core.dialect import parse
@@ -60,10 +61,11 @@
 from sqlmesh.core.signal import signal
 from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory
 from sqlmesh.utils.date import TimeLike, to_datetime, to_ds, to_timestamp
-from sqlmesh.utils.errors import ConfigError, SQLMeshError
+from sqlmesh.utils.errors import ConfigError, SQLMeshError, LinterError
 from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo, MacroExtractor
 from sqlmesh.utils.metaprogramming import Executable
 from sqlmesh.core.macros import RuntimeStage
+from tests.utils.test_helpers import use_terminal_console
 
 
 def missing_schema_warning_msg(model, deps):
@@ -280,7 +282,8 @@ def test_model_validation_union_query():
     model.validate_definition()
 
 
-def test_model_qualification():
+@use_terminal_console
+def test_model_qualification(tmp_path: Path):
     with patch.object(get_console(), "log_warning") as mock_logger:
         expressions = d.parse(
             """
@@ -293,11 +296,14 @@ def test_model_qualification():
             """
         )
 
-        model = load_sql_based_model(expressions)
-        model.render_query(needs_optimization=True)
+        ctx = Context(
+            config=Config(linter=LinterConfig(enabled=True, warn_rules=["ALL"])), paths=tmp_path
+        )
+        ctx.upsert_model(load_sql_based_model(expressions))
+
         assert (
-            mock_logger.call_args[0][0]
-            == """Column '"a"' could not be resolved for model '"db"."table"', the column may not exist or is ambiguous."""
+            """Column '"a"' could not be resolved for model '"db"."table"', the column may not exist or is ambiguous."""
+            in mock_logger.call_args[0][0]
         )
@@ -2726,7 +2732,8 @@ def runtime_macro(**kwargs) -> None:
         model.validate_definition()
 
 
-def test_update_schema():
+@use_terminal_console
+def test_update_schema(tmp_path: Path):
     expressions = d.parse(
         """
         MODEL (name db.table);
@@ -2743,10 +2750,14 @@ def test_update_schema():
     model.update_schema(schema)
     assert model.mapping_schema == {'"table_a"': {"a": "INT"}}
 
+    ctx = Context(
+        config=Config(linter=LinterConfig(enabled=True, warn_rules=["ALL"])), paths=tmp_path
+    )
     with patch.object(get_console(), "log_warning") as mock_logger:
-        model.render_query(needs_optimization=True)
-        assert mock_logger.call_args[0][0] == missing_schema_warning_msg(
-            '"db"."table"', ('"table_b"',)
+        ctx.upsert_model(model)
+        assert (
+            missing_schema_warning_msg('"db"."table"', ('"table_b"',))
+            in mock_logger.call_args[0][0]
         )
 
     schema.add_table('"table_b"', {"b": exp.DataType.build("int")})
@@ -2758,7 +2769,8 @@ def test_update_schema():
     model.render_query(needs_optimization=True)
 
 
-def test_missing_schema_warnings():
+@use_terminal_console
+def test_missing_schema_warnings(tmp_path: Path):
     full_schema = MappingSchema(
         {
             "a": {"x": exp.DataType.build("int")},
@@ -2775,6 +2787,10 @@ def test_missing_schema_warnings():
 
     console = get_console()
 
+    ctx = Context(
+        config=Config(linter=LinterConfig(enabled=True, warn_rules=["ALL"])), paths=tmp_path
+    )
+
     # star, no schema, no deps
     with patch.object(console, "log_warning") as mock_logger:
         model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM (SELECT 1 a) x"))
@@ -2792,14 +2808,15 @@ def test_missing_schema_warnings():
     with patch.object(console, "log_warning") as mock_logger:
         model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM a CROSS JOIN b"))
         model.update_schema(partial_schema)
-        model.render_query(needs_optimization=True)
-        assert mock_logger.call_args[0][0] == missing_schema_warning_msg('"test"', ('"b"',))
+        ctx.upsert_model(model)
+
+        assert missing_schema_warning_msg('"test"', ('"b"',)) in mock_logger.call_args[0][0]
 
     # star, no schema
     with patch.object(console, "log_warning") as mock_logger:
         model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM b JOIN a"))
-        model.render_query(needs_optimization=True)
-        assert mock_logger.call_args[0][0] == missing_schema_warning_msg('"test"', ('"a"', '"b"'))
+        ctx.upsert_model(model)
+        assert missing_schema_warning_msg('"test"', ('"a"', '"b"')) in mock_logger.call_args[0][0]
 
     # no star, full schema
     with patch.object(console, "log_warning") as mock_logger:
@@ -3502,7 +3519,6 @@ def test_project_level_properties(sushi_context):
         enabled=False,
         allow_partials=True,
         interval_unit="quarter_hour",
-        validate_query=True,
         optimize_query=True,
         cron="@hourly",
     )
@@ -3529,7 +3545,6 @@ def test_project_level_properties(sushi_context):
     assert model.allow_partials
     assert model.interval_unit == IntervalUnit.QUARTER_HOUR
     assert model.optimize_query
-    assert model.validate_query
     assert model.cron == "@hourly"
 
     assert model.session_properties == {
@@ -3580,7 +3595,6 @@ def test_project_level_properties_python_model():
         "enabled": False,
         "allow_partials": True,
         "interval_unit": "quarter_hour",
-        "validate_query": True,
         "optimize_query": True,
     }
@@ -3607,7 +3621,6 @@ def python_model_prop(context, **kwargs):
 
     # Even if in the project wide defaults these are ignored for python models
     assert not m.optimize_query
-    assert not m.validate_query
     assert not m.enabled
     assert m.allow_partials
@@ -7646,123 +7659,46 @@ def model_with_virtual_statements(context, **kwargs):
     )
 
 
-def test_compile_time_checks(tmp_path: Path, assert_exp_eq):
+def test_compile_time_checks(tmp_path: Path):
+    ctx = Context(
+        config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), paths=tmp_path
+    )
+
+    cfg_err = "Linter detected errors in the code. Please fix them before proceeding."
+
+    # Strict SELECT * expansion
+    linter_cfg = LinterConfig(
+        enabled=True, rules=["ambiguousorinvalidcolumn", "invalidselectstarexpansion"]
+    )
+    ctx.config.linter = linter_cfg
     strict_query = d.parse(
         """
         MODEL (
             name test,
-            validate_query True,
         );
 
         SELECT * FROM tbl
         """
     )
 
-    with pytest.raises(
-        ConfigError,
-        match=r".*cannot be expanded due to missing schema.*",
-    ):
-        load_sql_based_model(strict_query).render_query()
+    ctx.load()
+
+    with pytest.raises(LinterError, match=cfg_err):
+        ctx.upsert_model(load_sql_based_model(strict_query))
 
     # Strict column resolution
     strict_query = d.parse(
         """
         MODEL (
             name test,
-            validate_query True,
         );
 
        SELECT foo
         """
     )
 
-    with pytest.raises(
-        ConfigError,
-        match=r"""Column '"foo"' could not be resolved for model.*""",
-    ):
-        load_sql_based_model(strict_query).render_query()
-
-    # Non-strict model with strict defaults raises error, otherwise can still render
-    strict_default = ModelDefaultsConfig(validate_query=True).dict()
-    query = d.parse(
-        """
-        MODEL (
-            name test,
-        );
-
-        SELECT * FROM tbl
-        """
-    )
-
-    with pytest.raises(
-        ConfigError,
-        match=r".*cannot be expanded due to missing schema.*",
-    ):
-        load_sql_based_model(query, defaults=strict_default).render_query()
-
-    assert_exp_eq(load_sql_based_model(query).render_query(), 'SELECT * FROM "tbl" AS "tbl"')
-
-    # Ensure plan works for valid queries & cache is invalidated if strict changes
-    context = Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")))
-
-    query = d.parse(
-        """
-        MODEL (
-            name db.test,
-            validate_query True,
-        );
-
-        SELECT 1 AS col
-        """
-    )
-
-    context.upsert_model(load_sql_based_model(query, default_catalog=context.default_catalog))
-    context.plan(auto_apply=True, no_prompts=True)
-
-    context.upsert_model("db.test", validate_query=False)
-    plan = context.plan(no_prompts=True, auto_apply=True)
-
-    snapshots = list(plan.snapshots.values())
-    assert len(snapshots) == 1
-
-    snapshot = snapshots[0]
-    assert len(snapshot.previous_versions) == 1
-    assert snapshot.change_category == SnapshotChangeCategory.METADATA
-
-    # Ensure non-SQLModels raise if strict mode is set to True
-    seed_path = tmp_path / "seed.csv"
-    model_kind = SeedKind(path=str(seed_path.absolute()))
-    with open(seed_path, "w", encoding="utf-8") as fd:
-        fd.write(
-            """
-col_a,col_b,col_c
-1,text_a,1.0"""
-        )
-    model = create_seed_model("test_db.test_seed_model", model_kind, validate_query=True)
-    context = Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")))
-
-    with pytest.raises(
-        ConfigError,
-        match=r"Query validation can only be enabled for SQL models at",
-    ):
-        context.upsert_model(model)
-        context.plan(auto_apply=True, no_prompts=True)
-
-    model = create_seed_model("test_db.test_seed_model", model_kind, validate_query=False)
-    context.upsert_model(model)
-    context.plan(auto_apply=True, no_prompts=True)
-
-    # Ensure strict defaults don't break all non SQL models to which they weren't applicable in the first place
-    seed_strict_defaults = create_seed_model(
-        "test_db.test_seed_model", model_kind, defaults=strict_default
-    )
-    external_strict_defaults = create_external_model(
-        "test_db.test_external_model", columns={"a": "int", "limit": "int"}, defaults=strict_default
-    )
-    context.upsert_model(seed_strict_defaults)
-    context.upsert_model(external_strict_defaults)
-    context.plan(auto_apply=True, no_prompts=True)
+    with pytest.raises(LinterError, match=cfg_err):
+        ctx.upsert_model(load_sql_based_model(strict_query))
 
 
 def test_partition_interval_unit():
@@ -8023,3 +7959,28 @@ def test_missing_column_data_in_columns_key():
     )
     with pytest.raises(ConfigError, match="Missing data type for column 'culprit'."):
         load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql"))
+
+
+def test_ignored_rules_serialization():
+    expressions = d.parse(
+        """
+        MODEL(
+            name test_model,
+            ignored_rules ['foo', 'bar']
+        );
+
+        SELECT * FROM tbl;
+        """,
+        default_dialect="bigquery",
+    )
+
+    model = load_sql_based_model(expressions)
+
+    model_json = model.json()
+    model_json_parsed = json.loads(model_json)
+
+    assert "ignored_rules" not in model_json_parsed
+    assert "ignored_rules_" not in model_json_parsed
+
+    deserialized_model = SqlModel.parse_raw(model_json)
+    assert deserialized_model.dict() == model.dict()
diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py
index 06cd7c159b..5b622d264c 100644
--- a/tests/core/test_snapshot.py
+++ b/tests/core/test_snapshot.py
@@ -860,7 +860,7 @@ def test_fingerprint(model: Model, parent_model: Model):
 
     original_fingerprint = SnapshotFingerprint(
         data_hash="1312415267",
-        metadata_hash="2906564841",
+        metadata_hash="221611364",
     )
 
     assert fingerprint == original_fingerprint
@@ -921,7 +921,7 @@ def test_fingerprint_seed_model():
 
     expected_fingerprint = SnapshotFingerprint(
         data_hash="1909791099",
-        metadata_hash="1153541408",
+        metadata_hash="3403817841",
     )
 
     model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql"))
@@ -960,7 +960,7 @@ def test_fingerprint_jinja_macros(model: Model):
     )
     original_fingerprint = SnapshotFingerprint(
         data_hash="923305614",
-        metadata_hash="2906564841",
+        metadata_hash="221611364",
     )
 
     fingerprint = fingerprint_from_node(model, nodes={})
diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py
index 26792ec296..b19869a0e5 100644
--- a/tests/core/test_state_sync.py
+++ b/tests/core/test_state_sync.py
@@ -81,6 +81,10 @@ def snapshots(make_snapshot: t.Callable) -> t.List[Snapshot]:
     ]
 
 
+def compare_snapshot_intervals(x: SnapshotIntervals) -> str:
+    return x.identifier or ""
+
+
 def promote_snapshots(
     state_sync: EngineAdapterStateSync,
     snapshots: t.List[Snapshot],
@@ -1512,24 +1516,27 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_version(
     # Check all intervals
     assert sorted(
         state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]),
-        key=lambda x: x.identifier or "",
-    ) == [
-        SnapshotIntervals(
-            name='"a"',
-            identifier=snapshot.identifier,
-            version=snapshot.version,
-            dev_version=snapshot.dev_version,
-            intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
-            dev_intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
-        ),
-        SnapshotIntervals(
-            name='"a"',
-            identifier=new_snapshot.identifier,
-            version=snapshot.version,
-            dev_version=new_snapshot.dev_version,
-            intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-06"))],
-        ),
-    ]
+        key=compare_snapshot_intervals,
+    ) == sorted(
+        [
+            SnapshotIntervals(
+                name='"a"',
+                identifier=snapshot.identifier,
+                version=snapshot.version,
+                dev_version=snapshot.dev_version,
+                intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
+                dev_intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
+            ),
+            SnapshotIntervals(
+                name='"a"',
+                identifier=new_snapshot.identifier,
+                version=snapshot.version,
+                dev_version=new_snapshot.dev_version,
+                intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-06"))],
+            ),
+        ],
+        key=compare_snapshot_intervals,
+    )
 
     # Delete the expired snapshot
     assert state_sync.delete_expired_snapshots() == [
@@ -1547,25 +1554,28 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_version(
     # Check all intervals
     assert sorted(
         state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]),
-        key=lambda x: x.identifier or "",
-    ) == [
-        # The intervals of the old snapshot is preserved with the null identifier
-        SnapshotIntervals(
-            name='"a"',
-            identifier=None,
-            version=snapshot.version,
-            dev_version=None,
-            intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
-        ),
-        # The intervals of the new snapshot has identifier
-        SnapshotIntervals(
-            name='"a"',
-            identifier=new_snapshot.identifier,
-            version=snapshot.version,
-            dev_version=new_snapshot.dev_version,
-            intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-06"))],
-        ),
-    ]
+        key=compare_snapshot_intervals,
+    ) == sorted(
+        [
+            # The intervals of the old snapshot are preserved with the null identifier
+            SnapshotIntervals(
+                name='"a"',
+                identifier=None,
+                version=snapshot.version,
+                dev_version=None,
+                intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
+            ),
+            # The intervals of the new snapshot have an identifier
+            SnapshotIntervals(
+                name='"a"',
+                identifier=new_snapshot.identifier,
+                version=snapshot.version,
+                dev_version=new_snapshot.dev_version,
+                intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-06"))],
+            ),
+        ],
+        key=compare_snapshot_intervals,
+    )
 
 
 def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version(
@@ -1625,24 +1635,27 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version(
     # Check all intervals
     assert sorted(
         state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]),
-        key=lambda x: x.identifier or "",
-    ) == [
-        SnapshotIntervals(
-            name='"a"',
-            identifier=snapshot.identifier,
-            version=snapshot.version,
-            dev_version=snapshot.dev_version,
-            intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
-            dev_intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-08"))],
-        ),
-        SnapshotIntervals(
-            name='"a"',
-            identifier=new_snapshot.identifier,
-            version=snapshot.version,
-            dev_version=new_snapshot.dev_version,
-            dev_intervals=[(to_timestamp("2023-01-08"), to_timestamp("2023-01-10"))],
-        ),
-    ]
+        key=compare_snapshot_intervals,
+    ) == sorted(
+        [
+            SnapshotIntervals(
+                name='"a"',
+                identifier=snapshot.identifier,
+                version=snapshot.version,
+                dev_version=snapshot.dev_version,
+                intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
+                dev_intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-08"))],
+            ),
+            SnapshotIntervals(
+                name='"a"',
+                identifier=new_snapshot.identifier,
+                version=snapshot.version,
+                dev_version=new_snapshot.dev_version,
+                dev_intervals=[(to_timestamp("2023-01-08"), to_timestamp("2023-01-10"))],
+            ),
+        ],
+        key=compare_snapshot_intervals,
+    )
 
     # Delete the expired snapshot
     assert state_sync.delete_expired_snapshots() == []
@@ -1660,30 +1673,33 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version(
     # Check all intervals
     assert sorted(
        state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]),
-        key=lambda x: x.identifier or "",
-    ) == [
-        SnapshotIntervals(
-            name='"a"',
-            identifier=None,
-            version=snapshot.version,
-            dev_version=None,
-            intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
-        ),
-        SnapshotIntervals(
-            name='"a"',
-            identifier=None,
-            version=snapshot.version,
-            dev_version=snapshot.dev_version,
-            dev_intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-08"))],
-        ),
-        SnapshotIntervals(
-            name='"a"',
-            identifier=new_snapshot.identifier,
-            version=snapshot.version,
-            dev_version=new_snapshot.dev_version,
-            dev_intervals=[(to_timestamp("2023-01-08"), to_timestamp("2023-01-10"))],
-        ),
-    ]
+        key=compare_snapshot_intervals,
+    ) == sorted(
+        [
+            SnapshotIntervals(
+                name='"a"',
+                identifier=None,
+                version=snapshot.version,
+                dev_version=None,
+                intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
+            ),
+            SnapshotIntervals(
+                name='"a"',
+                identifier=None,
+                version=snapshot.version,
+                dev_version=snapshot.dev_version,
+                dev_intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-08"))],
+            ),
+            SnapshotIntervals(
+                name='"a"',
+                identifier=new_snapshot.identifier,
+                version=snapshot.version,
+                dev_version=new_snapshot.dev_version,
+                dev_intervals=[(to_timestamp("2023-01-08"), to_timestamp("2023-01-10"))],
+            ),
+        ],
+        key=compare_snapshot_intervals,
+    )
 
 
 def test_compact_intervals_after_cleanup(
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index 48008064a9..586d8abb6d 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -1,7 +1,10 @@
 import pytest
+from functools import wraps
 from sqlglot import expressions
 from sqlglot.optimizer.annotate_types import annotate_types
 
+from sqlmesh.core.console import set_console, get_console, TerminalConsole
+
 from sqlmesh.utils import columns_to_types_all_known
 
 
@@ -72,3 +75,16 @@
 )
 def test_columns_to_types_all_known(columns_to_types, expected) -> None:
     assert columns_to_types_all_known(columns_to_types) == expected
+
+
+def use_terminal_console(func):
+    @wraps(func)
+    def test_wrapper(*args, **kwargs):
+        orig_console = get_console()
+        try:
+            set_console(TerminalConsole())
+            func(*args, **kwargs)
+        finally:
+            set_console(orig_console)
+
+    return test_wrapper
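Since `use_terminal_console` is a plain decorator, new tests can opt into the real `TerminalConsole` (and patch its `log_warning`/`log_error` hooks) without leaking global state; usage mirrors the tests above:

```python
from unittest.mock import patch

from sqlmesh.core.console import get_console
from tests.utils.test_helpers import use_terminal_console


@use_terminal_console
def test_emits_linter_warning():
    with patch.object(get_console(), "log_warning") as mock_logger:
        ...  # exercise code that should produce a linter warning
```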