Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/multi/repo_1/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ after_all:

model_defaults:
dialect: 'duckdb'

linter:
enabled: True

warn_rules: "ALL"
15 changes: 15 additions & 0 deletions examples/multi/repo_1/linter/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Contains all the standard rules included with SQLMesh"""

from __future__ import annotations

import typing as t

from sqlmesh.core.linter.rule import Rule, RuleViolation
from sqlmesh.core.model import Model


class NoMissingDescription(Rule):
"""All models should be documented."""

def check_model(self, model: Model) -> t.Optional[RuleViolation]:
return self.violation() if not model.description else None
5 changes: 5 additions & 0 deletions examples/multi/repo_2/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ after_all:

model_defaults:
dialect: 'duckdb'

linter:
enabled: True

ignored_rules: "ALL"
13 changes: 13 additions & 0 deletions examples/sushi/linter/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

import typing as t

from sqlmesh.core.linter.rule import Rule, RuleViolation
from sqlmesh.core.model import Model


class NoMissingOwner(Rule):
"""All models should have an owner specified."""

def check_model(self, model: Model) -> t.Optional[RuleViolation]:
return self.violation() if not model.owner else None
1 change: 1 addition & 0 deletions sqlmesh/core/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from sqlmesh.core.config.migration import MigrationConfig as MigrationConfig
from sqlmesh.core.config.model import ModelDefaultsConfig as ModelDefaultsConfig
from sqlmesh.core.config.naming import NameInferenceConfig as NameInferenceConfig
from sqlmesh.core.config.linter import LinterConfig as LinterConfig
from sqlmesh.core.config.plan import PlanConfig as PlanConfig
from sqlmesh.core.config.root import Config as Config
from sqlmesh.core.config.run import RunConfig as RunConfig
Expand Down
44 changes: 44 additions & 0 deletions sqlmesh/core/config/linter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.helper import ensure_collection

from sqlmesh.core.config.base import BaseConfig

from sqlmesh.utils.pydantic import field_validator


class LinterConfig(BaseConfig):
"""Configuration for model linting

Args:
enabled: Flag indicating whether the linter should run

rules: A list of error rules to be applied on model
warn_rules: A list of rules to be applied on models but produce warnings instead of raising errors.
ignored_rules: A list of rules to be excluded/ignored

