From 870a5045c1ac469f0b66f0d0f671a959125e2a1d Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Wed, 9 Apr 2025 12:00:12 +0300 Subject: [PATCH 01/11] Fix: Propagate gateway variables to model tests --- sqlmesh/core/context.py | 64 +++++++----------- sqlmesh/core/test/__init__.py | 6 +- sqlmesh/core/test/discovery.py | 47 ++++++++++++++ sqlmesh/core/test/runner.py | 114 ++++++++++----------------------- tests/core/test_test.py | 86 ++++++++++++++++++++++++- 5 files changed, 190 insertions(+), 127 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index ec88ecea14..749e4302d3 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -110,8 +110,7 @@ from sqlmesh.core.test import ( ModelTextTestResult, generate_test, - get_all_model_tests, - run_model_tests, + load_model_tests, run_tests, ) from sqlmesh.core.user import User @@ -1786,47 +1785,30 @@ def test( if verbosity >= Verbosity.VERBOSE: pd.set_option("display.max_columns", None) - if tests: - result = run_model_tests( - tests=tests, - models=self._models, - config=self.config, - gateway=self.gateway, - dialect=self.default_dialect, - verbosity=verbosity, - patterns=match_patterns, - preserve_fixtures=preserve_fixtures, - stream=stream, - default_catalog=self.default_catalog, - default_catalog_dialect=self.engine_adapter.DIALECT, - ) - else: - test_meta = [] - - for path, config in self.configs.items(): - test_meta.extend( - get_all_model_tests( - path / c.TESTS, - patterns=match_patterns, - ignore_patterns=config.ignore_patterns, - variables=config.variables, - ) - ) + default_gateway = self.gateway or self.config.default_gateway_name - result = run_tests( - model_test_metadata=test_meta, - models=self._models, - config=self.config, - gateway=self.gateway, - dialect=self.default_dialect, - verbosity=verbosity, - preserve_fixtures=preserve_fixtures, - stream=stream, - default_catalog=self.default_catalog, - default_catalog_dialect=self.engine_adapter.DIALECT, - ) + # Merge the root variables with the gateway's variables + variables = {**self.config.variables, **self.config.get_gateway(default_gateway).variables} - return result + test_meta = load_model_tests( + configs=self.configs, + tests=tests, + patterns=match_patterns, + variables=variables, + ) + + return run_tests( + model_test_metadata=test_meta, + models=self._models, + config=self.config, + default_gateway=default_gateway, + dialect=self.default_dialect, + verbosity=verbosity, + preserve_fixtures=preserve_fixtures, + stream=stream, + default_catalog=self.default_catalog, + default_catalog_dialect=self.engine_adapter.DIALECT, + ) @python_api_analytics def audit( diff --git a/sqlmesh/core/test/__init__.py b/sqlmesh/core/test/__init__.py index c907aacf7c..3e1f5b0c81 100644 --- a/sqlmesh/core/test/__init__.py +++ b/sqlmesh/core/test/__init__.py @@ -6,9 +6,7 @@ filter_tests_by_patterns as filter_tests_by_patterns, get_all_model_tests as get_all_model_tests, load_model_test_file as load_model_test_file, + load_model_tests as load_model_tests, ) from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult -from sqlmesh.core.test.runner import ( - run_model_tests as run_model_tests, - run_tests as run_tests, -) +from sqlmesh.core.test.runner import run_tests as run_tests diff --git a/sqlmesh/core/test/discovery.py b/sqlmesh/core/test/discovery.py index a7977d2036..8c4d544696 100644 --- a/sqlmesh/core/test/discovery.py +++ b/sqlmesh/core/test/discovery.py @@ -8,10 +8,14 @@ import ruamel +from sqlmesh.core import constants as c from sqlmesh.utils import unique from sqlmesh.utils.pydantic import PydanticModel from sqlmesh.utils.yaml import load as yaml_load +if t.TYPE_CHECKING: + from sqlmesh.core.config.loader import C + class ModelTestMetadata(PydanticModel): path: pathlib.Path @@ -113,3 +117,46 @@ def get_all_model_tests( if patterns: model_test_metadatas = filter_tests_by_patterns(model_test_metadatas, patterns) return model_test_metadatas + + +def load_model_tests( + configs: dict[pathlib.Path, C], + tests: t.Optional[t.List[str]] = None, + patterns: list[str] | None = None, + variables: dict[str, t.Any] | None = None, +) -> list[ModelTestMetadata]: + """Load model tests into a list of ModelTestMetadata which will be propagated to the test runner. + + Args: + tests: A list of tests to load; If not specified, all tests are loaded + patterns: A list of patterns to match against. + variables: A dictionary of variables to use when loading the tests. + configs: A dictionary of configs to use when loading all the tests. + """ + test_meta = [] + + if tests: + for test in tests: + filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") + + test_file = load_model_test_file(pathlib.Path(filename), variables=variables) + if test_name: + test_meta.append(test_file[test_name]) + else: + test_meta.extend(test_file.values()) + + if patterns: + test_meta = filter_tests_by_patterns(test_meta, patterns) + + else: + for path, config in configs.items(): + test_meta.extend( + get_all_model_tests( + path / c.TESTS, + patterns=patterns, + ignore_patterns=config.ignore_patterns, + variables=variables, + ) + ) + + return test_meta diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py index 08a12bb59b..3a6286ff7a 100644 --- a/sqlmesh/core/test/runner.py +++ b/sqlmesh/core/test/runner.py @@ -2,7 +2,6 @@ import sys import time -import pathlib import threading import typing as t import unittest @@ -86,7 +85,7 @@ def run_tests( model_test_metadata: list[ModelTestMetadata], models: UniqueKeyDict[str, Model], config: C, - gateway: t.Optional[str] = None, + default_gateway: str, dialect: str | None = None, verbosity: Verbosity = Verbosity.DEFAULT, preserve_fixtures: bool = False, @@ -102,8 +101,6 @@ def run_tests( verbosity: The verbosity level. preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. """ - default_gateway = gateway or config.default_gateway_name - default_test_connection = config.get_test_connection( gateway_name=default_gateway, default_catalog=default_catalog, @@ -128,34 +125,40 @@ def run_tests( def _run_single_test( metadata: ModelTestMetadata, engine_adapter: EngineAdapter - ) -> ModelTextTestResult: - test = ModelTest.create_test( - body=metadata.body, - test_name=metadata.test_name, - models=models, - engine_adapter=engine_adapter, - dialect=dialect, - path=metadata.path, - default_catalog=default_catalog, - preserve_fixtures=preserve_fixtures, - ) + ) -> t.Optional[ModelTextTestResult]: + result: t.Optional[ModelTextTestResult] = None + try: + test = ModelTest.create_test( + body=metadata.body, + test_name=metadata.test_name, + models=models, + engine_adapter=engine_adapter, + dialect=dialect, + path=metadata.path, + default_catalog=default_catalog, + preserve_fixtures=preserve_fixtures, + ) - result = t.cast( - ModelTextTestResult, - ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), - ) + result = t.cast( + ModelTextTestResult, + ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), + ) + + with lock: + if result.successes: + combined_results.addSuccess(result.successes[0]) + elif result.errors: + combined_results.addError(result.original_err[0], result.original_err[1]) + elif result.failures: + combined_results.addFailure(result.original_err[0], result.original_err[1]) + elif result.skipped: + skipped_args = result.skipped[0] + combined_results.addSkip(skipped_args[0], skipped_args[1]) - with lock: - if result.successes: - combined_results.addSuccess(result.successes[0]) - elif result.errors: - combined_results.addError(result.original_err[0], result.original_err[1]) - elif result.failures: - combined_results.addFailure(result.original_err[0], result.original_err[1]) - elif result.skipped: - skipped_args = result.skipped[0] - combined_results.addSkip(skipped_args[0], skipped_args[1]) - return result + combined_results.testsRun += 1 + + finally: + return result test_results = [] @@ -180,57 +183,6 @@ def _run_single_test( end_time = time.perf_counter() - combined_results.testsRun = len(test_results) - combined_results.log_test_report(test_duration=end_time - start_time) return combined_results - - -def run_model_tests( - tests: list[str], - models: UniqueKeyDict[str, Model], - config: C, - gateway: t.Optional[str] = None, - dialect: str | None = None, - verbosity: Verbosity = Verbosity.DEFAULT, - patterns: list[str] | None = None, - preserve_fixtures: bool = False, - stream: t.TextIO | None = None, - default_catalog: t.Optional[str] = None, - default_catalog_dialect: str = "", -) -> ModelTextTestResult: - """Load and run tests. - - Args: - tests: A list of tests to run, e.g. [tests/test_orders.yaml::test_single_order] - models: All models to use for expansion and mapping of physical locations. - verbosity: The verbosity level. - patterns: A list of patterns to match against. - preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. - """ - loaded_tests = [] - for test in tests: - filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") - path = pathlib.Path(filename) - - if test_name: - loaded_tests.append(load_model_test_file(path, variables=config.variables)[test_name]) - else: - loaded_tests.extend(load_model_test_file(path, variables=config.variables).values()) - - if patterns: - loaded_tests = filter_tests_by_patterns(loaded_tests, patterns) - - return run_tests( - loaded_tests, - models, - config, - gateway=gateway, - dialect=dialect, - verbosity=verbosity, - preserve_fixtures=preserve_fixtures, - stream=stream, - default_catalog=default_catalog, - default_catalog_dialect=default_catalog_dialect, - ) diff --git a/tests/core/test_test.py b/tests/core/test_test.py index e50c3bdf06..85f687a1f8 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -32,6 +32,8 @@ from sqlmesh.utils.yaml import dump as dump_yaml from sqlmesh.utils.yaml import load as load_yaml +from tests.utils.test_helpers import use_terminal_console + if t.TYPE_CHECKING: from unittest import TestResult @@ -1560,10 +1562,13 @@ def execute(context, start, end, execution_time, **kwargs): def test_variable_usage(tmp_path: Path) -> None: init_example_project(tmp_path, dialect="duckdb") + variables = {"gold": "gold_db", "silver": "silver_db"} + + # Case 1: Test root variables config = Config( default_connection=DuckDBConnectionConfig(), model_defaults=ModelDefaultsConfig(dialect="duckdb"), - variables={"gold": "gold_db", "silver": "silver_db"}, + variables=variables, ) context = Context(paths=tmp_path, config=config) @@ -1611,6 +1616,35 @@ def test_variable_usage(tmp_path: Path) -> None: # The example project has one test and we added another one above assert len(results.successes) == 2 + # Case 2: Test gateway variables + context.config = Config( + gateways={ + "main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + + results = context.test() + + assert not results.failures + assert not results.errors + assert len(results.successes) == 2 + + # Case 3: Test gateway variables overriding root variables + context.config = Config( + gateways={ + "main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"gold": "foo", "silver": "bar"}, + ) + + results = context.test() + + assert not results.failures + assert not results.errors + assert len(results.successes) == 2 + def test_custom_testing_schema(mocker: MockerFixture) -> None: test = _create_test( @@ -2203,3 +2237,53 @@ def test_test_output(tmp_path: Path) -> None: assert "Ran 102 tests" in output.stderr assert "FAILED (failures=51)" in output.stderr + + +@use_terminal_console +def test_test_output_with_invalid_model_name(tmp_path: Path) -> None: + init_example_project(tmp_path, dialect="duckdb") + + wrong_test_file = tmp_path / "tests" / "test_incorrect_model_name.yaml" + wrong_test_file.write_text( + """ +test_example_full_model: + model: invalid_model + description: This is an invalid test + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 2 + """ + ) + + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + context = Context(paths=tmp_path, config=config) + + with patch.object(get_console(), "log_warning") as mock_logger: + with capture_output() as output: + context.test() + + assert ( + f"""Model '"invalid_model"' was not found at {wrong_test_file}""" + in mock_logger.call_args[0][0] + ) + assert ( + ".\n----------------------------------------------------------------------\nRan 1 test in" + in output.stderr + ) + assert "OK" in output.stderr From b148c71f2d1667d96a6d51752512f938cb2dfc21 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Wed, 9 Apr 2025 19:05:01 +0300 Subject: [PATCH 02/11] PR Feedback 1 --- sqlmesh/core/context.py | 9 ++--- sqlmesh/core/test/runner.py | 72 ++++++++++++++++++------------------- tests/core/test_test.py | 4 +-- 3 files changed, 43 insertions(+), 42 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 749e4302d3..140a48ec78 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -1785,10 +1785,11 @@ def test( if verbosity >= Verbosity.VERBOSE: pd.set_option("display.max_columns", None) - default_gateway = self.gateway or self.config.default_gateway_name - # Merge the root variables with the gateway's variables - variables = {**self.config.variables, **self.config.get_gateway(default_gateway).variables} + variables = { + **self.config.variables, + **self.config.get_gateway(self.selected_gateway).variables, + } test_meta = load_model_tests( configs=self.configs, @@ -1801,7 +1802,7 @@ def test( model_test_metadata=test_meta, models=self._models, config=self.config, - default_gateway=default_gateway, + selected_gateway=self.selected_gateway, dialect=self.default_dialect, verbosity=verbosity, preserve_fixtures=preserve_fixtures, diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py index 3a6286ff7a..04b7af7461 100644 --- a/sqlmesh/core/test/runner.py +++ b/sqlmesh/core/test/runner.py @@ -47,7 +47,7 @@ def __init__( def create_testing_engine_adapters( model_test_metadata: list[ModelTestMetadata], config: C, - default_gateway: str, + selected_gateway: str, default_catalog: str | None = None, default_catalog_dialect: str = "", ) -> t.Dict[ModelTestMetadata, EngineAdapter]: @@ -55,7 +55,7 @@ def create_testing_engine_adapters( metadata_to_adapter = {} for metadata in model_test_metadata: - gateway = metadata.body.get("gateway") or default_gateway + gateway = metadata.body.get("gateway") or selected_gateway test_connection = config.get_test_connection( gateway, default_catalog, default_catalog_dialect ) @@ -85,7 +85,7 @@ def run_tests( model_test_metadata: list[ModelTestMetadata], models: UniqueKeyDict[str, Model], config: C, - default_gateway: str, + selected_gateway: str, dialect: str | None = None, verbosity: Verbosity = Verbosity.DEFAULT, preserve_fixtures: bool = False, @@ -102,7 +102,7 @@ def run_tests( preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. """ default_test_connection = config.get_test_connection( - gateway_name=default_gateway, + gateway_name=selected_gateway, default_catalog=default_catalog, default_catalog_dialect=default_catalog_dialect, ) @@ -118,7 +118,7 @@ def run_tests( metadata_to_adapter = create_testing_engine_adapters( model_test_metadata=model_test_metadata, config=config, - default_gateway=default_gateway, + selected_gateway=selected_gateway, default_catalog=default_catalog, default_catalog_dialect=default_catalog_dialect, ) @@ -126,39 +126,39 @@ def run_tests( def _run_single_test( metadata: ModelTestMetadata, engine_adapter: EngineAdapter ) -> t.Optional[ModelTextTestResult]: - result: t.Optional[ModelTextTestResult] = None - try: - test = ModelTest.create_test( - body=metadata.body, - test_name=metadata.test_name, - models=models, - engine_adapter=engine_adapter, - dialect=dialect, - path=metadata.path, - default_catalog=default_catalog, - preserve_fixtures=preserve_fixtures, - ) + test = ModelTest.create_test( + body=metadata.body, + test_name=metadata.test_name, + models=models, + engine_adapter=engine_adapter, + dialect=dialect, + path=metadata.path, + default_catalog=default_catalog, + preserve_fixtures=preserve_fixtures, + ) - result = t.cast( - ModelTextTestResult, - ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), - ) + if not test: + return None + + result = t.cast( + ModelTextTestResult, + ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), + ) + + with lock: + if result.successes: + combined_results.addSuccess(result.successes[0]) + elif result.errors: + combined_results.addError(result.original_err[0], result.original_err[1]) + elif result.failures: + combined_results.addFailure(result.original_err[0], result.original_err[1]) + elif result.skipped: + skipped_args = result.skipped[0] + combined_results.addSkip(skipped_args[0], skipped_args[1]) + + combined_results.testsRun += 1 - with lock: - if result.successes: - combined_results.addSuccess(result.successes[0]) - elif result.errors: - combined_results.addError(result.original_err[0], result.original_err[1]) - elif result.failures: - combined_results.addFailure(result.original_err[0], result.original_err[1]) - elif result.skipped: - skipped_args = result.skipped[0] - combined_results.addSkip(skipped_args[0], skipped_args[1]) - - combined_results.testsRun += 1 - - finally: - return result + return result test_results = [] diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 85f687a1f8..c2c6e64e79 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -1619,7 +1619,7 @@ def test_variable_usage(tmp_path: Path) -> None: # Case 2: Test gateway variables context.config = Config( gateways={ - "main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + "": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), }, model_defaults=ModelDefaultsConfig(dialect="duckdb"), ) @@ -1633,7 +1633,7 @@ def test_variable_usage(tmp_path: Path) -> None: # Case 3: Test gateway variables overriding root variables context.config = Config( gateways={ - "main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + "": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), }, model_defaults=ModelDefaultsConfig(dialect="duckdb"), variables={"gold": "foo", "silver": "bar"}, From be0f0fb85d7f35e74194410f8df747b66a3cf105 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 10 Apr 2025 11:16:28 +0300 Subject: [PATCH 03/11] Load variables from gateway specified in YAML --- sqlmesh/core/context.py | 9 +--- sqlmesh/core/test/__init__.py | 3 -- sqlmesh/core/test/discovery.py | 60 ++++++++++-------------- sqlmesh/core/test/runner.py | 3 -- sqlmesh/magics.py | 13 +----- sqlmesh/utils/yaml.py | 9 ++++ tests/core/test_test.py | 84 ++++++++++++++++++++++++++++------ 7 files changed, 104 insertions(+), 77 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 140a48ec78..226a8861d0 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -1785,17 +1785,10 @@ def test( if verbosity >= Verbosity.VERBOSE: pd.set_option("display.max_columns", None) - # Merge the root variables with the gateway's variables - variables = { - **self.config.variables, - **self.config.get_gateway(self.selected_gateway).variables, - } - test_meta = load_model_tests( - configs=self.configs, + loaders=self._loaders, tests=tests, patterns=match_patterns, - variables=variables, ) return run_tests( diff --git a/sqlmesh/core/test/__init__.py b/sqlmesh/core/test/__init__.py index 3e1f5b0c81..909bcb34da 100644 --- a/sqlmesh/core/test/__init__.py +++ b/sqlmesh/core/test/__init__.py @@ -3,9 +3,6 @@ from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test from sqlmesh.core.test.discovery import ( ModelTestMetadata as ModelTestMetadata, - filter_tests_by_patterns as filter_tests_by_patterns, - get_all_model_tests as get_all_model_tests, - load_model_test_file as load_model_test_file, load_model_tests as load_model_tests, ) from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult diff --git a/sqlmesh/core/test/discovery.py b/sqlmesh/core/test/discovery.py index 8c4d544696..52783a84ff 100644 --- a/sqlmesh/core/test/discovery.py +++ b/sqlmesh/core/test/discovery.py @@ -14,7 +14,7 @@ from sqlmesh.utils.yaml import load as yaml_load if t.TYPE_CHECKING: - from sqlmesh.core.config.loader import C + from sqlmesh.core.loader import Loader class ModelTestMetadata(PydanticModel): @@ -31,7 +31,8 @@ def __hash__(self) -> int: def load_model_test_file( - path: pathlib.Path, variables: dict[str, t.Any] | None = None + path: pathlib.Path, + get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], ) -> dict[str, ModelTestMetadata]: """Load a single model test file. @@ -42,7 +43,7 @@ def load_model_test_file( A list of ModelTestMetadata named tuples. """ model_test_metadata = {} - contents = yaml_load(path, variables=variables) + contents = yaml_load(path, get_variables=get_variables) for test_name, value in contents.items(): model_test_metadata[test_name] = ModelTestMetadata( @@ -53,8 +54,8 @@ def load_model_test_file( def discover_model_tests( path: pathlib.Path, + get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], ignore_patterns: list[str] | None = None, - variables: dict[str, t.Any] | None = None, ) -> Iterator[ModelTestMetadata]: """Discover model tests. @@ -78,7 +79,7 @@ def discover_model_tests( break else: for model_test_metadata in load_model_test_file( - yaml_file, variables=variables + yaml_file, get_variables=get_variables ).values(): yield model_test_metadata @@ -103,34 +104,16 @@ def filter_tests_by_patterns( ) -def get_all_model_tests( - *paths: pathlib.Path, - patterns: list[str] | None = None, - ignore_patterns: list[str] | None = None, - variables: dict[str, t.Any] | None = None, -) -> list[ModelTestMetadata]: - model_test_metadatas = [ - meta - for path in paths - for meta in discover_model_tests(pathlib.Path(path), ignore_patterns, variables=variables) - ] - if patterns: - model_test_metadatas = filter_tests_by_patterns(model_test_metadatas, patterns) - return model_test_metadatas - - def load_model_tests( - configs: dict[pathlib.Path, C], + loaders: list[Loader], tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None, - variables: dict[str, t.Any] | None = None, ) -> list[ModelTestMetadata]: """Load model tests into a list of ModelTestMetadata which will be propagated to the test runner. Args: tests: A list of tests to load; If not specified, all tests are loaded - patterns: A list of patterns to match against. - variables: A dictionary of variables to use when loading the tests. + patterns: A list of patterns that'll be used to filter tests by file name. configs: A dictionary of configs to use when loading all the tests. """ test_meta = [] @@ -139,24 +122,27 @@ def load_model_tests( for test in tests: filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") - test_file = load_model_test_file(pathlib.Path(filename), variables=variables) + test_file = load_model_test_file( + pathlib.Path(filename), get_variables=loaders[0]._get_variables + ) if test_name: test_meta.append(test_file[test_name]) else: test_meta.extend(test_file.values()) - - if patterns: - test_meta = filter_tests_by_patterns(test_meta, patterns) - else: - for path, config in configs.items(): + for loader in loaders: test_meta.extend( - get_all_model_tests( - path / c.TESTS, - patterns=patterns, - ignore_patterns=config.ignore_patterns, - variables=variables, - ) + [ + meta + for meta in discover_model_tests( + pathlib.Path(loader.config_path / c.TESTS), + ignore_patterns=loader.config.ignore_patterns, # type: ignore + get_variables=loader._get_variables, + ) + ] ) + if patterns: + test_meta = filter_tests_by_patterns(test_meta, patterns) + return test_meta diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py index 04b7af7461..a070a6e34b 100644 --- a/sqlmesh/core/test/runner.py +++ b/sqlmesh/core/test/runner.py @@ -14,9 +14,6 @@ from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test from sqlmesh.core.test.discovery import ( ModelTestMetadata as ModelTestMetadata, - filter_tests_by_patterns as filter_tests_by_patterns, - get_all_model_tests as get_all_model_tests, - load_model_test_file as load_model_test_file, ) from sqlmesh.core.config.connection import BaseDuckDBConnectionConfig diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 19eac1d51f..c7178fcd23 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -26,13 +26,12 @@ from rich.jupyter import JupyterRenderable from sqlmesh.cli.example_project import ProjectTemplate, init_example_project from sqlmesh.core import analytics -from sqlmesh.core import constants as c from sqlmesh.core.config import load_configs from sqlmesh.core.console import create_console, set_console, configure_console from sqlmesh.core.context import Context from sqlmesh.core.dialect import format_model_expressions, parse from sqlmesh.core.model import load_sql_based_model -from sqlmesh.core.test import ModelTestMetadata, get_all_model_tests +from sqlmesh.core.test import ModelTestMetadata, load_model_tests from sqlmesh.utils import sqlglot_dialects, yaml, Verbosity from sqlmesh.utils.errors import MagicError, MissingContextException, SQLMeshError @@ -273,15 +272,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None if not args.test_name and not args.ls: raise MagicError("Must provide either test name or `--ls` to list tests") - test_meta = [] - - for path, config in context.configs.items(): - test_meta.extend( - get_all_model_tests( - path / c.TESTS, - ignore_patterns=config.ignore_patterns, - ) - ) + test_meta = load_model_tests(loaders=context._loaders) tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict) for model_test_metadata in test_meta: diff --git a/sqlmesh/utils/yaml.py b/sqlmesh/utils/yaml.py index 549d849902..768fc750f3 100644 --- a/sqlmesh/utils/yaml.py +++ b/sqlmesh/utils/yaml.py @@ -5,6 +5,7 @@ from decimal import Decimal from os import getenv from pathlib import Path +import re from ruamel import yaml @@ -17,6 +18,9 @@ } +gateway_pattern = re.compile(r"gateway:\s*([^\s]+)") + + def YAML(typ: t.Optional[str] = "safe") -> yaml.YAML: yaml_obj = yaml.YAML(typ=typ) @@ -36,6 +40,7 @@ def load( render_jinja: bool = True, allow_duplicate_keys: bool = False, variables: t.Optional[t.Dict[str, t.Any]] = None, + get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]] | None = None, ) -> t.Dict: """Loads a YAML object from either a raw string or a file.""" path: t.Optional[Path] = None @@ -45,6 +50,10 @@ def load( with open(source, "r", encoding="utf-8") as file: source = file.read() + if get_variables: + gateway = gateway_pattern.search(source) + variables = get_variables(gateway.group(1) if gateway else None) + if render_jinja: source = ENVIRONMENT.from_string(source).render( { diff --git a/tests/core/test_test.py b/tests/core/test_test.py index c2c6e64e79..8d87077de9 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -1563,26 +1563,32 @@ def test_variable_usage(tmp_path: Path) -> None: init_example_project(tmp_path, dialect="duckdb") variables = {"gold": "gold_db", "silver": "silver_db"} - - # Case 1: Test root variables - config = Config( - default_connection=DuckDBConnectionConfig(), - model_defaults=ModelDefaultsConfig(dialect="duckdb"), - variables=variables, - ) - context = Context(paths=tmp_path, config=config) + incorrect_variables = {"gold": "foo", "silver": "bar"} parent = _create_model( "SELECT 1 AS id, '2022-01-02'::DATE AS ds, @start_ts AS start_ts", meta="MODEL (name silver_db.sch.b, kind INCREMENTAL_BY_TIME_RANGE(time_column ds))", ) - parent = t.cast(SqlModel, context.upsert_model(parent)) child = _create_model( "SELECT ds, @IF(@VAR('myvar'), id, id + 1) AS id FROM silver_db.sch.b WHERE ds BETWEEN @start_ds and @end_ds", meta="MODEL (name gold_db.sch.a, kind INCREMENTAL_BY_TIME_RANGE(time_column ds))", ) - child = t.cast(SqlModel, context.upsert_model(child)) + + def init_context(config: Config, **kwargs): + context = Context(paths=tmp_path, config=config, **kwargs) + context.upsert_model(parent) + context.upsert_model(child) + return context + + # Case 1: Test root variables + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables=variables, + ) + + context = init_context(config) test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml" test_file.write_text( @@ -1617,13 +1623,15 @@ def test_variable_usage(tmp_path: Path) -> None: assert len(results.successes) == 2 # Case 2: Test gateway variables - context.config = Config( + config = Config( gateways={ - "": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + "main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), }, model_defaults=ModelDefaultsConfig(dialect="duckdb"), ) + context = init_context(config, gateway="main") + results = context.test() assert not results.failures @@ -1631,14 +1639,60 @@ def test_variable_usage(tmp_path: Path) -> None: assert len(results.successes) == 2 # Case 3: Test gateway variables overriding root variables - context.config = Config( + config = Config( gateways={ - "": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + "main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), }, model_defaults=ModelDefaultsConfig(dialect="duckdb"), - variables={"gold": "foo", "silver": "bar"}, + variables=incorrect_variables, ) + context = init_context(config, gateway="main") + + results = context.test() + + assert not results.failures + assert not results.errors + assert len(results.successes) == 2 + + # Case 4: Use variable from the defined gateway + test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml" + test_file.write_text( + """ +test_parameterized_model_names: + model: {{ var('gold') }}.sch.a + gateway: secondary + vars: + myvar: True + start_ds: 2022-01-01 + end_ds: 2022-01-03 + inputs: + {{ var('silver') }}.sch.b: + - ds: 2022-01-01 + id: 1 + - ds: 2022-01-01 + id: 2 + outputs: + query: + - ds: 2022-01-01 + id: 1 + - ds: 2022-01-01 + id: 2 + """ + ) + + config = Config( + gateways={ + "main": GatewayConfig( + connection=DuckDBConnectionConfig(), variables=incorrect_variables + ), + "secondary": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + + context = init_context(config, gateway="main") + results = context.test() assert not results.failures From 9cfa8f6732f201c7fe077fddc22290e24f9ddace Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 10 Apr 2025 15:29:35 +0300 Subject: [PATCH 04/11] Move get_variables up to Context --- sqlmesh/core/context.py | 30 ++++++++++++++++++++++++++- sqlmesh/core/loader.py | 22 +------------------- sqlmesh/core/test/discovery.py | 37 +++++++++++++++++----------------- sqlmesh/magics.py | 2 +- sqlmesh/utils/yaml.py | 13 ++++++++---- 5 files changed, 59 insertions(+), 45 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 226a8861d0..d00ca84d3b 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -364,6 +364,7 @@ def __init__( self._excluded_requirements: t.Set[str] = set() self._default_catalog: t.Optional[str] = None self._linters: t.Dict[str, Linter] = {} + self._variables_by_project_gateway: t.Dict[t.Tuple[str, str], t.Dict[str, t.Any]] = {} self._loaded: bool = False self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items()))) @@ -1786,9 +1787,10 @@ def test( pd.set_option("display.max_columns", None) test_meta = load_model_tests( - loaders=self._loaders, + configs=self.configs, tests=tests, patterns=match_patterns, + get_variables=self._get_variables, ) return run_tests( @@ -2466,6 +2468,32 @@ def lint_models( "Linter detected errors in the code. Please fix them before proceeding." ) + def _get_variables( + self, config: t.Optional[C] = None, gateway_name: t.Optional[str] = None + ) -> t.Dict[str, t.Any]: + config = config or self.config + gateway_name = gateway_name or self.selected_gateway + + key = (config.project, gateway_name) + if key not in self._variables_by_project_gateway: + try: + gateway = config.get_gateway(gateway_name) + except ConfigError: + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Gateway '{gateway_name}' not found in project '{config.project}'." + ) + gateway = None + + self._variables_by_project_gateway[key] = { + **config.variables, + **(gateway.variables if gateway else {}), + c.GATEWAY: gateway_name, + } + + return self._variables_by_project_gateway[key] + class Context(GenericContext[Config]): CONFIG_TYPE = Config diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index edd4911156..0fe723a84c 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -74,7 +74,6 @@ def __init__(self, context: GenericContext, path: Path) -> None: self.context = context self.config_path = path self.config = self.context.configs[self.config_path] - self._variables_by_gateway: t.Dict[str, t.Dict[str, t.Any]] = {} def load(self) -> LoadedProject: """ @@ -329,26 +328,7 @@ def _track_file(self, path: Path) -> None: self._path_mtimes[path] = path.stat().st_mtime def _get_variables(self, gateway_name: t.Optional[str] = None) -> t.Dict[str, t.Any]: - gateway_name = gateway_name or self.context.selected_gateway - - if gateway_name not in self._variables_by_gateway: - try: - gateway = self.config.get_gateway(gateway_name) - except ConfigError: - from sqlmesh.core.console import get_console - - get_console().log_warning( - f"Gateway '{gateway_name}' not found in project '{self.config.project}'." - ) - gateway = None - - self._variables_by_gateway[gateway_name] = { - **self.config.variables, - **(gateway.variables if gateway else {}), - c.GATEWAY: gateway_name, - } - - return self._variables_by_gateway[gateway_name] + return self.context._get_variables(config=self.config, gateway_name=gateway_name) class SqlMeshLoader(Loader): diff --git a/sqlmesh/core/test/discovery.py b/sqlmesh/core/test/discovery.py index 52783a84ff..4b8d64db05 100644 --- a/sqlmesh/core/test/discovery.py +++ b/sqlmesh/core/test/discovery.py @@ -14,7 +14,7 @@ from sqlmesh.utils.yaml import load as yaml_load if t.TYPE_CHECKING: - from sqlmesh.core.loader import Loader + from sqlmesh.core.config.loader import C class ModelTestMetadata(PydanticModel): @@ -32,7 +32,8 @@ def __hash__(self) -> int: def load_model_test_file( path: pathlib.Path, - get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], + config: C, + get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]], ) -> dict[str, ModelTestMetadata]: """Load a single model test file. @@ -43,7 +44,7 @@ def load_model_test_file( A list of ModelTestMetadata named tuples. """ model_test_metadata = {} - contents = yaml_load(path, get_variables=get_variables) + contents = yaml_load(path, config=config, get_variables=get_variables) for test_name, value in contents.items(): model_test_metadata[test_name] = ModelTestMetadata( @@ -54,8 +55,8 @@ def load_model_test_file( def discover_model_tests( path: pathlib.Path, - get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], - ignore_patterns: list[str] | None = None, + config: C, + get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]], ) -> Iterator[ModelTestMetadata]: """Discover model tests. @@ -74,12 +75,12 @@ def discover_model_tests( search_path.glob("**/test*.yaml"), search_path.glob("**/test*.yml"), ): - for ignore_pattern in ignore_patterns or []: + for ignore_pattern in config.ignore_patterns or []: if yaml_file.match(ignore_pattern): break else: for model_test_metadata in load_model_test_file( - yaml_file, get_variables=get_variables + yaml_file, config=config, get_variables=get_variables ).values(): yield model_test_metadata @@ -105,7 +106,8 @@ def filter_tests_by_patterns( def load_model_tests( - loaders: list[Loader], + configs: t.Dict[pathlib.Path, C], + get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]], tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None, ) -> list[ModelTestMetadata]: @@ -123,23 +125,22 @@ def load_model_tests( filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") test_file = load_model_test_file( - pathlib.Path(filename), get_variables=loaders[0]._get_variables + pathlib.Path(filename), + config=next(iter(configs.values())), + get_variables=get_variables, ) if test_name: test_meta.append(test_file[test_name]) else: test_meta.extend(test_file.values()) else: - for loader in loaders: + for path, config in configs.items(): test_meta.extend( - [ - meta - for meta in discover_model_tests( - pathlib.Path(loader.config_path / c.TESTS), - ignore_patterns=loader.config.ignore_patterns, # type: ignore - get_variables=loader._get_variables, - ) - ] + discover_model_tests( + pathlib.Path(path / c.TESTS), + config=config, + get_variables=get_variables, + ) ) if patterns: diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index c7178fcd23..7463656a79 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -272,7 +272,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None if not args.test_name and not args.ls: raise MagicError("Must provide either test name or `--ls` to list tests") - test_meta = load_model_tests(loaders=context._loaders) + test_meta = load_model_tests(configs=context.configs, get_variables=context._get_variables) tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict) for model_test_metadata in test_meta: diff --git a/sqlmesh/utils/yaml.py b/sqlmesh/utils/yaml.py index 768fc750f3..9fd7af4152 100644 --- a/sqlmesh/utils/yaml.py +++ b/sqlmesh/utils/yaml.py @@ -13,12 +13,16 @@ from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.jinja import ENVIRONMENT, create_var +if t.TYPE_CHECKING: + from sqlmesh.core.config.loader import C + + JINJA_METHODS = { "env_var": lambda key, default=None: getenv(key, default), } -gateway_pattern = re.compile(r"gateway:\s*([^\s]+)") +GATEWAY_PATTERN = re.compile(r"gateway:\s*([^\s]+)") def YAML(typ: t.Optional[str] = "safe") -> yaml.YAML: @@ -40,7 +44,8 @@ def load( render_jinja: bool = True, allow_duplicate_keys: bool = False, variables: t.Optional[t.Dict[str, t.Any]] = None, - get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]] | None = None, + config: t.Optional[C] = None, + get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]] | None = None, ) -> t.Dict: """Loads a YAML object from either a raw string or a file.""" path: t.Optional[Path] = None @@ -51,8 +56,8 @@ def load( source = file.read() if get_variables: - gateway = gateway_pattern.search(source) - variables = get_variables(gateway.group(1) if gateway else None) + gateway = GATEWAY_PATTERN.search(source) + variables = get_variables(config, gateway.group(1) if gateway else None) if render_jinja: source = ENVIRONMENT.from_string(source).render( From 860f7d7836f1b9d7ea1e881f3bf59631cd5fcf34 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 10 Apr 2025 20:27:19 +0300 Subject: [PATCH 05/11] Handle special characters in gateway name --- sqlmesh/utils/yaml.py | 10 ++-- tests/core/test_test.py | 115 ++++++++++++++-------------------------- 2 files changed, 48 insertions(+), 77 deletions(-) diff --git a/sqlmesh/utils/yaml.py b/sqlmesh/utils/yaml.py index 9fd7af4152..c453d557fd 100644 --- a/sqlmesh/utils/yaml.py +++ b/sqlmesh/utils/yaml.py @@ -49,6 +49,7 @@ def load( ) -> t.Dict: """Loads a YAML object from either a raw string or a file.""" path: t.Optional[Path] = None + yaml = YAML() if isinstance(source, Path): path = source @@ -56,8 +57,12 @@ def load( source = file.read() if get_variables: - gateway = GATEWAY_PATTERN.search(source) - variables = get_variables(config, gateway.group(1) if gateway else None) + # If the user has specified a quoted/escaped gateway (e.g. "gateway: 'ma\tin'"), we need to + # parse it as YAML to match the gateway name stored in the config + gateway_line = GATEWAY_PATTERN.search(source) + gateway = yaml.load(gateway_line.group(0))["gateway"] if gateway_line else None + + variables = get_variables(config, gateway) if render_jinja: source = ENVIRONMENT.from_string(source).render( @@ -67,7 +72,6 @@ def load( } ) - yaml = YAML() yaml.allow_duplicate_keys = allow_duplicate_keys contents = yaml.load(source) if contents is None: diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 8d87077de9..c358772ef1 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -1575,32 +1575,15 @@ def test_variable_usage(tmp_path: Path) -> None: meta="MODEL (name gold_db.sch.a, kind INCREMENTAL_BY_TIME_RANGE(time_column ds))", ) - def init_context(config: Config, **kwargs): - context = Context(paths=tmp_path, config=config, **kwargs) - context.upsert_model(parent) - context.upsert_model(child) - return context - - # Case 1: Test root variables - config = Config( - default_connection=DuckDBConnectionConfig(), - model_defaults=ModelDefaultsConfig(dialect="duckdb"), - variables=variables, - ) - - context = init_context(config) - - test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml" - test_file.write_text( - """ + test_text = """ test_parameterized_model_names: - model: {{ var('gold') }}.sch.a + model: {{{{ var('gold') }}}}.sch.a {gateway} vars: myvar: True start_ds: 2022-01-01 end_ds: 2022-01-03 inputs: - {{ var('silver') }}.sch.b: + {{{{ var('silver') }}}}.sch.b: - ds: 2022-01-01 id: 1 - ds: 2022-01-01 @@ -1610,17 +1593,31 @@ def init_context(config: Config, **kwargs): - ds: 2022-01-01 id: 1 - ds: 2022-01-01 - id: 2 - """ - ) + id: 2""" - results = context.test() + test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml" + + def init_context_and_validate_results(config: Config, **kwargs): + context = Context(paths=tmp_path, config=config, **kwargs) + context.upsert_model(parent) + context.upsert_model(child) + + results = context.test() + + assert not results.failures + assert not results.errors + assert len(results.successes) == 2 + + # Case 1: Test root variables + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables=variables, + ) - assert not results.failures - assert not results.errors + test_file.write_text(test_text.format(gateway="")) - # The example project has one test and we added another one above - assert len(results.successes) == 2 + init_context_and_validate_results(config) # Case 2: Test gateway variables config = Config( @@ -1629,14 +1626,7 @@ def init_context(config: Config, **kwargs): }, model_defaults=ModelDefaultsConfig(dialect="duckdb"), ) - - context = init_context(config, gateway="main") - - results = context.test() - - assert not results.failures - assert not results.errors - assert len(results.successes) == 2 + init_context_and_validate_results(config) # Case 3: Test gateway variables overriding root variables config = Config( @@ -1646,41 +1636,9 @@ def init_context(config: Config, **kwargs): model_defaults=ModelDefaultsConfig(dialect="duckdb"), variables=incorrect_variables, ) - - context = init_context(config, gateway="main") - - results = context.test() - - assert not results.failures - assert not results.errors - assert len(results.successes) == 2 + init_context_and_validate_results(config, gateway="main") # Case 4: Use variable from the defined gateway - test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml" - test_file.write_text( - """ -test_parameterized_model_names: - model: {{ var('gold') }}.sch.a - gateway: secondary - vars: - myvar: True - start_ds: 2022-01-01 - end_ds: 2022-01-03 - inputs: - {{ var('silver') }}.sch.b: - - ds: 2022-01-01 - id: 1 - - ds: 2022-01-01 - id: 2 - outputs: - query: - - ds: 2022-01-01 - id: 1 - - ds: 2022-01-01 - id: 2 - """ - ) - config = Config( gateways={ "main": GatewayConfig( @@ -1691,13 +1649,22 @@ def init_context(config: Config, **kwargs): model_defaults=ModelDefaultsConfig(dialect="duckdb"), ) - context = init_context(config, gateway="main") + test_file.write_text(test_text.format(gateway="\n gateway: secondary")) + init_context_and_validate_results(config, gateway="main") - results = context.test() + # Case 5: Use gateways with escaped characters + config = Config( + gateways={ + "main": GatewayConfig( + connection=DuckDBConnectionConfig(), variables=incorrect_variables + ), + "secon\tdary": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) - assert not results.failures - assert not results.errors - assert len(results.successes) == 2 + test_file.write_text(test_text.format(gateway='\n gateway: "secon\\tdary"')) + init_context_and_validate_results(config, gateway="main") def test_custom_testing_schema(mocker: MockerFixture) -> None: From 552d50d89cb66b359ab85e42264ee4b3464007ba Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Fri, 11 Apr 2025 12:04:55 +0300 Subject: [PATCH 06/11] Move test loading logic to Loader --- sqlmesh/core/context.py | 40 +++---------- sqlmesh/core/loader.py | 80 ++++++++++++++++++++++++- sqlmesh/core/test/__init__.py | 2 +- sqlmesh/core/test/discovery.py | 105 --------------------------------- sqlmesh/core/test/result.py | 1 - sqlmesh/magics.py | 4 +- sqlmesh/utils/yaml.py | 9 +-- tests/core/test_test.py | 64 ++++++++++++++++++++ 8 files changed, 156 insertions(+), 149 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index d00ca84d3b..1ec32c8ef5 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -109,8 +109,8 @@ from sqlmesh.core.table_diff import TableDiff from sqlmesh.core.test import ( ModelTextTestResult, + ModelTestMetadata, generate_test, - load_model_tests, run_tests, ) from sqlmesh.core.user import User @@ -364,7 +364,6 @@ def __init__( self._excluded_requirements: t.Set[str] = set() self._default_catalog: t.Optional[str] = None self._linters: t.Dict[str, Linter] = {} - self._variables_by_project_gateway: t.Dict[t.Tuple[str, str], t.Dict[str, t.Any]] = {} self._loaded: bool = False self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items()))) @@ -1786,12 +1785,7 @@ def test( if verbosity >= Verbosity.VERBOSE: pd.set_option("display.max_columns", None) - test_meta = load_model_tests( - configs=self.configs, - tests=tests, - patterns=match_patterns, - get_variables=self._get_variables, - ) + test_meta = self._load_model_tests(tests=tests, patterns=match_patterns) return run_tests( model_test_metadata=test_meta, @@ -2468,31 +2462,15 @@ def lint_models( "Linter detected errors in the code. Please fix them before proceeding." ) - def _get_variables( - self, config: t.Optional[C] = None, gateway_name: t.Optional[str] = None - ) -> t.Dict[str, t.Any]: - config = config or self.config - gateway_name = gateway_name or self.selected_gateway - - key = (config.project, gateway_name) - if key not in self._variables_by_project_gateway: - try: - gateway = config.get_gateway(gateway_name) - except ConfigError: - from sqlmesh.core.console import get_console + def _load_model_tests( + self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None + ) -> t.List[ModelTestMetadata]: + model_tests = [] - get_console().log_warning( - f"Gateway '{gateway_name}' not found in project '{config.project}'." - ) - gateway = None - - self._variables_by_project_gateway[key] = { - **config.variables, - **(gateway.variables if gateway else {}), - c.GATEWAY: gateway_name, - } + for loader in self._loaders: + model_tests.extend(loader._load_model_tests(tests=tests, patterns=patterns)) - return self._variables_by_project_gateway[key] + return model_tests class Context(GenericContext[Config]): diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 0fe723a84c..5fa73fddc2 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -2,6 +2,7 @@ import abc import glob +import itertools import linecache import logging import os @@ -31,11 +32,13 @@ from sqlmesh.core.model import model as model_registry from sqlmesh.core.model.common import make_python_env from sqlmesh.core.signal import signal +from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns from sqlmesh.utils import UniqueKeyDict, sys_path from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor from sqlmesh.utils.metaprogramming import import_python_file -from sqlmesh.utils.yaml import YAML +from sqlmesh.utils.yaml import YAML, load as yaml_load + if t.TYPE_CHECKING: from sqlmesh.core.context import GenericContext @@ -74,6 +77,7 @@ def __init__(self, context: GenericContext, path: Path) -> None: self.context = context self.config_path = path self.config = self.context.configs[self.config_path] + self._variables_by_gateway: t.Dict[str, t.Dict[str, t.Any]] = {} def load(self) -> LoadedProject: """ @@ -289,6 +293,12 @@ def _load_linting_rules(self) -> RuleSet: """Loads user linting rules""" return RuleSet() + def _load_model_tests( + self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None + ) -> t.List[ModelTestMetadata]: + """Loads YAML-based model tests""" + return [] + def _glob_paths( self, path: Path, @@ -328,7 +338,26 @@ def _track_file(self, path: Path) -> None: self._path_mtimes[path] = path.stat().st_mtime def _get_variables(self, gateway_name: t.Optional[str] = None) -> t.Dict[str, t.Any]: - return self.context._get_variables(config=self.config, gateway_name=gateway_name) + gateway_name = gateway_name or self.context.selected_gateway + + if gateway_name not in self._variables_by_gateway: + try: + gateway = self.config.get_gateway(gateway_name) + except ConfigError: + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Gateway '{gateway_name}' not found in project '{self.config.project}'." + ) + gateway = None + + self._variables_by_gateway[gateway_name] = { + **self.config.variables, + **(gateway.variables if gateway else {}), + c.GATEWAY: gateway_name, + } + + return self._variables_by_gateway[gateway_name] class SqlMeshLoader(Loader): @@ -658,6 +687,53 @@ def _load_linting_rules(self) -> RuleSet: return RuleSet(user_rules.values()) + def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]: + """Load a single model test file.""" + model_test_metadata = {} + contents = yaml_load(path, get_variables=self._get_variables) + + for test_name, value in contents.items(): + model_test_metadata[test_name] = ModelTestMetadata( + path=path, test_name=test_name, body=value + ) + + return model_test_metadata + + def _load_model_tests( + self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None + ) -> t.List[ModelTestMetadata]: + """Loads YAML-based model tests""" + test_meta: t.List[ModelTestMetadata] = [] + + if tests: + for test in tests: + filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") + + test_file = self._load_model_test_file(Path(filename)) + if test_name: + test_meta.append(test_file[test_name]) + else: + test_meta.extend(test_file.values()) + else: + search_path = Path(self.config_path) / c.TESTS + + for yaml_file in itertools.chain( + search_path.glob("**/test*.yaml"), + search_path.glob("**/test*.yml"), + ): + if any( + yaml_file.match(ignore_pattern) + for ignore_pattern in self.config.ignore_patterns or [] + ): + continue + + test_meta.extend(self._load_model_test_file(yaml_file).values()) + + if patterns: + test_meta = filter_tests_by_patterns(test_meta, patterns) + + return test_meta + class _Cache(CacheBase): def __init__(self, loader: SqlMeshLoader, config_path: Path): self._loader = loader diff --git a/sqlmesh/core/test/__init__.py b/sqlmesh/core/test/__init__.py index 909bcb34da..6353370f45 100644 --- a/sqlmesh/core/test/__init__.py +++ b/sqlmesh/core/test/__init__.py @@ -3,7 +3,7 @@ from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test from sqlmesh.core.test.discovery import ( ModelTestMetadata as ModelTestMetadata, - load_model_tests as load_model_tests, + filter_tests_by_patterns as filter_tests_by_patterns, ) from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult from sqlmesh.core.test.runner import run_tests as run_tests diff --git a/sqlmesh/core/test/discovery.py b/sqlmesh/core/test/discovery.py index 4b8d64db05..0f60fe6fa9 100644 --- a/sqlmesh/core/test/discovery.py +++ b/sqlmesh/core/test/discovery.py @@ -4,17 +4,11 @@ import itertools import pathlib import typing as t -from collections.abc import Iterator import ruamel -from sqlmesh.core import constants as c from sqlmesh.utils import unique from sqlmesh.utils.pydantic import PydanticModel -from sqlmesh.utils.yaml import load as yaml_load - -if t.TYPE_CHECKING: - from sqlmesh.core.config.loader import C class ModelTestMetadata(PydanticModel): @@ -30,61 +24,6 @@ def __hash__(self) -> int: return self.fully_qualified_test_name.__hash__() -def load_model_test_file( - path: pathlib.Path, - config: C, - get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]], -) -> dict[str, ModelTestMetadata]: - """Load a single model test file. - - Args: - path: The path to the test file - - returns: - A list of ModelTestMetadata named tuples. - """ - model_test_metadata = {} - contents = yaml_load(path, config=config, get_variables=get_variables) - - for test_name, value in contents.items(): - model_test_metadata[test_name] = ModelTestMetadata( - path=path, test_name=test_name, body=value - ) - return model_test_metadata - - -def discover_model_tests( - path: pathlib.Path, - config: C, - get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]], -) -> Iterator[ModelTestMetadata]: - """Discover model tests. - - Model tests are defined in YAML files and contain the inputs and outputs used to test model queries. - - Args: - path: A path to search for tests. - ignore_patterns: An optional list of patterns to ignore. - - Returns: - A list of ModelTestMetadata named tuples. - """ - search_path = pathlib.Path(path) - - for yaml_file in itertools.chain( - search_path.glob("**/test*.yaml"), - search_path.glob("**/test*.yml"), - ): - for ignore_pattern in config.ignore_patterns or []: - if yaml_file.match(ignore_pattern): - break - else: - for model_test_metadata in load_model_test_file( - yaml_file, config=config, get_variables=get_variables - ).values(): - yield model_test_metadata - - def filter_tests_by_patterns( tests: list[ModelTestMetadata], patterns: list[str] ) -> list[ModelTestMetadata]: @@ -103,47 +42,3 @@ def filter_tests_by_patterns( if ("*" in pattern and fnmatch.fnmatchcase(test.fully_qualified_test_name, pattern)) or pattern in test.fully_qualified_test_name ) - - -def load_model_tests( - configs: t.Dict[pathlib.Path, C], - get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]], - tests: t.Optional[t.List[str]] = None, - patterns: list[str] | None = None, -) -> list[ModelTestMetadata]: - """Load model tests into a list of ModelTestMetadata which will be propagated to the test runner. - - Args: - tests: A list of tests to load; If not specified, all tests are loaded - patterns: A list of patterns that'll be used to filter tests by file name. - configs: A dictionary of configs to use when loading all the tests. - """ - test_meta = [] - - if tests: - for test in tests: - filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") - - test_file = load_model_test_file( - pathlib.Path(filename), - config=next(iter(configs.values())), - get_variables=get_variables, - ) - if test_name: - test_meta.append(test_file[test_name]) - else: - test_meta.extend(test_file.values()) - else: - for path, config in configs.items(): - test_meta.extend( - discover_model_tests( - pathlib.Path(path / c.TESTS), - config=config, - get_variables=get_variables, - ) - ) - - if patterns: - test_meta = filter_tests_by_patterns(test_meta, patterns) - - return test_meta diff --git a/sqlmesh/core/test/result.py b/sqlmesh/core/test/result.py index 304cb013aa..30144f8508 100644 --- a/sqlmesh/core/test/result.py +++ b/sqlmesh/core/test/result.py @@ -107,7 +107,6 @@ def log_test_report(self, test_duration: float) -> None: for _, error in errors: stream.writeln(unittest.TextTestResult.separator1) stream.writeln(f"ERROR: {error}") - stream.writeln(unittest.TextTestResult.separator2) # Output final report stream.writeln(unittest.TextTestResult.separator2) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 7463656a79..ca17ba5e9b 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -31,7 +31,7 @@ from sqlmesh.core.context import Context from sqlmesh.core.dialect import format_model_expressions, parse from sqlmesh.core.model import load_sql_based_model -from sqlmesh.core.test import ModelTestMetadata, load_model_tests +from sqlmesh.core.test import ModelTestMetadata from sqlmesh.utils import sqlglot_dialects, yaml, Verbosity from sqlmesh.utils.errors import MagicError, MissingContextException, SQLMeshError @@ -272,7 +272,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None if not args.test_name and not args.ls: raise MagicError("Must provide either test name or `--ls` to list tests") - test_meta = load_model_tests(configs=context.configs, get_variables=context._get_variables) + test_meta = context._load_model_tests() tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict) for model_test_metadata in test_meta: diff --git a/sqlmesh/utils/yaml.py b/sqlmesh/utils/yaml.py index c453d557fd..ac1c374b90 100644 --- a/sqlmesh/utils/yaml.py +++ b/sqlmesh/utils/yaml.py @@ -13,10 +13,6 @@ from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.jinja import ENVIRONMENT, create_var -if t.TYPE_CHECKING: - from sqlmesh.core.config.loader import C - - JINJA_METHODS = { "env_var": lambda key, default=None: getenv(key, default), } @@ -44,8 +40,7 @@ def load( render_jinja: bool = True, allow_duplicate_keys: bool = False, variables: t.Optional[t.Dict[str, t.Any]] = None, - config: t.Optional[C] = None, - get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]] | None = None, + get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]] | None = None, ) -> t.Dict: """Loads a YAML object from either a raw string or a file.""" path: t.Optional[Path] = None @@ -62,7 +57,7 @@ def load( gateway_line = GATEWAY_PATTERN.search(source) gateway = yaml.load(gateway_line.group(0))["gateway"] if gateway_line else None - variables = get_variables(config, gateway) + variables = get_variables(gateway) if render_jinja: source = ENVIRONMENT.from_string(source).render( diff --git a/tests/core/test_test.py b/tests/core/test_test.py index c358772ef1..70666d99bd 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -2308,3 +2308,67 @@ def test_test_output_with_invalid_model_name(tmp_path: Path) -> None: in output.stderr ) assert "OK" in output.stderr + + +def test_number_of_tests_found(tmp_path: Path) -> None: + init_example_project(tmp_path, dialect="duckdb") + + # Example project contains 1 test and we add a new file with 2 tests + test_file = tmp_path / "tests" / "test_new.yaml" + test_file.write_text( + """ +test_example_full_model1: + model: sqlmesh_example.full_model + description: This is an invalid test + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 1 + +test_example_full_model2: + model: sqlmesh_example.full_model + description: This is an invalid test + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 1 + """ + ) + + context = Context(paths=tmp_path) + + # Case 1: All 3 tests should run without any tests specified + results = context.test() + assert len(results.successes) == 3 + + # Case 1: The "new_test.yaml" should amount to 2 subtests + results = context.test(tests=[f"{test_file}"]) + assert len(results.successes) == 2 + + # Case 3: The "new_test.yaml::test_example_full_model2" should amount to a single subtest + results = context.test(tests=[f"{test_file}::test_example_full_model2"]) + assert len(results.successes) == 1 From e0815b86843b45fb1dc7613350dcd5d325302a68 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Fri, 11 Apr 2025 15:20:17 +0300 Subject: [PATCH 07/11] Use a single loader if specific tests are specified --- sqlmesh/core/context.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 1ec32c8ef5..e47d3c060d 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -2465,9 +2465,12 @@ def lint_models( def _load_model_tests( self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None ) -> t.List[ModelTestMetadata]: - model_tests = [] + # If a set of tests is provided, use a single loader to load them + # Otherwise, gather all tests from all loaders/repos + loaders = [self._loaders[0]] if tests else self._loaders - for loader in self._loaders: + model_tests = [] + for loader in loaders: model_tests.extend(loader._load_model_tests(tests=tests, patterns=patterns)) return model_tests From a0f2fc6cd64774a7dec3f0ab89a2074fcc09dc8b Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Fri, 11 Apr 2025 18:48:06 +0300 Subject: [PATCH 08/11] Make load_model_tests public --- sqlmesh/core/context.py | 4 ++-- sqlmesh/magics.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index e47d3c060d..1948fd7dc3 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -1785,7 +1785,7 @@ def test( if verbosity >= Verbosity.VERBOSE: pd.set_option("display.max_columns", None) - test_meta = self._load_model_tests(tests=tests, patterns=match_patterns) + test_meta = self.load_model_tests(tests=tests, patterns=match_patterns) return run_tests( model_test_metadata=test_meta, @@ -2462,7 +2462,7 @@ def lint_models( "Linter detected errors in the code. Please fix them before proceeding." ) - def _load_model_tests( + def load_model_tests( self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None ) -> t.List[ModelTestMetadata]: # If a set of tests is provided, use a single loader to load them diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index ca17ba5e9b..6bf28a8496 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -272,7 +272,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None if not args.test_name and not args.ls: raise MagicError("Must provide either test name or `--ls` to list tests") - test_meta = context._load_model_tests() + test_meta = context.load_model_tests() tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict) for model_test_metadata in test_meta: From 45c0f80dd3690e684ac3ed1e714a5b49c3777cd7 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Mon, 14 Apr 2025 14:13:15 +0300 Subject: [PATCH 09/11] PR Feedback 2 --- sqlmesh/core/context.py | 6 +++--- sqlmesh/core/loader.py | 18 +++++++++--------- tests/core/test_test.py | 2 -- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 1948fd7dc3..2dbc9e4c45 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -2465,13 +2465,13 @@ def lint_models( def load_model_tests( self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None ) -> t.List[ModelTestMetadata]: - # If a set of tests is provided, use a single loader to load them - # Otherwise, gather all tests from all loaders/repos + # If a set of specific test path(s) are provided, we can use a single loader + # since it's not required to walk every tests/ folder in each repo loaders = [self._loaders[0]] if tests else self._loaders model_tests = [] for loader in loaders: - model_tests.extend(loader._load_model_tests(tests=tests, patterns=patterns)) + model_tests.extend(loader.load_model_tests(tests=tests, patterns=patterns)) return model_tests diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 5fa73fddc2..a42113b381 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -293,7 +293,7 @@ def _load_linting_rules(self) -> RuleSet: """Loads user linting rules""" return RuleSet() - def _load_model_tests( + def load_model_tests( self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None ) -> t.List[ModelTestMetadata]: """Loads YAML-based model tests""" @@ -699,21 +699,21 @@ def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]: return model_test_metadata - def _load_model_tests( + def load_model_tests( self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None ) -> t.List[ModelTestMetadata]: """Loads YAML-based model tests""" - test_meta: t.List[ModelTestMetadata] = [] + test_meta_list: t.List[ModelTestMetadata] = [] if tests: for test in tests: filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") - test_file = self._load_model_test_file(Path(filename)) + test_meta = self._load_model_test_file(Path(filename)) if test_name: - test_meta.append(test_file[test_name]) + test_meta_list.append(test_meta[test_name]) else: - test_meta.extend(test_file.values()) + test_meta_list.extend(test_meta.values()) else: search_path = Path(self.config_path) / c.TESTS @@ -727,12 +727,12 @@ def _load_model_tests( ): continue - test_meta.extend(self._load_model_test_file(yaml_file).values()) + test_meta_list.extend(self._load_model_test_file(yaml_file).values()) if patterns: - test_meta = filter_tests_by_patterns(test_meta, patterns) + test_meta_list = filter_tests_by_patterns(test_meta_list, patterns) - return test_meta + return test_meta_list class _Cache(CacheBase): def __init__(self, loader: SqlMeshLoader, config_path: Path): diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 70666d99bd..fec0b49d3c 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -2319,7 +2319,6 @@ def test_number_of_tests_found(tmp_path: Path) -> None: """ test_example_full_model1: model: sqlmesh_example.full_model - description: This is an invalid test inputs: sqlmesh_example.incremental_model: rows: @@ -2339,7 +2338,6 @@ def test_number_of_tests_found(tmp_path: Path) -> None: test_example_full_model2: model: sqlmesh_example.full_model - description: This is an invalid test inputs: sqlmesh_example.incremental_model: rows: From a754ca9a9bbc0d58f57c0ff3c8214af42977c15e Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Mon, 14 Apr 2025 18:46:02 +0300 Subject: [PATCH 10/11] Remove get_variables from yaml.py::load() --- sqlmesh/core/loader.py | 13 ++++++++++++- sqlmesh/utils/yaml.py | 15 +-------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index a42113b381..bd27d1e752 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -6,6 +6,7 @@ import linecache import logging import os +import re import typing as t from collections import Counter, defaultdict from dataclasses import dataclass @@ -46,6 +47,8 @@ logger = logging.getLogger(__name__) +GATEWAY_PATTERN = re.compile(r"gateway:\s*([^\s]+)") + @dataclass class LoadedProject: @@ -690,7 +693,15 @@ def _load_linting_rules(self) -> RuleSet: def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]: """Load a single model test file.""" model_test_metadata = {} - contents = yaml_load(path, get_variables=self._get_variables) + + with open(path, "r", encoding="utf-8") as file: + source = file.read() + # If the user has specified a quoted/escaped gateway (e.g. "gateway: 'ma\tin'"), we need to + # parse it as YAML to match the gateway name stored in the config + gateway_line = GATEWAY_PATTERN.search(source) + gateway = YAML().load(gateway_line.group(0))["gateway"] if gateway_line else None + + contents = yaml_load(source, variables=self._get_variables(gateway)) for test_name, value in contents.items(): model_test_metadata[test_name] = ModelTestMetadata( diff --git a/sqlmesh/utils/yaml.py b/sqlmesh/utils/yaml.py index ac1c374b90..549d849902 100644 --- a/sqlmesh/utils/yaml.py +++ b/sqlmesh/utils/yaml.py @@ -5,7 +5,6 @@ from decimal import Decimal from os import getenv from pathlib import Path -import re from ruamel import yaml @@ -18,9 +17,6 @@ } -GATEWAY_PATTERN = re.compile(r"gateway:\s*([^\s]+)") - - def YAML(typ: t.Optional[str] = "safe") -> yaml.YAML: yaml_obj = yaml.YAML(typ=typ) @@ -40,25 +36,15 @@ def load( render_jinja: bool = True, allow_duplicate_keys: bool = False, variables: t.Optional[t.Dict[str, t.Any]] = None, - get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]] | None = None, ) -> t.Dict: """Loads a YAML object from either a raw string or a file.""" path: t.Optional[Path] = None - yaml = YAML() if isinstance(source, Path): path = source with open(source, "r", encoding="utf-8") as file: source = file.read() - if get_variables: - # If the user has specified a quoted/escaped gateway (e.g. "gateway: 'ma\tin'"), we need to - # parse it as YAML to match the gateway name stored in the config - gateway_line = GATEWAY_PATTERN.search(source) - gateway = yaml.load(gateway_line.group(0))["gateway"] if gateway_line else None - - variables = get_variables(gateway) - if render_jinja: source = ENVIRONMENT.from_string(source).render( { @@ -67,6 +53,7 @@ def load( } ) + yaml = YAML() yaml.allow_duplicate_keys = allow_duplicate_keys contents = yaml.load(source) if contents is None: From e9a35174c721f705c5e21cc1d39d31939e57fae5 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Mon, 14 Apr 2025 18:56:46 +0300 Subject: [PATCH 11/11] Test typo --- tests/core/test_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_test.py b/tests/core/test_test.py index fec0b49d3c..df9ff3ea5e 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -2363,7 +2363,7 @@ def test_number_of_tests_found(tmp_path: Path) -> None: results = context.test() assert len(results.successes) == 3 - # Case 1: The "new_test.yaml" should amount to 2 subtests + # Case 2: The "new_test.yaml" should amount to 2 subtests results = context.test(tests=[f"{test_file}"]) assert len(results.successes) == 2