Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 25 additions & 26 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
earliest_start_date,
missing_intervals,
merge_intervals,
snapshots_to_dag,
Intervals,
)
from sqlmesh.core.snapshot.definition import (
Expand Down Expand Up @@ -344,35 +345,26 @@ def run(

return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS

def batch_intervals(
self,
merged_intervals: SnapshotToIntervals,
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
execution_time: t.Optional[TimeLike] = None,
) -> t.Dict[Snapshot, Intervals]:
def expand_range_as_interval(
start_ts: int, end_ts: int, interval_unit: IntervalUnit
) -> t.List[Interval]:
values = expand_range(start_ts, end_ts, interval_unit)
return [(values[i], values[i + 1]) for i in range(len(values) - 1)]

dag = DAG[str]()

for snapshot in merged_intervals:
dag.add(snapshot.name, [p.name for p in snapshot.parents])

snapshot_intervals = {
snapshot: [
i
for interval in intervals
for i in expand_range_as_interval(*interval, snapshot.node.interval_unit)
]
def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snapshot, Intervals]:
dag = snapshots_to_dag(merged_intervals)

snapshot_intervals: t.Dict[SnapshotId, t.Tuple[Snapshot, t.List[Interval]]] = {
snapshot.snapshot_id: (
snapshot,
[
i
for interval in intervals
for i in _expand_range_as_interval(*interval, snapshot.node.interval_unit)
],
)
for snapshot, intervals in merged_intervals.items()
}
snapshot_batches = {}
all_unready_intervals: t.Dict[str, set[Interval]] = {}
for snapshot, intervals in snapshot_intervals.items():
for snapshot_id in dag:
if snapshot_id not in snapshot_intervals:
continue
snapshot, intervals = snapshot_intervals[snapshot_id]
unready = set(intervals)
intervals = snapshot.check_ready_intervals(intervals)
unready -= set(intervals)
Expand Down Expand Up @@ -429,7 +421,7 @@ def run_merged_intervals(
"""
execution_time = execution_time or now_timestamp()

batched_intervals = self.batch_intervals(merged_intervals, start, end, execution_time)
batched_intervals = self.batch_intervals(merged_intervals)

self.console.start_evaluation_progress(
{snapshot: len(intervals) for snapshot, intervals in batched_intervals.items()},
Expand Down Expand Up @@ -686,3 +678,10 @@ def _resolve_one_snapshot_per_version(
snapshot_per_version[key] = snapshot

return snapshot_per_version


def _expand_range_as_interval(
    start_ts: int, end_ts: int, interval_unit: IntervalUnit
) -> t.List[Interval]:
    """Split ``[start_ts, end_ts]`` into consecutive unit-sized intervals.

    ``expand_range`` yields the boundary timestamps for ``interval_unit``;
    pairing each boundary with its successor produces the (start, end)
    interval tuples covering the range.
    """
    boundaries = expand_range(start_ts, end_ts, interval_unit)
    # Pair each boundary with the next one: (b0, b1), (b1, b2), ...
    return list(zip(boundaries[:-1], boundaries[1:]))
88 changes: 87 additions & 1 deletion tests/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _get_batched_missing_intervals(
execution_time: t.Optional[TimeLike] = None,
) -> SnapshotToIntervals:
merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time)
return scheduler.batch_intervals(merged_intervals, start, end, execution_time)
return scheduler.batch_intervals(merged_intervals)

return _get_batched_missing_intervals

Expand Down Expand Up @@ -722,3 +722,89 @@ def signal_b(batch: DatetimeRanges):
c: [],
d: [],
}


def test_signals_snapshots_out_of_order(
    mocker: MockerFixture, make_snapshot, get_batched_missing_intervals
):
    """Batching must respect DAG order even when snapshots are supplied out of order.

    ``c`` reads from both ``a`` and ``b``; the scheduler is constructed with the
    snapshots reversed, and the signal attached to ``a`` marks only the first
    requested range as ready.  The expected batches show that ``a`` (and hence
    its downstream ``c``) is limited to the single ready interval while ``b``
    remains unconstrained.
    """

    @signal()
    def signal_base(batch: DatetimeRanges):
        # Only the first requested range is ever considered ready.
        return [batch[0]]

    signals = signal.get_registry()

    # Upstream model gated by the signal above.
    snapshot_a = make_snapshot(
        load_sql_based_model(
            parse(  # type: ignore
                """
                MODEL (
                    name a,
                    kind INCREMENTAL_BY_TIME_RANGE(
                        lookback 1,
                        time_column dt,
                    ),
                    start '2023-01-01',
                    signals SIGNAL_BASE(),
                );
                SELECT @start_date AS dt;
                """
            ),
            signal_definitions=signals,
        ),
    )

    # Upstream model with no signal — all intervals are ready.
    snapshot_b = make_snapshot(
        load_sql_based_model(
            parse(  # type: ignore
                """
                MODEL (
                    name b,
                    kind INCREMENTAL_BY_TIME_RANGE(
                        lookback 1,
                        time_column dt,
                    ),
                    start '2023-01-01'
                );
                SELECT @start_date AS dt;
                """
            ),
            signal_definitions=signals,
        )
    )

    # Downstream model depending on both upstreams.
    snapshot_c = make_snapshot(
        load_sql_based_model(
            parse(  # type: ignore
                """
                MODEL (
                    name c,
                    kind INCREMENTAL_BY_TIME_RANGE(
                        lookback 1,
                        time_column dt,
                    ),
                    start '2023-01-01',
                );
                SELECT * FROM a UNION SELECT * FROM b
                """
            ),
            signal_definitions=signals,
        ),
        nodes={snapshot_a.name: snapshot_a.model, snapshot_b.name: snapshot_b.model},
    )

    snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1)
    scheduler = Scheduler(
        snapshots=[snapshot_c, snapshot_b, snapshot_a],  # reverse order
        snapshot_evaluator=snapshot_evaluator,
        state_sync=mocker.MagicMock(),
        max_workers=2,
        default_catalog=None,
    )

    batches = get_batched_missing_intervals(scheduler, "2023-01-01", "2023-01-03", None)

    assert batches == {
        snapshot_a: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
        snapshot_b: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
        snapshot_c: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
    }