"""

enabled: bool = False

rules: t.Set[str] = set()
warn_rules: t.Set[str] = set()
ignored_rules: t.Set[str] = set()

@classmethod
def _validate_rules(cls, v: t.Any) -> t.Set[str]:
if isinstance(v, exp.Paren):
v = v.unnest().name
elif isinstance(v, (exp.Tuple, exp.Array)):
v = [e.name for e in v.expressions]
elif isinstance(v, exp.Expression):
v = v.name

return {name.lower() for name in ensure_collection(v)}

@field_validator("rules", "warn_rules", "ignored_rules", mode="before")
def rules_validator(cls, vs: t.Any) -> t.Set[str]:
return cls._validate_rules(vs)
2 changes: 0 additions & 2 deletions sqlmesh/core/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ class ModelDefaultsConfig(BaseConfig):
session_properties: A key-value mapping of properties specific to the target engine that are applied to the engine session.
audits: The audits to be applied globally to all models in the project.
optimize_query: Whether the SQL models should be optimized.
validate_query: Whether the SQL models should be validated at compile time.
allow_partials: Whether the models can process partial (incomplete) data intervals.
enabled: Whether the models are enabled.
interval_unit: The temporal granularity of the models data intervals. By default computed from cron.
Expand All @@ -58,7 +57,6 @@ class ModelDefaultsConfig(BaseConfig):
session_properties: t.Optional[t.Dict[str, t.Any]] = None
audits: t.Optional[t.List[FunctionCall]] = None
optimize_query: t.Optional[bool] = None
validate_query: t.Optional[bool] = None
allow_partials: t.Optional[bool] = None
interval_unit: t.Optional[IntervalUnit] = None
enabled: t.Optional[bool] = None
Expand Down
3 changes: 3 additions & 0 deletions sqlmesh/core/config/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sqlmesh.core.config.migration import MigrationConfig
from sqlmesh.core.config.model import ModelDefaultsConfig
from sqlmesh.core.config.naming import NameInferenceConfig as NameInferenceConfig
from sqlmesh.core.config.linter import LinterConfig as LinterConfig
from sqlmesh.core.config.plan import PlanConfig
from sqlmesh.core.config.run import RunConfig
from sqlmesh.core.config.scheduler import (
Expand Down Expand Up @@ -124,6 +125,7 @@ class Config(BaseConfig):
disable_anonymized_analytics: bool = False
before_all: t.Optional[t.List[str]] = None
after_all: t.Optional[t.List[str]] = None
linter: LinterConfig = LinterConfig()

_FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
"gateways": UpdateStrategy.NESTED_UPDATE,
Expand All @@ -141,6 +143,7 @@ class Config(BaseConfig):
"plan": UpdateStrategy.NESTED_UPDATE,
"before_all": UpdateStrategy.EXTEND,
"after_all": UpdateStrategy.EXTEND,
"linter": UpdateStrategy.NESTED_UPDATE,
}

_connection_config_validator = connection_config_validator
Expand Down
25 changes: 25 additions & 0 deletions sqlmesh/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from rich.tree import Tree

from sqlmesh.core.environment import EnvironmentNamingInfo
from sqlmesh.core.linter.rule import RuleViolation
from sqlmesh.core.model import Model
from sqlmesh.core.snapshot import (
Snapshot,
SnapshotChangeCategory,
Expand Down Expand Up @@ -319,6 +321,12 @@ def _limit_model_names(self, tree: Tree, verbose: bool = False) -> Tree:
]
return tree

@abc.abstractmethod
def show_linter_violations(
self, violations: t.List[RuleViolation], model: Model, is_error: bool = False
) -> None:
"""Prints all linter violations depending on their severity"""


class NoopConsole(Console):
def start_plan_evaluation(self, plan: EvaluatablePlan) -> None:
Expand Down Expand Up @@ -481,6 +489,11 @@ def show_row_diff(
def print_environments(self, environments_summary: t.Dict[str, int]) -> None:
pass

def show_linter_violations(
self, violations: t.List[RuleViolation], model: Model, is_error: bool = False
) -> None:
pass


def make_progress_bar(message: str, console: t.Optional[RichConsole] = None) -> Progress:
return Progress(
Expand Down Expand Up @@ -1548,6 +1561,18 @@ def _snapshot_change_choices(
}
return labeled_choices

def show_linter_violations(
self, violations: t.List[RuleViolation], model: Model, is_error: bool = False
) -> None:
severity = "errors" if is_error else "warnings"
violations_msg = "\n".join(f" - {violation}" for violation in violations)
msg = f"Linter {severity} for {model._path}:\n{violations_msg}"

if is_error:
self.log_error(msg)
else:
self.log_warning(msg)


def add_to_layout_widget(target_widget: LayoutWidget, *widgets: widgets.Widget) -> LayoutWidget:
"""Helper function to add a widget to a layout widget.
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
AUDITS = "audits"
CACHE = ".cache"
EXTERNAL_MODELS = "external_models"
LINTER = "linter"
MACROS = "macros"
MATERIALIZATIONS = "materializations"
METRICS = "metrics"
Expand Down
32 changes: 30 additions & 2 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements
from sqlmesh.core.loader import Loader
from sqlmesh.core.linter.definition import Linter
from sqlmesh.core.linter.rules import BUILTIN_RULES
from sqlmesh.core.macros import ExecutableOrMacro, macro
from sqlmesh.core.metric import Metric, rewrite
from sqlmesh.core.model import Model, update_model_schemas
Expand Down Expand Up @@ -121,6 +123,7 @@
PlanError,
SQLMeshError,
UncategorizedPlanError,
LinterError,
)
from sqlmesh.utils.config import print_config
from sqlmesh.utils.jinja import JinjaMacroRegistry
Expand Down Expand Up @@ -349,6 +352,7 @@ def __init__(
self._environment_statements: t.List[EnvironmentStatements] = []
self._excluded_requirements: t.Set[str] = set()
self._default_catalog: t.Optional[str] = None
self._linters: t.Dict[str, Linter] = {}
self._loaded: bool = False

self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items())))
Expand Down Expand Up @@ -497,6 +501,8 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:

