diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 92ffdef7d2..5d28ef9551 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -147,8 +147,8 @@ from typing_extensions import Literal from sqlmesh.core.engine_adapter._typing import ( - DF, BigframeSession, + DF, PySparkDataFrame, PySparkSession, SnowparkSession, @@ -403,6 +403,7 @@ def __init__( self._model_test_metadata_path_index: t.Dict[Path, t.List[ModelTestMetadata]] = {} self._model_test_metadata_fully_qualified_name_index: t.Dict[str, ModelTestMetadata] = {} self._models_with_tests: t.Set[str] = set() + self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros") self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics") self._jinja_macros = JinjaMacroRegistry() @@ -656,6 +657,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: self._requirements.update(project.requirements) self._excluded_requirements.update(project.excluded_requirements) self._environment_statements.extend(project.environment_statements) + self._model_test_metadata.extend(project.model_test_metadata) for metadata in project.model_test_metadata: if metadata.path not in self._model_test_metadata_path_index: @@ -2243,9 +2245,7 @@ def test( pd.set_option("display.max_columns", None) - test_meta = self._select_tests( - test_meta=self._model_test_metadata, tests=tests, patterns=match_patterns - ) + test_meta = self.select_tests(tests=tests, patterns=match_patterns) result = run_tests( model_test_metadata=test_meta, @@ -2807,33 +2807,6 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.") return self.engine_adapter - def _select_tests( - self, - test_meta: t.List[ModelTestMetadata], - tests: t.Optional[t.List[str]] = None, - patterns: t.Optional[t.List[str]] = None, - ) -> t.List[ModelTestMetadata]: - """Filter pre-loaded test metadata based on tests and patterns.""" - - if tests: - filtered_tests = [] - for test in tests: - if "::" in test: - if test in self._model_test_metadata_fully_qualified_name_index: - filtered_tests.append( - self._model_test_metadata_fully_qualified_name_index[test] - ) - else: - test_path = Path(test) - if test_path in self._model_test_metadata_path_index: - filtered_tests.extend(self._model_test_metadata_path_index[test_path]) - test_meta = filtered_tests - - if patterns: - test_meta = filter_tests_by_patterns(test_meta, patterns) - - return test_meta - def _snapshots( self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None ) -> t.Dict[str, Snapshot]: @@ -3245,18 +3218,34 @@ def lint_models( return all_violations - def load_model_tests( - self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None + def select_tests( + self, + tests: t.Optional[t.List[str]] = None, + patterns: t.Optional[t.List[str]] = None, ) -> t.List[ModelTestMetadata]: - # If a set of specific test path(s) are provided, we can use a single loader - # since it's not required to walk every tests/ folder in each repo - loaders = [self._loaders[0]] if tests else self._loaders + """Filter pre-loaded test metadata based on tests and patterns.""" + + test_meta = self._model_test_metadata + + if tests: + filtered_tests = [] + for test in tests: + if "::" in test: + if test in self._model_test_metadata_fully_qualified_name_index: + filtered_tests.append( + self._model_test_metadata_fully_qualified_name_index[test] + ) + else: + test_path = Path(test) + if test_path in self._model_test_metadata_path_index: + filtered_tests.extend(self._model_test_metadata_path_index[test_path]) + + test_meta = filtered_tests - model_tests = [] - for loader in loaders: - model_tests.extend(loader.load_model_tests(tests=tests, patterns=patterns)) + if patterns: + test_meta = filter_tests_by_patterns(test_meta, patterns) - return model_tests + return test_meta class Context(GenericContext[Config]): diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py index 5058f3a58a..c28822a154 100644 --- a/sqlmesh/core/linter/rules/builtin.py +++ b/sqlmesh/core/linter/rules/builtin.py @@ -130,7 +130,7 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]: class NoMissingUnitTest(Rule): - """All models must have a unit test found in the test/ directory yaml files""" + """All models must have a unit test found in the tests/ directory yaml files""" def check_model(self, model: Model) -> t.Optional[RuleViolation]: # External models cannot have unit tests diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index fda35ca75c..a43f5f28ff 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -35,7 +35,7 @@ from sqlmesh.core.model import model as model_registry from sqlmesh.core.model.common import make_python_env from sqlmesh.core.signal import signal -from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns +from sqlmesh.core.test import ModelTestMetadata from sqlmesh.utils import UniqueKeyDict, sys_path from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor @@ -427,9 +427,7 @@ def _load_linting_rules(self) -> RuleSet: """Loads user linting rules""" return RuleSet() - def load_model_tests( - self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None - ) -> t.List[ModelTestMetadata]: + def load_model_tests(self) -> t.List[ModelTestMetadata]: """Loads YAML-based model tests""" return [] @@ -868,38 +866,23 @@ def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]: return model_test_metadata - def load_model_tests( - self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None - ) -> t.List[ModelTestMetadata]: + def load_model_tests(self) -> t.List[ModelTestMetadata]: """Loads YAML-based model tests""" test_meta_list: t.List[ModelTestMetadata] = [] - if tests: - for test in tests: - filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") + search_path = Path(self.config_path) / c.TESTS - test_meta = self._load_model_test_file(Path(filename)) - if test_name: - test_meta_list.append(test_meta[test_name]) - else: - test_meta_list.extend(test_meta.values()) - else: - search_path = Path(self.config_path) / c.TESTS - - for yaml_file in itertools.chain( - search_path.glob("**/test*.yaml"), - search_path.glob("**/test*.yml"), + for yaml_file in itertools.chain( + search_path.glob("**/test*.yaml"), + search_path.glob("**/test*.yml"), + ): + if any( + yaml_file.match(ignore_pattern) + for ignore_pattern in self.config.ignore_patterns or [] ): - if any( - yaml_file.match(ignore_pattern) - for ignore_pattern in self.config.ignore_patterns or [] - ): - continue - - test_meta_list.extend(self._load_model_test_file(yaml_file).values()) + continue - if patterns: - test_meta_list = filter_tests_by_patterns(test_meta_list, patterns) + test_meta_list.extend(self._load_model_test_file(yaml_file).values()) return test_meta_list diff --git a/sqlmesh/lsp/context.py b/sqlmesh/lsp/context.py index 50265ec306..a94db7c421 100644 --- a/sqlmesh/lsp/context.py +++ b/sqlmesh/lsp/context.py @@ -72,7 +72,7 @@ def __init__(self, context: Context) -> None: def list_workspace_tests(self) -> t.List[TestEntry]: """List all tests in the workspace.""" - tests = self.context.load_model_tests() + tests = self.context.select_tests() # Use a set to ensure unique URIs unique_test_uris = {URI.from_path(test.path).value for test in tests} @@ -81,7 +81,9 @@ def list_workspace_tests(self) -> t.List[TestEntry]: test_ranges = get_test_ranges(URI(uri).to_path()) if uri not in test_uris: test_uris[uri] = {} + test_uris[uri].update(test_ranges) + return [ TestEntry( name=test.test_name, @@ -100,7 +102,7 @@ def get_document_tests(self, uri: URI) -> t.List[TestEntry]: Returns: List of TestEntry objects for the specified document. """ - tests = self.context.load_model_tests(tests=[str(uri.to_path())]) + tests = self.context.select_tests(tests=[str(uri.to_path())]) test_ranges = get_test_ranges(uri.to_path()) return [ TestEntry( diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 2b5f185aa9..0a433360df 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -337,7 +337,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None if not args.test_name and not args.ls: raise MagicError("Must provide either test name or `--ls` to list tests") - test_meta = context.load_model_tests() + test_meta = context.select_tests() tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict) for model_test_metadata in test_meta: