From e9f146a25b9ddea37155672730d4136324a70588 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Wed, 26 Mar 2025 12:34:38 -0400 Subject: [PATCH 1/7] Feat: Make model tests concurrent --- sqlmesh/core/test/__init__.py | 162 ++++++++++++++++++++++++++++------ sqlmesh/core/test/result.py | 13 +++ 2 files changed, 147 insertions(+), 28 deletions(-) diff --git a/sqlmesh/core/test/__init__.py b/sqlmesh/core/test/__init__.py index 5c5cb36380..a15dda5c73 100644 --- a/sqlmesh/core/test/__init__.py +++ b/sqlmesh/core/test/__init__.py @@ -1,10 +1,16 @@ from __future__ import annotations +import sys +import time import pathlib +import threading import typing as t import unittest -from sqlmesh.core.engine_adapter import EngineAdapter + +import concurrent +from concurrent.futures import ThreadPoolExecutor + from sqlmesh.core.model import Model from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test from sqlmesh.core.test.discovery import ( @@ -20,6 +26,66 @@ from sqlmesh.core.config.loader import C +class ModelTextTestRunner(unittest.TextTestRunner): + def __init__( + self, + **kwargs: t.Any, + ) -> None: + # StringIO is used to capture the output of the tests since we'll + # run them in parallel and we don't want to mix the output streams + from io import StringIO + + super().__init__( + stream=StringIO(), + resultclass=ModelTextTestResult, + **kwargs, + ) + + +def log_test_report(results: ModelTextTestResult, test_duration: float) -> None: + # Aggregate parallel test run results + tests_run = results.testsRun + errors = results.errors + failures = results.failures + skipped = results.skipped + + is_success = not (errors or failures) + + # Compute test info + infos = [] + if failures: + infos.append(f"failures={len(failures)}") + if errors: + infos.append(f"errors={len(errors)}") + if skipped: + infos.append(f"skipped={skipped}") + + # Report test errors + stream = results.stream + + stream.write("\n") + + if errors or failures: + stream.writeln(unittest.TextTestResult.separator1) + for failure in failures: + stream.writeln(f"FAIL: {failure[0]}") + + stream.writeln(unittest.TextTestResult.separator2) + for error in errors: + stream.writeln(error[1]) + for failure in failures: + stream.writeln(failure[1]) + + # Test report + stream.writeln(unittest.TextTestResult.separator2) + stream.writeln( + f'Ran {tests_run} {"tests" if tests_run > 1 else "test"} in {test_duration:.3f}s \n' + ) + stream.write( + f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}' + ) + + def run_tests( model_test_metadata: list[ModelTestMetadata], models: UniqueKeyDict[str, Model], @@ -40,22 +106,35 @@ def run_tests( verbosity: The verbosity level. preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. 
""" - testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {} default_gateway = gateway or config.default_gateway_name - try: - tests = [] - for metadata in model_test_metadata: + default_test_connection = config.get_test_connection( + gateway_name=default_gateway, + default_catalog=default_catalog, + default_catalog_dialect=default_catalog_dialect, + ) + + lock = threading.Lock() + + combined_results = ModelTextTestResult( + stream=unittest.runner._WritelnDecorator(stream or sys.stderr), # type: ignore + verbosity=2 if verbosity >= Verbosity.VERBOSE else 1, + descriptions=None, + ) + + def _run_single_test(metadata: ModelTestMetadata) -> ModelTextTestResult: + testing_engine_adapter = None + + try: body = metadata.body gateway = body.get("gateway") or default_gateway - testing_engine_adapter = testing_adapter_by_gateway.get(gateway) - if not testing_engine_adapter: - testing_engine_adapter = config.get_test_connection( - gateway, - default_catalog, - default_catalog_dialect, - ).create_engine_adapter(register_comments_override=False) - testing_adapter_by_gateway[gateway] = testing_engine_adapter + + # Create new connection for each test to avoid concurrency issues + testing_engine_adapter = config.get_test_connection( + gateway, + default_catalog, + default_catalog_dialect, + ).create_engine_adapter(register_comments_override=False) test = ModelTest.create_test( body=body, @@ -67,22 +146,49 @@ def run_tests( default_catalog=default_catalog, preserve_fixtures=preserve_fixtures, ) - if test: - tests.append(test) - - result = t.cast( - ModelTextTestResult, - unittest.TextTestRunner( - stream=stream, - verbosity=2 if verbosity >= Verbosity.VERBOSE else 1, - resultclass=ModelTextTestResult, - ).run(unittest.TestSuite(tests)), - ) - finally: - for testing_engine_adapter in testing_adapter_by_gateway.values(): - testing_engine_adapter.close() - return result + result = t.cast( + ModelTextTestResult, + ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), + ) + + with lock: + if result.successes: + combined_results.addSuccess(result.successes[0]) + elif result.errors: + combined_results.addError(result.err[0], result.err[1]) + elif result.failures: + combined_results.addFailure(result.err[0], result.err[1]) + elif result.skipped: + skipped_args = result.skipped[0] + combined_results.addSkip(skipped_args[0], skipped_args[1]) + + finally: + if testing_engine_adapter: + testing_engine_adapter.close() + + return result + + test_results = [] + + workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks) + + start_time = time.perf_counter() + with ThreadPoolExecutor(max_workers=workers) as pool: + futures = [ + pool.submit(_run_single_test, metadata=metadata) for metadata in model_test_metadata + ] + + for future in concurrent.futures.as_completed(futures): + test_results.append(future.result()) + + end_time = time.perf_counter() + + combined_results.testsRun = len(test_results) + + log_test_report(combined_results, test_duration=end_time - start_time) + + return combined_results def run_model_tests( diff --git a/sqlmesh/core/test/result.py b/sqlmesh/core/test/result.py index 27eb0c12da..fac59cb360 100644 --- a/sqlmesh/core/test/result.py +++ b/sqlmesh/core/test/result.py @@ -49,8 +49,21 @@ def addFailure(self, test: unittest.TestCase, err: ErrorType) -> None: err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback). 
""" exctype, value, tb = err + self.err = (test, err) return super().addFailure(test, (exctype, value, None)) # type: ignore + def addError(self, test: unittest.TestCase, err: ErrorType) -> None: + """Called when the test case test signals an error. + + The traceback is suppressed because it is redundant and not useful. + + Args: + test: The test case. + err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback). + """ + self.err = (test, err) + return super().addError(test, err) # type: ignore + def addSuccess(self, test: unittest.TestCase) -> None: """Called when the test case test succeeds. From 803609063c26642bd48582b7f4d213bc80567d5d Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Fri, 28 Mar 2025 12:49:30 -0400 Subject: [PATCH 2/7] Reuse engine adapters and close at the end --- sqlmesh/core/test/__init__.py | 151 ++++++++++++++++++---------------- sqlmesh/core/test/runner.py | 22 +++++ 2 files changed, 101 insertions(+), 72 deletions(-) create mode 100644 sqlmesh/core/test/runner.py diff --git a/sqlmesh/core/test/__init__.py b/sqlmesh/core/test/__init__.py index a15dda5c73..a420ada7af 100644 --- a/sqlmesh/core/test/__init__.py +++ b/sqlmesh/core/test/__init__.py @@ -11,6 +11,7 @@ import concurrent from concurrent.futures import ThreadPoolExecutor +from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.model import Model from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test from sqlmesh.core.test.discovery import ( @@ -20,28 +21,13 @@ load_model_test_file as load_model_test_file, ) from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult +from sqlmesh.core.test.runner import ModelTextTestRunner as ModelTextTestRunner from sqlmesh.utils import UniqueKeyDict, Verbosity if t.TYPE_CHECKING: from sqlmesh.core.config.loader import C -class ModelTextTestRunner(unittest.TextTestRunner): - def __init__( - self, - **kwargs: t.Any, - ) -> None: - # StringIO is used to capture the output of the tests since we'll - # run them in parallel and we don't want to mix the output streams - from io import StringIO - - super().__init__( - stream=StringIO(), - resultclass=ModelTextTestResult, - **kwargs, - ) - - def log_test_report(results: ModelTextTestResult, test_duration: float) -> None: # Aggregate parallel test run results tests_run = results.testsRun @@ -65,23 +51,23 @@ def log_test_report(results: ModelTextTestResult, test_duration: float) -> None: stream.write("\n") - if errors or failures: + for test_case, err in failures: stream.writeln(unittest.TextTestResult.separator1) - for failure in failures: - stream.writeln(f"FAIL: {failure[0]}") + stream.writeln(f"FAIL: {test_case}") + stream.writeln(unittest.TextTestResult.separator2) + stream.writeln(err) + for error in errors: + stream.writeln(unittest.TextTestResult.separator1) + stream.writeln(f"ERROR: {error[1]}") stream.writeln(unittest.TextTestResult.separator2) - for error in errors: - stream.writeln(error[1]) - for failure in failures: - stream.writeln(failure[1]) # Test report stream.writeln(unittest.TextTestResult.separator2) stream.writeln( f'Ran {tests_run} {"tests" if tests_run > 1 else "test"} in {test_duration:.3f}s \n' ) - stream.write( + stream.writeln( f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}' ) @@ -106,6 +92,7 @@ def run_tests( verbosity: The verbosity level. preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. 
""" + testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {} default_gateway = gateway or config.default_gateway_name default_test_connection = config.get_test_connection( @@ -122,51 +109,65 @@ def run_tests( descriptions=None, ) - def _run_single_test(metadata: ModelTestMetadata) -> ModelTextTestResult: - testing_engine_adapter = None - - try: - body = metadata.body - gateway = body.get("gateway") or default_gateway - - # Create new connection for each test to avoid concurrency issues - testing_engine_adapter = config.get_test_connection( - gateway, - default_catalog, - default_catalog_dialect, - ).create_engine_adapter(register_comments_override=False) - - test = ModelTest.create_test( - body=body, - test_name=metadata.test_name, - models=models, - engine_adapter=testing_engine_adapter, - dialect=dialect, - path=metadata.path, - default_catalog=default_catalog, - preserve_fixtures=preserve_fixtures, - ) + worker_payload = [] - result = t.cast( - ModelTextTestResult, - ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), + for metadata in model_test_metadata: + gateway = metadata.body.get("gateway") or default_gateway + test_connection = config.get_test_connection( + gateway, default_catalog, default_catalog_dialect + ) + + concurrent_tasks = test_connection.concurrent_tasks + + from sqlmesh.core.config.connection import BaseDuckDBConnectionConfig + + is_duckdb_connection = isinstance(test_connection, BaseDuckDBConnectionConfig) + + engine_adapter = None + if is_duckdb_connection: + # Ensure DuckDB connections are fully isolated from each other + # by forcing the creation of a new adapter with SingletonConnectionPool + test_connection.concurrent_tasks = 1 + engine_adapter = test_connection.create_engine_adapter(register_comments_override=False) + test_connection.concurrent_tasks = concurrent_tasks + elif gateway not in testing_adapter_by_gateway: + # All other engines can share connections between threads + testing_adapter_by_gateway[gateway] = test_connection.create_engine_adapter( + register_comments_override=False ) - with lock: - if result.successes: - combined_results.addSuccess(result.successes[0]) - elif result.errors: - combined_results.addError(result.err[0], result.err[1]) - elif result.failures: - combined_results.addFailure(result.err[0], result.err[1]) - elif result.skipped: - skipped_args = result.skipped[0] - combined_results.addSkip(skipped_args[0], skipped_args[1]) - - finally: - if testing_engine_adapter: - testing_engine_adapter.close() + engine_adapter = engine_adapter or testing_adapter_by_gateway[gateway] + worker_payload.append((metadata, engine_adapter)) + + def _run_single_test( + metadata: ModelTestMetadata, engine_adapter: EngineAdapter + ) -> ModelTextTestResult: + test = ModelTest.create_test( + body=metadata.body, + test_name=metadata.test_name, + models=models, + engine_adapter=engine_adapter, + dialect=dialect, + path=metadata.path, + default_catalog=default_catalog, + preserve_fixtures=preserve_fixtures, + ) + + result = t.cast( + ModelTextTestResult, + ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), + ) + with lock: + if result.successes: + combined_results.addSuccess(result.successes[0]) + elif result.errors: + combined_results.addError(result.err[0], result.err[1]) + elif result.failures: + combined_results.addFailure(result.err[0], result.err[1]) + elif result.skipped: + skipped_args = result.skipped[0] + combined_results.addSkip(skipped_args[0], skipped_args[1]) return result test_results = [] @@ -174,13 +175,19 @@ def 
_run_single_test(metadata: ModelTestMetadata) -> ModelTextTestResult: workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks) start_time = time.perf_counter() - with ThreadPoolExecutor(max_workers=workers) as pool: - futures = [ - pool.submit(_run_single_test, metadata=metadata) for metadata in model_test_metadata - ] - - for future in concurrent.futures.as_completed(futures): - test_results.append(future.result()) + try: + with ThreadPoolExecutor(max_workers=workers) as pool: + futures = [ + pool.submit(_run_single_test, metadata=metadata, engine_adapter=engine_adapter) + for metadata, engine_adapter in worker_payload + ] + + for future in concurrent.futures.as_completed(futures): + test_results.append(future.result()) + finally: + for _, engine_adapter in worker_payload: + if engine_adapter: + engine_adapter.close() end_time = time.perf_counter() diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py new file mode 100644 index 0000000000..96b2c7eb1c --- /dev/null +++ b/sqlmesh/core/test/runner.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import typing as t +import unittest + +from sqlmesh.core.test.result import ModelTextTestResult + + +class ModelTextTestRunner(unittest.TextTestRunner): + def __init__( + self, + **kwargs: t.Any, + ) -> None: + # StringIO is used to capture the output of the tests since we'll + # run them in parallel and we don't want to mix the output streams + from io import StringIO + + super().__init__( + stream=StringIO(), + resultclass=ModelTextTestResult, + **kwargs, + ) From 5dabc2ded1d00254a6d50c6ef82a46ccb630556f Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Mon, 31 Mar 2025 12:32:26 -0400 Subject: [PATCH 3/7] PR Feedback 2 --- sqlmesh/core/test/__init__.py | 49 +++-------------------------------- sqlmesh/core/test/runner.py | 44 +++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 45 deletions(-) diff --git a/sqlmesh/core/test/__init__.py b/sqlmesh/core/test/__init__.py index a420ada7af..0d4f5fe518 100644 --- a/sqlmesh/core/test/__init__.py +++ b/sqlmesh/core/test/__init__.py @@ -21,57 +21,16 @@ load_model_test_file as load_model_test_file, ) from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult -from sqlmesh.core.test.runner import ModelTextTestRunner as ModelTextTestRunner +from sqlmesh.core.test.runner import ( + ModelTextTestRunner, + log_test_report, +) from sqlmesh.utils import UniqueKeyDict, Verbosity if t.TYPE_CHECKING: from sqlmesh.core.config.loader import C -def log_test_report(results: ModelTextTestResult, test_duration: float) -> None: - # Aggregate parallel test run results - tests_run = results.testsRun - errors = results.errors - failures = results.failures - skipped = results.skipped - - is_success = not (errors or failures) - - # Compute test info - infos = [] - if failures: - infos.append(f"failures={len(failures)}") - if errors: - infos.append(f"errors={len(errors)}") - if skipped: - infos.append(f"skipped={skipped}") - - # Report test errors - stream = results.stream - - stream.write("\n") - - for test_case, err in failures: - stream.writeln(unittest.TextTestResult.separator1) - stream.writeln(f"FAIL: {test_case}") - stream.writeln(unittest.TextTestResult.separator2) - stream.writeln(err) - - for error in errors: - stream.writeln(unittest.TextTestResult.separator1) - stream.writeln(f"ERROR: {error[1]}") - stream.writeln(unittest.TextTestResult.separator2) - - # Test report - stream.writeln(unittest.TextTestResult.separator2) - 
stream.writeln( - f'Ran {tests_run} {"tests" if tests_run > 1 else "test"} in {test_duration:.3f}s \n' - ) - stream.writeln( - f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}' - ) - - def run_tests( model_test_metadata: list[ModelTestMetadata], models: UniqueKeyDict[str, Model], diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py index 96b2c7eb1c..ae22607e67 100644 --- a/sqlmesh/core/test/runner.py +++ b/sqlmesh/core/test/runner.py @@ -20,3 +20,47 @@ def __init__( resultclass=ModelTextTestResult, **kwargs, ) + + +def log_test_report(results: ModelTextTestResult, test_duration: float) -> None: + # Aggregate parallel test run results + tests_run = results.testsRun + errors = results.errors + failures = results.failures + skipped = results.skipped + + is_success = not (errors or failures) + + # Compute test info + infos = [] + if failures: + infos.append(f"failures={len(failures)}") + if errors: + infos.append(f"errors={len(errors)}") + if skipped: + infos.append(f"skipped={skipped}") + + # Report test errors + stream = results.stream + + stream.write("\n") + + for test_case, err in failures: + stream.writeln(unittest.TextTestResult.separator1) + stream.writeln(f"FAIL: {test_case}") + stream.writeln(unittest.TextTestResult.separator2) + stream.writeln(err) + + for error in errors: + stream.writeln(unittest.TextTestResult.separator1) + stream.writeln(f"ERROR: {error[1]}") + stream.writeln(unittest.TextTestResult.separator2) + + # Test report + stream.writeln(unittest.TextTestResult.separator2) + stream.writeln( + f'Ran {tests_run} {"tests" if tests_run > 1 else "test"} in {test_duration:.3f}s \n' + ) + stream.writeln( + f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}' + ) From ed6851b0795aa4987483c6494f03f8cf1ed2a7f8 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Tue, 1 Apr 2025 09:58:04 -0400 Subject: [PATCH 4/7] PR Feedback 3 --- sqlmesh/core/test/__init__.py | 196 +---------------------------- sqlmesh/core/test/result.py | 4 +- sqlmesh/core/test/runner.py | 230 +++++++++++++++++++++++++++++++++- tests/core/test_test.py | 75 +++++++++++ 4 files changed, 304 insertions(+), 201 deletions(-) diff --git a/sqlmesh/core/test/__init__.py b/sqlmesh/core/test/__init__.py index 0d4f5fe518..c907aacf7c 100644 --- a/sqlmesh/core/test/__init__.py +++ b/sqlmesh/core/test/__init__.py @@ -1,18 +1,5 @@ from __future__ import annotations -import sys -import time -import pathlib -import threading -import typing as t -import unittest - - -import concurrent -from concurrent.futures import ThreadPoolExecutor - -from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.core.model import Model from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test from sqlmesh.core.test.discovery import ( ModelTestMetadata as ModelTestMetadata, @@ -22,185 +9,6 @@ ) from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult from sqlmesh.core.test.runner import ( - ModelTextTestRunner, - log_test_report, + run_model_tests as run_model_tests, + run_tests as run_tests, ) -from sqlmesh.utils import UniqueKeyDict, Verbosity - -if t.TYPE_CHECKING: - from sqlmesh.core.config.loader import C - - -def run_tests( - model_test_metadata: list[ModelTestMetadata], - models: UniqueKeyDict[str, Model], - config: C, - gateway: t.Optional[str] = None, - dialect: str | None = None, - verbosity: Verbosity = Verbosity.DEFAULT, - preserve_fixtures: bool = False, - stream: t.TextIO | None = 
None, - default_catalog: str | None = None, - default_catalog_dialect: str = "", -) -> ModelTextTestResult: - """Create a test suite of ModelTest objects and run it. - - Args: - model_test_metadata: A list of ModelTestMetadata named tuples. - models: All models to use for expansion and mapping of physical locations. - verbosity: The verbosity level. - preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. - """ - testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {} - default_gateway = gateway or config.default_gateway_name - - default_test_connection = config.get_test_connection( - gateway_name=default_gateway, - default_catalog=default_catalog, - default_catalog_dialect=default_catalog_dialect, - ) - - lock = threading.Lock() - - combined_results = ModelTextTestResult( - stream=unittest.runner._WritelnDecorator(stream or sys.stderr), # type: ignore - verbosity=2 if verbosity >= Verbosity.VERBOSE else 1, - descriptions=None, - ) - - worker_payload = [] - - for metadata in model_test_metadata: - gateway = metadata.body.get("gateway") or default_gateway - test_connection = config.get_test_connection( - gateway, default_catalog, default_catalog_dialect - ) - - concurrent_tasks = test_connection.concurrent_tasks - - from sqlmesh.core.config.connection import BaseDuckDBConnectionConfig - - is_duckdb_connection = isinstance(test_connection, BaseDuckDBConnectionConfig) - - engine_adapter = None - if is_duckdb_connection: - # Ensure DuckDB connections are fully isolated from each other - # by forcing the creation of a new adapter with SingletonConnectionPool - test_connection.concurrent_tasks = 1 - engine_adapter = test_connection.create_engine_adapter(register_comments_override=False) - test_connection.concurrent_tasks = concurrent_tasks - elif gateway not in testing_adapter_by_gateway: - # All other engines can share connections between threads - testing_adapter_by_gateway[gateway] = test_connection.create_engine_adapter( - register_comments_override=False - ) - - engine_adapter = engine_adapter or testing_adapter_by_gateway[gateway] - worker_payload.append((metadata, engine_adapter)) - - def _run_single_test( - metadata: ModelTestMetadata, engine_adapter: EngineAdapter - ) -> ModelTextTestResult: - test = ModelTest.create_test( - body=metadata.body, - test_name=metadata.test_name, - models=models, - engine_adapter=engine_adapter, - dialect=dialect, - path=metadata.path, - default_catalog=default_catalog, - preserve_fixtures=preserve_fixtures, - ) - - result = t.cast( - ModelTextTestResult, - ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), - ) - - with lock: - if result.successes: - combined_results.addSuccess(result.successes[0]) - elif result.errors: - combined_results.addError(result.err[0], result.err[1]) - elif result.failures: - combined_results.addFailure(result.err[0], result.err[1]) - elif result.skipped: - skipped_args = result.skipped[0] - combined_results.addSkip(skipped_args[0], skipped_args[1]) - return result - - test_results = [] - - workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks) - - start_time = time.perf_counter() - try: - with ThreadPoolExecutor(max_workers=workers) as pool: - futures = [ - pool.submit(_run_single_test, metadata=metadata, engine_adapter=engine_adapter) - for metadata, engine_adapter in worker_payload - ] - - for future in concurrent.futures.as_completed(futures): - test_results.append(future.result()) - finally: - for _, engine_adapter in worker_payload: - if 
engine_adapter: - engine_adapter.close() - - end_time = time.perf_counter() - - combined_results.testsRun = len(test_results) - - log_test_report(combined_results, test_duration=end_time - start_time) - - return combined_results - - -def run_model_tests( - tests: list[str], - models: UniqueKeyDict[str, Model], - config: C, - gateway: t.Optional[str] = None, - dialect: str | None = None, - verbosity: Verbosity = Verbosity.DEFAULT, - patterns: list[str] | None = None, - preserve_fixtures: bool = False, - stream: t.TextIO | None = None, - default_catalog: t.Optional[str] = None, - default_catalog_dialect: str = "", -) -> ModelTextTestResult: - """Load and run tests. - - Args: - tests: A list of tests to run, e.g. [tests/test_orders.yaml::test_single_order] - models: All models to use for expansion and mapping of physical locations. - verbosity: The verbosity level. - patterns: A list of patterns to match against. - preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. - """ - loaded_tests = [] - for test in tests: - filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") - path = pathlib.Path(filename) - - if test_name: - loaded_tests.append(load_model_test_file(path, variables=config.variables)[test_name]) - else: - loaded_tests.extend(load_model_test_file(path, variables=config.variables).values()) - - if patterns: - loaded_tests = filter_tests_by_patterns(loaded_tests, patterns) - - return run_tests( - loaded_tests, - models, - config, - gateway=gateway, - dialect=dialect, - verbosity=verbosity, - preserve_fixtures=preserve_fixtures, - stream=stream, - default_catalog=default_catalog, - default_catalog_dialect=default_catalog_dialect, - ) diff --git a/sqlmesh/core/test/result.py b/sqlmesh/core/test/result.py index fac59cb360..ce6a8a25ed 100644 --- a/sqlmesh/core/test/result.py +++ b/sqlmesh/core/test/result.py @@ -49,7 +49,7 @@ def addFailure(self, test: unittest.TestCase, err: ErrorType) -> None: err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback). """ exctype, value, tb = err - self.err = (test, err) + self.original_err = (test, err) return super().addFailure(test, (exctype, value, None)) # type: ignore def addError(self, test: unittest.TestCase, err: ErrorType) -> None: @@ -61,7 +61,7 @@ def addError(self, test: unittest.TestCase, err: ErrorType) -> None: test: The test case. err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback). 
""" - self.err = (test, err) + self.original_err = (test, err) return super().addError(test, err) # type: ignore def addSuccess(self, test: unittest.TestCase) -> None: diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py index ae22607e67..8a0c928c5c 100644 --- a/sqlmesh/core/test/runner.py +++ b/sqlmesh/core/test/runner.py @@ -1,9 +1,32 @@ from __future__ import annotations +import sys +import time +import pathlib +import threading import typing as t import unittest -from sqlmesh.core.test.result import ModelTextTestResult +import concurrent +from concurrent.futures import ThreadPoolExecutor + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.model import Model +from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test +from sqlmesh.core.test.discovery import ( + ModelTestMetadata as ModelTestMetadata, + filter_tests_by_patterns as filter_tests_by_patterns, + get_all_model_tests as get_all_model_tests, + load_model_test_file as load_model_test_file, +) +from sqlmesh.core.config.connection import BaseDuckDBConnectionConfig + +from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult +from sqlmesh.utils import UniqueKeyDict, Verbosity + + +if t.TYPE_CHECKING: + from sqlmesh.core.config.loader import C class ModelTextTestRunner(unittest.TextTestRunner): @@ -45,15 +68,16 @@ def log_test_report(results: ModelTextTestResult, test_duration: float) -> None: stream.write("\n") - for test_case, err in failures: + for test_case, failure in failures: stream.writeln(unittest.TextTestResult.separator1) stream.writeln(f"FAIL: {test_case}") + stream.writeln(f"{test_case.shortDescription()}") stream.writeln(unittest.TextTestResult.separator2) - stream.writeln(err) + stream.writeln(failure) - for error in errors: + for _, error in errors: stream.writeln(unittest.TextTestResult.separator1) - stream.writeln(f"ERROR: {error[1]}") + stream.writeln(f"ERROR: {error}") stream.writeln(unittest.TextTestResult.separator2) # Test report @@ -64,3 +88,199 @@ def log_test_report(results: ModelTextTestResult, test_duration: float) -> None: stream.writeln( f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}' ) + + +def create_test_engine_adapters( + model_test_metadata: list[ModelTestMetadata], + config: C, + default_gateway: str, + testing_adapter_by_gateway: t.Dict[str, EngineAdapter], + default_catalog: str | None = None, + default_catalog_dialect: str = "", +) -> list[EngineAdapter]: + engine_adapters = [] + for metadata in model_test_metadata: + gateway = metadata.body.get("gateway") or default_gateway + test_connection = config.get_test_connection( + gateway, default_catalog, default_catalog_dialect + ) + + concurrent_tasks = test_connection.concurrent_tasks + + is_duckdb_connection = isinstance(test_connection, BaseDuckDBConnectionConfig) + adapter = None + if is_duckdb_connection: + # Ensure DuckDB connections are fully isolated from each other + # by forcing the creation of a new adapter with SingletonConnectionPool + test_connection.concurrent_tasks = 1 + adapter = test_connection.create_engine_adapter(register_comments_override=False) + test_connection.concurrent_tasks = concurrent_tasks + elif gateway not in testing_adapter_by_gateway: + # All other engines can share connections between threads + testing_adapter_by_gateway[gateway] = test_connection.create_engine_adapter( + register_comments_override=False + ) + + engine_adapters.append(adapter or 
testing_adapter_by_gateway[gateway]) + + return engine_adapters + + +def run_tests( + model_test_metadata: list[ModelTestMetadata], + models: UniqueKeyDict[str, Model], + config: C, + gateway: t.Optional[str] = None, + dialect: str | None = None, + verbosity: Verbosity = Verbosity.DEFAULT, + preserve_fixtures: bool = False, + stream: t.TextIO | None = None, + default_catalog: str | None = None, + default_catalog_dialect: str = "", +) -> ModelTextTestResult: + """Create a test suite of ModelTest objects and run it. + + Args: + model_test_metadata: A list of ModelTestMetadata named tuples. + models: All models to use for expansion and mapping of physical locations. + verbosity: The verbosity level. + preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. + """ + testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {} + default_gateway = gateway or config.default_gateway_name + + default_test_connection = config.get_test_connection( + gateway_name=default_gateway, + default_catalog=default_catalog, + default_catalog_dialect=default_catalog_dialect, + ) + + lock = threading.Lock() + + combined_results = ModelTextTestResult( + stream=unittest.runner._WritelnDecorator(stream or sys.stderr), # type: ignore + verbosity=2 if verbosity >= Verbosity.VERBOSE else 1, + descriptions=True, + ) + + engine_adapters = create_test_engine_adapters( + model_test_metadata, + config, + default_gateway, + testing_adapter_by_gateway, + default_catalog, + default_catalog_dialect, + ) + + def _run_single_test( + metadata: ModelTestMetadata, engine_adapter: EngineAdapter + ) -> ModelTextTestResult: + test = ModelTest.create_test( + body=metadata.body, + test_name=metadata.test_name, + models=models, + engine_adapter=engine_adapter, + dialect=dialect, + path=metadata.path, + default_catalog=default_catalog, + preserve_fixtures=preserve_fixtures, + ) + + result = t.cast( + ModelTextTestResult, + ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), + ) + + with lock: + if result.successes: + combined_results.addSuccess(result.successes[0]) + elif result.errors: + combined_results.addError(result.original_err[0], result.original_err[1]) + elif result.failures: + combined_results.addFailure(result.original_err[0], result.original_err[1]) + elif result.skipped: + skipped_args = result.skipped[0] + combined_results.addSkip(skipped_args[0], skipped_args[1]) + return result + + test_results = [] + + # Ensure workers are not greater than the number of tests + num_workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks) + + start_time = time.perf_counter() + try: + with ThreadPoolExecutor(max_workers=num_workers) as pool: + futures = [ + pool.submit(_run_single_test, metadata=metadata, engine_adapter=engine_adapter) + for metadata, engine_adapter in zip(model_test_metadata, engine_adapters) + ] + + for future in concurrent.futures.as_completed(futures): + test_results.append(future.result()) + finally: + closed_adapters: t.Set[int] = set() + + for engine_adapter in engine_adapters: + # The engine adapters list might have duplicates, so we ensure that we close each adapter once + hashed_adapter = hash(engine_adapter) + if engine_adapter and hashed_adapter not in closed_adapters: + engine_adapter.close() + closed_adapters.add(hashed_adapter) + + end_time = time.perf_counter() + + combined_results.testsRun = len(test_results) + + log_test_report(combined_results, test_duration=end_time - start_time) + + return combined_results + + +def run_model_tests( 
+ tests: list[str], + models: UniqueKeyDict[str, Model], + config: C, + gateway: t.Optional[str] = None, + dialect: str | None = None, + verbosity: Verbosity = Verbosity.DEFAULT, + patterns: list[str] | None = None, + preserve_fixtures: bool = False, + stream: t.TextIO | None = None, + default_catalog: t.Optional[str] = None, + default_catalog_dialect: str = "", +) -> ModelTextTestResult: + """Load and run tests. + + Args: + tests: A list of tests to run, e.g. [tests/test_orders.yaml::test_single_order] + models: All models to use for expansion and mapping of physical locations. + verbosity: The verbosity level. + patterns: A list of patterns to match against. + preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. + """ + loaded_tests = [] + for test in tests: + filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") + path = pathlib.Path(filename) + + if test_name: + loaded_tests.append(load_model_test_file(path, variables=config.variables)[test_name]) + else: + loaded_tests.extend(load_model_test_file(path, variables=config.variables).values()) + + if patterns: + loaded_tests = filter_tests_by_patterns(loaded_tests, patterns) + + return run_tests( + loaded_tests, + models, + config, + gateway=gateway, + dialect=dialect, + verbosity=verbosity, + preserve_fixtures=preserve_fixtures, + stream=stream, + default_catalog=default_catalog, + default_catalog_dialect=default_catalog_dialect, + ) diff --git a/tests/core/test_test.py b/tests/core/test_test.py index b9c0562d36..3afaac4145 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -4,11 +4,13 @@ import typing as t from pathlib import Path from unittest.mock import call, patch +from shutil import copyfile import pandas as pd import pytest from pytest_mock.plugin import MockerFixture from sqlglot import exp +from IPython.utils.capture import capture_output from sqlmesh.cli.example_project import init_example_project from sqlmesh.core import constants as c @@ -2128,3 +2130,76 @@ def test_test_with_resolve_template_macro(tmp_path: Path): context = Context(paths=tmp_path, config=config) _check_successful_or_raise(context.test()) + + +def test_test_generation_report(tmp_path: Path) -> None: + init_example_project(tmp_path, dialect="duckdb") + + original_test_file = tmp_path / "tests" / "test_full_model.yaml" + + new_test_file = tmp_path / "tests" / "test_full_model_error.yaml" + new_test_file.write_text( + """ +test_example_full_model: + model: sqlmesh_example.full_model + description: This is a test + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 2 + """ + ) + + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_test_connection=DuckDBConnectionConfig(concurrent_tasks=8), + ) + context = Context(paths=tmp_path, config=config) + + # Case 1: Assert the log report is structured correctly + with capture_output() as output: + context.test() + + # Order may change due to concurrent execution + assert "F." 
in output.stderr or ".F" in output.stderr + assert ( + f"""====================================================================== +FAIL: test_example_full_model ({new_test_file}) +This is a test +---------------------------------------------------------------------- +AssertionError: Data mismatch (exp: expected, act: actual) + + num_orders + exp act +1 2.0 1.0 + +----------------------------------------------------------------------""" + in output.stderr + ) + + assert "Ran 2 tests" in output.stderr + assert "FAILED (failures=1)" in output.stderr + + # Case 2: Assert that concurrent execution is working properly + for i in range(50): + copyfile(original_test_file, tmp_path / "tests" / f"test_success_{i}.yaml") + copyfile(new_test_file, tmp_path / "tests" / f"test_failure_{i}.yaml") + + with capture_output() as output: + context.test() + + assert "Ran 102 tests" in output.stderr + assert "FAILED (failures=51)" in output.stderr From f960ebeb033f9d9a06ac6fc199d37867d767ed86 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Tue, 1 Apr 2025 16:19:45 -0400 Subject: [PATCH 5/7] Move log_test_report to result class --- sqlmesh/core/test/result.py | 47 +++++++++++++++++++++++++++++++++++++ sqlmesh/core/test/runner.py | 47 +------------------------------------ 2 files changed, 48 insertions(+), 46 deletions(-) diff --git a/sqlmesh/core/test/result.py b/sqlmesh/core/test/result.py index ce6a8a25ed..d58a5d750d 100644 --- a/sqlmesh/core/test/result.py +++ b/sqlmesh/core/test/result.py @@ -72,3 +72,50 @@ def addSuccess(self, test: unittest.TestCase) -> None: """ super().addSuccess(test) self.successes.append(test) + + def log_test_report(self, test_duration: float) -> None: + """ + Log the test report following unittest's conventions. + + Args: + test_duration: The duration of the tests. 
+ """ + tests_run = self.testsRun + errors = self.errors + failures = self.failures + skipped = self.skipped + + is_success = not (errors or failures) + + infos = [] + if failures: + infos.append(f"failures={len(failures)}") + if errors: + infos.append(f"errors={len(errors)}") + if skipped: + infos.append(f"skipped={skipped}") + + stream = self.stream + + stream.write("\n") + + for test_case, failure in failures: + stream.writeln(unittest.TextTestResult.separator1) + stream.writeln(f"FAIL: {test_case}") + stream.writeln(f"{test_case.shortDescription()}") + stream.writeln(unittest.TextTestResult.separator2) + stream.writeln(failure) + + for _, error in errors: + stream.writeln(unittest.TextTestResult.separator1) + stream.writeln(f"ERROR: {error}") + stream.writeln(unittest.TextTestResult.separator2) + + # Output final report + stream.writeln(unittest.TextTestResult.separator2) + stream.writeln( + f'Ran {tests_run} {"tests" if tests_run > 1 else "test"} in {test_duration:.3f}s \n' + ) + stream.writeln( + f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}' + ) diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py index 8a0c928c5c..179e03a34f 100644 --- a/sqlmesh/core/test/runner.py +++ b/sqlmesh/core/test/runner.py @@ -45,51 +45,6 @@ def __init__( ) -def log_test_report(results: ModelTextTestResult, test_duration: float) -> None: - # Aggregate parallel test run results - tests_run = results.testsRun - errors = results.errors - failures = results.failures - skipped = results.skipped - - is_success = not (errors or failures) - - # Compute test info - infos = [] - if failures: - infos.append(f"failures={len(failures)}") - if errors: - infos.append(f"errors={len(errors)}") - if skipped: - infos.append(f"skipped={skipped}") - - # Report test errors - stream = results.stream - - stream.write("\n") - - for test_case, failure in failures: - stream.writeln(unittest.TextTestResult.separator1) - stream.writeln(f"FAIL: {test_case}") - stream.writeln(f"{test_case.shortDescription()}") - stream.writeln(unittest.TextTestResult.separator2) - stream.writeln(failure) - - for _, error in errors: - stream.writeln(unittest.TextTestResult.separator1) - stream.writeln(f"ERROR: {error}") - stream.writeln(unittest.TextTestResult.separator2) - - # Test report - stream.writeln(unittest.TextTestResult.separator2) - stream.writeln( - f'Ran {tests_run} {"tests" if tests_run > 1 else "test"} in {test_duration:.3f}s \n' - ) - stream.writeln( - f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}' - ) - - def create_test_engine_adapters( model_test_metadata: list[ModelTestMetadata], config: C, @@ -232,7 +187,7 @@ def _run_single_test( combined_results.testsRun = len(test_results) - log_test_report(combined_results, test_duration=end_time - start_time) + combined_results.log_test_report(test_duration=end_time - start_time) return combined_results From 311618631b7e5f8032dc8a78e3287d0ace5daeb2 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Wed, 2 Apr 2025 11:02:42 -0400 Subject: [PATCH 6/7] PR Feedback 4 --- sqlmesh/core/test/result.py | 2 -- sqlmesh/core/test/runner.py | 37 ++++++++++++++++--------------------- tests/core/test_test.py | 2 +- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/sqlmesh/core/test/result.py b/sqlmesh/core/test/result.py index d58a5d750d..304cb013aa 100644 --- a/sqlmesh/core/test/result.py +++ b/sqlmesh/core/test/result.py @@ -55,8 +55,6 @@ def addFailure(self, test: unittest.TestCase, err: ErrorType) -> 
None: def addError(self, test: unittest.TestCase, err: ErrorType) -> None: """Called when the test case test signals an error. - The traceback is suppressed because it is redundant and not useful. - Args: test: The test case. err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback). diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py index 179e03a34f..1c6ef1b25e 100644 --- a/sqlmesh/core/test/runner.py +++ b/sqlmesh/core/test/runner.py @@ -45,15 +45,16 @@ def __init__( ) -def create_test_engine_adapters( +def create_test_engine_adapters_for_tests( model_test_metadata: list[ModelTestMetadata], config: C, default_gateway: str, - testing_adapter_by_gateway: t.Dict[str, EngineAdapter], default_catalog: str | None = None, default_catalog_dialect: str = "", -) -> list[EngineAdapter]: - engine_adapters = [] +) -> t.Dict[ModelTestMetadata, EngineAdapter]: + testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {} + metadata_to_adapter = {} + for metadata in model_test_metadata: gateway = metadata.body.get("gateway") or default_gateway test_connection = config.get_test_connection( @@ -76,9 +77,9 @@ def create_test_engine_adapters( register_comments_override=False ) - engine_adapters.append(adapter or testing_adapter_by_gateway[gateway]) + metadata_to_adapter[metadata] = adapter or testing_adapter_by_gateway[gateway] - return engine_adapters + return metadata_to_adapter def run_tests( @@ -101,7 +102,6 @@ def run_tests( verbosity: The verbosity level. preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. """ - testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {} default_gateway = gateway or config.default_gateway_name default_test_connection = config.get_test_connection( @@ -118,13 +118,12 @@ def run_tests( descriptions=True, ) - engine_adapters = create_test_engine_adapters( - model_test_metadata, - config, - default_gateway, - testing_adapter_by_gateway, - default_catalog, - default_catalog_dialect, + metadata_to_adapter = create_test_engine_adapters_for_tests( + model_test_metadata=model_test_metadata, + config=config, + default_gateway=default_gateway, + default_catalog=default_catalog, + default_catalog_dialect=default_catalog_dialect, ) def _run_single_test( @@ -168,20 +167,16 @@ def _run_single_test( with ThreadPoolExecutor(max_workers=num_workers) as pool: futures = [ pool.submit(_run_single_test, metadata=metadata, engine_adapter=engine_adapter) - for metadata, engine_adapter in zip(model_test_metadata, engine_adapters) + for metadata, engine_adapter in metadata_to_adapter.items() ] for future in concurrent.futures.as_completed(futures): test_results.append(future.result()) finally: - closed_adapters: t.Set[int] = set() - - for engine_adapter in engine_adapters: + for engine_adapter in set(metadata_to_adapter.values()): # The engine adapters list might have duplicates, so we ensure that we close each adapter once - hashed_adapter = hash(engine_adapter) - if engine_adapter and hashed_adapter not in closed_adapters: + if engine_adapter: engine_adapter.close() - closed_adapters.add(hashed_adapter) end_time = time.perf_counter() diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 3afaac4145..e50c3bdf06 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -2132,7 +2132,7 @@ def test_test_with_resolve_template_macro(tmp_path: Path): _check_successful_or_raise(context.test()) -def test_test_generation_report(tmp_path: Path) -> None: +def test_test_output(tmp_path: Path) 
-> None: init_example_project(tmp_path, dialect="duckdb") original_test_file = tmp_path / "tests" / "test_full_model.yaml" From 2f5e48f02afa7d99c446f9b45dfa185908814a19 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Wed, 2 Apr 2025 16:36:38 -0400 Subject: [PATCH 7/7] Rename helper --- sqlmesh/core/test/runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py index 1c6ef1b25e..08a12bb59b 100644 --- a/sqlmesh/core/test/runner.py +++ b/sqlmesh/core/test/runner.py @@ -45,7 +45,7 @@ def __init__( ) -def create_test_engine_adapters_for_tests( +def create_testing_engine_adapters( model_test_metadata: list[ModelTestMetadata], config: C, default_gateway: str, @@ -118,7 +118,7 @@ def run_tests( descriptions=True, ) - metadata_to_adapter = create_test_engine_adapters_for_tests( + metadata_to_adapter = create_testing_engine_adapters( model_test_metadata=model_test_metadata, config=config, default_gateway=default_gateway,
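
A note on the overall shape of the series: every test gets its own unittest run with a private StringIO stream, so parallel runs never interleave their output, and a lock guards the merge into one combined result. That pattern reduces to the standard library. The sketch below is illustrative only; ToyTest and run_concurrently stand in for ModelTest and run_tests and are not the sqlmesh API.

    from __future__ import annotations

    import io
    import threading
    import time
    import unittest
    from concurrent.futures import ThreadPoolExecutor, as_completed


    class ToyTest(unittest.TestCase):
        """Stand-in for ModelTest: passes for even values, fails for odd ones."""

        def __init__(self, value: int) -> None:
            super().__init__("runTest")
            self.value = value

        def runTest(self) -> None:
            self.assertEqual(self.value % 2, 0, f"value {self.value} is odd")


    def run_concurrently(values: list[int], concurrent_tasks: int = 4) -> unittest.TestResult:
        lock = threading.Lock()
        combined = unittest.TestResult()

        def run_single(value: int) -> None:
            # A private stream per run keeps parallel output from interleaving,
            # mirroring ModelTextTestRunner's use of StringIO.
            runner = unittest.TextTestRunner(stream=io.StringIO(), verbosity=0)
            result = runner.run(ToyTest(value))
            with lock:  # one worker at a time mutates the shared result
                combined.testsRun += 1
                combined.failures.extend(result.failures)
                combined.errors.extend(result.errors)

        # Cap the pool at the number of tests, as run_tests does with
        # default_test_connection.concurrent_tasks.
        workers = min(len(values) or 1, concurrent_tasks)

        start = time.perf_counter()
        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = [pool.submit(run_single, value) for value in values]
            for future in as_completed(futures):
                future.result()  # re-raise any exception from a worker
        duration = time.perf_counter() - start

        status = "OK" if not (combined.failures or combined.errors) else "FAILED"
        print(f"Ran {combined.testsRun} tests in {duration:.3f}s: {status}")
        return combined


    if __name__ == "__main__":
        run_concurrently([1, 2, 3, 4])

Merging per-run results instead of sharing one live TestResult across threads is what lets each run keep its own stream and its own timing; the combined object only ever sees finished outcomes.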
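
The DuckDB special case in create_testing_engine_adapters is worth spelling out. An in-memory DuckDB catalog lives inside a single connection, so tests sharing a pool could observe each other's fixture tables; the series sidesteps this by temporarily setting concurrent_tasks to 1 so that create_engine_adapter builds an adapter backed by a SingletonConnectionPool, as the inline comment in patch 2 says. A toy version of that flip, assuming a dataclass stand-in for the connection config (ToyConnectionConfig and adapters_for are hypothetical names, and the string tags stand in for engine adapters):

    from __future__ import annotations

    import itertools
    from dataclasses import dataclass

    _ids = itertools.count()


    @dataclass
    class ToyConnectionConfig:
        """Stand-in for a sqlmesh connection config, not the real class."""

        gateway: str
        is_duckdb: bool
        concurrent_tasks: int = 4

        def create_adapter(self) -> str:
            # The pool choice mimics the assumption above: concurrent_tasks == 1
            # selects a single-connection (singleton) pool.
            pool = "singleton" if self.concurrent_tasks == 1 else "threadlocal"
            return f"{self.gateway}/{pool}/{next(_ids)}"


    def adapters_for(configs: list[ToyConnectionConfig]) -> list[str]:
        shared_by_gateway: dict[str, str] = {}
        adapters = []
        for config in configs:
            if config.is_duckdb:
                # Force a fresh, fully isolated adapter for this one test.
                saved = config.concurrent_tasks
                config.concurrent_tasks = 1
                try:
                    adapters.append(config.create_adapter())
                finally:
                    config.concurrent_tasks = saved
            else:
                # All other engines share one adapter per gateway across threads.
                if config.gateway not in shared_by_gateway:
                    shared_by_gateway[config.gateway] = config.create_adapter()
                adapters.append(shared_by_gateway[config.gateway])
        return adapters

Because the shared adapters appear once per test in the returned collection, teardown must deduplicate before closing, which is why the final revision iterates over set(metadata_to_adapter.values()) instead of tracking adapter hashes by hand.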
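
Finally, the original_err attribute on ModelTextTestResult exists because unittest formats the exc_info tuple into a string the moment addFailure or addError runs; the combined result needs the raw (test, exc_info) tuple to re-register the outcome later, and addFailure additionally drops the traceback from the per-test report since the assertion message already carries the data mismatch. A minimal sketch of that bookkeeping, with RecordingResult as a made-up name rather than the sqlmesh class:

    from __future__ import annotations

    import typing as t
    import unittest


    class RecordingResult(unittest.TextTestResult):
        """Keeps the raw (test, exc_info) tuple next to unittest's formatted copy."""

        original_err: t.Optional[t.Tuple[unittest.TestCase, t.Any]] = None

        def addFailure(self, test: unittest.TestCase, err: t.Any) -> None:
            exctype, value, tb = err
            self.original_err = (test, err)  # raw tuple, re-usable by a combined result
            # Report without the traceback; unittest accepts tb=None and
            # prints only the exception message.
            super().addFailure(test, (exctype, value, None))

        def addError(self, test: unittest.TestCase, err: t.Any) -> None:
            self.original_err = (test, err)
            super().addError(test, err)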