diff --git a/sqlmesh/core/test/__init__.py b/sqlmesh/core/test/__init__.py index 5c5cb36380..c907aacf7c 100644 --- a/sqlmesh/core/test/__init__.py +++ b/sqlmesh/core/test/__init__.py @@ -1,11 +1,5 @@ from __future__ import annotations -import pathlib -import typing as t -import unittest - -from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.core.model import Model from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test from sqlmesh.core.test.discovery import ( ModelTestMetadata as ModelTestMetadata, @@ -14,121 +8,7 @@ load_model_test_file as load_model_test_file, ) from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult -from sqlmesh.utils import UniqueKeyDict, Verbosity - -if t.TYPE_CHECKING: - from sqlmesh.core.config.loader import C - - -def run_tests( - model_test_metadata: list[ModelTestMetadata], - models: UniqueKeyDict[str, Model], - config: C, - gateway: t.Optional[str] = None, - dialect: str | None = None, - verbosity: Verbosity = Verbosity.DEFAULT, - preserve_fixtures: bool = False, - stream: t.TextIO | None = None, - default_catalog: str | None = None, - default_catalog_dialect: str = "", -) -> ModelTextTestResult: - """Create a test suite of ModelTest objects and run it. - - Args: - model_test_metadata: A list of ModelTestMetadata named tuples. - models: All models to use for expansion and mapping of physical locations. - verbosity: The verbosity level. - preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. 
- """ - testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {} - default_gateway = gateway or config.default_gateway_name - - try: - tests = [] - for metadata in model_test_metadata: - body = metadata.body - gateway = body.get("gateway") or default_gateway - testing_engine_adapter = testing_adapter_by_gateway.get(gateway) - if not testing_engine_adapter: - testing_engine_adapter = config.get_test_connection( - gateway, - default_catalog, - default_catalog_dialect, - ).create_engine_adapter(register_comments_override=False) - testing_adapter_by_gateway[gateway] = testing_engine_adapter - - test = ModelTest.create_test( - body=body, - test_name=metadata.test_name, - models=models, - engine_adapter=testing_engine_adapter, - dialect=dialect, - path=metadata.path, - default_catalog=default_catalog, - preserve_fixtures=preserve_fixtures, - ) - if test: - tests.append(test) - - result = t.cast( - ModelTextTestResult, - unittest.TextTestRunner( - stream=stream, - verbosity=2 if verbosity >= Verbosity.VERBOSE else 1, - resultclass=ModelTextTestResult, - ).run(unittest.TestSuite(tests)), - ) - finally: - for testing_engine_adapter in testing_adapter_by_gateway.values(): - testing_engine_adapter.close() - - return result - - -def run_model_tests( - tests: list[str], - models: UniqueKeyDict[str, Model], - config: C, - gateway: t.Optional[str] = None, - dialect: str | None = None, - verbosity: Verbosity = Verbosity.DEFAULT, - patterns: list[str] | None = None, - preserve_fixtures: bool = False, - stream: t.TextIO | None = None, - default_catalog: t.Optional[str] = None, - default_catalog_dialect: str = "", -) -> ModelTextTestResult: - """Load and run tests. - - Args: - tests: A list of tests to run, e.g. [tests/test_orders.yaml::test_single_order] - models: All models to use for expansion and mapping of physical locations. - verbosity: The verbosity level. - patterns: A list of patterns to match against. 
- preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. - """ - loaded_tests = [] - for test in tests: - filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") - path = pathlib.Path(filename) - - if test_name: - loaded_tests.append(load_model_test_file(path, variables=config.variables)[test_name]) - else: - loaded_tests.extend(load_model_test_file(path, variables=config.variables).values()) - - if patterns: - loaded_tests = filter_tests_by_patterns(loaded_tests, patterns) - - return run_tests( - loaded_tests, - models, - config, - gateway=gateway, - dialect=dialect, - verbosity=verbosity, - preserve_fixtures=preserve_fixtures, - stream=stream, - default_catalog=default_catalog, - default_catalog_dialect=default_catalog_dialect, - ) +from sqlmesh.core.test.runner import ( + run_model_tests as run_model_tests, + run_tests as run_tests, +) diff --git a/sqlmesh/core/test/result.py b/sqlmesh/core/test/result.py index 27eb0c12da..304cb013aa 100644 --- a/sqlmesh/core/test/result.py +++ b/sqlmesh/core/test/result.py @@ -49,8 +49,19 @@ def addFailure(self, test: unittest.TestCase, err: ErrorType) -> None: err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback). """ exctype, value, tb = err + self.original_err = (test, err) return super().addFailure(test, (exctype, value, None)) # type: ignore + def addError(self, test: unittest.TestCase, err: ErrorType) -> None: + """Called when the test case test signals an error. + + Args: + test: The test case. + err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback). + """ + self.original_err = (test, err) + return super().addError(test, err) # type: ignore + def addSuccess(self, test: unittest.TestCase) -> None: """Called when the test case test succeeds. 
@@ -59,3 +70,50 @@ def addSuccess(self, test: unittest.TestCase) -> None: """ super().addSuccess(test) self.successes.append(test) + + def log_test_report(self, test_duration: float) -> None: + """ + Log the test report following unittest's conventions. + + Args: + test_duration: The duration of the tests. + """ + tests_run = self.testsRun + errors = self.errors + failures = self.failures + skipped = self.skipped + + is_success = not (errors or failures) + + infos = [] + if failures: + infos.append(f"failures={len(failures)}") + if errors: + infos.append(f"errors={len(errors)}") + if skipped: + infos.append(f"skipped={len(skipped)}") + + stream = self.stream + + stream.write("\n") + + for test_case, failure in failures: + stream.writeln(unittest.TextTestResult.separator1) + stream.writeln(f"FAIL: {test_case}") + stream.writeln(f"{test_case.shortDescription()}") + stream.writeln(unittest.TextTestResult.separator2) + stream.writeln(failure) + + for _, error in errors: + stream.writeln(unittest.TextTestResult.separator1) + stream.writeln(f"ERROR: {error}") + stream.writeln(unittest.TextTestResult.separator2) + + # Output final report + stream.writeln(unittest.TextTestResult.separator2) + stream.writeln( + f'Ran {tests_run} {"tests" if tests_run != 1 else "test"} in {test_duration:.3f}s \n' + ) + stream.writeln( + f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}' + ) diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py new file mode 100644 index 0000000000..08a12bb59b --- /dev/null +++ b/sqlmesh/core/test/runner.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +import sys +import time +import pathlib +import threading +import typing as t +import unittest + +import concurrent +from concurrent.futures import ThreadPoolExecutor + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.model import Model +from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as 
generate_test +from sqlmesh.core.test.discovery import ( + ModelTestMetadata as ModelTestMetadata, + filter_tests_by_patterns as filter_tests_by_patterns, + get_all_model_tests as get_all_model_tests, + load_model_test_file as load_model_test_file, +) +from sqlmesh.core.config.connection import BaseDuckDBConnectionConfig + +from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult +from sqlmesh.utils import UniqueKeyDict, Verbosity + + +if t.TYPE_CHECKING: + from sqlmesh.core.config.loader import C + + +class ModelTextTestRunner(unittest.TextTestRunner): + def __init__( + self, + **kwargs: t.Any, + ) -> None: + # StringIO is used to capture the output of the tests since we'll + # run them in parallel and we don't want to mix the output streams + from io import StringIO + + super().__init__( + stream=StringIO(), + resultclass=ModelTextTestResult, + **kwargs, + ) + + +def create_testing_engine_adapters( + model_test_metadata: list[ModelTestMetadata], + config: C, + default_gateway: str, + default_catalog: str | None = None, + default_catalog_dialect: str = "", +) -> t.Dict[ModelTestMetadata, EngineAdapter]: + testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {} + metadata_to_adapter = {} + + for metadata in model_test_metadata: + gateway = metadata.body.get("gateway") or default_gateway + test_connection = config.get_test_connection( + gateway, default_catalog, default_catalog_dialect + ) + + concurrent_tasks = test_connection.concurrent_tasks + + is_duckdb_connection = isinstance(test_connection, BaseDuckDBConnectionConfig) + adapter = None + if is_duckdb_connection: + # Ensure DuckDB connections are fully isolated from each other + # by forcing the creation of a new adapter with SingletonConnectionPool + test_connection.concurrent_tasks = 1 + adapter = test_connection.create_engine_adapter(register_comments_override=False) + test_connection.concurrent_tasks = concurrent_tasks + elif gateway not in testing_adapter_by_gateway: + # All 
other engines can share connections between threads + testing_adapter_by_gateway[gateway] = test_connection.create_engine_adapter( + register_comments_override=False + ) + + metadata_to_adapter[metadata] = adapter or testing_adapter_by_gateway[gateway] + + return metadata_to_adapter + + +def run_tests( + model_test_metadata: list[ModelTestMetadata], + models: UniqueKeyDict[str, Model], + config: C, + gateway: t.Optional[str] = None, + dialect: str | None = None, + verbosity: Verbosity = Verbosity.DEFAULT, + preserve_fixtures: bool = False, + stream: t.TextIO | None = None, + default_catalog: str | None = None, + default_catalog_dialect: str = "", +) -> ModelTextTestResult: + """Create a test suite of ModelTest objects and run it. + + Args: + model_test_metadata: A list of ModelTestMetadata named tuples. + models: All models to use for expansion and mapping of physical locations. + verbosity: The verbosity level. + preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. 
+ """ + default_gateway = gateway or config.default_gateway_name + + default_test_connection = config.get_test_connection( + gateway_name=default_gateway, + default_catalog=default_catalog, + default_catalog_dialect=default_catalog_dialect, + ) + + lock = threading.Lock() + + combined_results = ModelTextTestResult( + stream=unittest.runner._WritelnDecorator(stream or sys.stderr), # type: ignore + verbosity=2 if verbosity >= Verbosity.VERBOSE else 1, + descriptions=True, + ) + + metadata_to_adapter = create_testing_engine_adapters( + model_test_metadata=model_test_metadata, + config=config, + default_gateway=default_gateway, + default_catalog=default_catalog, + default_catalog_dialect=default_catalog_dialect, + ) + + def _run_single_test( + metadata: ModelTestMetadata, engine_adapter: EngineAdapter + ) -> ModelTextTestResult: + test = ModelTest.create_test( + body=metadata.body, + test_name=metadata.test_name, + models=models, + engine_adapter=engine_adapter, + dialect=dialect, + path=metadata.path, + default_catalog=default_catalog, + preserve_fixtures=preserve_fixtures, + ) + + result = t.cast( + ModelTextTestResult, + ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), + ) + + with lock: + if result.successes: + combined_results.addSuccess(result.successes[0]) + elif result.errors: + combined_results.addError(result.original_err[0], result.original_err[1]) + elif result.failures: + combined_results.addFailure(result.original_err[0], result.original_err[1]) + elif result.skipped: + skipped_args = result.skipped[0] + combined_results.addSkip(skipped_args[0], skipped_args[1]) + return result + + test_results = [] + + # Ensure workers are not greater than the number of tests + num_workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks) + + start_time = time.perf_counter() + try: + with ThreadPoolExecutor(max_workers=num_workers) as pool: + futures = [ + pool.submit(_run_single_test, metadata=metadata, 
engine_adapter=engine_adapter) + for metadata, engine_adapter in metadata_to_adapter.items() + ] + + for future in concurrent.futures.as_completed(futures): + test_results.append(future.result()) + finally: + for engine_adapter in set(metadata_to_adapter.values()): + # The engine adapters list might have duplicates, so we ensure that we close each adapter once + if engine_adapter: + engine_adapter.close() + + end_time = time.perf_counter() + + combined_results.testsRun = len(test_results) + + combined_results.log_test_report(test_duration=end_time - start_time) + + return combined_results + + +def run_model_tests( + tests: list[str], + models: UniqueKeyDict[str, Model], + config: C, + gateway: t.Optional[str] = None, + dialect: str | None = None, + verbosity: Verbosity = Verbosity.DEFAULT, + patterns: list[str] | None = None, + preserve_fixtures: bool = False, + stream: t.TextIO | None = None, + default_catalog: t.Optional[str] = None, + default_catalog_dialect: str = "", +) -> ModelTextTestResult: + """Load and run tests. + + Args: + tests: A list of tests to run, e.g. [tests/test_orders.yaml::test_single_order] + models: All models to use for expansion and mapping of physical locations. + verbosity: The verbosity level. + patterns: A list of patterns to match against. + preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. 
+ """ + loaded_tests = [] + for test in tests: + filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") + path = pathlib.Path(filename) + + if test_name: + loaded_tests.append(load_model_test_file(path, variables=config.variables)[test_name]) + else: + loaded_tests.extend(load_model_test_file(path, variables=config.variables).values()) + + if patterns: + loaded_tests = filter_tests_by_patterns(loaded_tests, patterns) + + return run_tests( + loaded_tests, + models, + config, + gateway=gateway, + dialect=dialect, + verbosity=verbosity, + preserve_fixtures=preserve_fixtures, + stream=stream, + default_catalog=default_catalog, + default_catalog_dialect=default_catalog_dialect, + ) diff --git a/tests/core/test_test.py b/tests/core/test_test.py index b9c0562d36..e50c3bdf06 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -4,11 +4,13 @@ import typing as t from pathlib import Path from unittest.mock import call, patch +from shutil import copyfile import pandas as pd import pytest from pytest_mock.plugin import MockerFixture from sqlglot import exp +from IPython.utils.capture import capture_output from sqlmesh.cli.example_project import init_example_project from sqlmesh.core import constants as c @@ -2128,3 +2130,76 @@ def test_test_with_resolve_template_macro(tmp_path: Path): context = Context(paths=tmp_path, config=config) _check_successful_or_raise(context.test()) + + +def test_test_output(tmp_path: Path) -> None: + init_example_project(tmp_path, dialect="duckdb") + + original_test_file = tmp_path / "tests" / "test_full_model.yaml" + + new_test_file = tmp_path / "tests" / "test_full_model_error.yaml" + new_test_file.write_text( + """ +test_example_full_model: + model: sqlmesh_example.full_model + description: This is a test + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + 
num_orders: 2 + """ + ) + + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_test_connection=DuckDBConnectionConfig(concurrent_tasks=8), + ) + context = Context(paths=tmp_path, config=config) + + # Case 1: Assert the log report is structured correctly + with capture_output() as output: + context.test() + + # Order may change due to concurrent execution + assert "F." in output.stderr or ".F" in output.stderr + assert ( + f"""====================================================================== +FAIL: test_example_full_model ({new_test_file}) +This is a test +---------------------------------------------------------------------- +AssertionError: Data mismatch (exp: expected, act: actual) + + num_orders + exp act +1 2.0 1.0 + +----------------------------------------------------------------------""" + in output.stderr + ) + + assert "Ran 2 tests" in output.stderr + assert "FAILED (failures=1)" in output.stderr + + # Case 2: Assert that concurrent execution is working properly + for i in range(50): + copyfile(original_test_file, tmp_path / "tests" / f"test_success_{i}.yaml") + copyfile(new_test_file, tmp_path / "tests" / f"test_failure_{i}.yaml") + + with capture_output() as output: + context.test() + + assert "Ran 102 tests" in output.stderr + assert "FAILED (failures=51)" in output.stderr