Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 4 additions & 124 deletions sqlmesh/core/test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
from __future__ import annotations

import pathlib
import typing as t
import unittest

from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.model import Model
from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test
from sqlmesh.core.test.discovery import (
ModelTestMetadata as ModelTestMetadata,
Expand All @@ -14,121 +8,7 @@
load_model_test_file as load_model_test_file,
)
from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult
from sqlmesh.utils import UniqueKeyDict, Verbosity

if t.TYPE_CHECKING:
from sqlmesh.core.config.loader import C


def run_tests(
    model_test_metadata: list[ModelTestMetadata],
    models: UniqueKeyDict[str, Model],
    config: C,
    gateway: t.Optional[str] = None,
    dialect: str | None = None,
    verbosity: Verbosity = Verbosity.DEFAULT,
    preserve_fixtures: bool = False,
    stream: t.TextIO | None = None,
    default_catalog: str | None = None,
    default_catalog_dialect: str = "",
) -> ModelTextTestResult:
    """Create a test suite of ModelTest objects and run it.

    Args:
        model_test_metadata: A list of ModelTestMetadata named tuples.
        models: All models to use for expansion and mapping of physical locations.
        config: The project configuration; used to resolve test connections per gateway.
        gateway: The gateway to fall back to when a test body doesn't specify one.
        dialect: The SQL dialect used to parse the test definitions.
        verbosity: The verbosity level.
        preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
        stream: The stream the unittest runner writes its report to (defaults to stderr).
        default_catalog: The default catalog applied to model names.
        default_catalog_dialect: The dialect of the default catalog.

    Returns:
        The aggregated ModelTextTestResult for the whole suite.
    """
    testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {}
    default_gateway = gateway or config.default_gateway_name

    try:
        tests = []
        for metadata in model_test_metadata:
            body = metadata.body
            # Use a distinct local instead of rebinding the `gateway`
            # parameter (the original shadowed it inside the loop).
            test_gateway = body.get("gateway") or default_gateway

            # Engine adapters are created lazily and shared across all tests
            # that target the same gateway.
            testing_engine_adapter = testing_adapter_by_gateway.get(test_gateway)
            if not testing_engine_adapter:
                testing_engine_adapter = config.get_test_connection(
                    test_gateway,
                    default_catalog,
                    default_catalog_dialect,
                ).create_engine_adapter(register_comments_override=False)
                testing_adapter_by_gateway[test_gateway] = testing_engine_adapter

            test = ModelTest.create_test(
                body=body,
                test_name=metadata.test_name,
                models=models,
                engine_adapter=testing_engine_adapter,
                dialect=dialect,
                path=metadata.path,
                default_catalog=default_catalog,
                preserve_fixtures=preserve_fixtures,
            )
            if test:
                tests.append(test)

        result = t.cast(
            ModelTextTestResult,
            unittest.TextTestRunner(
                stream=stream,
                verbosity=2 if verbosity >= Verbosity.VERBOSE else 1,
                resultclass=ModelTextTestResult,
            ).run(unittest.TestSuite(tests)),
        )
    finally:
        # Always release the testing connections, even if suite setup failed.
        for testing_engine_adapter in testing_adapter_by_gateway.values():
            testing_engine_adapter.close()

    return result


def run_model_tests(
    tests: list[str],
    models: UniqueKeyDict[str, Model],
    config: C,
    gateway: t.Optional[str] = None,
    dialect: str | None = None,
    verbosity: Verbosity = Verbosity.DEFAULT,
    patterns: list[str] | None = None,
    preserve_fixtures: bool = False,
    stream: t.TextIO | None = None,
    default_catalog: t.Optional[str] = None,
    default_catalog_dialect: str = "",
) -> ModelTextTestResult:
    """Load and run tests.

    Args:
        tests: A list of tests to run, e.g. [tests/test_orders.yaml::test_single_order]
        models: All models to use for expansion and mapping of physical locations.
        verbosity: The verbosity level.
        patterns: A list of patterns to match against.
        preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
    """
    loaded_tests = []
    for selector in tests:
        # A selector is either "path/to/file.yaml" or
        # "path/to/file.yaml::test_name".
        filename, _, test_name = selector.partition("::")
        tests_in_file = load_model_test_file(
            pathlib.Path(filename), variables=config.variables
        )
        if test_name:
            loaded_tests.append(tests_in_file[test_name])
        else:
            loaded_tests.extend(tests_in_file.values())

    if patterns:
        loaded_tests = filter_tests_by_patterns(loaded_tests, patterns)

    return run_tests(
        loaded_tests,
        models,
        config,
        gateway=gateway,
        dialect=dialect,
        verbosity=verbosity,
        preserve_fixtures=preserve_fixtures,
        stream=stream,
        default_catalog=default_catalog,
        default_catalog_dialect=default_catalog_dialect,
    )
from sqlmesh.core.test.runner import (
run_model_tests as run_model_tests,
run_tests as run_tests,
)
58 changes: 58 additions & 0 deletions sqlmesh/core/test/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,19 @@ def addFailure(self, test: unittest.TestCase, err: ErrorType) -> None:
err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback).
"""
exctype, value, tb = err
self.original_err = (test, err)
return super().addFailure(test, (exctype, value, None)) # type: ignore

def addError(self, test: unittest.TestCase, err: ErrorType) -> None:
    """Called when the test case test signals an error.

    Args:
        test: The test case.
        err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback).
    """
    # Record the failing test together with its full exc_info tuple so the
    # unmodified traceback can be inspected later.
    self.original_err = (test, err)
    super().addError(test, err)  # type: ignore

def addSuccess(self, test: unittest.TestCase) -> None:
"""Called when the test case test succeeds.

Expand All @@ -59,3 +70,50 @@ def addSuccess(self, test: unittest.TestCase) -> None:
"""
super().addSuccess(test)
self.successes.append(test)

def log_test_report(self, test_duration: float) -> None:
    """
    Log the test report following unittest's conventions.

    Writes a section per failure and per error, then the final
    "Ran N tests in Xs" / "OK"/"FAILED" summary to ``self.stream``.

    Args:
        test_duration: The duration of the tests.
    """
    tests_run = self.testsRun
    errors = self.errors
    failures = self.failures
    skipped = self.skipped

    is_success = not (errors or failures)

    infos = []
    if failures:
        infos.append(f"failures={len(failures)}")
    if errors:
        infos.append(f"errors={len(errors)}")
    if skipped:
        # `skipped` is a list of (test, reason) pairs; report its length,
        # not the list itself, to match unittest's summary format.
        infos.append(f"skipped={len(skipped)}")

    stream = self.stream

    stream.write("\n")

    for test_case, failure in failures:
        stream.writeln(unittest.TextTestResult.separator1)
        stream.writeln(f"FAIL: {test_case}")
        # Only print the short description when the test has one; writing
        # the raw return value would print a literal "None" line otherwise.
        description = test_case.shortDescription()
        if description:
            stream.writeln(description)
        stream.writeln(unittest.TextTestResult.separator2)
        stream.writeln(failure)

    for test_case, error in errors:
        stream.writeln(unittest.TextTestResult.separator1)
        # Mirror the failure section (and unittest's printErrorList): title
        # the section with the test case, then print the error details.
        stream.writeln(f"ERROR: {test_case}")
        stream.writeln(unittest.TextTestResult.separator2)
        stream.writeln(error)

    # Output final report
    stream.writeln(unittest.TextTestResult.separator2)
    stream.writeln(
        f'Ran {tests_run} {"test" if tests_run == 1 else "tests"} in {test_duration:.3f}s \n'
    )
    stream.writeln(
        f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}'
    )
Loading