diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index ec88ecea14..2dbc9e4c45 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -109,9 +109,8 @@ from sqlmesh.core.table_diff import TableDiff from sqlmesh.core.test import ( ModelTextTestResult, + ModelTestMetadata, generate_test, - get_all_model_tests, - run_model_tests, run_tests, ) from sqlmesh.core.user import User @@ -1786,47 +1785,20 @@ def test( if verbosity >= Verbosity.VERBOSE: pd.set_option("display.max_columns", None) - if tests: - result = run_model_tests( - tests=tests, - models=self._models, - config=self.config, - gateway=self.gateway, - dialect=self.default_dialect, - verbosity=verbosity, - patterns=match_patterns, - preserve_fixtures=preserve_fixtures, - stream=stream, - default_catalog=self.default_catalog, - default_catalog_dialect=self.engine_adapter.DIALECT, - ) - else: - test_meta = [] - - for path, config in self.configs.items(): - test_meta.extend( - get_all_model_tests( - path / c.TESTS, - patterns=match_patterns, - ignore_patterns=config.ignore_patterns, - variables=config.variables, - ) - ) + test_meta = self.load_model_tests(tests=tests, patterns=match_patterns) - result = run_tests( - model_test_metadata=test_meta, - models=self._models, - config=self.config, - gateway=self.gateway, - dialect=self.default_dialect, - verbosity=verbosity, - preserve_fixtures=preserve_fixtures, - stream=stream, - default_catalog=self.default_catalog, - default_catalog_dialect=self.engine_adapter.DIALECT, - ) - - return result + return run_tests( + model_test_metadata=test_meta, + models=self._models, + config=self.config, + selected_gateway=self.selected_gateway, + dialect=self.default_dialect, + verbosity=verbosity, + preserve_fixtures=preserve_fixtures, + stream=stream, + default_catalog=self.default_catalog, + default_catalog_dialect=self.engine_adapter.DIALECT, + ) @python_api_analytics def audit( @@ -2490,6 +2462,19 @@ def lint_models( "Linter detected errors in the 
code. Please fix them before proceeding." )
+    def load_model_tests(
+        self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
+    ) -> t.List[ModelTestMetadata]:
+        """Gather model test metadata by delegating to this context's loaders."""
+        # Explicitly requested test path(s) don't require walking every tests/ folder
+        # in each repo, so the first loader alone is used to resolve them
+        active_loaders = [self._loaders[0]] if tests else self._loaders
+        return [
+            test_metadata
+            for loader in active_loaders
+            for test_metadata in loader.load_model_tests(tests=tests, patterns=patterns)
+        ]
+
 class Context(GenericContext[Config]): CONFIG_TYPE = Config diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index edd4911156..bd27d1e752 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -2,9 +2,11 @@ import abc import glob +import itertools import linecache import logging import os +import re import typing as t from collections import Counter, defaultdict from dataclasses import dataclass @@ -31,11 +33,13 @@ from sqlmesh.core.model import model as model_registry from sqlmesh.core.model.common import make_python_env from sqlmesh.core.signal import signal +from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns from sqlmesh.utils import UniqueKeyDict, sys_path from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor from sqlmesh.utils.metaprogramming import import_python_file -from sqlmesh.utils.yaml import YAML +from sqlmesh.utils.yaml import YAML, load as yaml_load + if t.TYPE_CHECKING: from sqlmesh.core.context import GenericContext @@ -43,6 +47,8 @@ logger = logging.getLogger(__name__) +GATEWAY_PATTERN = re.compile(r"gateway:\s*([^\s]+)") + @dataclass class LoadedProject: @@ -290,6 +296,12 @@ def _load_linting_rules(self) -> RuleSet: """Loads user linting rules""" return RuleSet() + def load_model_tests( + self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None + ) -> t.List[ModelTestMetadata]: + """Loads 
YAML-based model tests""" + return [] + def _glob_paths( self, path: Path, @@ -678,6 +690,80 @@ def _load_linting_rules(self) -> RuleSet: return RuleSet(user_rules.values())
+    def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]:
+        """Load a single model test file.
+
+        Args:
+            path: The path to the YAML test file.
+
+        Returns:
+            A mapping from each test name in the file to its ModelTestMetadata.
+        """
+        with open(path, "r", encoding="utf-8") as file:
+            source = file.read()
+
+        # The gateway selects the variables handed to the YAML loader, so it is read from
+        # the raw text first. Only the first "gateway:" match in the file is considered.
+        # If the user has specified a quoted/escaped gateway (e.g. "gateway: 'ma\tin'"), we need to
+        # parse it as YAML to match the gateway name stored in the config
+        gateway_line = GATEWAY_PATTERN.search(source)
+        gateway = YAML().load(gateway_line.group(0))["gateway"] if gateway_line else None
+
+        contents = yaml_load(source, variables=self._get_variables(gateway))
+
+        return {
+            test_name: ModelTestMetadata(path=path, test_name=test_name, body=body)
+            for test_name, body in contents.items()
+        }
+
+    def load_model_tests(
+        self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
+    ) -> t.List[ModelTestMetadata]:
+        """Loads YAML-based model tests
+
+        Args:
+            tests: Optional selectors, each either a test file path or "<path>::<test_name>".
+                When empty, test*.yaml / test*.yml files under the project's c.TESTS directory
+                are loaded, skipping files that match the config's ignore patterns.
+            patterns: Optional patterns passed to filter_tests_by_patterns.
+
+        Returns:
+            The metadata of the selected tests.
+
+        Raises:
+            ConfigError: If a "<path>::<test_name>" selector names a test missing from the file.
+        """
+        test_meta_list: t.List[ModelTestMetadata] = []
+
+        if tests:
+            for test in tests:
+                filename, _, test_name = test.partition("::")
+                test_meta = self._load_model_test_file(Path(filename))
+
+                if not test_name:
+                    test_meta_list.extend(test_meta.values())
+                elif test_name in test_meta:
+                    test_meta_list.append(test_meta[test_name])
+                else:
+                    raise ConfigError(f"Test '{test_name}' was not found in '{filename}'.")
+        else:
+            search_path = Path(self.config_path) / c.TESTS
+            ignore_patterns = self.config.ignore_patterns or []
+
+            for yaml_file in itertools.chain(
+                search_path.glob("**/test*.yaml"),
+                search_path.glob("**/test*.yml"),
+            ):
+                if any(yaml_file.match(pattern) for pattern in ignore_patterns):
+                    continue
+
+                test_meta_list.extend(self._load_model_test_file(yaml_file).values())
+
+        if patterns:
+            test_meta_list = filter_tests_by_patterns(test_meta_list, patterns)
+
+        return test_meta_list
+
 class _Cache(CacheBase): def __init__(self, loader: 
SqlMeshLoader, config_path: Path): self._loader = loader diff --git a/sqlmesh/core/test/__init__.py b/sqlmesh/core/test/__init__.py index c907aacf7c..6353370f45 100644 --- a/sqlmesh/core/test/__init__.py +++ b/sqlmesh/core/test/__init__.py @@ -4,11 +4,6 @@ from sqlmesh.core.test.discovery import ( ModelTestMetadata as ModelTestMetadata, filter_tests_by_patterns as filter_tests_by_patterns, - get_all_model_tests as get_all_model_tests, - load_model_test_file as load_model_test_file, ) from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult -from sqlmesh.core.test.runner import ( - run_model_tests as run_model_tests, - run_tests as run_tests, -) +from sqlmesh.core.test.runner import run_tests as run_tests diff --git a/sqlmesh/core/test/discovery.py b/sqlmesh/core/test/discovery.py index a7977d2036..0f60fe6fa9 100644 --- a/sqlmesh/core/test/discovery.py +++ b/sqlmesh/core/test/discovery.py @@ -4,13 +4,11 @@ import itertools import pathlib import typing as t -from collections.abc import Iterator import ruamel from sqlmesh.utils import unique from sqlmesh.utils.pydantic import PydanticModel -from sqlmesh.utils.yaml import load as yaml_load class ModelTestMetadata(PydanticModel): @@ -26,59 +24,6 @@ def __hash__(self) -> int: return self.fully_qualified_test_name.__hash__() -def load_model_test_file( - path: pathlib.Path, variables: dict[str, t.Any] | None = None -) -> dict[str, ModelTestMetadata]: - """Load a single model test file. - - Args: - path: The path to the test file - - returns: - A list of ModelTestMetadata named tuples. 
- """ - model_test_metadata = {} - contents = yaml_load(path, variables=variables) - - for test_name, value in contents.items(): - model_test_metadata[test_name] = ModelTestMetadata( - path=path, test_name=test_name, body=value - ) - return model_test_metadata - - -def discover_model_tests( - path: pathlib.Path, - ignore_patterns: list[str] | None = None, - variables: dict[str, t.Any] | None = None, -) -> Iterator[ModelTestMetadata]: - """Discover model tests. - - Model tests are defined in YAML files and contain the inputs and outputs used to test model queries. - - Args: - path: A path to search for tests. - ignore_patterns: An optional list of patterns to ignore. - - Returns: - A list of ModelTestMetadata named tuples. - """ - search_path = pathlib.Path(path) - - for yaml_file in itertools.chain( - search_path.glob("**/test*.yaml"), - search_path.glob("**/test*.yml"), - ): - for ignore_pattern in ignore_patterns or []: - if yaml_file.match(ignore_pattern): - break - else: - for model_test_metadata in load_model_test_file( - yaml_file, variables=variables - ).values(): - yield model_test_metadata - - def filter_tests_by_patterns( tests: list[ModelTestMetadata], patterns: list[str] ) -> list[ModelTestMetadata]: @@ -97,19 +42,3 @@ def filter_tests_by_patterns( if ("*" in pattern and fnmatch.fnmatchcase(test.fully_qualified_test_name, pattern)) or pattern in test.fully_qualified_test_name ) - - -def get_all_model_tests( - *paths: pathlib.Path, - patterns: list[str] | None = None, - ignore_patterns: list[str] | None = None, - variables: dict[str, t.Any] | None = None, -) -> list[ModelTestMetadata]: - model_test_metadatas = [ - meta - for path in paths - for meta in discover_model_tests(pathlib.Path(path), ignore_patterns, variables=variables) - ] - if patterns: - model_test_metadatas = filter_tests_by_patterns(model_test_metadatas, patterns) - return model_test_metadatas diff --git a/sqlmesh/core/test/result.py b/sqlmesh/core/test/result.py index 
304cb013aa..30144f8508 100644 --- a/sqlmesh/core/test/result.py +++ b/sqlmesh/core/test/result.py @@ -107,7 +107,6 @@ def log_test_report(self, test_duration: float) -> None: for _, error in errors: stream.writeln(unittest.TextTestResult.separator1) stream.writeln(f"ERROR: {error}") - stream.writeln(unittest.TextTestResult.separator2) # Output final report stream.writeln(unittest.TextTestResult.separator2) diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py index 08a12bb59b..a070a6e34b 100644 --- a/sqlmesh/core/test/runner.py +++ b/sqlmesh/core/test/runner.py @@ -2,7 +2,6 @@ import sys import time -import pathlib import threading import typing as t import unittest @@ -15,9 +14,6 @@ from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test from sqlmesh.core.test.discovery import ( ModelTestMetadata as ModelTestMetadata, - filter_tests_by_patterns as filter_tests_by_patterns, - get_all_model_tests as get_all_model_tests, - load_model_test_file as load_model_test_file, ) from sqlmesh.core.config.connection import BaseDuckDBConnectionConfig @@ -48,7 +44,7 @@ def __init__( def create_testing_engine_adapters( model_test_metadata: list[ModelTestMetadata], config: C, - default_gateway: str, + selected_gateway: str, default_catalog: str | None = None, default_catalog_dialect: str = "", ) -> t.Dict[ModelTestMetadata, EngineAdapter]: @@ -56,7 +52,7 @@ def create_testing_engine_adapters( metadata_to_adapter = {} for metadata in model_test_metadata: - gateway = metadata.body.get("gateway") or default_gateway + gateway = metadata.body.get("gateway") or selected_gateway test_connection = config.get_test_connection( gateway, default_catalog, default_catalog_dialect ) @@ -86,7 +82,7 @@ def run_tests( model_test_metadata: list[ModelTestMetadata], models: UniqueKeyDict[str, Model], config: C, - gateway: t.Optional[str] = None, + selected_gateway: str, dialect: str | None = None, verbosity: Verbosity = Verbosity.DEFAULT, 
preserve_fixtures: bool = False, @@ -102,10 +98,8 @@ def run_tests( verbosity: The verbosity level. preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. """ - default_gateway = gateway or config.default_gateway_name - default_test_connection = config.get_test_connection( - gateway_name=default_gateway, + gateway_name=selected_gateway, default_catalog=default_catalog, default_catalog_dialect=default_catalog_dialect, ) @@ -121,14 +115,14 @@ def run_tests( metadata_to_adapter = create_testing_engine_adapters( model_test_metadata=model_test_metadata, config=config, - default_gateway=default_gateway, + selected_gateway=selected_gateway, default_catalog=default_catalog, default_catalog_dialect=default_catalog_dialect, ) def _run_single_test( metadata: ModelTestMetadata, engine_adapter: EngineAdapter - ) -> ModelTextTestResult: + ) -> t.Optional[ModelTextTestResult]: test = ModelTest.create_test( body=metadata.body, test_name=metadata.test_name, @@ -140,6 +134,9 @@ def _run_single_test( preserve_fixtures=preserve_fixtures, ) + if not test: + return None + result = t.cast( ModelTextTestResult, ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), @@ -155,6 +152,9 @@ def _run_single_test( elif result.skipped: skipped_args = result.skipped[0] combined_results.addSkip(skipped_args[0], skipped_args[1]) + + combined_results.testsRun += 1 + return result test_results = [] @@ -180,57 +180,6 @@ def _run_single_test( end_time = time.perf_counter() - combined_results.testsRun = len(test_results) - combined_results.log_test_report(test_duration=end_time - start_time) return combined_results - - -def run_model_tests( - tests: list[str], - models: UniqueKeyDict[str, Model], - config: C, - gateway: t.Optional[str] = None, - dialect: str | None = None, - verbosity: Verbosity = Verbosity.DEFAULT, - patterns: list[str] | None = None, - preserve_fixtures: bool = False, - stream: t.TextIO | None = None, - default_catalog: t.Optional[str] = None, 
- default_catalog_dialect: str = "", -) -> ModelTextTestResult: - """Load and run tests. - - Args: - tests: A list of tests to run, e.g. [tests/test_orders.yaml::test_single_order] - models: All models to use for expansion and mapping of physical locations. - verbosity: The verbosity level. - patterns: A list of patterns to match against. - preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. - """ - loaded_tests = [] - for test in tests: - filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") - path = pathlib.Path(filename) - - if test_name: - loaded_tests.append(load_model_test_file(path, variables=config.variables)[test_name]) - else: - loaded_tests.extend(load_model_test_file(path, variables=config.variables).values()) - - if patterns: - loaded_tests = filter_tests_by_patterns(loaded_tests, patterns) - - return run_tests( - loaded_tests, - models, - config, - gateway=gateway, - dialect=dialect, - verbosity=verbosity, - preserve_fixtures=preserve_fixtures, - stream=stream, - default_catalog=default_catalog, - default_catalog_dialect=default_catalog_dialect, - ) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 19eac1d51f..6bf28a8496 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -26,13 +26,12 @@ from rich.jupyter import JupyterRenderable from sqlmesh.cli.example_project import ProjectTemplate, init_example_project from sqlmesh.core import analytics -from sqlmesh.core import constants as c from sqlmesh.core.config import load_configs from sqlmesh.core.console import create_console, set_console, configure_console from sqlmesh.core.context import Context from sqlmesh.core.dialect import format_model_expressions, parse from sqlmesh.core.model import load_sql_based_model -from sqlmesh.core.test import ModelTestMetadata, get_all_model_tests +from sqlmesh.core.test import ModelTestMetadata from sqlmesh.utils import sqlglot_dialects, yaml, Verbosity from sqlmesh.utils.errors import 
MagicError, MissingContextException, SQLMeshError @@ -273,15 +272,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None if not args.test_name and not args.ls: raise MagicError("Must provide either test name or `--ls` to list tests") - test_meta = [] - - for path, config in context.configs.items(): - test_meta.extend( - get_all_model_tests( - path / c.TESTS, - ignore_patterns=config.ignore_patterns, - ) - ) + test_meta = context.load_model_tests() tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict) for model_test_metadata in test_meta: diff --git a/tests/core/test_test.py b/tests/core/test_test.py index e50c3bdf06..df9ff3ea5e 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -32,6 +32,8 @@ from sqlmesh.utils.yaml import dump as dump_yaml from sqlmesh.utils.yaml import load as load_yaml +from tests.utils.test_helpers import use_terminal_console + if t.TYPE_CHECKING: from unittest import TestResult @@ -1560,36 +1562,28 @@ def execute(context, start, end, execution_time, **kwargs): def test_variable_usage(tmp_path: Path) -> None: init_example_project(tmp_path, dialect="duckdb") - config = Config( - default_connection=DuckDBConnectionConfig(), - model_defaults=ModelDefaultsConfig(dialect="duckdb"), - variables={"gold": "gold_db", "silver": "silver_db"}, - ) - context = Context(paths=tmp_path, config=config) + variables = {"gold": "gold_db", "silver": "silver_db"} + incorrect_variables = {"gold": "foo", "silver": "bar"} parent = _create_model( "SELECT 1 AS id, '2022-01-02'::DATE AS ds, @start_ts AS start_ts", meta="MODEL (name silver_db.sch.b, kind INCREMENTAL_BY_TIME_RANGE(time_column ds))", ) - parent = t.cast(SqlModel, context.upsert_model(parent)) child = _create_model( "SELECT ds, @IF(@VAR('myvar'), id, id + 1) AS id FROM silver_db.sch.b WHERE ds BETWEEN @start_ds and @end_ds", meta="MODEL (name gold_db.sch.a, kind INCREMENTAL_BY_TIME_RANGE(time_column ds))", ) - child = t.cast(SqlModel, 
context.upsert_model(child)) - test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml" - test_file.write_text( - """ + test_text = """ test_parameterized_model_names: - model: {{ var('gold') }}.sch.a + model: {{{{ var('gold') }}}}.sch.a {gateway} vars: myvar: True start_ds: 2022-01-01 end_ds: 2022-01-03 inputs: - {{ var('silver') }}.sch.b: + {{{{ var('silver') }}}}.sch.b: - ds: 2022-01-01 id: 1 - ds: 2022-01-01 @@ -1599,17 +1593,78 @@ def test_variable_usage(tmp_path: Path) -> None: - ds: 2022-01-01 id: 1 - ds: 2022-01-01 - id: 2 - """ + id: 2""" + + test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml" + + def init_context_and_validate_results(config: Config, **kwargs): + context = Context(paths=tmp_path, config=config, **kwargs) + context.upsert_model(parent) + context.upsert_model(child) + + results = context.test() + + assert not results.failures + assert not results.errors + assert len(results.successes) == 2 + + # Case 1: Test root variables + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables=variables, ) - results = context.test() + test_file.write_text(test_text.format(gateway="")) - assert not results.failures - assert not results.errors + init_context_and_validate_results(config) - # The example project has one test and we added another one above - assert len(results.successes) == 2 + # Case 2: Test gateway variables + config = Config( + gateways={ + "main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + init_context_and_validate_results(config) + + # Case 3: Test gateway variables overriding root variables + config = Config( + gateways={ + "main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables=incorrect_variables, + ) + 
init_context_and_validate_results(config, gateway="main") + + # Case 4: Use variable from the defined gateway + config = Config( + gateways={ + "main": GatewayConfig( + connection=DuckDBConnectionConfig(), variables=incorrect_variables + ), + "secondary": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + + test_file.write_text(test_text.format(gateway="\n gateway: secondary")) + init_context_and_validate_results(config, gateway="main") + + # Case 5: Use gateways with escaped characters + config = Config( + gateways={ + "main": GatewayConfig( + connection=DuckDBConnectionConfig(), variables=incorrect_variables + ), + "secon\tdary": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + + test_file.write_text(test_text.format(gateway='\n gateway: "secon\\tdary"')) + init_context_and_validate_results(config, gateway="main") def test_custom_testing_schema(mocker: MockerFixture) -> None: @@ -2203,3 +2258,115 @@ def test_test_output(tmp_path: Path) -> None: assert "Ran 102 tests" in output.stderr assert "FAILED (failures=51)" in output.stderr + + +@use_terminal_console +def test_test_output_with_invalid_model_name(tmp_path: Path) -> None: + init_example_project(tmp_path, dialect="duckdb") + + wrong_test_file = tmp_path / "tests" / "test_incorrect_model_name.yaml" + wrong_test_file.write_text( + """ +test_example_full_model: + model: invalid_model + description: This is an invalid test + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 2 + """ + ) + + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + context = Context(paths=tmp_path, config=config) + + with 
patch.object(get_console(), "log_warning") as mock_logger: + with capture_output() as output: + context.test() + + assert ( + f"""Model '"invalid_model"' was not found at {wrong_test_file}""" + in mock_logger.call_args[0][0] + ) + assert ( + ".\n----------------------------------------------------------------------\nRan 1 test in" + in output.stderr + ) + assert "OK" in output.stderr + + +def test_number_of_tests_found(tmp_path: Path) -> None: + init_example_project(tmp_path, dialect="duckdb") + + # Example project contains 1 test and we add a new file with 2 tests + test_file = tmp_path / "tests" / "test_new.yaml" + test_file.write_text( + """ +test_example_full_model1: + model: sqlmesh_example.full_model + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 1 + +test_example_full_model2: + model: sqlmesh_example.full_model + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 1 + """ + ) + + context = Context(paths=tmp_path) + + # Case 1: All 3 tests should run without any tests specified + results = context.test() + assert len(results.successes) == 3 + + # Case 2: The "new_test.yaml" should amount to 2 subtests + results = context.test(tests=[f"{test_file}"]) + assert len(results.successes) == 2 + + # Case 3: The "new_test.yaml::test_example_full_model2" should amount to a single subtest + results = context.test(tests=[f"{test_file}::test_example_full_model2"]) + assert len(results.successes) == 1