diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 0ed2425854..6073a9f98d 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -21,6 +21,7 @@ earliest_start_date, missing_intervals, merge_intervals, + snapshots_to_dag, Intervals, ) from sqlmesh.core.snapshot.definition import ( @@ -344,35 +345,26 @@ def run( return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS - def batch_intervals( - self, - merged_intervals: SnapshotToIntervals, - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - execution_time: t.Optional[TimeLike] = None, - ) -> t.Dict[Snapshot, Intervals]: - def expand_range_as_interval( - start_ts: int, end_ts: int, interval_unit: IntervalUnit - ) -> t.List[Interval]: - values = expand_range(start_ts, end_ts, interval_unit) - return [(values[i], values[i + 1]) for i in range(len(values) - 1)] - - dag = DAG[str]() - - for snapshot in merged_intervals: - dag.add(snapshot.name, [p.name for p in snapshot.parents]) - - snapshot_intervals = { - snapshot: [ - i - for interval in intervals - for i in expand_range_as_interval(*interval, snapshot.node.interval_unit) - ] + def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snapshot, Intervals]: + dag = snapshots_to_dag(merged_intervals) + + snapshot_intervals: t.Dict[SnapshotId, t.Tuple[Snapshot, t.List[Interval]]] = { + snapshot.snapshot_id: ( + snapshot, + [ + i + for interval in intervals + for i in _expand_range_as_interval(*interval, snapshot.node.interval_unit) + ], + ) for snapshot, intervals in merged_intervals.items() } snapshot_batches = {} all_unready_intervals: t.Dict[str, set[Interval]] = {} - for snapshot, intervals in snapshot_intervals.items(): + for snapshot_id in dag: + if snapshot_id not in snapshot_intervals: + continue + snapshot, intervals = snapshot_intervals[snapshot_id] unready = set(intervals) intervals = snapshot.check_ready_intervals(intervals) unready -= set(intervals) @@ -429,7 +421,7 @@ def run_merged_intervals( """ execution_time = execution_time or now_timestamp() - batched_intervals = self.batch_intervals(merged_intervals, start, end, execution_time) + batched_intervals = self.batch_intervals(merged_intervals) self.console.start_evaluation_progress( {snapshot: len(intervals) for snapshot, intervals in batched_intervals.items()}, @@ -686,3 +678,10 @@ def _resolve_one_snapshot_per_version( snapshot_per_version[key] = snapshot return snapshot_per_version + + +def _expand_range_as_interval( + start_ts: int, end_ts: int, interval_unit: IntervalUnit +) -> t.List[Interval]: + values = expand_range(start_ts, end_ts, interval_unit) + return [(values[i], values[i + 1]) for i in range(len(values) - 1)] diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 7e3ed0f1ad..cfe3bf52bb 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -76,7 +76,7 @@ def _get_batched_missing_intervals( execution_time: t.Optional[TimeLike] = None, ) -> SnapshotToIntervals: merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time) - return scheduler.batch_intervals(merged_intervals, start, end, execution_time) + return scheduler.batch_intervals(merged_intervals) return _get_batched_missing_intervals @@ -722,3 +722,89 @@ def signal_b(batch: DatetimeRanges): c: [], d: [], } + + +def test_signals_snapshots_out_of_order( + mocker: MockerFixture, make_snapshot, get_batched_missing_intervals +): + @signal() + def signal_base(batch: DatetimeRanges): + return [batch[0]] + + signals = signal.get_registry() + + snapshot_a = make_snapshot( + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name a, + kind INCREMENTAL_BY_TIME_RANGE( + lookback 1, + time_column dt, + ), + start '2023-01-01', + signals SIGNAL_BASE(), + ); + SELECT @start_date AS dt; + """ + ), + signal_definitions=signals, + ), + ) + + snapshot_b = make_snapshot( + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name b, + kind INCREMENTAL_BY_TIME_RANGE( + lookback 1, + time_column dt, + ), + start '2023-01-01' + ); + SELECT @start_date AS dt; + """ + ), + signal_definitions=signals, + ) + ) + + snapshot_c = make_snapshot( + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name c, + kind INCREMENTAL_BY_TIME_RANGE( + lookback 1, + time_column dt, + ), + start '2023-01-01', + ); + SELECT * FROM a UNION SELECT * FROM b + """ + ), + signal_definitions=signals, + ), + nodes={snapshot_a.name: snapshot_a.model, snapshot_b.name: snapshot_b.model}, + ) + + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) + scheduler = Scheduler( + snapshots=[snapshot_c, snapshot_b, snapshot_a], # reverse order + snapshot_evaluator=snapshot_evaluator, + state_sync=mocker.MagicMock(), + max_workers=2, + default_catalog=None, + ) + + batches = get_batched_missing_intervals(scheduler, "2023-01-01", "2023-01-03", None) + + assert batches == { + snapshot_a: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], + snapshot_b: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + snapshot_c: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], + }