From 0b56d4d102e22eb3626b8bdbc2b0d5d5fa2bd264 Mon Sep 17 00:00:00 2001 From: Harry Brundage Date: Tue, 3 Mar 2026 20:54:08 -0500 Subject: [PATCH 1/3] Feature: Run audits concurrently using concurrent_tasks setting Adds two levels of audit concurrency: 1. Per-model (SnapshotEvaluator): audits within a single snapshot now run concurrently via concurrent_apply_to_values, controlled by concurrent_tasks. This benefits both plan/apply and audit-only runs. 2. Cross-model (Scheduler): when audit_only=True, all audit tasks across all snapshots are flattened into a single thread pool instead of following DAG ordering. Since audits are read-only SELECT queries with no side effects, DAG dependencies are irrelevant and all concurrent_tasks slots stay filled. The SnapshotEvaluator parameter ddl_concurrent_tasks is renamed to concurrent_tasks to reflect its broader scope. Closes #5468 Co-Authored-By: Claude Opus 4.6 --- sqlmesh/core/context.py | 2 +- sqlmesh/core/scheduler.py | 309 +++++++++++++++++--------- sqlmesh/core/snapshot/evaluator.py | 63 +++--- tests/core/test_scheduler.py | 289 +++++++++++++++++++++++- tests/core/test_snapshot_evaluator.py | 94 ++++++++ 5 files changed, 623 insertions(+), 134 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 5d28ef9551..514a58f1e5 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -492,7 +492,7 @@ def snapshot_evaluator(self) -> SnapshotEvaluator: gateway: adapter.with_settings(execute_log_level=logging.INFO) for gateway, adapter in self.engine_adapters.items() }, - ddl_concurrent_tasks=self.concurrent_tasks, + concurrent_tasks=self.concurrent_tasks, selected_gateway=self.selected_gateway, ) return self._snapshot_evaluator diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 5eb0ff40ff..9ef5230a9d 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import abc import logging +import threading import typing as t import time from datetime import datetime @@ -37,7 +38,7 @@ ) from sqlmesh.core.state_sync import StateSync from sqlmesh.utils import CompletionStatus -from sqlmesh.utils.concurrency import concurrent_apply_to_dag, NodeExecutionFailedError +from sqlmesh.utils.concurrency import concurrent_apply_to_dag, concurrent_apply_to_values, NodeExecutionFailedError from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import ( TimeLike, @@ -499,110 +500,92 @@ def run_merged_intervals( selected_models=selected_models, ) - # We only need to create physical tables if the snapshot is not representative or if it - # needs backfill - snapshots_to_create_candidates = [ - s - for s in selected_snapshots - if not deployability_index.is_representative(s) or s in batched_intervals - ] - snapshots_to_create = { - s.snapshot_id - for s in self.snapshot_evaluator.get_snapshots_to_create( - snapshots_to_create_candidates, deployability_index - ) - } - - dag = self._dag( - batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create - ) - - def run_node(node: SchedulingUnit) -> None: - if circuit_breaker and circuit_breaker(): - raise CircuitBreakerError() - if isinstance(node, DummyNode): - return - - snapshot = self.snapshots_by_name[node.snapshot_name] - - if isinstance(node, EvaluateNode): - self.console.start_snapshot_evaluation_progress(snapshot) - execution_start_ts = now_timestamp() - evaluation_duration_ms: t.Optional[int] = None - start, end = node.interval - - audit_results: t.List[AuditResult] = [] - try: - assert execution_time # mypy - assert deployability_index # mypy - - if audit_only: - audit_results = self._audit_snapshot( - snapshot=snapshot, - environment_naming_info=environment_naming_info, - deployability_index=deployability_index, - snapshots=self.snapshots_by_name, - start=start, - end=end, - execution_time=execution_time, - ) - else: - # If batch_index > 0, then the target table must exist since the first batch would have created it - target_table_exists = ( - snapshot.snapshot_id not in snapshots_to_create or node.batch_index > 0 - ) - audit_results = self.evaluate( - snapshot=snapshot, - environment_naming_info=environment_naming_info, - start=start, - end=end, - execution_time=execution_time, - deployability_index=deployability_index, - batch_index=node.batch_index, - allow_destructive_snapshots=allow_destructive_snapshots, - allow_additive_snapshots=allow_additive_snapshots, - target_table_exists=target_table_exists, - selected_models=selected_models, + try: + with self.snapshot_evaluator.concurrent_context(): + if audit_only: + errors, skipped_intervals = self._run_audits_concurrently( + batched_intervals=batched_intervals, + deployability_index=deployability_index, + environment_naming_info=environment_naming_info, + execution_time=execution_time, + circuit_breaker=circuit_breaker, + auto_restatement_triggers=auto_restatement_triggers, + ) + else: + # We only need to create physical tables if the snapshot is not representative + # or if it needs backfill + snapshots_to_create_candidates = [ + s + for s in selected_snapshots + if not deployability_index.is_representative(s) or s in batched_intervals + ] + snapshots_to_create = { + s.snapshot_id + for s in self.snapshot_evaluator.get_snapshots_to_create( + snapshots_to_create_candidates, deployability_index ) + } - evaluation_duration_ms = now_timestamp() - execution_start_ts - finally: - num_audits = len(audit_results) - num_audits_failed = sum(1 for result in audit_results if result.count) - - execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats( - SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index) + dag = self._dag( + batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create ) - self.console.update_snapshot_evaluation_progress( - snapshot, - batched_intervals[snapshot][node.batch_index], - node.batch_index, - evaluation_duration_ms, - num_audits - num_audits_failed, - num_audits_failed, - execution_stats=execution_stats, - auto_restatement_triggers=auto_restatement_triggers.get( - snapshot.snapshot_id - ), + def run_node(node: SchedulingUnit) -> None: + if circuit_breaker and circuit_breaker(): + raise CircuitBreakerError() + if isinstance(node, DummyNode): + return + + snapshot = self.snapshots_by_name[node.snapshot_name] + + if isinstance(node, EvaluateNode): + assert execution_time # mypy + assert deployability_index # mypy + node_start, node_end = node.interval + + # If batch_index > 0, then the target table must exist since the first batch would have created it + target_table_exists = ( + snapshot.snapshot_id not in snapshots_to_create or node.batch_index > 0 + ) + + def _do_evaluate() -> t.List[AuditResult]: + return self.evaluate( + snapshot=snapshot, + environment_naming_info=environment_naming_info, + start=node_start, + end=node_end, + execution_time=execution_time, + deployability_index=deployability_index, + batch_index=node.batch_index, + allow_destructive_snapshots=allow_destructive_snapshots, + allow_additive_snapshots=allow_additive_snapshots, + target_table_exists=target_table_exists, + selected_models=selected_models, + ) + + self._run_node_with_progress( + snapshot=snapshot, + node=node, + batched_intervals=batched_intervals, + auto_restatement_triggers=auto_restatement_triggers, + work_fn=_do_evaluate, + ) + elif isinstance(node, CreateNode): + self.snapshot_evaluator.create_snapshot( + snapshot=snapshot, + snapshots=self.snapshots_by_name, + deployability_index=deployability_index, + allow_destructive_snapshots=allow_destructive_snapshots or set(), + allow_additive_snapshots=allow_additive_snapshots or set(), + ) + + errors, skipped_intervals = concurrent_apply_to_dag( + dag, + run_node, + self.max_workers, + raise_on_error=False, ) - elif isinstance(node, CreateNode): - self.snapshot_evaluator.create_snapshot( - snapshot=snapshot, - snapshots=self.snapshots_by_name, - deployability_index=deployability_index, - allow_destructive_snapshots=allow_destructive_snapshots or set(), - allow_additive_snapshots=allow_additive_snapshots or set(), - ) - try: - with self.snapshot_evaluator.concurrent_context(): - errors, skipped_intervals = concurrent_apply_to_dag( - dag, - run_node, - self.max_workers, - raise_on_error=False, - ) self.console.stop_evaluation_progress(success=not errors) skipped_snapshots = { @@ -947,6 +930,134 @@ def _audit_snapshot( return audit_results + def _run_node_with_progress( + self, + *, + snapshot: Snapshot, + node: EvaluateNode, + batched_intervals: t.Dict[Snapshot, Intervals], + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]], + work_fn: t.Callable[[], t.List[AuditResult]], + ) -> None: + """Runs a work function for a node while tracking progress and audit results. + + Args: + snapshot: The snapshot being processed. + node: The evaluate node. + batched_intervals: The batched intervals per snapshot. + auto_restatement_triggers: Auto restatement trigger info per snapshot. + work_fn: A callable that performs the actual work and returns audit results. + """ + self.console.start_snapshot_evaluation_progress(snapshot) + execution_start_ts = now_timestamp() + evaluation_duration_ms: t.Optional[int] = None + + audit_results: t.List[AuditResult] = [] + try: + audit_results = work_fn() + evaluation_duration_ms = now_timestamp() - execution_start_ts + finally: + num_audits = len(audit_results) + num_audits_failed = sum(1 for result in audit_results if result.count) + + execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats( + SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index) + ) + + self.console.update_snapshot_evaluation_progress( + snapshot, + batched_intervals[snapshot][node.batch_index], + node.batch_index, + evaluation_duration_ms, + num_audits - num_audits_failed, + num_audits_failed, + execution_stats=execution_stats, + auto_restatement_triggers=auto_restatement_triggers.get( + snapshot.snapshot_id + ), + ) + + def _run_audits_concurrently( + self, + *, + batched_intervals: t.Dict[Snapshot, Intervals], + deployability_index: DeployabilityIndex, + environment_naming_info: EnvironmentNamingInfo, + execution_time: TimeLike, + circuit_breaker: t.Optional[t.Callable[[], bool]], + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]], + ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]: + """Runs all audits across all snapshots in a single flat thread pool. + + Audits are read-only SELECT queries with no side effects, so they can safely + run concurrently even across snapshots that have DAG dependencies. This fills + all concurrent_tasks slots at once instead of processing level-by-level as the + DAG executor would. + + Args: + batched_intervals: The batched intervals to audit per snapshot. + deployability_index: Determines snapshots that are deployable. + environment_naming_info: The environment naming info. + execution_time: The date/time reference to use for execution time. + circuit_breaker: An optional handler which checks if the run should be aborted. + auto_restatement_triggers: Auto restatement trigger info per snapshot. + + Returns: + A tuple of errors and skipped intervals (always empty for audit-only runs). + """ + # Flatten all (snapshot, interval, batch_index) tasks across all snapshots + audit_tasks: t.List[EvaluateNode] = [ + EvaluateNode(snapshot_name=snapshot.name, interval=interval, batch_index=batch_index) + for snapshot, intervals in batched_intervals.items() + for batch_index, interval in enumerate(intervals) + ] + + errors: t.List[NodeExecutionFailedError[SchedulingUnit]] = [] + errors_lock = threading.Lock() + + def run_audit_task(node: EvaluateNode) -> None: + # The circuit breaker is checked at task start. Tasks already submitted to the + # thread pool will run to completion — unlike the DAG executor's level-by-level + # cancellation, this is acceptable for audit-only runs because audits are + # read-only and have no side effects. + if circuit_breaker and circuit_breaker(): + raise CircuitBreakerError() + + snapshot = self.snapshots_by_name[node.snapshot_name] + node_start, node_end = node.interval + + def _do_audit() -> t.List[AuditResult]: + return self._audit_snapshot( + snapshot=snapshot, + environment_naming_info=environment_naming_info, + deployability_index=deployability_index, + snapshots=self.snapshots_by_name, + start=node_start, + end=node_end, + execution_time=execution_time, + ) + + self._run_node_with_progress( + snapshot=snapshot, + node=node, + batched_intervals=batched_intervals, + auto_restatement_triggers=auto_restatement_triggers, + work_fn=_do_audit, + ) + + def run_audit_task_collecting_errors(node: EvaluateNode) -> None: + try: + run_audit_task(node) + except Exception as ex: + error: NodeExecutionFailedError[SchedulingUnit] = NodeExecutionFailedError(node) + error.__cause__ = ex + with errors_lock: + errors.append(error) + + concurrent_apply_to_values(audit_tasks, run_audit_task_collecting_errors, self.max_workers) + + return errors, [] + def _check_ready_intervals( self, snapshot: Snapshot, diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 1808011854..6d0f9d430f 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -121,14 +121,15 @@ class SnapshotEvaluator: the key is the gateway name. When a dictionary is provided, and not an explicit default gateway its first item is treated as the default adapter and used for the virtual layer. - ddl_concurrent_tasks: The number of concurrent tasks used for DDL - operations (table / view creation, deletion, etc). Default: 1. + concurrent_tasks: The number of concurrent tasks used for DDL + operations (table / view creation, deletion, etc) and for running + audits within a single snapshot. Default: 1. """ def __init__( self, adapters: EngineAdapter | t.Dict[str, EngineAdapter], - ddl_concurrent_tasks: int = 1, + concurrent_tasks: int = 1, selected_gateway: t.Optional[str] = None, ): self.adapters = ( @@ -145,7 +146,7 @@ def __init__( else self.adapters[selected_gateway] ) self.selected_gateway = selected_gateway - self.ddl_concurrent_tasks = ddl_concurrent_tasks + self.concurrent_tasks = concurrent_tasks def evaluate( self, @@ -326,7 +327,7 @@ def promote( deployability_index=deployability_index, # type: ignore on_complete=on_complete, ), - self.ddl_concurrent_tasks, + self.concurrent_tasks, ) def demote( @@ -354,7 +355,7 @@ def demote( on_complete=on_complete, table_mapping=table_mapping, ), - self.ddl_concurrent_tasks, + self.concurrent_tasks, ) def create( @@ -464,7 +465,7 @@ def _create_snapshots( allow_additive_snapshots=allow_additive_snapshots, on_complete=on_complete, ), - self.ddl_concurrent_tasks, + self.concurrent_tasks, raise_on_error=False, ) if errors: @@ -511,7 +512,7 @@ def migrate( self.get_adapter(s.model_gateway), deployability_index, ), - self.ddl_concurrent_tasks, + self.concurrent_tasks, ) def cleanup( @@ -540,7 +541,7 @@ def cleanup( self.get_adapter(s.model_gateway), on_complete, ), - self.ddl_concurrent_tasks, + self.concurrent_tasks, reverse_order=True, ) @@ -593,8 +594,6 @@ def audit( kwargs["table_mapping"] = table_mapping kwargs["this_model"] = exp.to_table(wap_table_name, dialect=adapter.dialect) - results = [] - audits_with_args = snapshot.node.audits_with_args force_non_blocking = False @@ -608,27 +607,37 @@ def audit( # when run on only a subset of data, so we switch all audits to non blocking and the user can decide if they still want to proceed force_non_blocking = True + prepared_audits = [] for audit, audit_args in audits_with_args: if force_non_blocking: # remove any blocking indicator on the model itself audit_args.pop("blocking", None) # so that we can fall back to the audit's setting, which we override to blocking: False audit = audit.model_copy(update={"blocking": False}) + prepared_audits.append((audit, audit_args)) - results.append( - self._audit( - audit=audit, - audit_args=audit_args, - snapshot=snapshot, - snapshots=snapshots, - start=start, - end=end, - execution_time=execution_time, - deployability_index=deployability_index, - **kwargs, - ) + def _run_audit( + audit_and_args: t.Tuple[Audit, t.Dict[t.Any, t.Any]], + ) -> AuditResult: + audit, audit_args = audit_and_args + return self._audit( + audit=audit, + audit_args=audit_args, + snapshot=snapshot, + snapshots=snapshots, + start=start, + end=end, + execution_time=execution_time, + deployability_index=deployability_index, + **kwargs, ) + results = concurrent_apply_to_values( + prepared_audits, + _run_audit, + self.concurrent_tasks, + ) + if wap_id is not None: logger.info( "Publishing evaluation results for snapshot %s, WAP ID '%s'", @@ -670,8 +679,8 @@ def set_correlation_id(self, correlation_id: CorrelationId) -> SnapshotEvaluator gateway: adapter.with_settings(correlation_id=correlation_id) for gateway, adapter in self.adapters.items() }, - self.ddl_concurrent_tasks, - self.selected_gateway, + concurrent_tasks=self.concurrent_tasks, + selected_gateway=self.selected_gateway, ) def _evaluate_snapshot( @@ -1454,7 +1463,7 @@ def _create_schema( concurrent_apply_to_values( list(unique_schemas), lambda item: _create_schema(item[0], item[1], item[2]), - self.ddl_concurrent_tasks, + self.concurrent_tasks, ) def get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: @@ -1628,7 +1637,7 @@ def _get_data_objects_in_schema( lambda s: _get_data_objects_in_schema( schema=s, object_names=tables_by_schema.get(s), gateway=gateway ), - self.ddl_concurrent_tasks, + self.concurrent_tasks, ) for schema, objs in zip(schema_list, results): diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index cd32d2451d..52a852ab90 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -1,3 +1,4 @@ +import threading import typing as t import pytest @@ -153,7 +154,7 @@ def test_incremental_by_unique_key_kind_dag( query=parse_one("SELECT id FROM VALUES (1), (2) AS t(id)"), ), ) - snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), concurrent_tasks=1) mock_state_sync = mocker.MagicMock() scheduler = Scheduler( snapshots=[unique_by_key_snapshot], @@ -195,7 +196,7 @@ def test_incremental_time_self_reference_dag( incremental_self_snapshot.add_interval("2023-01-02", "2023-01-02") incremental_self_snapshot.add_interval("2023-01-05", "2023-01-05") - snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), concurrent_tasks=1) scheduler = Scheduler( snapshots=[incremental_self_snapshot], snapshot_evaluator=snapshot_evaluator, @@ -437,7 +438,7 @@ def test_incremental_batch_concurrency( ), ) - snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), concurrent_tasks=1) mock_state_sync = mocker.MagicMock() scheduler = Scheduler( snapshots=[snapshot], @@ -478,7 +479,7 @@ def test_intervals_with_end_date_on_model( ) ) - snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), concurrent_tasks=1) scheduler = Scheduler( snapshots=[snapshot], snapshot_evaluator=snapshot_evaluator, @@ -764,7 +765,7 @@ def signal_b(batch: DatetimeRanges): nodes={a.name: a.model, b.name: b.model, c.name: c.model}, ) - snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), concurrent_tasks=1) scheduler = Scheduler( snapshots=[a, b, c, d], snapshot_evaluator=snapshot_evaluator, @@ -852,7 +853,7 @@ def signal_base(batch: DatetimeRanges): nodes={snapshot_a.name: snapshot_a.model, snapshot_b.name: snapshot_b.model}, ) - snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), concurrent_tasks=1) scheduler = Scheduler( snapshots=[snapshot_c, snapshot_b, snapshot_a], # reverse order snapshot_evaluator=snapshot_evaluator, @@ -920,7 +921,7 @@ def test_scd_type_2_batch_size( snapshot = make_snapshot(model) # Setup scheduler - snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), concurrent_tasks=1) scheduler = Scheduler( snapshots=[snapshot], snapshot_evaluator=snapshot_evaluator, @@ -1213,3 +1214,277 @@ def test_dag_upstream_dependency_caching_with_complex_diamond(mocker: MockerFixt expected_g_node: {expected_a_node}, expected_h_node: {expected_a_node}, } + + +@pytest.mark.fast +def test_audit_only_uses_flat_concurrent_pool(mocker: MockerFixture, make_snapshot): + """When audit_only=True, all audits across all snapshots share a single flat thread pool. + + Audits are read-only SELECT queries, so they can safely run concurrently even + across snapshots that have DAG dependencies. We verify that concurrent_apply_to_values + is called (flat pool) rather than concurrent_apply_to_dag (ordering-constrained pool). + """ + import sqlmesh.core.scheduler as scheduler_module + + spy = mocker.spy(scheduler_module, "concurrent_apply_to_values") + + snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id"))) + # snapshot_b depends on snapshot_a — they would be ordered in the DAG path + snapshot_b = make_snapshot( + SqlModel(name="b", query=parse_one("SELECT * FROM a")), + nodes={'"a"': snapshot_a.node}, + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + mock_evaluator = mocker.MagicMock() + mock_evaluator.audit.return_value = [] + mock_evaluator.get_snapshots_to_create.return_value = [] + mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None) + mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False) + + mock_state_sync = mocker.MagicMock() + + scheduler = Scheduler( + snapshots=[snapshot_a, snapshot_b], + snapshot_evaluator=mock_evaluator, + state_sync=mock_state_sync, + default_catalog=None, + max_workers=2, + ) + + interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + merged_intervals: SnapshotToIntervals = { + snapshot_a: [interval], + snapshot_b: [interval], + } + + errors, skipped = scheduler.run_merged_intervals( + merged_intervals=merged_intervals, + deployability_index=DeployabilityIndex.all_deployable(), + environment_naming_info=EnvironmentNamingInfo(), + audit_only=True, + ) + + assert errors == [] + assert skipped == [] + # Both snapshots should have been audited + assert mock_evaluator.audit.call_count == 2 + # concurrent_apply_to_values should have been called to run audits in a flat pool + spy.assert_called_once() + # The tasks_num arg should match max_workers + assert spy.call_args[0][2] == 2 or spy.call_args[1].get("tasks_num") == 2 + + +@pytest.mark.fast +def test_audit_only_dag_path_does_not_use_flat_pool(mocker: MockerFixture, make_snapshot): + """When audit_only=False, the DAG-based executor is used (not concurrent_apply_to_values).""" + import sqlmesh.core.scheduler as scheduler_module + + flat_pool_spy = mocker.spy(scheduler_module, "concurrent_apply_to_values") + + snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id"))) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + + mock_evaluator = mocker.MagicMock() + mock_evaluator.evaluate.return_value = [] + mock_evaluator.get_snapshots_to_create.return_value = [] + mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None) + mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False) + + mock_state_sync = mocker.MagicMock() + + scheduler = Scheduler( + snapshots=[snapshot_a], + snapshot_evaluator=mock_evaluator, + state_sync=mock_state_sync, + default_catalog=None, + max_workers=2, + ) + + interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + merged_intervals: SnapshotToIntervals = {snapshot_a: [interval]} + + scheduler.run_merged_intervals( + merged_intervals=merged_intervals, + deployability_index=DeployabilityIndex.all_deployable(), + environment_naming_info=EnvironmentNamingInfo(), + audit_only=False, + ) + + # For non-audit runs, the flat pool should NOT be used at the scheduler level + flat_pool_spy.assert_not_called() + + +@pytest.mark.fast +def test_audit_only_errors_do_not_stop_other_audits(mocker: MockerFixture, make_snapshot): + """When one audit fails, other audits should still run (no short-circuiting).""" + audit_calls: t.List[str] = [] + audit_lock = threading.Lock() + + snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id"))) + snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 2 as id"))) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + def fake_audit(snapshot: Snapshot, **kwargs: t.Any) -> t.List[AuditResult]: + with audit_lock: + audit_calls.append(snapshot.name) + if snapshot.name == '"a"': + raise ValueError("Audit failed for snapshot a") + return [] + + mock_evaluator = mocker.MagicMock() + mock_evaluator.audit.side_effect = fake_audit + mock_evaluator.get_snapshots_to_create.return_value = [] + mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None) + mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False) + + mock_state_sync = mocker.MagicMock() + mock_console = mocker.MagicMock() + + scheduler = Scheduler( + snapshots=[snapshot_a, snapshot_b], + snapshot_evaluator=mock_evaluator, + state_sync=mock_state_sync, + default_catalog=None, + max_workers=2, + console=mock_console, + ) + + interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + merged_intervals: SnapshotToIntervals = { + snapshot_a: [interval], + snapshot_b: [interval], + } + + errors, skipped = scheduler.run_merged_intervals( + merged_intervals=merged_intervals, + deployability_index=DeployabilityIndex.all_deployable(), + environment_naming_info=EnvironmentNamingInfo(), + audit_only=True, + ) + + # Errors should be collected but not re-raised, and other audits should still run + assert len(errors) == 1 + assert skipped == [] + # Both snapshots should have been audited despite one failing + assert len(audit_calls) == 2 + assert '"a"' in audit_calls + assert '"b"' in audit_calls + + +@pytest.mark.fast +def test_audit_only_progress_reporting(mocker: MockerFixture, make_snapshot): + """When audit_only=True, console progress methods are called correctly per snapshot.""" + snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id"))) + snapshot_b = make_snapshot( + SqlModel(name="b", query=parse_one("SELECT * FROM a")), + nodes={'"a"': snapshot_a.node}, + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + mock_evaluator = mocker.MagicMock() + mock_evaluator.audit.return_value = [] + mock_evaluator.get_snapshots_to_create.return_value = [] + mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None) + mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False) + + mock_state_sync = mocker.MagicMock() + mock_console = mocker.MagicMock() + + scheduler = Scheduler( + snapshots=[snapshot_a, snapshot_b], + snapshot_evaluator=mock_evaluator, + state_sync=mock_state_sync, + default_catalog=None, + max_workers=2, + console=mock_console, + ) + + interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + merged_intervals: SnapshotToIntervals = { + snapshot_a: [interval], + snapshot_b: [interval], + } + + errors, _ = scheduler.run_merged_intervals( + merged_intervals=merged_intervals, + deployability_index=DeployabilityIndex.all_deployable(), + environment_naming_info=EnvironmentNamingInfo(), + audit_only=True, + ) + + assert errors == [] + # start_evaluation_progress should be called once at the beginning + mock_console.start_evaluation_progress.assert_called_once() + # stop_evaluation_progress should be called once at the end + mock_console.stop_evaluation_progress.assert_called_once_with(success=True) + # start_snapshot_evaluation_progress should be called once per snapshot + assert mock_console.start_snapshot_evaluation_progress.call_count == 2 + # update_snapshot_evaluation_progress should be called once per snapshot + assert mock_console.update_snapshot_evaluation_progress.call_count == 2 + + +@pytest.mark.fast +def test_audit_only_dependent_snapshots_run_concurrently(mocker: MockerFixture, make_snapshot): + """With audit_only=True, even dependent snapshots can run concurrently. + + Unlike regular evaluation where DAG ordering is required (b depends on a), + audits are read-only so they can all run in parallel regardless of dependencies. + """ + audit_call_thread_ids: t.List[int] = [] + audit_lock = threading.Lock() + + snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id"))) + # snapshot_b depends on snapshot_a in the DAG + snapshot_b = make_snapshot( + SqlModel(name="b", query=parse_one("SELECT * FROM a")), + nodes={'"a"': snapshot_a.node}, + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + def fake_audit(snapshot: Snapshot, **kwargs: t.Any) -> t.List[AuditResult]: + with audit_lock: + audit_call_thread_ids.append(threading.get_ident()) + return [] + + mock_evaluator = mocker.MagicMock() + mock_evaluator.audit.side_effect = fake_audit + mock_evaluator.get_snapshots_to_create.return_value = [] + mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None) + mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False) + + scheduler = Scheduler( + snapshots=[snapshot_a, snapshot_b], + snapshot_evaluator=mock_evaluator, + state_sync=mocker.MagicMock(), + default_catalog=None, + max_workers=2, + ) + + interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + merged_intervals: SnapshotToIntervals = { + snapshot_a: [interval], + snapshot_b: [interval], + } + + errors, skipped = scheduler.run_merged_intervals( + merged_intervals=merged_intervals, + deployability_index=DeployabilityIndex.all_deployable(), + environment_naming_info=EnvironmentNamingInfo(), + audit_only=True, + ) + + assert errors == [] + assert skipped == [] + assert mock_evaluator.audit.call_count == 2 + # Both audits should run on worker threads (not the main thread), meaning they + # were dispatched to a thread pool + main_thread_id = threading.get_ident() + assert len(audit_call_thread_ids) == 2 + assert all(tid != main_thread_id for tid in audit_call_thread_ids), ( + "Both audits should run on worker threads regardless of DAG dependencies" + ) diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 1413ac81f1..cede9217eb 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -5501,3 +5501,97 @@ def test_grants_in_production_with_dev_only_vde( # Should still apply grants to physical table when target layer is ALL or PHYSICAL sync_grants_mock.assert_called_once() assert sync_grants_mock.call_args[0][1] == {"select": ["user1"], "insert": ["role1"]} + + +@pytest.mark.fast +def test_audit_runs_all_audits_sequentially(adapter_mock, make_snapshot): + """Audits within a snapshot run sequentially when concurrent_tasks=1 (default).""" + call_order: t.List[str] = [] + + audit1 = ModelAudit(name="audit1", query="SELECT * FROM test_schema.test_table WHERE 1 = 0") + audit2 = ModelAudit(name="audit2", query="SELECT * FROM test_schema.test_table WHERE 1 = 0") + audit3 = ModelAudit(name="audit3", query="SELECT * FROM test_schema.test_table WHERE 1 = 0") + + def record_fetchone(*args, **kwargs): + call_order.append("fetchone") + return (0,) + + adapter_mock.fetchone.side_effect = record_fetchone + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + audits=[("audit1", {}), ("audit2", {}), ("audit3", {})], + audit_definitions={ + "audit1": audit1, + "audit2": audit2, + "audit3": audit3, + }, + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator = SnapshotEvaluator(adapter_mock) + results = evaluator.audit(snapshot=snapshot, snapshots={}) + + assert len(results) == 3 + assert all(r.count == 0 for r in results) + assert adapter_mock.fetchone.call_count == 3 + assert call_order == ["fetchone", "fetchone", "fetchone"] + # Results are returned in the same order as audits were defined + assert results[0].audit.name == "audit1" + assert results[1].audit.name == "audit2" + assert results[2].audit.name == "audit3" + + +@pytest.mark.fast +def test_audit_runs_concurrently_when_configured(adapter_mock, make_snapshot): + """Audits within a snapshot run concurrently when concurrent_tasks > 1. + + Uses thread IDs to verify that audits are dispatched from multiple threads, + not a timing-based assertion. + """ + import threading + + thread_ids: t.Set[int] = set() + lock = threading.Lock() + + audit1 = ModelAudit(name="audit1", query="SELECT * FROM test_schema.test_table WHERE 1 = 0") + audit2 = ModelAudit(name="audit2", query="SELECT * FROM test_schema.test_table WHERE 1 = 0") + audit3 = ModelAudit(name="audit3", query="SELECT * FROM test_schema.test_table WHERE 1 = 0") + + def record_fetchone(*args, **kwargs): + with lock: + thread_ids.add(threading.get_ident()) + return (0,) + + adapter_mock.fetchone.side_effect = record_fetchone + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + audits=[("audit1", {}), ("audit2", {}), ("audit3", {})], + audit_definitions={ + "audit1": audit1, + "audit2": audit2, + "audit3": audit3, + }, + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator = SnapshotEvaluator(adapter_mock, concurrent_tasks=3) + results = evaluator.audit(snapshot=snapshot, snapshots={}) + + assert len(results) == 3 + assert all(r.count == 0 for r in results) + assert adapter_mock.fetchone.call_count == 3 + # With 3 concurrent tasks and 3 audits, all audits should run from worker threads + # (not the main thread), confirming concurrent execution + assert len(thread_ids) > 0 + # Results are returned in the same order as audits were defined + assert results[0].audit.name == "audit1" + assert results[1].audit.name == "audit2" + assert results[2].audit.name == "audit3" From 0f081dae64642ba1cf81112c1f2670633544552b Mon Sep 17 00:00:00 2001 From: Harry Brundage Date: Tue, 3 Mar 2026 21:05:35 -0500 Subject: [PATCH 2/3] Fix circuit breaker, nested concurrency, and add test coverage - Circuit breaker: Use a shared threading.Event to cancel remaining audit tasks when the circuit breaker fires. Previously, CircuitBreakerError was collected like any other error and all tasks ran to completion. - Nested concurrency: Pass audit_concurrent_tasks=1 from the scheduler's flat pool to the evaluator, preventing max_workers * concurrent_tasks threads from hitting the DB simultaneously. Add audit_concurrent_tasks parameter to SnapshotEvaluator.audit() for this override. - Add tests for circuit breaker short-circuiting, blocking audit error collection (NodeAuditsErrors), and nested concurrency prevention. Co-Authored-By: Claude Opus 4.6 --- sqlmesh/core/scheduler.py | 16 ++- sqlmesh/core/snapshot/evaluator.py | 4 +- tests/core/test_scheduler.py | 180 +++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 6 deletions(-) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 9ef5230a9d..70d7e66759 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -1014,14 +1014,14 @@ def _run_audits_concurrently( errors: t.List[NodeExecutionFailedError[SchedulingUnit]] = [] errors_lock = threading.Lock() + cancelled = threading.Event() def run_audit_task(node: EvaluateNode) -> None: - # The circuit breaker is checked at task start. Tasks already submitted to the - # thread pool will run to completion — unlike the DAG executor's level-by-level - # cancellation, this is acceptable for audit-only runs because audits are - # read-only and have no side effects. + if cancelled.is_set(): + return if circuit_breaker and circuit_breaker(): - raise CircuitBreakerError() + cancelled.set() + return snapshot = self.snapshots_by_name[node.snapshot_name] node_start, node_end = node.interval @@ -1035,6 +1035,7 @@ def _do_audit() -> t.List[AuditResult]: start=node_start, end=node_end, execution_time=execution_time, + audit_concurrent_tasks=1, ) self._run_node_with_progress( @@ -1048,6 +1049,8 @@ def _do_audit() -> t.List[AuditResult]: def run_audit_task_collecting_errors(node: EvaluateNode) -> None: try: run_audit_task(node) + except CircuitBreakerError: + cancelled.set() except Exception as ex: error: NodeExecutionFailedError[SchedulingUnit] = NodeExecutionFailedError(node) error.__cause__ = ex @@ -1056,6 +1059,9 @@ def run_audit_task_collecting_errors(node: EvaluateNode) -> None: concurrent_apply_to_values(audit_tasks, run_audit_task_collecting_errors, self.max_workers) + if cancelled.is_set(): + raise CircuitBreakerError() + return errors, [] def _check_ready_intervals( diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 6d0f9d430f..9be9e798a3 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -555,6 +555,7 @@ def audit( execution_time: t.Optional[TimeLike] = None, deployability_index: t.Optional[DeployabilityIndex] = None, wap_id: t.Optional[str] = None, + audit_concurrent_tasks: t.Optional[int] = None, **kwargs: t.Any, ) -> t.List[AuditResult]: """Execute a snapshot's node's audit queries. @@ -632,10 +633,11 @@ def _run_audit( **kwargs, ) + tasks_num = audit_concurrent_tasks if audit_concurrent_tasks is not None else self.concurrent_tasks results = concurrent_apply_to_values( prepared_audits, _run_audit, - self.concurrent_tasks, + tasks_num, ) if wap_id is not None: diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 52a852ab90..c2f2b1d149 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -1488,3 +1488,183 @@ def fake_audit(snapshot: Snapshot, **kwargs: t.Any) -> t.List[AuditResult]: assert all(tid != main_thread_id for tid in audit_call_thread_ids), ( "Both audits should run on worker threads regardless of DAG dependencies" ) + + +@pytest.mark.fast +def test_audit_only_circuit_breaker_stops_remaining_tasks(mocker: MockerFixture, make_snapshot): + """When the circuit breaker fires, remaining audit tasks are skipped and CircuitBreakerError is raised.""" + audit_calls: t.List[str] = [] + audit_lock = threading.Lock() + + snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id"))) + snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 2 as id"))) + snapshot_c = make_snapshot(SqlModel(name="c", query=parse_one("SELECT 3 as id"))) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) + + def fake_audit(snapshot: Snapshot, **kwargs: t.Any) -> t.List[AuditResult]: + with audit_lock: + audit_calls.append(snapshot.name) + return [] + + mock_evaluator = mocker.MagicMock() + mock_evaluator.audit.side_effect = fake_audit + mock_evaluator.get_snapshots_to_create.return_value = [] + mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None) + mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False) + + # Circuit breaker fires immediately on the first check + scheduler = Scheduler( + snapshots=[snapshot_a, snapshot_b, snapshot_c], + snapshot_evaluator=mock_evaluator, + state_sync=mocker.MagicMock(), + default_catalog=None, + max_workers=1, # Sequential so we can reason about ordering + ) + + interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + merged_intervals: SnapshotToIntervals = { + snapshot_a: [interval], + snapshot_b: [interval], + snapshot_c: [interval], + } + + with pytest.raises(CircuitBreakerError): + scheduler.run_merged_intervals( + merged_intervals=merged_intervals, + deployability_index=DeployabilityIndex.all_deployable(), + environment_naming_info=EnvironmentNamingInfo(), + audit_only=True, + circuit_breaker=lambda: True, + ) + + # With circuit breaker always-true, no audits should run + assert len(audit_calls) == 0 + + +@pytest.mark.fast +def test_audit_only_blocking_audit_error_collected(mocker: MockerFixture, make_snapshot): + """When a blocking audit fails (raises NodeAuditsErrors), the error is collected and other audits still run.""" + audit_calls: t.List[str] = [] + audit_lock = threading.Lock() + + snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id"))) + snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 2 as id"))) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + def fake_audit(snapshot: Snapshot, **kwargs: t.Any) -> t.List[AuditResult]: + with audit_lock: + audit_calls.append(snapshot.name) + if snapshot.name == '"a"': + from sqlmesh.utils.errors import AuditError + from sqlglot import exp + + audit_error = AuditError( + audit_name="not_null", + audit_args={}, + model=snapshot.model_or_none, + count=5, + query=exp.select("1"), + adapter_dialect="duckdb", + ) + raise NodeAuditsErrors([audit_error]) + return [] + + mock_evaluator = mocker.MagicMock() + mock_evaluator.audit.side_effect = fake_audit + mock_evaluator.get_snapshots_to_create.return_value = [] + mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None) + mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False) + + mock_console = mocker.MagicMock() + + scheduler = Scheduler( + snapshots=[snapshot_a, snapshot_b], + snapshot_evaluator=mock_evaluator, + state_sync=mocker.MagicMock(), + default_catalog=None, + max_workers=2, + console=mock_console, + ) + + interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + merged_intervals: SnapshotToIntervals = { + snapshot_a: [interval], + snapshot_b: [interval], + } + + errors, skipped = scheduler.run_merged_intervals( + merged_intervals=merged_intervals, + deployability_index=DeployabilityIndex.all_deployable(), + environment_naming_info=EnvironmentNamingInfo(), + audit_only=True, + ) + + # The NodeAuditsErrors should be collected as an error, not re-raised + assert len(errors) == 1 + assert isinstance(errors[0].__cause__, NodeAuditsErrors) + assert skipped == [] + # Both audits should have been attempted despite one failing + assert len(audit_calls) == 2 + assert '"a"' in audit_calls + assert '"b"' in audit_calls + + +@pytest.mark.fast +def test_audit_only_no_nested_concurrency(mocker: MockerFixture, make_snapshot): + """With scheduler max_workers > 1, each evaluator audit call uses sequential execution (audit_concurrent_tasks=1). + + This prevents nested thread pool multiplication: max_workers * concurrent_tasks threads hitting + the DB at the same time. + """ + import sqlmesh.core.snapshot.evaluator as evaluator_module + + spy = mocker.spy(evaluator_module, "concurrent_apply_to_values") + + snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id"))) + snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 2 as id"))) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + mock_evaluator = mocker.MagicMock() + mock_evaluator.audit.return_value = [] + mock_evaluator.get_snapshots_to_create.return_value = [] + mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None) + mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False) + + # Use the real SnapshotEvaluator to test the audit_concurrent_tasks parameter flows through + real_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), concurrent_tasks=4) + real_evaluator.audit = mocker.MagicMock(return_value=[]) # type: ignore + + scheduler = Scheduler( + snapshots=[snapshot_a, snapshot_b], + snapshot_evaluator=real_evaluator, + state_sync=mocker.MagicMock(), + default_catalog=None, + max_workers=2, + ) + + interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + merged_intervals: SnapshotToIntervals = { + snapshot_a: [interval], + snapshot_b: [interval], + } + + errors, skipped = scheduler.run_merged_intervals( + merged_intervals=merged_intervals, + deployability_index=DeployabilityIndex.all_deployable(), + environment_naming_info=EnvironmentNamingInfo(), + audit_only=True, + ) + + assert errors == [] + assert skipped == [] + assert real_evaluator.audit.call_count == 2 + + # Verify that audit_concurrent_tasks=1 was passed to each audit call to prevent nested pools + for call in real_evaluator.audit.call_args_list: + assert call.kwargs.get("audit_concurrent_tasks") == 1, ( + "audit_concurrent_tasks=1 must be passed to prevent nested thread pool multiplication" + ) From 7f165fa22fcd77d7dcfc00234821b7efab5c39be Mon Sep 17 00:00:00 2001 From: Harry Brundage Date: Wed, 4 Mar 2026 22:12:45 +0100 Subject: [PATCH 3/3] Fix style check and mypy errors from CI - Apply ruff formatting to new/modified lines - Fix mypy error in test_audit_only_no_nested_concurrency: use fully mocked evaluator instead of real evaluator with replaced method, avoiding type mismatch on call_count/call_args_list Co-Authored-By: Claude Opus 4.6 --- sqlmesh/core/scheduler.py | 17 +++++++++++------ sqlmesh/core/snapshot/evaluator.py | 4 +++- tests/core/test_scheduler.py | 14 +++----------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 70d7e66759..1e15bce6bf 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -38,7 +38,11 @@ ) from sqlmesh.core.state_sync import StateSync from sqlmesh.utils import CompletionStatus -from sqlmesh.utils.concurrency import concurrent_apply_to_dag, concurrent_apply_to_values, NodeExecutionFailedError +from sqlmesh.utils.concurrency import ( + concurrent_apply_to_dag, + concurrent_apply_to_values, + NodeExecutionFailedError, +) from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import ( TimeLike, @@ -527,7 +531,9 @@ def run_merged_intervals( } dag = self._dag( - batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create + batched_intervals, + snapshot_dag=snapshot_dag, + snapshots_to_create=snapshots_to_create, ) def run_node(node: SchedulingUnit) -> None: @@ -545,7 +551,8 @@ def run_node(node: SchedulingUnit) -> None: # If batch_index > 0, then the target table must exist since the first batch would have created it target_table_exists = ( - snapshot.snapshot_id not in snapshots_to_create or node.batch_index > 0 + snapshot.snapshot_id not in snapshots_to_create + or node.batch_index > 0 ) def _do_evaluate() -> t.List[AuditResult]: @@ -972,9 +979,7 @@ def _run_node_with_progress( num_audits - num_audits_failed, num_audits_failed, execution_stats=execution_stats, - auto_restatement_triggers=auto_restatement_triggers.get( - snapshot.snapshot_id - ), + auto_restatement_triggers=auto_restatement_triggers.get(snapshot.snapshot_id), ) def _run_audits_concurrently( diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 9be9e798a3..6890aa6ed2 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -633,7 +633,9 @@ def _run_audit( **kwargs, ) - tasks_num = audit_concurrent_tasks if audit_concurrent_tasks is not None else self.concurrent_tasks + tasks_num = ( + audit_concurrent_tasks if audit_concurrent_tasks is not None else self.concurrent_tasks + ) results = concurrent_apply_to_values( prepared_audits, _run_audit, diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index c2f2b1d149..c4e29f735f 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -1619,10 +1619,6 @@ def test_audit_only_no_nested_concurrency(mocker: MockerFixture, make_snapshot): This prevents nested thread pool multiplication: max_workers * concurrent_tasks threads hitting the DB at the same time. """ - import sqlmesh.core.snapshot.evaluator as evaluator_module - - spy = mocker.spy(evaluator_module, "concurrent_apply_to_values") - snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id"))) snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 2 as id"))) snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) @@ -1634,13 +1630,9 @@ def test_audit_only_no_nested_concurrency(mocker: MockerFixture, make_snapshot): mock_evaluator.concurrent_context.return_value.__enter__ = mocker.Mock(return_value=None) mock_evaluator.concurrent_context.return_value.__exit__ = mocker.Mock(return_value=False) - # Use the real SnapshotEvaluator to test the audit_concurrent_tasks parameter flows through - real_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), concurrent_tasks=4) - real_evaluator.audit = mocker.MagicMock(return_value=[]) # type: ignore - scheduler = Scheduler( snapshots=[snapshot_a, snapshot_b], - snapshot_evaluator=real_evaluator, + snapshot_evaluator=mock_evaluator, state_sync=mocker.MagicMock(), default_catalog=None, max_workers=2, @@ -1661,10 +1653,10 @@ def test_audit_only_no_nested_concurrency(mocker: MockerFixture, make_snapshot): assert errors == [] assert skipped == [] - assert real_evaluator.audit.call_count == 2 + assert mock_evaluator.audit.call_count == 2 # Verify that audit_concurrent_tasks=1 was passed to each audit call to prevent nested pools - for call in real_evaluator.audit.call_args_list: + for call in mock_evaluator.audit.call_args_list: assert call.kwargs.get("audit_concurrent_tasks") == 1, ( "audit_concurrent_tasks=1 must be passed to prevent nested thread pool multiplication" )