Skip to content
69 changes: 27 additions & 42 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,8 @@
from sqlmesh.core.table_diff import TableDiff
from sqlmesh.core.test import (
ModelTextTestResult,
ModelTestMetadata,
generate_test,
get_all_model_tests,
run_model_tests,
run_tests,
)
from sqlmesh.core.user import User
Expand Down Expand Up @@ -1786,47 +1785,20 @@ def test(
if verbosity >= Verbosity.VERBOSE:
pd.set_option("display.max_columns", None)

if tests:
result = run_model_tests(
tests=tests,
models=self._models,
config=self.config,
gateway=self.gateway,
dialect=self.default_dialect,
verbosity=verbosity,
patterns=match_patterns,
preserve_fixtures=preserve_fixtures,
stream=stream,
default_catalog=self.default_catalog,
default_catalog_dialect=self.engine_adapter.DIALECT,
)
else:
test_meta = []

for path, config in self.configs.items():
test_meta.extend(
get_all_model_tests(
path / c.TESTS,
patterns=match_patterns,
ignore_patterns=config.ignore_patterns,
variables=config.variables,
)
)
test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)

result = run_tests(
model_test_metadata=test_meta,
models=self._models,
config=self.config,
gateway=self.gateway,
dialect=self.default_dialect,
verbosity=verbosity,
preserve_fixtures=preserve_fixtures,
stream=stream,
default_catalog=self.default_catalog,
default_catalog_dialect=self.engine_adapter.DIALECT,
)

return result
return run_tests(
model_test_metadata=test_meta,
models=self._models,
config=self.config,
selected_gateway=self.selected_gateway,
dialect=self.default_dialect,
verbosity=verbosity,
preserve_fixtures=preserve_fixtures,
stream=stream,
default_catalog=self.default_catalog,
default_catalog_dialect=self.engine_adapter.DIALECT,
)

