diff --git a/tests/conftest.py b/tests/conftest.py index 955b50234c..e76588ffe2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -261,9 +261,12 @@ def rescope_lineage_cache(request): @pytest.fixture(autouse=True) def reset_console(): - from sqlmesh.core.console import set_console, NoopConsole + from sqlmesh.core.console import set_console, NoopConsole, get_console + orig_console = get_console() set_console(NoopConsole()) + yield + set_console(orig_console) @pytest.fixture diff --git a/tests/dbt/conftest.py b/tests/dbt/conftest.py index 846dfc6aa9..5e6444c8e6 100644 --- a/tests/dbt/conftest.py +++ b/tests/dbt/conftest.py @@ -127,3 +127,26 @@ def dbt_dummy_postgres_config() -> PostgresConfig: port=5432, schema="schema", ) + + +@pytest.fixture(scope="function", autouse=True) +def reset_dbt_globals(): + # This fixture is used to clear the memoized cache for _get_package_with_retries + # in dbt.clients.registry. This is necessary because the cache is shared across + # tests and can cause unexpected behavior if not cleared as some tests depend on + # the deprecation warning that _get_package_with_retries fires + yield + # https://github.com/dbt-labs/dbt-core/blob/main/tests/functional/conftest.py#L9 + try: + from dbt.clients.registry import _get_cached + + _get_cached.cache = {} + except Exception: + pass + # https://github.com/dbt-labs/dbt-core/blob/main/core/dbt/tests/util.py#L82 + try: + from dbt_common.events.functions import reset_metadata_vars + + reset_metadata_vars() + except Exception: + pass diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index 797d638858..489d69683b 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -6,6 +6,7 @@ from sqlglot import exp from sqlglot.errors import SchemaError from sqlmesh import Context +from sqlmesh.core.console import NoopConsole, get_console from sqlmesh.core.model import TimeColumn, IncrementalByTimeRangeKind from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json @@ -537,6 +538,7 @@ def test_load_deprecated_incremental_time_column( f.write(incremental_time_range_contents) snapshot_fqn = '"local"."main"."incremental_time_range"' + assert isinstance(get_console(), NoopConsole) context = Context(paths=project_dir) model = context.snapshots[snapshot_fqn].model # Validate model-level attributes