model.validate_definition()

self.lint_models(model)

return model

def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
Expand Down Expand Up @@ -576,9 +582,10 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
self._metrics.clear()
self._requirements.clear()
self._excluded_requirements.clear()
self._linters.clear()
self._environment_statements = []

for project in loaded_projects:
for loader, project in zip(self._loaders, loaded_projects):
self._jinja_macros = self._jinja_macros.merge(project.jinja_macros)
self._macros.update(project.macros)
self._models.update(project.models)
Expand All @@ -590,6 +597,11 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
if project.environment_statements:
self._environment_statements.append(project.environment_statements)

config = loader.config
self._linters[config.project] = Linter.from_rules(
BUILTIN_RULES.union(project.user_rules), config.linter
)

uncached = set()

if any(self._projects):
Expand Down Expand Up @@ -623,10 +635,13 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:

update_model_schemas(self.dag, models=self._models, context_path=self.path)

for model in self.models.values():
models = self.models.values()
for model in models:
# The model definition can be validated correctly only after the schema is set.
model.validate_definition()

self.lint_models(*models)

duplicates = set(self._models) & set(self._standalone_audits)
if duplicates:
raise ConfigError(
Expand Down Expand Up @@ -2353,6 +2368,19 @@ def _get_models_for_interval_end(
)
return models_for_interval_end

def lint_models(self, *models: Model) -> None:
found_error = False

for model in models:
# Linter may be `None` if the context is not loaded yet
if linter := self._linters.get(model.project):
found_error = linter.lint_model(model) or found_error

if found_error:
raise LinterError(
"Linter detected errors in the code. Please fix them before proceeding."
)


class Context(GenericContext[Config]):
CONFIG_TYPE = Config
Empty file added sqlmesh/core/linter/__init__.py
Empty file.
72 changes: 72 additions & 0 deletions sqlmesh/core/linter/definition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

import typing as t

from sqlmesh.core.config.linter import LinterConfig

from sqlmesh.core.model import Model

from sqlmesh.utils.errors import raise_config_error
from sqlmesh.core.console import get_console
from sqlmesh.core.linter.rule import RuleSet


def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
if "all" in rule_names:
return all_rules

rules = set()
for rule_name in rule_names:
if rule_name not in all_rules:
raise_config_error(f"Rule {rule_name} could not be found")

rules.add(all_rules[rule_name])

return RuleSet(rules)


class Linter:
def __init__(
self, enabled: bool, all_rules: RuleSet, rules: RuleSet, warn_rules: RuleSet
) -> None:
self.enabled = enabled
self.all_rules = all_rules
self.rules = rules
self.warn_rules = warn_rules

@classmethod
def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
ignored_rules = select_rules(all_rules, config.ignored_rules)
included_rules = all_rules.difference(ignored_rules)

rules = select_rules(included_rules, config.rules)
warn_rules = select_rules(included_rules, config.warn_rules)

if overlapping := rules.intersection(warn_rules):
overlapping_rules = ", ".join(rule for rule in overlapping)
raise_config_error(
f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]"
)

return Linter(config.enabled, all_rules, rules, warn_rules)

def lint_model(self, model: Model) -> bool:
if not self.enabled:
return False

ignored_rules = select_rules(self.all_rules, model.ignored_rules)

rules = self.rules.difference(ignored_rules)
warn_rules = self.warn_rules.difference(ignored_rules)

error_violations = rules.check_model(model)
warn_violations = warn_rules.check_model(model)

if warn_violations:
get_console().show_linter_violations(warn_violations, model)

if error_violations:
get_console().show_linter_violations(error_violations, model, is_error=True)
return True

return False
Loading