@python_api_analytics
def audit(
Expand Down Expand Up @@ -2490,6 +2462,19 @@ def lint_models(
"Linter detected errors in the code. Please fix them before proceeding."
)

def load_model_tests(
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
) -> t.List[ModelTestMetadata]:
# If a set of specific test path(s) are provided, we can use a single loader
# since it's not required to walk every tests/ folder in each repo
loaders = [self._loaders[0]] if tests else self._loaders

model_tests = []
for loader in loaders:
model_tests.extend(loader.load_model_tests(tests=tests, patterns=patterns))

return model_tests


class Context(GenericContext[Config]):
CONFIG_TYPE = Config
69 changes: 68 additions & 1 deletion sqlmesh/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import abc
import glob
import itertools
import linecache
import logging
import os
import re
import typing as t
from collections import Counter, defaultdict
from dataclasses import dataclass
Expand All @@ -31,18 +33,22 @@
from sqlmesh.core.model import model as model_registry
from sqlmesh.core.model.common import make_python_env
from sqlmesh.core.signal import signal
from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns
from sqlmesh.utils import UniqueKeyDict, sys_path
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
from sqlmesh.utils.metaprogramming import import_python_file
from sqlmesh.utils.yaml import YAML
from sqlmesh.utils.yaml import YAML, load as yaml_load


if t.TYPE_CHECKING:
from sqlmesh.core.context import GenericContext


logger = logging.getLogger(__name__)

GATEWAY_PATTERN = re.compile(r"gateway:\s*([^\s]+)")


@dataclass
class LoadedProject:
Expand Down Expand Up @@ -290,6 +296,12 @@ def _load_linting_rules(self) -> RuleSet:
"""Loads user linting rules"""
return RuleSet()

def load_model_tests(
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
) -> t.List[ModelTestMetadata]:
"""Loads YAML-based model tests"""
return []

def _glob_paths(
self,
path: Path,
Expand Down Expand Up @@ -678,6 +690,61 @@ def _load_linting_rules(self) -> RuleSet:

return RuleSet(user_rules.values())

def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]:
"""Load a single model test file."""
model_test_metadata = {}

with open(path, "r", encoding="utf-8") as file:
source = file.read()
# If the user has specified a quoted/escaped gateway (e.g. "gateway: 'ma\tin'"), we need to
# parse it as YAML to match the gateway name stored in the config
gateway_line = GATEWAY_PATTERN.search(source)
gateway = YAML().load(gateway_line.group(0))["gateway"] if gateway_line else None

contents = yaml_load(source, variables=self._get_variables(gateway))

for test_name, value in contents.items():
model_test_metadata[test_name] = ModelTestMetadata(
path=path, test_name=test_name, body=value
)

return model_test_metadata

def load_model_tests(
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
) -> t.List[ModelTestMetadata]:
"""Loads YAML-based model tests"""
test_meta_list: t.List[ModelTestMetadata] = []

if tests:
for test in tests:
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")

test_meta = self._load_model_test_file(Path(filename))
if test_name:
test_meta_list.append(test_meta[test_name])
else:
test_meta_list.extend(test_meta.values())
else:
search_path = Path(self.config_path) / c.TESTS

for yaml_file in itertools.chain(
search_path.glob("**/test*.yaml"),
search_path.glob("**/test*.yml"),
):
if any(
yaml_file.match(ignore_pattern)
for ignore_pattern in self.config.ignore_patterns or []
):
continue

test_meta_list.extend(self._load_model_test_file(yaml_file).values())

if patterns:
test_meta_list = filter_tests_by_patterns(test_meta_list, patterns)

return test_meta_list

class _Cache(CacheBase):
def __init__(self, loader: SqlMeshLoader, config_path: Path):
self._loader = loader
Expand Down
7 changes: 1 addition & 6 deletions sqlmesh/core/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@
from sqlmesh.core.test.discovery import (
ModelTestMetadata as ModelTestMetadata,
filter_tests_by_patterns as filter_tests_by_patterns,
get_all_model_tests as get_all_model_tests,
load_model_test_file as load_model_test_file,
)
from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult
from sqlmesh.core.test.runner import (
run_model_tests as run_model_tests,
run_tests as run_tests,
)
from sqlmesh.core.test.runner import run_tests as run_tests
71 changes: 0 additions & 71 deletions sqlmesh/core/test/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
import itertools
import pathlib
import typing as t
from collections.abc import Iterator

import ruamel

from sqlmesh.utils import unique
from sqlmesh.utils.pydantic import PydanticModel
from sqlmesh.utils.yaml import load as yaml_load


class ModelTestMetadata(PydanticModel):
Expand All @@ -26,59 +24,6 @@ def __hash__(self) -> int:
return self.fully_qualified_test_name.__hash__()


def load_model_test_file(
path: pathlib.Path, variables: dict[str, t.Any] | None = None
) -> dict[str, ModelTestMetadata]:
"""Load a single model test file.

Args:
path: The path to the test file

returns:
A list of ModelTestMetadata named tuples.
"""
model_test_metadata = {}
contents = yaml_load(path, variables=variables)

for test_name, value in contents.items():
model_test_metadata[test_name] = ModelTestMetadata(
path=path, test_name=test_name, body=value
)
return model_test_metadata


def discover_model_tests(
path: pathlib.Path,
ignore_patterns: list[str] | None = None,
variables: dict[str, t.Any] | None = None,
) -> Iterator[ModelTestMetadata]:
"""Discover model tests.

Model tests are defined in YAML files and contain the inputs and outputs used to test model queries.

Args:
path: A path to search for tests.
ignore_patterns: An optional list of patterns to ignore.

Returns:
A list of ModelTestMetadata named tuples.
"""
search_path = pathlib.Path(path)

for yaml_file in itertools.chain(
search_path.glob("**/test*.yaml"),
search_path.glob("**/test*.yml"),
):
for ignore_pattern in ignore_patterns or []:
if yaml_file.match(ignore_pattern):
break
else:
for model_test_metadata in load_model_test_file(
yaml_file, variables=variables
).values():
yield model_test_metadata


def filter_tests_by_patterns(
tests: list[ModelTestMetadata], patterns: list[str]
) -> list[ModelTestMetadata]:
Expand All @@ -97,19 +42,3 @@ def filter_tests_by_patterns(
if ("*" in pattern and fnmatch.fnmatchcase(test.fully_qualified_test_name, pattern))
or pattern in test.fully_qualified_test_name
)


def get_all_model_tests(
*paths: pathlib.Path,
patterns: list[str] | None = None,
ignore_patterns: list[str] | None = None,
variables: dict[str, t.Any] | None = None,
) -> list[ModelTestMetadata]:
model_test_metadatas = [
meta
for path in paths
for meta in discover_model_tests(pathlib.Path(path), ignore_patterns, variables=variables)
]
if patterns:
model_test_metadatas = filter_tests_by_patterns(model_test_metadatas, patterns)
return model_test_metadatas
1 change: 0 additions & 1 deletion sqlmesh/core/test/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def log_test_report(self, test_duration: float) -> None:
for _, error in errors:
stream.writeln(unittest.TextTestResult.separator1)
stream.writeln(f"ERROR: {error}")
stream.writeln(unittest.TextTestResult.separator2)

# Output final report
stream.writeln(unittest.TextTestResult.separator2)
Expand Down
Loading