Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def snapshot_evaluator(self) -> SnapshotEvaluator:
gateway: adapter.with_settings(execute_log_level=logging.INFO)
for gateway, adapter in self.engine_adapters.items()
},
ddl_concurrent_tasks=self.concurrent_tasks,
concurrent_tasks=self.concurrent_tasks,
selected_gateway=self.selected_gateway,
)
return self._snapshot_evaluator
Expand Down
320 changes: 221 additions & 99 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
import abc
import logging
import threading
import typing as t
import time
from datetime import datetime
Expand Down Expand Up @@ -37,7 +38,11 @@
)
from sqlmesh.core.state_sync import StateSync
from sqlmesh.utils import CompletionStatus
from sqlmesh.utils.concurrency import concurrent_apply_to_dag, NodeExecutionFailedError
from sqlmesh.utils.concurrency import (
concurrent_apply_to_dag,
concurrent_apply_to_values,
NodeExecutionFailedError,
)
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import (
TimeLike,
Expand Down Expand Up @@ -499,110 +504,95 @@ def run_merged_intervals(
selected_models=selected_models,
)

# We only need to create physical tables if the snapshot is not representative or if it
# needs backfill
snapshots_to_create_candidates = [
s
for s in selected_snapshots
if not deployability_index.is_representative(s) or s in batched_intervals
]
snapshots_to_create = {
s.snapshot_id
for s in self.snapshot_evaluator.get_snapshots_to_create(
snapshots_to_create_candidates, deployability_index
)
}

dag = self._dag(
batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create
)

def run_node(node: SchedulingUnit) -> None:
if circuit_breaker and circuit_breaker():
raise CircuitBreakerError()
if isinstance(node, DummyNode):
return

snapshot = self.snapshots_by_name[node.snapshot_name]

if isinstance(node, EvaluateNode):
self.console.start_snapshot_evaluation_progress(snapshot)
execution_start_ts = now_timestamp()
evaluation_duration_ms: t.Optional[int] = None
start, end = node.interval

audit_results: t.List[AuditResult] = []
try:
assert execution_time # mypy
assert deployability_index # mypy

if audit_only:
audit_results = self._audit_snapshot(
snapshot=snapshot,
environment_naming_info=environment_naming_info,
deployability_index=deployability_index,
snapshots=self.snapshots_by_name,
start=start,
end=end,
execution_time=execution_time,
)
else:
# If batch_index > 0, then the target table must exist since the first batch would have created it
target_table_exists = (
snapshot.snapshot_id not in snapshots_to_create or node.batch_index > 0
)
audit_results = self.evaluate(
snapshot=snapshot,
environment_naming_info=environment_naming_info,
start=start,
end=end,
execution_time=execution_time,
deployability_index=deployability_index,
batch_index=node.batch_index,
allow_destructive_snapshots=allow_destructive_snapshots,
allow_additive_snapshots=allow_additive_snapshots,
target_table_exists=target_table_exists,
selected_models=selected_models,
try:
with self.snapshot_evaluator.concurrent_context():
if audit_only:
errors, skipped_intervals = self._run_audits_concurrently(
batched_intervals=batched_intervals,
deployability_index=deployability_index,
environment_naming_info=environment_naming_info,
execution_time=execution_time,
circuit_breaker=circuit_breaker,
auto_restatement_triggers=auto_restatement_triggers,
)
else:
# We only need to create physical tables if the snapshot is not representative
# or if it needs backfill
snapshots_to_create_candidates = [
s
for s in selected_snapshots
if not deployability_index.is_representative(s) or s in batched_intervals
]
snapshots_to_create = {
s.snapshot_id
for s in self.snapshot_evaluator.get_snapshots_to_create(
snapshots_to_create_candidates, deployability_index
)
}

evaluation_duration_ms = now_timestamp() - execution_start_ts
finally:
num_audits = len(audit_results)
num_audits_failed = sum(1 for result in audit_results if result.count)

execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats(
SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index)
dag = self._dag(
batched_intervals,
snapshot_dag=snapshot_dag,
snapshots_to_create=snapshots_to_create,
)

self.console.update_snapshot_evaluation_progress(
snapshot,
batched_intervals[snapshot][node.batch_index],
node.batch_index,
evaluation_duration_ms,
num_audits - num_audits_failed,
num_audits_failed,
execution_stats=execution_stats,
auto_restatement_triggers=auto_restatement_triggers.get(
snapshot.snapshot_id
),
def run_node(node: SchedulingUnit) -> None:
if circuit_breaker and circuit_breaker():
raise CircuitBreakerError()
if isinstance(node, DummyNode):
return

snapshot = self.snapshots_by_name[node.snapshot_name]

if isinstance(node, EvaluateNode):
assert execution_time # mypy
assert deployability_index # mypy
node_start, node_end = node.interval

# If batch_index > 0, then the target table must exist since the first batch would have created it
target_table_exists = (
snapshot.snapshot_id not in snapshots_to_create
or node.batch_index > 0
)

def _do_evaluate() -> t.List[AuditResult]:
return self.evaluate(
snapshot=snapshot,
environment_naming_info=environment_naming_info,
start=node_start,
end=node_end,
execution_time=execution_time,
deployability_index=deployability_index,
batch_index=node.batch_index,
allow_destructive_snapshots=allow_destructive_snapshots,
allow_additive_snapshots=allow_additive_snapshots,
target_table_exists=target_table_exists,
selected_models=selected_models,
)

self._run_node_with_progress(
snapshot=snapshot,
node=node,
batched_intervals=batched_intervals,
auto_restatement_triggers=auto_restatement_triggers,
work_fn=_do_evaluate,
)
elif isinstance(node, CreateNode):
self.snapshot_evaluator.create_snapshot(
snapshot=snapshot,
snapshots=self.snapshots_by_name,
deployability_index=deployability_index,
allow_destructive_snapshots=allow_destructive_snapshots or set(),
allow_additive_snapshots=allow_additive_snapshots or set(),
)

errors, skipped_intervals = concurrent_apply_to_dag(
dag,
run_node,
self.max_workers,
raise_on_error=False,
)
elif isinstance(node, CreateNode):
self.snapshot_evaluator.create_snapshot(
snapshot=snapshot,
snapshots=self.snapshots_by_name,
deployability_index=deployability_index,
allow_destructive_snapshots=allow_destructive_snapshots or set(),
allow_additive_snapshots=allow_additive_snapshots or set(),
)

try:
with self.snapshot_evaluator.concurrent_context():
errors, skipped_intervals = concurrent_apply_to_dag(
dag,
run_node,
self.max_workers,
raise_on_error=False,
)
self.console.stop_evaluation_progress(success=not errors)

skipped_snapshots = {
Expand Down Expand Up @@ -947,6 +937,138 @@ def _audit_snapshot(

return audit_results

def _run_node_with_progress(
    self,
    *,
    snapshot: Snapshot,
    node: EvaluateNode,
    batched_intervals: t.Dict[Snapshot, Intervals],
    auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]],
    work_fn: t.Callable[[], t.List[AuditResult]],
) -> None:
    """Executes a node's work function while reporting progress and audit outcomes to the console.

    Progress is always reported, even when ``work_fn`` raises: the ``finally``
    clause emits the update with whatever audit results were gathered (none on
    failure) and a ``None`` duration.

    Args:
        snapshot: The snapshot being processed.
        node: The evaluate node.
        batched_intervals: The batched intervals per snapshot.
        auto_restatement_triggers: Auto restatement trigger info per snapshot.
        work_fn: A callable that performs the actual work and returns audit results.
    """
    self.console.start_snapshot_evaluation_progress(snapshot)
    started_at = now_timestamp()
    duration_ms: t.Optional[int] = None

    results: t.List[AuditResult] = []
    try:
        results = work_fn()
        duration_ms = now_timestamp() - started_at
    finally:
        # An audit result with a non-zero count is treated as a failure.
        failed_count = sum(1 for result in results if result.count)
        passed_count = len(results) - failed_count

        stats = self.snapshot_evaluator.execution_tracker.get_execution_stats(
            SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index)
        )

        self.console.update_snapshot_evaluation_progress(
            snapshot,
            batched_intervals[snapshot][node.batch_index],
            node.batch_index,
            duration_ms,
            passed_count,
            failed_count,
            execution_stats=stats,
            auto_restatement_triggers=auto_restatement_triggers.get(snapshot.snapshot_id),
        )

def _run_audits_concurrently(
    self,
    *,
    batched_intervals: t.Dict[Snapshot, Intervals],
    deployability_index: DeployabilityIndex,
    environment_naming_info: EnvironmentNamingInfo,
    execution_time: TimeLike,
    circuit_breaker: t.Optional[t.Callable[[], bool]],
    auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]],
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
    """Runs all audits across all snapshots in a single flat thread pool.

    Audits are read-only SELECT queries with no side effects, so they can safely
    run concurrently even across snapshots that have DAG dependencies. This fills
    all concurrent_tasks slots at once instead of processing level-by-level as the
    DAG executor would.

    Args:
        batched_intervals: The batched intervals to audit per snapshot.
        deployability_index: Determines snapshots that are deployable.
        environment_naming_info: The environment naming info.
        execution_time: The date/time reference to use for execution time.
        circuit_breaker: An optional handler which checks if the run should be aborted.
        auto_restatement_triggers: Auto restatement trigger info per snapshot.

    Returns:
        A tuple of errors and skipped intervals (always empty for audit-only runs).

    Raises:
        CircuitBreakerError: If the circuit breaker tripped during the run.
    """
    # One flat task per (snapshot, interval, batch_index) — no DAG ordering needed.
    tasks: t.List[EvaluateNode] = []
    for snapshot, intervals in batched_intervals.items():
        for batch_index, interval in enumerate(intervals):
            tasks.append(
                EvaluateNode(
                    snapshot_name=snapshot.name, interval=interval, batch_index=batch_index
                )
            )

    failures: t.List[NodeExecutionFailedError[SchedulingUnit]] = []
    failures_lock = threading.Lock()
    abort = threading.Event()

    def _audit_one(node: EvaluateNode) -> None:
        # Short-circuit remaining tasks once an abort was signalled.
        if abort.is_set():
            return
        if circuit_breaker and circuit_breaker():
            abort.set()
            return

        target_snapshot = self.snapshots_by_name[node.snapshot_name]
        interval_start, interval_end = node.interval

        def _work() -> t.List[AuditResult]:
            # Each task audits a single interval; parallelism comes from the
            # flat pool, so the per-audit concurrency is pinned to 1.
            return self._audit_snapshot(
                snapshot=target_snapshot,
                environment_naming_info=environment_naming_info,
                deployability_index=deployability_index,
                snapshots=self.snapshots_by_name,
                start=interval_start,
                end=interval_end,
                execution_time=execution_time,
                audit_concurrent_tasks=1,
            )

        self._run_node_with_progress(
            snapshot=target_snapshot,
            node=node,
            batched_intervals=batched_intervals,
            auto_restatement_triggers=auto_restatement_triggers,
            work_fn=_work,
        )

    def _audit_one_safely(node: EvaluateNode) -> None:
        try:
            _audit_one(node)
        except CircuitBreakerError:
            abort.set()
        except Exception as ex:
            failure: NodeExecutionFailedError[SchedulingUnit] = NodeExecutionFailedError(node)
            failure.__cause__ = ex
            with failures_lock:
                failures.append(failure)

    concurrent_apply_to_values(tasks, _audit_one_safely, self.max_workers)

    # A tripped breaker takes precedence over any collected task failures.
    if abort.is_set():
        raise CircuitBreakerError()

    return failures, []

def _check_ready_intervals(
self,
snapshot: Snapshot,
Expand Down
Loading