From 7a775c84e13ce2a9dddb424dc41950ea28eec79c Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Mon, 6 Oct 2025 09:44:42 -0700 Subject: [PATCH 1/2] feat: batch expired snapshots --- docs/reference/configuration.md | 7 +- sqlmesh/core/config/janitor.py | 12 + sqlmesh/core/context.py | 20 +- sqlmesh/core/state_sync/base.py | 72 +++- sqlmesh/core/state_sync/cache.py | 14 +- sqlmesh/core/state_sync/common.py | 91 ++++- sqlmesh/core/state_sync/db/facade.py | 37 +- sqlmesh/core/state_sync/db/snapshot.py | 199 ++++++----- tests/core/state_sync/test_state_sync.py | 423 ++++++++++++++++++++++- tests/core/test_context.py | 2 +- 10 files changed, 740 insertions(+), 137 deletions(-) diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index 676f9d7389..b13438ee2d 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -125,9 +125,10 @@ Formatting settings for the `sqlmesh format` command and UI. Configuration for the `sqlmesh janitor` command. 
-| Option | Description | Type | Required | -|--------------------------|----------------------------------------------------------------------------------------------------------------------------|:-------:|:--------:| -| `warn_on_delete_failure` | Whether to warn instead of erroring if the janitor fails to delete the expired environment schema / views (Default: False) | boolean | N | +| Option | Description | Type | Required | +|---------------------------------|----------------------------------------------------------------------------------------------------------------------------|:-------:|:--------:| +| `warn_on_delete_failure` | Whether to warn instead of erroring if the janitor fails to delete the expired environment schema / views (Default: False) | boolean | N | +| `expired_snapshots_batch_size` | Maximum number of expired snapshots to clean in a single batch (Default: 200) | int | N | ## UI diff --git a/sqlmesh/core/config/janitor.py b/sqlmesh/core/config/janitor.py index d288c90b3e..0f1c953bc0 100644 --- a/sqlmesh/core/config/janitor.py +++ b/sqlmesh/core/config/janitor.py @@ -1,7 +1,9 @@ from __future__ import annotations +import typing as t from sqlmesh.core.config.base import BaseConfig +from sqlmesh.utils.pydantic import field_validator class JanitorConfig(BaseConfig): @@ -9,6 +11,16 @@ class JanitorConfig(BaseConfig): Args: warn_on_delete_failure: Whether to warn instead of erroring if the janitor fails to delete the expired environment schema / views. + expired_snapshots_batch_size: Maximum number of expired snapshots to clean in a single batch. 
""" warn_on_delete_failure: bool = False + expired_snapshots_batch_size: t.Optional[int] = None + + @field_validator("expired_snapshots_batch_size", mode="before") + @classmethod + def _validate_batch_size(cls, value: int) -> int: + batch_size = int(value) + if batch_size <= 0: + raise ValueError("expired_snapshots_batch_size must be greater than 0") + return batch_size diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index e31a04fe81..bd8647f811 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -109,6 +109,7 @@ StateSync, cleanup_expired_views, ) +from sqlmesh.core.state_sync.common import delete_expired_snapshots from sqlmesh.core.table_diff import TableDiff from sqlmesh.core.test import ( ModelTextTestResult, @@ -2852,19 +2853,14 @@ def _run_janitor(self, ignore_ttl: bool = False) -> None: # Clean up expired environments by removing their views and schemas self._cleanup_environments(current_ts=current_ts) - cleanup_targets = self.state_sync.get_expired_snapshots( - ignore_ttl=ignore_ttl, current_ts=current_ts - ) - - # Remove the expired snapshots tables - self.snapshot_evaluator.cleanup( - target_snapshots=cleanup_targets, - on_complete=self.console.update_cleanup_progress, + delete_expired_snapshots( + self.state_sync, + self.snapshot_evaluator, + current_ts=current_ts, + ignore_ttl=ignore_ttl, + console=self.console, + batch_size=self.config.janitor.expired_snapshots_batch_size, ) - - # Delete the expired snapshot records from the state sync - self.state_sync.delete_expired_snapshots(ignore_ttl=ignore_ttl, current_ts=current_ts) - self.state_sync.compact_intervals() def _cleanup_environments(self, current_ts: t.Optional[int] = None) -> None: diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 2f8a68dd4a..0cc2fe9d53 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -72,6 +72,50 @@ def _schema_version_validator(cls, v: t.Any) -> int: SCHEMA_VERSION: int = 
MIN_SCHEMA_VERSION + len(MIGRATIONS) - 1 +class BatchBoundary(PydanticModel): + updated_ts: int + name: str + identifier: str + + def to_upper_batch_boundary(self) -> UpperBatchBoundary: + return UpperBatchBoundary( + updated_ts=self.updated_ts, + name=self.name, + identifier=self.identifier, + ) + + def to_lower_batch_boundary(self, batch_size: int) -> LowerBatchBoundary: + return LowerBatchBoundary( + updated_ts=self.updated_ts, + name=self.name, + identifier=self.identifier, + batch_size=batch_size, + ) + + +class UpperBatchBoundary(BatchBoundary): + @classmethod + def include_all_boundary(cls) -> UpperBatchBoundary: + # 9999-12-31T23:59:59.999Z in epoch milliseconds + return UpperBatchBoundary(updated_ts=253_402_300_799_999, name="", identifier="") + + +class LowerBatchBoundary(BatchBoundary): + batch_size: int + + @classmethod + def init_batch_boundary(cls, batch_size: int) -> LowerBatchBoundary: + return LowerBatchBoundary(updated_ts=0, name="", identifier="", batch_size=batch_size) + + +class ExpiredSnapshotBatch(PydanticModel): + """A batch of expired snapshots to be cleaned up.""" + + expired_snapshot_ids: t.Set[SnapshotId] + cleanup_tasks: t.List[SnapshotTableCleanupTask] + batch_boundary: BatchBoundary + + class PromotionResult(PydanticModel): added: t.List[SnapshotTableInfo] removed: t.List[SnapshotTableInfo] @@ -315,15 +359,23 @@ def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStre @abc.abstractmethod def get_expired_snapshots( - self, current_ts: t.Optional[int] = None, ignore_ttl: bool = False - ) -> t.List[SnapshotTableCleanupTask]: - """Aggregates the id's of the expired snapshots and creates a list of table cleanup tasks. + self, + *, + batch_boundary: BatchBoundary, + current_ts: t.Optional[int] = None, + ignore_ttl: bool = False, + ) -> t.Optional[ExpiredSnapshotBatch]: + """Returns a single batch of expired snapshots ordered by (updated_ts, name, identifier). 
- Expired snapshots are snapshots that have exceeded their time-to-live - and are no longer in use within an environment. + Args: + current_ts: Timestamp used to evaluate expiration. + ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). + batch_boundary: The boundary used to page through expired snapshots. + If lower boundary then snapshots later than that will be returned (exclusive). + If upper boundary then snapshots earlier than that will be returned (inclusive). Returns: - The list of table cleanup tasks. + A batch describing expired snapshots or None if no snapshots are pending cleanup. """ @abc.abstractmethod @@ -363,7 +415,10 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: @abc.abstractmethod def delete_expired_snapshots( - self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None + self, + ignore_ttl: bool = False, + current_ts: t.Optional[int] = None, + upper_batch_boundary: t.Optional[UpperBatchBoundary] = None, ) -> None: """Removes expired snapshots. @@ -373,6 +428,9 @@ def delete_expired_snapshots( Args: ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting all snapshots that are not referenced in any environment + current_ts: Timestamp used to evaluate expiration. + upper_batch_boundary: The upper boundary to delete expired snapshots till (inclusive). If not provided, + deletes all expired snapshots. 
""" @abc.abstractmethod diff --git a/sqlmesh/core/state_sync/cache.py b/sqlmesh/core/state_sync/cache.py index 3de4e7bf51..59c7b8ab69 100644 --- a/sqlmesh/core/state_sync/cache.py +++ b/sqlmesh/core/state_sync/cache.py @@ -11,7 +11,7 @@ SnapshotInfoLike, ) from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals -from sqlmesh.core.state_sync.base import DelegatingStateSync, StateSync +from sqlmesh.core.state_sync.base import DelegatingStateSync, StateSync, UpperBatchBoundary from sqlmesh.utils.date import TimeLike, now_timestamp @@ -108,11 +108,17 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: self.state_sync.delete_snapshots(snapshot_ids) def delete_expired_snapshots( - self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None + self, + ignore_ttl: bool = False, + current_ts: t.Optional[int] = None, + upper_batch_boundary: t.Optional[UpperBatchBoundary] = None, ) -> None: - current_ts = current_ts or now_timestamp() self.snapshot_cache.clear() - self.state_sync.delete_expired_snapshots(current_ts=current_ts, ignore_ttl=ignore_ttl) + self.state_sync.delete_expired_snapshots( + upper_batch_boundary=upper_batch_boundary, + ignore_ttl=ignore_ttl, + current_ts=current_ts, + ) def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: for snapshot_intervals in snapshots_intervals: diff --git a/sqlmesh/core/state_sync/common.py b/sqlmesh/core/state_sync/common.py index cd8c389e33..ad2f8f559e 100644 --- a/sqlmesh/core/state_sync/common.py +++ b/sqlmesh/core/state_sync/common.py @@ -14,14 +14,16 @@ from sqlmesh.utils.pydantic import PydanticModel from sqlmesh.core.environment import Environment, EnvironmentStatements from sqlmesh.utils.errors import SQLMeshError -from sqlmesh.core.snapshot import Snapshot +from sqlmesh.core.snapshot import Snapshot, SnapshotEvaluator if t.TYPE_CHECKING: from sqlmesh.core.engine_adapter.base import EngineAdapter - from sqlmesh.core.state_sync.base 
import Versions + from sqlmesh.core.state_sync.base import Versions, ExpiredSnapshotBatch, StateReader, StateSync logger = logging.getLogger(__name__) +EXPIRED_SNAPSHOT_DEFAULT_BATCH_SIZE = 200 + def cleanup_expired_views( default_adapter: EngineAdapter, @@ -215,3 +217,88 @@ def __iter__(self) -> t.Iterator[StateStreamContents]: yield EnvironmentsChunk(environments) return _StateStream() + + +def iter_expired_snapshot_batches( + state_reader: StateReader, + *, + current_ts: int, + ignore_ttl: bool = False, + batch_size: t.Optional[int] = None, +) -> t.Iterator[ExpiredSnapshotBatch]: + """Yields expired snapshot batches. + + Args: + state_reader: StateReader instance to query expired snapshots from. + current_ts: Timestamp used to evaluate expiration. + ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). + batch_size: Maximum number of snapshots to fetch per batch. + """ + from sqlmesh.core.state_sync.base import LowerBatchBoundary + + batch_size = batch_size if batch_size is not None else EXPIRED_SNAPSHOT_DEFAULT_BATCH_SIZE + batch_boundary = LowerBatchBoundary.init_batch_boundary(batch_size=batch_size) + + while True: + batch = state_reader.get_expired_snapshots( + current_ts=current_ts, + ignore_ttl=ignore_ttl, + batch_boundary=batch_boundary, + ) + + if batch is None: + return + + yield batch + + batch_boundary = batch.batch_boundary.to_lower_batch_boundary(batch_size=batch_size) + + +def delete_expired_snapshots( + state_sync: StateSync, + snapshot_evaluator: SnapshotEvaluator, + *, + current_ts: int, + ignore_ttl: bool = False, + batch_size: t.Optional[int] = None, + console: t.Optional[Console] = None, +) -> None: + """Delete all expired snapshots in batches. + + This helper function encapsulates the logic for deleting expired snapshots in batches, + eliminating code duplication across different use cases. + + Args: + state_sync: StateSync instance to query and delete expired snapshots from. 
+ snapshot_evaluator: SnapshotEvaluator instance to clean up tables associated with snapshots. + current_ts: Timestamp used to evaluate expiration. + ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). + batch_size: Maximum number of snapshots to fetch per batch. + console: Optional console for reporting progress. + + Returns: + None. The total number of deleted snapshots is reported via logging. + """ + num_expired_snapshots = 0 + for batch in iter_expired_snapshot_batches( + state_reader=state_sync, + current_ts=current_ts, + ignore_ttl=ignore_ttl, + batch_size=batch_size, + ): + logger.info( + "Processing batch of size %s and max_updated_ts of %s", + len(batch.expired_snapshot_ids), + batch.batch_boundary.updated_ts, + ) + snapshot_evaluator.cleanup( + target_snapshots=batch.cleanup_tasks, + on_complete=console.update_cleanup_progress if console else None, + ) + state_sync.delete_expired_snapshots( + upper_batch_boundary=batch.batch_boundary.to_upper_batch_boundary(), + ignore_ttl=ignore_ttl, + ) + logger.info("Cleaned up expired snapshots batch") + num_expired_snapshots += len(batch.expired_snapshot_ids) + logger.info("Cleaned up %s expired snapshots", num_expired_snapshots) diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 3c23ef339c..674399ebd1 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -35,7 +35,6 @@ SnapshotInfoLike, SnapshotIntervals, SnapshotNameVersion, - SnapshotTableCleanupTask, SnapshotTableInfo, start_date, ) @@ -43,9 +42,12 @@ Interval, ) from sqlmesh.core.state_sync.base import ( + ExpiredSnapshotBatch, PromotionResult, StateSync, Versions, + BatchBoundary, + UpperBatchBoundary, ) from sqlmesh.core.state_sync.common import ( EnvironmentsChunk, @@ -261,11 +263,18 @@ def invalidate_environment(self, name: str, protect_prod: bool = True) -> None: self.environment_state.invalidate_environment(name, protect_prod) def get_expired_snapshots( - 
self, current_ts: t.Optional[int] = None, ignore_ttl: bool = False - ) -> t.List[SnapshotTableCleanupTask]: + self, + *, + batch_boundary: BatchBoundary, + current_ts: t.Optional[int] = None, + ignore_ttl: bool = False, + ) -> t.Optional[ExpiredSnapshotBatch]: current_ts = current_ts or now_timestamp() return self.snapshot_state.get_expired_snapshots( - self.environment_state.get_environments(), current_ts=current_ts, ignore_ttl=ignore_ttl + environments=self.environment_state.get_environments(), + current_ts=current_ts, + ignore_ttl=ignore_ttl, + batch_boundary=batch_boundary, ) def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary]: @@ -273,14 +282,20 @@ def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary @transactional() def delete_expired_snapshots( - self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None + self, + ignore_ttl: bool = False, + current_ts: t.Optional[int] = None, + upper_batch_boundary: t.Optional[UpperBatchBoundary] = None, ) -> None: - current_ts = current_ts or now_timestamp() - for expired_snapshot_ids, cleanup_targets in self.snapshot_state._get_expired_snapshots( - self.environment_state.get_environments(), ignore_ttl=ignore_ttl, current_ts=current_ts - ): - self.snapshot_state.delete_snapshots(expired_snapshot_ids) - self.interval_state.cleanup_intervals(cleanup_targets, expired_snapshot_ids) + upper_batch_boundary = upper_batch_boundary or UpperBatchBoundary.include_all_boundary() + batch = self.get_expired_snapshots( + ignore_ttl=ignore_ttl, + current_ts=current_ts, + batch_boundary=upper_batch_boundary, + ) + if batch and batch.expired_snapshot_ids: + self.snapshot_state.delete_snapshots(batch.expired_snapshot_ids) + self.interval_state.cleanup_intervals(batch.cleanup_tasks, batch.expired_snapshot_ids) @transactional() def delete_expired_environments( diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 
4a8b2c44c5..a3b5a57340 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -14,7 +14,6 @@ snapshot_id_filter, fetchone, fetchall, - create_batches, ) from sqlmesh.core.environment import Environment from sqlmesh.core.model import SeedModel, ModelKindName @@ -30,6 +29,7 @@ SnapshotId, SnapshotFingerprint, ) +from sqlmesh.core.state_sync.base import ExpiredSnapshotBatch, BatchBoundary, LowerBatchBoundary from sqlmesh.utils.migration import index_text_type, blob_text_type from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp from sqlmesh.utils import unique @@ -43,9 +43,6 @@ class SnapshotState: SNAPSHOT_BATCH_SIZE = 1000 - # Use a smaller batch size for expired snapshots to account for fetching - # of all snapshots that share the same version. - EXPIRED_SNAPSHOT_BATCH_SIZE = 200 def __init__( self, @@ -166,47 +163,38 @@ def get_expired_snapshots( self, environments: t.Iterable[Environment], current_ts: int, - ignore_ttl: bool = False, - ) -> t.List[SnapshotTableCleanupTask]: - """Aggregates the id's of the expired snapshots and creates a list of table cleanup tasks. - - Expired snapshots are snapshots that have exceeded their time-to-live - and are no longer in use within an environment. - - Returns: - The set of expired snapshot ids. - The list of table cleanup tasks. 
- """ - all_cleanup_targets = [] - for _, cleanup_targets in self._get_expired_snapshots( - environments=environments, - current_ts=current_ts, - ignore_ttl=ignore_ttl, - ): - all_cleanup_targets.extend(cleanup_targets) - return all_cleanup_targets - - def _get_expired_snapshots( - self, - environments: t.Iterable[Environment], - current_ts: int, - ignore_ttl: bool = False, - ) -> t.Iterator[t.Tuple[t.Set[SnapshotId], t.List[SnapshotTableCleanupTask]]]: - expired_query = exp.select("name", "identifier", "version").from_(self.snapshots_table) + ignore_ttl: bool, + batch_boundary: BatchBoundary, + ) -> t.Optional[ExpiredSnapshotBatch]: + expired_query = exp.select("name", "identifier", "version", "updated_ts").from_( + self.snapshots_table + ) if not ignore_ttl: expired_query = expired_query.where( (exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts ) - expired_candidates = { - SnapshotId(name=name, identifier=identifier): SnapshotNameVersion( - name=name, version=version + # Use tuple comparison for proper cursor-based pagination + operation = exp.GT if isinstance(batch_boundary, LowerBatchBoundary) else exp.LTE + expired_query = expired_query.where( + operation( + this=exp.Tuple( + expressions=[ + exp.column("updated_ts"), + exp.column("name"), + exp.column("identifier"), + ] + ), + expression=exp.Tuple( + expressions=[ + exp.Literal.number(batch_boundary.updated_ts), + exp.Literal.string(batch_boundary.name), + exp.Literal.string(batch_boundary.identifier), + ] + ), ) - for name, identifier, version in fetchall(self.engine_adapter, expired_query) - } - if not expired_candidates: - return + ) promoted_snapshot_ids = { snapshot.snapshot_id @@ -214,63 +202,106 @@ def _get_expired_snapshots( for snapshot in environment.snapshots } + if promoted_snapshot_ids: + not_in_conditions = [ + exp.not_(condition) + for condition in snapshot_id_filter( + self.engine_adapter, + promoted_snapshot_ids, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ) + ] + expired_query = 
expired_query.where(exp.and_(*not_in_conditions)) + + expired_query = expired_query.order_by( + exp.column("updated_ts"), exp.column("name"), exp.column("identifier") + ) + + if isinstance(batch_boundary, LowerBatchBoundary): + expired_query = expired_query.limit(batch_boundary.batch_size) + + rows = fetchall(self.engine_adapter, expired_query) + + if not rows: + return None + + expired_candidates = { + SnapshotId(name=name, identifier=identifier): SnapshotNameVersion( + name=name, version=version + ) + for name, identifier, version, _ in rows + } + if not expired_candidates: + return None + def _is_snapshot_used(snapshot: SnapshotIdAndVersion) -> bool: return ( snapshot.snapshot_id in promoted_snapshot_ids or snapshot.snapshot_id not in expired_candidates ) - unique_expired_versions = unique(expired_candidates.values()) - version_batches = create_batches( - unique_expired_versions, batch_size=self.EXPIRED_SNAPSHOT_BATCH_SIZE + # Extract cursor values from last row for pagination + last_row = rows[-1] + batch_boundary = BatchBoundary( + updated_ts=last_row[3], + name=last_row[0], + identifier=last_row[1], ) - for versions_batch in version_batches: - snapshots = self._get_snapshots_with_same_version(versions_batch) - - snapshots_by_version = defaultdict(set) - snapshots_by_dev_version = defaultdict(set) - for s in snapshots: - snapshots_by_version[(s.name, s.version)].add(s.snapshot_id) - snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id) - - expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)] - all_expired_snapshot_ids = {s.snapshot_id for s in expired_snapshots} - - cleanup_targets: t.List[t.Tuple[SnapshotId, bool]] = [] - for snapshot in expired_snapshots: - shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)] - shared_version_snapshots.discard(snapshot.snapshot_id) - - shared_dev_version_snapshots = snapshots_by_dev_version[ - (snapshot.name, snapshot.dev_version) - ] - 
shared_dev_version_snapshots.discard(snapshot.snapshot_id) - - if not shared_dev_version_snapshots: - dev_table_only = bool(shared_version_snapshots) - cleanup_targets.append((snapshot.snapshot_id, dev_table_only)) - - snapshot_ids_to_cleanup = [snapshot_id for snapshot_id, _ in cleanup_targets] - for snapshot_id_batch in create_batches( - snapshot_ids_to_cleanup, batch_size=self.SNAPSHOT_BATCH_SIZE - ): - snapshot_id_batch_set = set(snapshot_id_batch) - full_snapshots = self._get_snapshots(snapshot_id_batch_set) - cleanup_tasks = [ + + unique_expired_versions = unique(expired_candidates.values()) + expired_snapshot_ids: t.Set[SnapshotId] = set() + cleanup_tasks: t.List[SnapshotTableCleanupTask] = [] + + snapshots = self._get_snapshots_with_same_version(unique_expired_versions) + + snapshots_by_version = defaultdict(set) + snapshots_by_dev_version = defaultdict(set) + for s in snapshots: + snapshots_by_version[(s.name, s.version)].add(s.snapshot_id) + snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id) + + expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)] + all_expired_snapshot_ids = {s.snapshot_id for s in expired_snapshots} + + cleanup_targets: t.List[t.Tuple[SnapshotId, bool]] = [] + for snapshot in expired_snapshots: + shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)] + shared_version_snapshots.discard(snapshot.snapshot_id) + + shared_dev_version_snapshots = snapshots_by_dev_version[ + (snapshot.name, snapshot.dev_version) + ] + shared_dev_version_snapshots.discard(snapshot.snapshot_id) + + if not shared_dev_version_snapshots: + dev_table_only = bool(shared_version_snapshots) + cleanup_targets.append((snapshot.snapshot_id, dev_table_only)) + + snapshot_ids_to_cleanup = [snapshot_id for snapshot_id, _ in cleanup_targets] + full_snapshots = self._get_snapshots(snapshot_ids_to_cleanup) + for snapshot_id, dev_table_only in cleanup_targets: + if snapshot_id in full_snapshots: + 
cleanup_tasks.append( SnapshotTableCleanupTask( snapshot=full_snapshots[snapshot_id].table_info, dev_table_only=dev_table_only, ) - for snapshot_id, dev_table_only in cleanup_targets - if snapshot_id in full_snapshots - ] - all_expired_snapshot_ids -= snapshot_id_batch_set - yield snapshot_id_batch_set, cleanup_tasks - - if all_expired_snapshot_ids: - # Remaining expired snapshots for which there are no tables - # to cleanup - yield all_expired_snapshot_ids, [] + ) + expired_snapshot_ids.add(snapshot_id) + all_expired_snapshot_ids.discard(snapshot_id) + + # Add any remaining expired snapshots that don't require cleanup + if all_expired_snapshot_ids: + expired_snapshot_ids.update(all_expired_snapshot_ids) + + if expired_snapshot_ids or cleanup_tasks: + return ExpiredSnapshotBatch( + expired_snapshot_ids=expired_snapshot_ids, + cleanup_tasks=cleanup_tasks, + batch_boundary=batch_boundary, + ) + + return None def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: """Deletes snapshots. 
diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index 51a646ce5d..5f5827422e 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -52,6 +52,21 @@ pytestmark = pytest.mark.slow +def _get_cleanup_tasks( + state_sync: EngineAdapterStateSync, + *, + limit: int = 1000, + ignore_ttl: bool = False, +) -> t.List[SnapshotTableCleanupTask]: + from sqlmesh.core.state_sync.base import LowerBatchBoundary + + batch = state_sync.get_expired_snapshots( + ignore_ttl=ignore_ttl, + batch_boundary=LowerBatchBoundary.init_batch_boundary(batch_size=limit), + ) + return [] if batch is None else batch.cleanup_tasks + + @pytest.fixture def state_sync(duck_conn, tmp_path): state_sync = EngineAdapterStateSync( @@ -1156,7 +1171,7 @@ def test_delete_expired_snapshots(state_sync: EngineAdapterStateSync, make_snaps new_snapshot.snapshot_id, } - assert state_sync.get_expired_snapshots() == [ + assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), SnapshotTableCleanupTask(snapshot=new_snapshot.table_info, dev_table_only=False), ] @@ -1165,6 +1180,388 @@ def test_delete_expired_snapshots(state_sync: EngineAdapterStateSync, make_snaps assert not state_sync.get_snapshots(all_snapshots) +def test_get_expired_snapshot_batch(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + from sqlmesh.core.state_sync.base import LowerBatchBoundary + + now_ts = now_timestamp() + + snapshots = [] + for idx in range(3): + snapshot = make_snapshot( + SqlModel( + name=f"model_{idx}", + query=parse_one("select 1 as a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - (20000 + idx * 1000) + snapshots.append(snapshot) + + state_sync.push_snapshots(snapshots) + + batch = state_sync.get_expired_snapshots( + 
batch_boundary=LowerBatchBoundary.init_batch_boundary(batch_size=2), + ) + assert batch is not None + assert len(batch.expired_snapshot_ids) == 2 + assert len(batch.cleanup_tasks) == 2 + + # Delete first batch using new API + state_sync.delete_expired_snapshots( + upper_batch_boundary=batch.batch_boundary.to_upper_batch_boundary(), + ) + + next_batch = state_sync.get_expired_snapshots( + batch_boundary=batch.batch_boundary.to_lower_batch_boundary(batch_size=2), + ) + assert next_batch is not None + assert len(next_batch.expired_snapshot_ids) == 1 + + # Delete second batch using new API + state_sync.delete_expired_snapshots( + upper_batch_boundary=next_batch.batch_boundary.to_upper_batch_boundary(), + ) + + assert ( + state_sync.get_expired_snapshots( + batch_boundary=next_batch.batch_boundary.to_lower_batch_boundary(batch_size=2), + ) + is None + ) + + +def test_get_expired_snapshot_batch_same_timestamp( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + """Test that pagination works correctly when multiple snapshots have the same updated_ts.""" + from sqlmesh.core.state_sync.base import LowerBatchBoundary + + now_ts = now_timestamp() + same_timestamp = now_ts - 20000 + + snapshots = [] + for idx in range(5): + snapshot = make_snapshot( + SqlModel( + name=f"model_{idx:02d}", # Zero-padded to ensure deterministic name ordering + query=parse_one("select 1 as a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + # All snapshots have the same updated_ts + snapshot.updated_ts = same_timestamp + snapshots.append(snapshot) + + state_sync.push_snapshots(snapshots) + + # Fetch first batch of 2 + batch1 = state_sync.get_expired_snapshots( + batch_boundary=LowerBatchBoundary.init_batch_boundary(batch_size=2), + ) + assert batch1 is not None + assert len(batch1.expired_snapshot_ids) == 2 + assert sorted([x.name for x in batch1.expired_snapshot_ids]) == [ + '"model_00"', + '"model_01"', + ] + + # Fetch 
second batch of 2 using cursor from batch1 + batch2 = state_sync.get_expired_snapshots( + batch_boundary=batch1.batch_boundary.to_lower_batch_boundary(batch_size=2), + ) + assert batch2 is not None + assert len(batch2.expired_snapshot_ids) == 2 + assert sorted([x.name for x in batch2.expired_snapshot_ids]) == [ + '"model_02"', + '"model_03"', + ] + + # Fetch third batch of 2 using cursor from batch2 + batch3 = state_sync.get_expired_snapshots( + batch_boundary=batch2.batch_boundary.to_lower_batch_boundary(batch_size=2), + ) + assert batch3 is not None + assert sorted([x.name for x in batch3.expired_snapshot_ids]) == [ + '"model_04"', + ] + + +def test_delete_expired_snapshots_batching_with_deletion( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + """Test that delete_expired_snapshots properly deletes batches as it pages through them.""" + from sqlmesh.core.state_sync.base import LowerBatchBoundary + + now_ts = now_timestamp() + + # Create 5 expired snapshots with different timestamps + snapshots = [] + for idx in range(5): + snapshot = make_snapshot( + SqlModel( + name=f"model_{idx}", + query=parse_one("select 1 as a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - (20000 + idx * 1000) + snapshots.append(snapshot) + + state_sync.push_snapshots(snapshots) + + # Verify all 5 snapshots exist + assert len(state_sync.get_snapshots(snapshots)) == 5 + + # Get first batch of 2 + batch1 = state_sync.get_expired_snapshots( + batch_boundary=LowerBatchBoundary.init_batch_boundary(batch_size=2), + ) + assert batch1 is not None + assert len(batch1.expired_snapshot_ids) == 2 + + # Delete the first batch using upper_batch_boundary + state_sync.delete_expired_snapshots( + upper_batch_boundary=batch1.batch_boundary.to_upper_batch_boundary(), + ) + + # Verify the 2 oldest snapshots (model_4 and model_3) are deleted and the 3 newest remain + remaining = 
state_sync.get_snapshots(snapshots) + assert len(remaining) == 3 + assert snapshots[0].snapshot_id in remaining # model_0 (newest) + assert snapshots[1].snapshot_id in remaining # model_1 + assert snapshots[2].snapshot_id in remaining # model_2 + assert snapshots[3].snapshot_id not in remaining # model_3 + assert snapshots[4].snapshot_id not in remaining # model_4 (oldest) + + # Get next batch of 2 (should start after batch1's boundary) + batch2 = state_sync.get_expired_snapshots( + batch_boundary=batch1.batch_boundary.to_lower_batch_boundary(batch_size=2), + ) + assert batch2 is not None + assert len(batch2.expired_snapshot_ids) == 2 + + # Delete the second batch + state_sync.delete_expired_snapshots( + upper_batch_boundary=batch2.batch_boundary.to_upper_batch_boundary(), + ) + + # Verify only the last snapshot remains + remaining = state_sync.get_snapshots(snapshots) + assert len(remaining) == 1 + assert snapshots[0].snapshot_id in remaining # model_0 (newest) + assert snapshots[1].snapshot_id not in remaining # model_1 + assert snapshots[2].snapshot_id not in remaining # model_2 + assert snapshots[3].snapshot_id not in remaining # model_3 + assert snapshots[4].snapshot_id not in remaining # model_4 (oldest) + + # Get final batch + batch3 = state_sync.get_expired_snapshots( + batch_boundary=batch2.batch_boundary.to_lower_batch_boundary(batch_size=2), + ) + assert batch3 is not None + assert len(batch3.expired_snapshot_ids) == 1 + + # Delete the final batch + state_sync.delete_expired_snapshots( + upper_batch_boundary=batch3.batch_boundary.to_upper_batch_boundary(), + ) + + # Verify all snapshots are deleted + assert len(state_sync.get_snapshots(snapshots)) == 0 + + # Verify no more expired snapshots exist + assert ( + state_sync.get_expired_snapshots( + batch_boundary=batch3.batch_boundary.to_lower_batch_boundary(batch_size=2), + ) + is None + ) + + +def test_iterator_expired_snapshot_batch( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + 
"""Test the for_each_expired_snapshot_batch helper function.""" + from sqlmesh.core.state_sync.common import iter_expired_snapshot_batches + + now_ts = now_timestamp() + + snapshots = [] + for idx in range(5): + snapshot = make_snapshot( + SqlModel( + name=f"model_{idx}", + query=parse_one("select 1 as a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - (20000 + idx * 1000) + snapshots.append(snapshot) + + state_sync.push_snapshots(snapshots) + + # Track all batches processed + batches_processed = [] + + # Process with batch size of 2 + for batch in iter_expired_snapshot_batches( + state_sync, + current_ts=now_ts, + ignore_ttl=False, + batch_size=2, + ): + batches_processed.append(batch) + + # Should have processed 3 batches (2 + 2 + 1) + assert len(batches_processed) == 3 + assert len(batches_processed[0].expired_snapshot_ids) == 2 + assert len(batches_processed[1].expired_snapshot_ids) == 2 + assert len(batches_processed[2].expired_snapshot_ids) == 1 + + # Verify all snapshots were processed + all_processed_ids = set() + for batch in batches_processed: + all_processed_ids.update(batch.expired_snapshot_ids) + + expected_ids = {s.snapshot_id for s in snapshots} + assert all_processed_ids == expected_ids + + +def test_delete_expired_snapshots_common_function_batching( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, mocker: MockerFixture +): + """Test that the common delete_expired_snapshots function properly pages through batches and deletes them.""" + from sqlmesh.core.state_sync.common import delete_expired_snapshots + from sqlmesh.core.state_sync.base import LowerBatchBoundary, UpperBatchBoundary + from unittest.mock import MagicMock + + now_ts = now_timestamp() + + # Create 5 expired snapshots with different timestamps + snapshots = [] + for idx in range(5): + snapshot = make_snapshot( + SqlModel( + name=f"model_{idx}", + query=parse_one("select 1 as a, 
ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - (20000 + idx * 1000) + snapshots.append(snapshot) + + state_sync.push_snapshots(snapshots) + + # Spy on get_expired_snapshots and delete_expired_snapshots methods + get_expired_spy = mocker.spy(state_sync, "get_expired_snapshots") + delete_expired_spy = mocker.spy(state_sync, "delete_expired_snapshots") + + # Mock snapshot evaluator + mock_evaluator = MagicMock() + mock_evaluator.cleanup = MagicMock() + + # Run delete_expired_snapshots with batch_size=2 + delete_expired_snapshots( + state_sync, + mock_evaluator, + current_ts=now_ts, + batch_size=2, + ) + + # Verify get_expired_snapshots was called the correct number of times: + # - 3 batches (2+2+1): each batch triggers 2 calls (one from for_each loop, one from delete_expired_snapshots) + # - Plus 1 final call that returns empty to exit the loop + # Total: 3 * 2 + 1 = 7 calls + assert get_expired_spy.call_count == 7 + + # Verify the progression of batch_boundary calls from the for_each loop + # (calls at indices 0, 2, 4, 6 are from for_each_expired_snapshot_batch) + # (calls at indices 1, 3, 5 are from delete_expired_snapshots in facade.py) + calls = get_expired_spy.call_args_list + + # First call from for_each should have a LowerBatchBoundary starting from the beginning + first_call_kwargs = calls[0][1] + assert "batch_boundary" in first_call_kwargs + first_boundary = first_call_kwargs["batch_boundary"] + assert isinstance(first_boundary, LowerBatchBoundary) + assert first_boundary.batch_size == 2 + assert first_boundary.updated_ts == 0 + assert first_boundary.name == "" + assert first_boundary.identifier == "" + + # Third call (second batch from for_each) should have a LowerBatchBoundary from the first batch's boundary + third_call_kwargs = calls[2][1] + assert "batch_boundary" in third_call_kwargs + second_boundary = third_call_kwargs["batch_boundary"] + assert 
isinstance(second_boundary, LowerBatchBoundary) + assert second_boundary.batch_size == 2 + # Should have progressed from the first batch + assert second_boundary.updated_ts > 0 + assert second_boundary.name == '"model_3"' + + # Fifth call (third batch from for_each) should have a LowerBatchBoundary from the second batch's boundary + fifth_call_kwargs = calls[4][1] + assert "batch_boundary" in fifth_call_kwargs + third_boundary = fifth_call_kwargs["batch_boundary"] + assert isinstance(third_boundary, LowerBatchBoundary) + assert third_boundary.batch_size == 2 + # Should have progressed from the second batch + assert third_boundary.updated_ts >= second_boundary.updated_ts + assert third_boundary.name == '"model_1"' + + # Seventh call (final call from for_each) should have a LowerBatchBoundary from the third batch's boundary + seventh_call_kwargs = calls[6][1] + assert "batch_boundary" in seventh_call_kwargs + fourth_boundary = seventh_call_kwargs["batch_boundary"] + assert isinstance(fourth_boundary, LowerBatchBoundary) + assert fourth_boundary.batch_size == 2 + # Should have progressed from the third batch + assert fourth_boundary.updated_ts >= third_boundary.updated_ts + assert fourth_boundary.name == '"model_0"' + + # Verify delete_expired_snapshots was called 3 times (once per batch) + assert delete_expired_spy.call_count == 3 + + # Verify each delete call used an UpperBatchBoundary + delete_calls = delete_expired_spy.call_args_list + + # First call should have an UpperBatchBoundary matching the first batch + first_delete_kwargs = delete_calls[0][1] + assert "upper_batch_boundary" in first_delete_kwargs + first_delete_boundary = first_delete_kwargs["upper_batch_boundary"] + assert isinstance(first_delete_boundary, UpperBatchBoundary) + assert first_delete_boundary.updated_ts == second_boundary.updated_ts + assert first_delete_boundary.name == second_boundary.name + assert first_delete_boundary.identifier == second_boundary.identifier + + second_delete_kwargs = 
delete_calls[1][1] + assert "upper_batch_boundary" in second_delete_kwargs + second_delete_boundary = second_delete_kwargs["upper_batch_boundary"] + assert isinstance(second_delete_boundary, UpperBatchBoundary) + assert second_delete_boundary.updated_ts == third_boundary.updated_ts + assert second_delete_boundary.name == third_boundary.name + assert second_delete_boundary.identifier == third_boundary.identifier + + third_delete_kwargs = delete_calls[2][1] + assert "upper_batch_boundary" in third_delete_kwargs + third_delete_boundary = third_delete_kwargs["upper_batch_boundary"] + assert isinstance(third_delete_boundary, UpperBatchBoundary) + assert third_delete_boundary.updated_ts == fourth_boundary.updated_ts + assert third_delete_boundary.name == fourth_boundary.name + assert third_delete_boundary.identifier == fourth_boundary.identifier + # Verify the cleanup method was called for each batch that had cleanup tasks + assert mock_evaluator.cleanup.call_count >= 1 + + # Verify all snapshots were deleted in the end + remaining = state_sync.get_snapshots(snapshots) + assert len(remaining) == 0 + + def test_delete_expired_snapshots_seed( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable ): @@ -1187,7 +1584,7 @@ def test_delete_expired_snapshots_seed( state_sync.push_snapshots(all_snapshots) assert set(state_sync.get_snapshots(all_snapshots)) == {snapshot.snapshot_id} - assert state_sync.get_expired_snapshots() == [ + assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=False), ] state_sync.delete_expired_snapshots() @@ -1228,7 +1625,7 @@ def test_delete_expired_snapshots_batching( snapshot_b.snapshot_id, } - assert state_sync.get_expired_snapshots() == [ + assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot_a.table_info, dev_table_only=False), SnapshotTableCleanupTask(snapshot=snapshot_b.table_info, dev_table_only=False), ] @@ -1265,7 +1662,7 @@ def 
test_delete_expired_snapshots_promoted( state_sync.promote(env) all_snapshots = [snapshot] - assert not state_sync.get_expired_snapshots() + assert not _get_cleanup_tasks(state_sync) state_sync.delete_expired_snapshots() assert set(state_sync.get_snapshots(all_snapshots)) == {snapshot.snapshot_id} @@ -1275,7 +1672,7 @@ def test_delete_expired_snapshots_promoted( now_timestamp_mock = mocker.patch("sqlmesh.core.state_sync.db.facade.now_timestamp") now_timestamp_mock.return_value = now_timestamp() + 11000 - assert state_sync.get_expired_snapshots() == [ + assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=False) ] state_sync.delete_expired_snapshots() @@ -1315,7 +1712,7 @@ def test_delete_expired_snapshots_dev_table_cleanup_only( new_snapshot.snapshot_id, } - assert state_sync.get_expired_snapshots() == [ + assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True) ] state_sync.delete_expired_snapshots() @@ -1357,7 +1754,7 @@ def test_delete_expired_snapshots_shared_dev_table( new_snapshot.snapshot_id, } - assert not state_sync.get_expired_snapshots() # No dev table cleanup + assert not _get_cleanup_tasks(state_sync) # No dev table cleanup state_sync.delete_expired_snapshots() assert set(state_sync.get_snapshots(all_snapshots)) == {new_snapshot.snapshot_id} @@ -1403,13 +1800,13 @@ def test_delete_expired_snapshots_ignore_ttl( state_sync.promote(env) # default TTL = 1 week, nothing to clean up yet if we take TTL into account - assert not state_sync.get_expired_snapshots() + assert not _get_cleanup_tasks(state_sync) state_sync.delete_expired_snapshots() assert state_sync.snapshots_exist([snapshot_c.snapshot_id]) == {snapshot_c.snapshot_id} # If we ignore TTL, only snapshot_c should get cleaned up because snapshot_a and snapshot_b are part of an environment assert snapshot_a.table_info != snapshot_b.table_info != snapshot_c.table_info - assert 
state_sync.get_expired_snapshots(ignore_ttl=True) == [ + assert _get_cleanup_tasks(state_sync, ignore_ttl=True) == [ SnapshotTableCleanupTask(snapshot=snapshot_c.table_info, dev_table_only=False) ] state_sync.delete_expired_snapshots(ignore_ttl=True) @@ -1476,7 +1873,7 @@ def test_delete_expired_snapshots_cleanup_intervals( ] assert not stored_new_snapshot.dev_intervals - assert state_sync.get_expired_snapshots() == [ + assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), SnapshotTableCleanupTask(snapshot=new_snapshot.table_info, dev_table_only=False), ] @@ -1564,7 +1961,7 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_version( ) # Delete the expired snapshot - assert state_sync.get_expired_snapshots() == [ + assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), ] state_sync.delete_expired_snapshots() @@ -1684,7 +2081,7 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version( ) # Delete the expired snapshot - assert state_sync.get_expired_snapshots() == [] + assert not _get_cleanup_tasks(state_sync) state_sync.delete_expired_snapshots() assert not state_sync.get_snapshots([snapshot]) @@ -1778,7 +2175,7 @@ def test_compact_intervals_after_cleanup( state_sync.add_interval(snapshot_c, "2023-01-07", "2023-01-09", is_dev=True) # Only the dev table of the original snapshot should be deleted - assert state_sync.get_expired_snapshots() == [ + assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot_a.table_info, dev_table_only=True), ] state_sync.delete_expired_snapshots() diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 6270cec56a..60ea3fd451 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -1030,7 +1030,7 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None: sushi_context._engine_adapter = adapter_mock 
sushi_context.engine_adapters = {sushi_context.config.default_gateway: adapter_mock} sushi_context._state_sync = state_sync_mock - state_sync_mock.get_expired_snapshots.return_value = [] + state_sync_mock.get_expired_snapshots.return_value = None sushi_context._run_janitor() # Assert that the schemas are dropped just twice for the schema based environment From c95209a887430a2f009a225eddccaa70070123ee Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Mon, 6 Oct 2025 17:02:48 -0700 Subject: [PATCH 2/2] use range instead of single boundary --- sqlmesh/core/state_sync/base.py | 81 +----- sqlmesh/core/state_sync/cache.py | 7 +- sqlmesh/core/state_sync/common.py | 207 +++++++++++++++- sqlmesh/core/state_sync/db/facade.py | 16 +- sqlmesh/core/state_sync/db/snapshot.py | 43 ++-- tests/core/state_sync/test_state_sync.py | 302 ++++++++++++++++------- 6 files changed, 442 insertions(+), 214 deletions(-) diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 0cc2fe9d53..3c8c72845d 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -11,7 +11,6 @@ from sqlmesh import migrations from sqlmesh.core.environment import ( Environment, - EnvironmentNamingInfo, EnvironmentStatements, EnvironmentSummary, ) @@ -21,8 +20,6 @@ SnapshotIdLike, SnapshotIdAndVersionLike, SnapshotInfoLike, - SnapshotTableCleanupTask, - SnapshotTableInfo, SnapshotNameVersion, SnapshotIdAndVersion, ) @@ -30,8 +27,13 @@ from sqlmesh.utils import major_minor from sqlmesh.utils.date import TimeLike from sqlmesh.utils.errors import SQLMeshError -from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator -from sqlmesh.core.state_sync.common import StateStream +from sqlmesh.utils.pydantic import PydanticModel, field_validator +from sqlmesh.core.state_sync.common import ( + StateStream, + ExpiredSnapshotBatch, + PromotionResult, + ExpiredBatchRange, +) logger = logging.getLogger(__name__) @@ 
-72,64 +74,6 @@ def _schema_version_validator(cls, v: t.Any) -> int: SCHEMA_VERSION: int = MIN_SCHEMA_VERSION + len(MIGRATIONS) - 1 -class BatchBoundary(PydanticModel): - updated_ts: int - name: str - identifier: str - - def to_upper_batch_boundary(self) -> UpperBatchBoundary: - return UpperBatchBoundary( - updated_ts=self.updated_ts, - name=self.name, - identifier=self.identifier, - ) - - def to_lower_batch_boundary(self, batch_size: int) -> LowerBatchBoundary: - return LowerBatchBoundary( - updated_ts=self.updated_ts, - name=self.name, - identifier=self.identifier, - batch_size=batch_size, - ) - - -class UpperBatchBoundary(BatchBoundary): - @classmethod - def include_all_boundary(cls) -> UpperBatchBoundary: - # 9999-12-31T23:59:59.999Z in epoch milliseconds - return UpperBatchBoundary(updated_ts=253_402_300_799_999, name="", identifier="") - - -class LowerBatchBoundary(BatchBoundary): - batch_size: int - - @classmethod - def init_batch_boundary(cls, batch_size: int) -> LowerBatchBoundary: - return LowerBatchBoundary(updated_ts=0, name="", identifier="", batch_size=batch_size) - - -class ExpiredSnapshotBatch(PydanticModel): - """A batch of expired snapshots to be cleaned up.""" - - expired_snapshot_ids: t.Set[SnapshotId] - cleanup_tasks: t.List[SnapshotTableCleanupTask] - batch_boundary: BatchBoundary - - -class PromotionResult(PydanticModel): - added: t.List[SnapshotTableInfo] - removed: t.List[SnapshotTableInfo] - removed_environment_naming_info: t.Optional[EnvironmentNamingInfo] - - @field_validator("removed_environment_naming_info") - def _validate_removed_environment_naming_info( - cls, v: t.Optional[EnvironmentNamingInfo], info: ValidationInfo - ) -> t.Optional[EnvironmentNamingInfo]: - if v and not info.data.get("removed"): - raise ValueError("removed_environment_naming_info must be None if removed is empty") - return v - - class StateReader(abc.ABC): """Abstract base class for read-only operations on snapshot and environment state.""" @@ -361,7 +305,7 @@ 
def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStre def get_expired_snapshots( self, *, - batch_boundary: BatchBoundary, + batch_range: ExpiredBatchRange, current_ts: t.Optional[int] = None, ignore_ttl: bool = False, ) -> t.Optional[ExpiredSnapshotBatch]: @@ -370,9 +314,7 @@ def get_expired_snapshots( Args: current_ts: Timestamp used to evaluate expiration. ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). - batch_boundary: If provided, gets snapshot relative to the given boundary. - If lower boundary then snapshots later than that will be returned (exclusive). - If upper boundary then snapshots earlier than that will be returned (inclusive). + batch_range: The range of the batch to fetch. Returns: A batch describing expired snapshots or None if no snapshots are pending cleanup. @@ -416,9 +358,9 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: @abc.abstractmethod def delete_expired_snapshots( self, + batch_range: ExpiredBatchRange, ignore_ttl: bool = False, current_ts: t.Optional[int] = None, - upper_batch_boundary: t.Optional[UpperBatchBoundary] = None, ) -> None: """Removes expired snapshots. @@ -426,11 +368,10 @@ def delete_expired_snapshots( and are no longer in use within an environment. Args: + batch_range: The range of snapshots to delete in this batch. ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting all snapshots that are not referenced in any environment current_ts: Timestamp used to evaluate expiration. - upper_batch_boundary: The upper boundary to delete expired snapshots till (inclusive). If not provided, - deletes all expired snapshots. 
""" @abc.abstractmethod diff --git a/sqlmesh/core/state_sync/cache.py b/sqlmesh/core/state_sync/cache.py index 59c7b8ab69..77f3fc6ba5 100644 --- a/sqlmesh/core/state_sync/cache.py +++ b/sqlmesh/core/state_sync/cache.py @@ -11,7 +11,8 @@ SnapshotInfoLike, ) from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals -from sqlmesh.core.state_sync.base import DelegatingStateSync, StateSync, UpperBatchBoundary +from sqlmesh.core.state_sync.base import DelegatingStateSync, StateSync +from sqlmesh.core.state_sync.common import ExpiredBatchRange from sqlmesh.utils.date import TimeLike, now_timestamp @@ -109,13 +110,13 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: def delete_expired_snapshots( self, + batch_range: ExpiredBatchRange, ignore_ttl: bool = False, current_ts: t.Optional[int] = None, - upper_batch_boundary: t.Optional[UpperBatchBoundary] = None, ) -> None: self.snapshot_cache.clear() self.state_sync.delete_expired_snapshots( - upper_batch_boundary=upper_batch_boundary, + batch_range=batch_range, ignore_ttl=ignore_ttl, current_ts=current_ts, ) diff --git a/sqlmesh/core/state_sync/common.py b/sqlmesh/core/state_sync/common.py index ad2f8f559e..3fdd0bc015 100644 --- a/sqlmesh/core/state_sync/common.py +++ b/sqlmesh/core/state_sync/common.py @@ -7,18 +7,26 @@ import abc from dataclasses import dataclass + +from pydantic_core.core_schema import ValidationInfo from sqlglot import exp from sqlmesh.core.console import Console from sqlmesh.core.dialect import schema_ -from sqlmesh.utils.pydantic import PydanticModel -from sqlmesh.core.environment import Environment, EnvironmentStatements +from sqlmesh.utils.pydantic import PydanticModel, field_validator +from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentNamingInfo from sqlmesh.utils.errors import SQLMeshError -from sqlmesh.core.snapshot import Snapshot, SnapshotEvaluator +from sqlmesh.core.snapshot import ( + Snapshot, + SnapshotEvaluator, + 
SnapshotId, + SnapshotTableCleanupTask, + SnapshotTableInfo, +) if t.TYPE_CHECKING: from sqlmesh.core.engine_adapter.base import EngineAdapter - from sqlmesh.core.state_sync.base import Versions, ExpiredSnapshotBatch, StateReader, StateSync + from sqlmesh.core.state_sync.base import Versions, StateReader, StateSync logger = logging.getLogger(__name__) @@ -219,6 +227,170 @@ def __iter__(self) -> t.Iterator[StateStreamContents]: return _StateStream() +class ExpiredBatchRange(PydanticModel): + start: RowBoundary + end: t.Union[RowBoundary, LimitBoundary] + + @classmethod + def init_batch_range(cls, batch_size: int) -> ExpiredBatchRange: + return ExpiredBatchRange( + start=RowBoundary.lowest_boundary(), + end=LimitBoundary(batch_size=batch_size), + ) + + @classmethod + def all_batch_range(cls) -> ExpiredBatchRange: + return ExpiredBatchRange( + start=RowBoundary.lowest_boundary(), + end=RowBoundary.highest_boundary(), + ) + + @classmethod + def _expanded_tuple_comparison( + cls, + columns: t.List[exp.Column], + values: t.List[exp.Literal], + operator: t.Type[exp.Expression], + ) -> exp.Expression: + """Generate expanded tuple comparison that works across all SQL engines. + + Converts tuple comparisons like (a, b, c) OP (x, y, z) into an expanded form + that's compatible with all SQL engines, since native tuple comparisons have + inconsistent support across engines (especially DuckDB, MySQL, SQLite). 
+
+        Repro of problem with DuckDB:
+        "SELECT * FROM VALUES(1,'2') as test(a,b) WHERE ((a, b) > (1, 'foo')) AND ((a, b) <= (10, 'baz'))"
+
+        Args:
+            columns: List of column expressions to compare
+            values: List of value expressions to compare against
+            operator: The comparison operator class (exp.GT, exp.GTE, exp.LT, exp.LTE)
+
+        Examples:
+            (a, b, c) > (x, y, z) expands to:
+            a > x OR (a = x AND b > y) OR (a = x AND b = y AND c > z)
+
+            (a, b, c) <= (x, y, z) expands to:
+            a < x OR (a = x AND b < y) OR (a = x AND b = y AND c <= z)
+
+            (a, b, c) >= (x, y, z) expands to:
+            a > x OR (a = x AND b > y) OR (a = x AND b = y AND c >= z)
+
+        Returns:
+            An expanded OR expression representing the tuple comparison
+        """
+        if operator not in (exp.GT, exp.GTE, exp.LT, exp.LTE):
+            raise ValueError(f"Unsupported operator: {operator}. Use GT, GTE, LT, or LTE.")
+
+        # For <= and >=, we use the strict operator for all but the last column
+        # e.g., (a, b) <= (x, y) becomes: a < x OR (a = x AND b <= y)
+        # For < and >, we use the strict operator throughout
+        # e.g., (a, b) > (x, y) becomes: a > x OR (a = x AND b > y)
+        strict_operator: t.Type[exp.Expression]
+        final_operator: t.Type[exp.Expression]
+
+        if operator in (exp.LTE, exp.GTE):
+            # For inclusive operators (<=, >=), use strict form for intermediate columns
+            # but keep inclusive form for the last column
+            strict_operator = exp.LT if operator == exp.LTE else exp.GT
+            final_operator = operator  # Keep LTE/GTE for last column
+        else:
+            # For strict operators (<, >), use them throughout
+            strict_operator = operator
+            final_operator = operator
+
+        conditions: t.List[exp.Expression] = []
+        for i in range(len(columns)):
+            # Build equality conditions for all columns before current
+            equality_conditions = [exp.EQ(this=columns[j], expression=values[j]) for j in range(i)]
+
+            # Use the final operator for the last column, strict for others
+            comparison_op = final_operator if i == len(columns) - 1 else strict_operator
+            comparison_condition = 
comparison_op(this=columns[i], expression=values[i]) + + if equality_conditions: + conditions.append(exp.and_(*equality_conditions, comparison_condition)) + else: + conditions.append(comparison_condition) + + return exp.or_(*conditions) if len(conditions) > 1 else conditions[0] + + @property + def where_filter(self) -> exp.Expression: + # Use expanded tuple comparisons for cross-engine compatibility + # Native tuple comparisons like (a, b) > (x, y) don't work reliably across all SQL engines + columns = [ + exp.column("updated_ts"), + exp.column("name"), + exp.column("identifier"), + ] + start_values = [ + exp.Literal.number(self.start.updated_ts), + exp.Literal.string(self.start.name), + exp.Literal.string(self.start.identifier), + ] + + start_condition = self._expanded_tuple_comparison(columns, start_values, exp.GT) + + range_filter: exp.Expression + if isinstance(self.end, RowBoundary): + end_values = [ + exp.Literal.number(self.end.updated_ts), + exp.Literal.string(self.end.name), + exp.Literal.string(self.end.identifier), + ] + end_condition = self._expanded_tuple_comparison(columns, end_values, exp.LTE) + range_filter = exp.and_(start_condition, end_condition) + else: + range_filter = start_condition + return range_filter + + +class RowBoundary(PydanticModel): + updated_ts: int + name: str + identifier: str + + @classmethod + def lowest_boundary(cls) -> RowBoundary: + return RowBoundary(updated_ts=0, name="", identifier="") + + @classmethod + def highest_boundary(cls) -> RowBoundary: + # 9999-12-31T23:59:59.999Z in epoch milliseconds + return RowBoundary(updated_ts=253_402_300_799_999, name="", identifier="") + + +class LimitBoundary(PydanticModel): + batch_size: int + + @classmethod + def init_batch_boundary(cls, batch_size: int) -> LimitBoundary: + return LimitBoundary(batch_size=batch_size) + + +class PromotionResult(PydanticModel): + added: t.List[SnapshotTableInfo] + removed: t.List[SnapshotTableInfo] + removed_environment_naming_info: 
t.Optional[EnvironmentNamingInfo] + + @field_validator("removed_environment_naming_info") + def _validate_removed_environment_naming_info( + cls, v: t.Optional[EnvironmentNamingInfo], info: ValidationInfo + ) -> t.Optional[EnvironmentNamingInfo]: + if v and not info.data.get("removed"): + raise ValueError("removed_environment_naming_info must be None if removed is empty") + return v + + +class ExpiredSnapshotBatch(PydanticModel): + """A batch of expired snapshots to be cleaned up.""" + + expired_snapshot_ids: t.Set[SnapshotId] + cleanup_tasks: t.List[SnapshotTableCleanupTask] + batch_range: ExpiredBatchRange + + def iter_expired_snapshot_batches( state_reader: StateReader, *, @@ -234,16 +406,15 @@ def iter_expired_snapshot_batches( ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). batch_size: Maximum number of snapshots to fetch per batch. """ - from sqlmesh.core.state_sync.base import LowerBatchBoundary batch_size = batch_size if batch_size is not None else EXPIRED_SNAPSHOT_DEFAULT_BATCH_SIZE - batch_boundary = LowerBatchBoundary.init_batch_boundary(batch_size=batch_size) + batch_range = ExpiredBatchRange.init_batch_range(batch_size=batch_size) while True: batch = state_reader.get_expired_snapshots( current_ts=current_ts, ignore_ttl=ignore_ttl, - batch_boundary=batch_boundary, + batch_range=batch_range, ) if batch is None: @@ -251,7 +422,13 @@ def iter_expired_snapshot_batches( yield batch - batch_boundary = batch.batch_boundary.to_lower_batch_boundary(batch_size=batch_size) + assert isinstance(batch.batch_range.end, RowBoundary), ( + "Only RowBoundary is supported for pagination currently" + ) + batch_range = ExpiredBatchRange( + start=batch.batch_range.end, + end=LimitBoundary(batch_size=batch_size), + ) def delete_expired_snapshots( @@ -286,17 +463,25 @@ def delete_expired_snapshots( ignore_ttl=ignore_ttl, batch_size=batch_size, ): + end_info = ( + f"updated_ts={batch.batch_range.end.updated_ts}" + if 
isinstance(batch.batch_range.end, RowBoundary) + else f"limit={batch.batch_range.end.batch_size}" + ) logger.info( - "Processing batch of size %s and max_updated_ts of %s", + "Processing batch of size %s with end %s", len(batch.expired_snapshot_ids), - batch.batch_boundary.updated_ts, + end_info, ) snapshot_evaluator.cleanup( target_snapshots=batch.cleanup_tasks, on_complete=console.update_cleanup_progress if console else None, ) state_sync.delete_expired_snapshots( - upper_batch_boundary=batch.batch_boundary.to_upper_batch_boundary(), + batch_range=ExpiredBatchRange( + start=RowBoundary.lowest_boundary(), + end=batch.batch_range.end, + ), ignore_ttl=ignore_ttl, ) logger.info("Cleaned up expired snapshots batch") diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 674399ebd1..49f7b5b92f 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -42,12 +42,8 @@ Interval, ) from sqlmesh.core.state_sync.base import ( - ExpiredSnapshotBatch, - PromotionResult, StateSync, Versions, - BatchBoundary, - UpperBatchBoundary, ) from sqlmesh.core.state_sync.common import ( EnvironmentsChunk, @@ -57,6 +53,9 @@ StateStream, chunk_iterable, EnvironmentWithStatements, + ExpiredSnapshotBatch, + PromotionResult, + ExpiredBatchRange, ) from sqlmesh.core.state_sync.db.interval import IntervalState from sqlmesh.core.state_sync.db.environment import EnvironmentState @@ -265,7 +264,7 @@ def invalidate_environment(self, name: str, protect_prod: bool = True) -> None: def get_expired_snapshots( self, *, - batch_boundary: BatchBoundary, + batch_range: ExpiredBatchRange, current_ts: t.Optional[int] = None, ignore_ttl: bool = False, ) -> t.Optional[ExpiredSnapshotBatch]: @@ -274,7 +273,7 @@ def get_expired_snapshots( environments=self.environment_state.get_environments(), current_ts=current_ts, ignore_ttl=ignore_ttl, - batch_boundary=batch_boundary, + batch_range=batch_range, ) def get_expired_environments(self, 
current_ts: int) -> t.List[EnvironmentSummary]: @@ -283,15 +282,14 @@ def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary @transactional() def delete_expired_snapshots( self, + batch_range: ExpiredBatchRange, ignore_ttl: bool = False, current_ts: t.Optional[int] = None, - upper_batch_boundary: t.Optional[UpperBatchBoundary] = None, ) -> None: - upper_batch_boundary = upper_batch_boundary or UpperBatchBoundary.include_all_boundary() batch = self.get_expired_snapshots( ignore_ttl=ignore_ttl, current_ts=current_ts, - batch_boundary=upper_batch_boundary, + batch_range=batch_range, ) if batch and batch.expired_snapshot_ids: self.snapshot_state.delete_snapshots(batch.expired_snapshot_ids) diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index a3b5a57340..4565990d65 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -29,7 +29,12 @@ SnapshotId, SnapshotFingerprint, ) -from sqlmesh.core.state_sync.base import ExpiredSnapshotBatch, BatchBoundary, LowerBatchBoundary +from sqlmesh.core.state_sync.common import ( + RowBoundary, + ExpiredSnapshotBatch, + ExpiredBatchRange, + LimitBoundary, +) from sqlmesh.utils.migration import index_text_type, blob_text_type from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp from sqlmesh.utils import unique @@ -164,7 +169,7 @@ def get_expired_snapshots( environments: t.Iterable[Environment], current_ts: int, ignore_ttl: bool, - batch_boundary: BatchBoundary, + batch_range: ExpiredBatchRange, ) -> t.Optional[ExpiredSnapshotBatch]: expired_query = exp.select("name", "identifier", "version", "updated_ts").from_( self.snapshots_table @@ -175,26 +180,7 @@ def get_expired_snapshots( (exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts ) - # Use tuple comparison for proper cursor-based pagination - operation = exp.GT if isinstance(batch_boundary, LowerBatchBoundary) else exp.LTE - expired_query = 
expired_query.where( - operation( - this=exp.Tuple( - expressions=[ - exp.column("updated_ts"), - exp.column("name"), - exp.column("identifier"), - ] - ), - expression=exp.Tuple( - expressions=[ - exp.Literal.number(batch_boundary.updated_ts), - exp.Literal.string(batch_boundary.name), - exp.Literal.string(batch_boundary.identifier), - ] - ), - ) - ) + expired_query = expired_query.where(batch_range.where_filter) promoted_snapshot_ids = { snapshot.snapshot_id @@ -217,8 +203,8 @@ def get_expired_snapshots( exp.column("updated_ts"), exp.column("name"), exp.column("identifier") ) - if isinstance(batch_boundary, LowerBatchBoundary): - expired_query = expired_query.limit(batch_boundary.batch_size) + if isinstance(batch_range.end, LimitBoundary): + expired_query = expired_query.limit(batch_range.end.batch_size) rows = fetchall(self.engine_adapter, expired_query) @@ -242,11 +228,16 @@ def _is_snapshot_used(snapshot: SnapshotIdAndVersion) -> bool: # Extract cursor values from last row for pagination last_row = rows[-1] - batch_boundary = BatchBoundary( + last_row_boundary = RowBoundary( updated_ts=last_row[3], name=last_row[0], identifier=last_row[1], ) + # The returned batch_range represents the actual range of rows in this batch + result_batch_range = ExpiredBatchRange( + start=batch_range.start, + end=last_row_boundary, + ) unique_expired_versions = unique(expired_candidates.values()) expired_snapshot_ids: t.Set[SnapshotId] = set() @@ -298,7 +289,7 @@ def _is_snapshot_used(snapshot: SnapshotIdAndVersion) -> bool: return ExpiredSnapshotBatch( expired_snapshot_ids=expired_snapshot_ids, cleanup_tasks=cleanup_tasks, - batch_boundary=batch_boundary, + batch_range=result_batch_range, ) return None diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index 5f5827422e..199ca43ee9 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -43,9 +43,14 @@ from sqlmesh.core.state_sync.base 
import ( SCHEMA_VERSION, SQLGLOT_VERSION, - PromotionResult, Versions, ) +from sqlmesh.core.state_sync.common import ( + ExpiredBatchRange, + LimitBoundary, + PromotionResult, + RowBoundary, +) from sqlmesh.utils.date import now_timestamp, to_datetime, to_timestamp from sqlmesh.utils.errors import SQLMeshError, StateMigrationError @@ -58,11 +63,9 @@ def _get_cleanup_tasks( limit: int = 1000, ignore_ttl: bool = False, ) -> t.List[SnapshotTableCleanupTask]: - from sqlmesh.core.state_sync.base import LowerBatchBoundary - batch = state_sync.get_expired_snapshots( ignore_ttl=ignore_ttl, - batch_boundary=LowerBatchBoundary.init_batch_boundary(batch_size=limit), + batch_range=ExpiredBatchRange.init_batch_range(batch_size=limit), ) return [] if batch is None else batch.cleanup_tasks @@ -1175,14 +1178,12 @@ def test_delete_expired_snapshots(state_sync: EngineAdapterStateSync, make_snaps SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), SnapshotTableCleanupTask(snapshot=new_snapshot.table_info, dev_table_only=False), ] - state_sync.delete_expired_snapshots() + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert not state_sync.get_snapshots(all_snapshots) def test_get_expired_snapshot_batch(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - from sqlmesh.core.state_sync.base import LowerBatchBoundary - now_ts = now_timestamp() snapshots = [] @@ -1201,31 +1202,41 @@ def test_get_expired_snapshot_batch(state_sync: EngineAdapterStateSync, make_sna state_sync.push_snapshots(snapshots) batch = state_sync.get_expired_snapshots( - batch_boundary=LowerBatchBoundary.init_batch_boundary(batch_size=2), + batch_range=ExpiredBatchRange.init_batch_range(batch_size=2), ) assert batch is not None assert len(batch.expired_snapshot_ids) == 2 assert len(batch.cleanup_tasks) == 2 - # Delete first batch using new API state_sync.delete_expired_snapshots( - 
upper_batch_boundary=batch.batch_boundary.to_upper_batch_boundary(), + batch_range=ExpiredBatchRange( + start=RowBoundary.lowest_boundary(), + end=batch.batch_range.end, + ), ) next_batch = state_sync.get_expired_snapshots( - batch_boundary=batch.batch_boundary.to_lower_batch_boundary(batch_size=2), + batch_range=ExpiredBatchRange( + start=batch.batch_range.end, + end=LimitBoundary(batch_size=2), + ), ) assert next_batch is not None assert len(next_batch.expired_snapshot_ids) == 1 - # Delete second batch using new API state_sync.delete_expired_snapshots( - upper_batch_boundary=next_batch.batch_boundary.to_upper_batch_boundary(), + batch_range=ExpiredBatchRange( + start=next_batch.batch_range.start, + end=next_batch.batch_range.end, + ), ) assert ( state_sync.get_expired_snapshots( - batch_boundary=next_batch.batch_boundary.to_lower_batch_boundary(batch_size=2), + batch_range=ExpiredBatchRange( + start=next_batch.batch_range.end, + end=LimitBoundary(batch_size=2), + ), ) is None ) @@ -1235,8 +1246,6 @@ def test_get_expired_snapshot_batch_same_timestamp( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable ): """Test that pagination works correctly when multiple snapshots have the same updated_ts.""" - from sqlmesh.core.state_sync.base import LowerBatchBoundary - now_ts = now_timestamp() same_timestamp = now_ts - 20000 @@ -1258,7 +1267,7 @@ def test_get_expired_snapshot_batch_same_timestamp( # Fetch first batch of 2 batch1 = state_sync.get_expired_snapshots( - batch_boundary=LowerBatchBoundary.init_batch_boundary(batch_size=2), + batch_range=ExpiredBatchRange.init_batch_range(batch_size=2), ) assert batch1 is not None assert len(batch1.expired_snapshot_ids) == 2 @@ -1269,7 +1278,10 @@ def test_get_expired_snapshot_batch_same_timestamp( # Fetch second batch of 2 using cursor from batch1 batch2 = state_sync.get_expired_snapshots( - batch_boundary=batch1.batch_boundary.to_lower_batch_boundary(batch_size=2), + batch_range=ExpiredBatchRange( + 
start=batch1.batch_range.end, + end=LimitBoundary(batch_size=2), + ), ) assert batch2 is not None assert len(batch2.expired_snapshot_ids) == 2 @@ -1280,7 +1292,10 @@ def test_get_expired_snapshot_batch_same_timestamp( # Fetch third batch of 2 using cursor from batch2 batch3 = state_sync.get_expired_snapshots( - batch_boundary=batch2.batch_boundary.to_lower_batch_boundary(batch_size=2), + batch_range=ExpiredBatchRange( + start=batch2.batch_range.end, + end=LimitBoundary(batch_size=2), + ), ) assert batch3 is not None assert sorted([x.name for x in batch3.expired_snapshot_ids]) == [ @@ -1292,8 +1307,6 @@ def test_delete_expired_snapshots_batching_with_deletion( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable ): """Test that delete_expired_snapshots properly deletes batches as it pages through them.""" - from sqlmesh.core.state_sync.base import LowerBatchBoundary - now_ts = now_timestamp() # Create 5 expired snapshots with different timestamps @@ -1317,14 +1330,17 @@ def test_delete_expired_snapshots_batching_with_deletion( # Get first batch of 2 batch1 = state_sync.get_expired_snapshots( - batch_boundary=LowerBatchBoundary.init_batch_boundary(batch_size=2), + batch_range=ExpiredBatchRange.init_batch_range(batch_size=2), ) assert batch1 is not None assert len(batch1.expired_snapshot_ids) == 2 - # Delete the first batch using upper_batch_boundary + # Delete the first batch using batch_range state_sync.delete_expired_snapshots( - upper_batch_boundary=batch1.batch_boundary.to_upper_batch_boundary(), + batch_range=ExpiredBatchRange( + start=batch1.batch_range.start, + end=batch1.batch_range.end, + ), ) # Verify first 2 snapshots (model_0 and model_1, the oldest) are deleted and last 3 remain @@ -1338,14 +1354,20 @@ def test_delete_expired_snapshots_batching_with_deletion( # Get next batch of 2 (should start after batch1's boundary) batch2 = state_sync.get_expired_snapshots( - batch_boundary=batch1.batch_boundary.to_lower_batch_boundary(batch_size=2), + 
batch_range=ExpiredBatchRange( + start=batch1.batch_range.end, + end=LimitBoundary(batch_size=2), + ), ) assert batch2 is not None assert len(batch2.expired_snapshot_ids) == 2 # Delete the second batch state_sync.delete_expired_snapshots( - upper_batch_boundary=batch2.batch_boundary.to_upper_batch_boundary(), + batch_range=ExpiredBatchRange( + start=batch2.batch_range.start, + end=batch2.batch_range.end, + ), ) # Verify only the last snapshot remains @@ -1359,14 +1381,20 @@ def test_delete_expired_snapshots_batching_with_deletion( # Get final batch batch3 = state_sync.get_expired_snapshots( - batch_boundary=batch2.batch_boundary.to_lower_batch_boundary(batch_size=2), + batch_range=ExpiredBatchRange( + start=batch2.batch_range.end, + end=LimitBoundary(batch_size=2), + ), ) assert batch3 is not None assert len(batch3.expired_snapshot_ids) == 1 # Delete the final batch state_sync.delete_expired_snapshots( - upper_batch_boundary=batch3.batch_boundary.to_upper_batch_boundary(), + batch_range=ExpiredBatchRange( + start=batch3.batch_range.start, + end=batch3.batch_range.end, + ), ) # Verify all snapshots are deleted @@ -1375,7 +1403,10 @@ def test_delete_expired_snapshots_batching_with_deletion( # Verify no more expired snapshots exist assert ( state_sync.get_expired_snapshots( - batch_boundary=batch3.batch_boundary.to_lower_batch_boundary(batch_size=2), + batch_range=ExpiredBatchRange( + start=batch3.batch_range.end, + end=LimitBoundary(batch_size=2), + ), ) is None ) @@ -1431,12 +1462,74 @@ def test_iterator_expired_snapshot_batch( assert all_processed_ids == expected_ids +@pytest.mark.parametrize( + "start_boundary,end_boundary,expected_sql", + [ + # Test with GT only (when end is LimitBoundary) + ( + RowBoundary(updated_ts=0, name="", identifier=""), + LimitBoundary(batch_size=100), + "updated_ts > 0 OR (updated_ts = 0 AND name > '') OR (updated_ts = 0 AND name = '' AND identifier > '')", + ), + # Test with GT and LTE (when both are RowBoundary) + ( + 
RowBoundary(updated_ts=1000, name="model_a", identifier="abc"), + RowBoundary(updated_ts=2000, name="model_z", identifier="xyz"), + "(updated_ts > 1000 OR (updated_ts = 1000 AND name > 'model_a') OR (updated_ts = 1000 AND name = 'model_a' AND identifier > 'abc')) AND (updated_ts < 2000 OR (updated_ts = 2000 AND name < 'model_z') OR (updated_ts = 2000 AND name = 'model_z' AND identifier <= 'xyz'))", + ), + # Test with zero timestamp + ( + RowBoundary(updated_ts=0, name="", identifier=""), + RowBoundary(updated_ts=1234567890, name="model_x", identifier="id_123"), + "(updated_ts > 0 OR (updated_ts = 0 AND name > '') OR (updated_ts = 0 AND name = '' AND identifier > '')) AND (updated_ts < 1234567890 OR (updated_ts = 1234567890 AND name < 'model_x') OR (updated_ts = 1234567890 AND name = 'model_x' AND identifier <= 'id_123'))", + ), + # Test with same timestamp, different names + ( + RowBoundary(updated_ts=5000, name="model_a", identifier="id_1"), + RowBoundary(updated_ts=5000, name="model_b", identifier="id_2"), + "(updated_ts > 5000 OR (updated_ts = 5000 AND name > 'model_a') OR (updated_ts = 5000 AND name = 'model_a' AND identifier > 'id_1')) AND (updated_ts < 5000 OR (updated_ts = 5000 AND name < 'model_b') OR (updated_ts = 5000 AND name = 'model_b' AND identifier <= 'id_2'))", + ), + # Test with same timestamp and name, different identifiers + ( + RowBoundary(updated_ts=7000, name="model_x", identifier="id_a"), + RowBoundary(updated_ts=7000, name="model_x", identifier="id_b"), + "(updated_ts > 7000 OR (updated_ts = 7000 AND name > 'model_x') OR (updated_ts = 7000 AND name = 'model_x' AND identifier > 'id_a')) AND (updated_ts < 7000 OR (updated_ts = 7000 AND name < 'model_x') OR (updated_ts = 7000 AND name = 'model_x' AND identifier <= 'id_b'))", + ), + # Test all_batch_range use case + ( + RowBoundary(updated_ts=0, name="", identifier=""), + RowBoundary(updated_ts=253_402_300_799_999, name="", identifier=""), + "(updated_ts > 0 OR (updated_ts = 0 AND name > '') OR 
(updated_ts = 0 AND name = '' AND identifier > '')) AND (updated_ts < 253402300799999 OR (updated_ts = 253402300799999 AND name < '') OR (updated_ts = 253402300799999 AND name = '' AND identifier <= ''))", + ), + ], +) +def test_expired_batch_range_where_filter(start_boundary, end_boundary, expected_sql): + """Test ExpiredBatchRange.where_filter generates correct SQL for various boundary combinations.""" + batch_range = ExpiredBatchRange(start=start_boundary, end=end_boundary) + result = batch_range.where_filter + assert result.sql() == expected_sql + + +def test_expired_batch_range_where_filter_with_limit(): + """Test that where_filter correctly handles LimitBoundary (only start condition, no end condition).""" + batch_range = ExpiredBatchRange( + start=RowBoundary(updated_ts=1000, name="model_a", identifier="abc"), + end=LimitBoundary(batch_size=50), + ) + result = batch_range.where_filter + # When end is LimitBoundary, should only have the start (GT) condition + assert ( + result.sql() + == "updated_ts > 1000 OR (updated_ts = 1000 AND name > 'model_a') OR (updated_ts = 1000 AND name = 'model_a' AND identifier > 'abc')" + ) + + def test_delete_expired_snapshots_common_function_batching( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, mocker: MockerFixture ): """Test that the common delete_expired_snapshots function properly pages through batches and deletes them.""" from sqlmesh.core.state_sync.common import delete_expired_snapshots - from sqlmesh.core.state_sync.base import LowerBatchBoundary, UpperBatchBoundary + from sqlmesh.core.state_sync.common import ExpiredBatchRange, RowBoundary, LimitBoundary from unittest.mock import MagicMock now_ts = now_timestamp() @@ -1474,86 +1567,103 @@ def test_delete_expired_snapshots_common_function_batching( ) # Verify get_expired_snapshots was called the correct number of times: - # - 3 batches (2+2+1): each batch triggers 2 calls (one from for_each loop, one from delete_expired_snapshots) + # - 3 batches 
(2+2+1): each batch triggers 2 calls (one from iter_expired_snapshot_batches, one from delete_expired_snapshots) # - Plus 1 final call that returns empty to exit the loop # Total: 3 * 2 + 1 = 7 calls assert get_expired_spy.call_count == 7 - # Verify the progression of batch_boundary calls from the for_each loop - # (calls at indices 0, 2, 4, 6 are from for_each_expired_snapshot_batch) + # Verify the progression of batch_range calls from the iter_expired_snapshot_batches loop + # (calls at indices 0, 2, 4, 6 are from iter_expired_snapshot_batches) # (calls at indices 1, 3, 5 are from delete_expired_snapshots in facade.py) calls = get_expired_spy.call_args_list - # First call from for_each should have a LowerBatchBoundary starting from the beginning + # First call from iterator should have a batch_range starting from the beginning first_call_kwargs = calls[0][1] - assert "batch_boundary" in first_call_kwargs - first_boundary = first_call_kwargs["batch_boundary"] - assert isinstance(first_boundary, LowerBatchBoundary) - assert first_boundary.batch_size == 2 - assert first_boundary.updated_ts == 0 - assert first_boundary.name == "" - assert first_boundary.identifier == "" - - # Third call (second batch from for_each) should have a LowerBatchBoundary from the first batch's boundary + assert "batch_range" in first_call_kwargs + first_range = first_call_kwargs["batch_range"] + assert isinstance(first_range, ExpiredBatchRange) + assert isinstance(first_range.start, RowBoundary) + assert isinstance(first_range.end, LimitBoundary) + assert first_range.end.batch_size == 2 + assert first_range.start.updated_ts == 0 + assert first_range.start.name == "" + assert first_range.start.identifier == "" + + # Third call (second batch from iterator) should have a batch_range from the first batch's range third_call_kwargs = calls[2][1] - assert "batch_boundary" in third_call_kwargs - second_boundary = third_call_kwargs["batch_boundary"] - assert isinstance(second_boundary, 
LowerBatchBoundary) - assert second_boundary.batch_size == 2 + assert "batch_range" in third_call_kwargs + second_range = third_call_kwargs["batch_range"] + assert isinstance(second_range, ExpiredBatchRange) + assert isinstance(second_range.start, RowBoundary) + assert isinstance(second_range.end, LimitBoundary) + assert second_range.end.batch_size == 2 # Should have progressed from the first batch - assert second_boundary.updated_ts > 0 - assert second_boundary.name == '"model_3"' + assert second_range.start.updated_ts > 0 + assert second_range.start.name == '"model_3"' - # Fifth call (third batch from for_each) should have a LowerBatchBoundary from the second batch's boundary + # Fifth call (third batch from iterator) should have a batch_range from the second batch's range fifth_call_kwargs = calls[4][1] - assert "batch_boundary" in fifth_call_kwargs - third_boundary = fifth_call_kwargs["batch_boundary"] - assert isinstance(third_boundary, LowerBatchBoundary) - assert third_boundary.batch_size == 2 + assert "batch_range" in fifth_call_kwargs + third_range = fifth_call_kwargs["batch_range"] + assert isinstance(third_range, ExpiredBatchRange) + assert isinstance(third_range.start, RowBoundary) + assert isinstance(third_range.end, LimitBoundary) + assert third_range.end.batch_size == 2 # Should have progressed from the second batch - assert third_boundary.updated_ts >= second_boundary.updated_ts - assert third_boundary.name == '"model_1"' + assert third_range.start.updated_ts >= second_range.start.updated_ts + assert third_range.start.name == '"model_1"' - # Seventh call (final call from for_each) should have a LowerBatchBoundary from the third batch's boundary + # Seventh call (final call from iterator) should have a batch_range from the third batch's range seventh_call_kwargs = calls[6][1] - assert "batch_boundary" in seventh_call_kwargs - fourth_boundary = seventh_call_kwargs["batch_boundary"] - assert isinstance(fourth_boundary, LowerBatchBoundary) - assert 
fourth_boundary.batch_size == 2 + assert "batch_range" in seventh_call_kwargs + fourth_range = seventh_call_kwargs["batch_range"] + assert isinstance(fourth_range, ExpiredBatchRange) + assert isinstance(fourth_range.start, RowBoundary) + assert isinstance(fourth_range.end, LimitBoundary) + assert fourth_range.end.batch_size == 2 # Should have progressed from the third batch - assert fourth_boundary.updated_ts >= third_boundary.updated_ts - assert fourth_boundary.name == '"model_0"' + assert fourth_range.start.updated_ts >= third_range.start.updated_ts + assert fourth_range.start.name == '"model_0"' # Verify delete_expired_snapshots was called 3 times (once per batch) assert delete_expired_spy.call_count == 3 - # Verify each delete call used an UpperBatchBoundary + # Verify each delete call used a batch_range delete_calls = delete_expired_spy.call_args_list - # First call should have an UpperBatchBoundary matching the first batch + # First call should have a batch_range matching the first batch first_delete_kwargs = delete_calls[0][1] - assert "upper_batch_boundary" in first_delete_kwargs - first_delete_boundary = first_delete_kwargs["upper_batch_boundary"] - assert isinstance(first_delete_boundary, UpperBatchBoundary) - assert first_delete_boundary.updated_ts == second_boundary.updated_ts - assert first_delete_boundary.name == second_boundary.name - assert first_delete_boundary.identifier == second_boundary.identifier + assert "batch_range" in first_delete_kwargs + first_delete_range = first_delete_kwargs["batch_range"] + assert isinstance(first_delete_range, ExpiredBatchRange) + assert isinstance(first_delete_range.start, RowBoundary) + assert first_delete_range.start.updated_ts == 0 + assert isinstance(first_delete_range.end, RowBoundary) + assert first_delete_range.end.updated_ts == second_range.start.updated_ts + assert first_delete_range.end.name == second_range.start.name + assert first_delete_range.end.identifier == second_range.start.identifier 
second_delete_kwargs = delete_calls[1][1] - assert "upper_batch_boundary" in second_delete_kwargs - second_delete_boundary = second_delete_kwargs["upper_batch_boundary"] - assert isinstance(second_delete_boundary, UpperBatchBoundary) - assert second_delete_boundary.updated_ts == third_boundary.updated_ts - assert second_delete_boundary.name == third_boundary.name - assert second_delete_boundary.identifier == third_boundary.identifier + assert "batch_range" in second_delete_kwargs + second_delete_range = second_delete_kwargs["batch_range"] + assert isinstance(second_delete_range, ExpiredBatchRange) + assert isinstance(second_delete_range.start, RowBoundary) + assert second_delete_range.start.updated_ts == 0 + assert isinstance(second_delete_range.end, RowBoundary) + assert second_delete_range.end.updated_ts == third_range.start.updated_ts + assert second_delete_range.end.name == third_range.start.name + assert second_delete_range.end.identifier == third_range.start.identifier third_delete_kwargs = delete_calls[2][1] - assert "upper_batch_boundary" in third_delete_kwargs - third_delete_boundary = third_delete_kwargs["upper_batch_boundary"] - assert isinstance(third_delete_boundary, UpperBatchBoundary) - assert third_delete_boundary.updated_ts == fourth_boundary.updated_ts - assert third_delete_boundary.name == fourth_boundary.name - assert third_delete_boundary.identifier == fourth_boundary.identifier + assert "batch_range" in third_delete_kwargs + third_delete_range = third_delete_kwargs["batch_range"] + assert isinstance(third_delete_range, ExpiredBatchRange) + assert isinstance(third_delete_range.start, RowBoundary) + assert third_delete_range.start.updated_ts == 0 + assert isinstance(third_delete_range.end, RowBoundary) + assert third_delete_range.end.updated_ts == fourth_range.start.updated_ts + assert third_delete_range.end.name == fourth_range.start.name + assert third_delete_range.end.identifier == fourth_range.start.identifier # Verify the cleanup method was 
called for each batch that had cleanup tasks assert mock_evaluator.cleanup.call_count >= 1 @@ -1587,7 +1697,7 @@ def test_delete_expired_snapshots_seed( assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=False), ] - state_sync.delete_expired_snapshots() + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert not state_sync.get_snapshots(all_snapshots) @@ -1629,7 +1739,7 @@ def test_delete_expired_snapshots_batching( SnapshotTableCleanupTask(snapshot=snapshot_a.table_info, dev_table_only=False), SnapshotTableCleanupTask(snapshot=snapshot_b.table_info, dev_table_only=False), ] - state_sync.delete_expired_snapshots() + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert not state_sync.get_snapshots(all_snapshots) @@ -1663,7 +1773,7 @@ def test_delete_expired_snapshots_promoted( all_snapshots = [snapshot] assert not _get_cleanup_tasks(state_sync) - state_sync.delete_expired_snapshots() + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert set(state_sync.get_snapshots(all_snapshots)) == {snapshot.snapshot_id} env.snapshots_ = [] @@ -1675,7 +1785,7 @@ def test_delete_expired_snapshots_promoted( assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=False) ] - state_sync.delete_expired_snapshots() + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert not state_sync.get_snapshots(all_snapshots) @@ -1715,7 +1825,7 @@ def test_delete_expired_snapshots_dev_table_cleanup_only( assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True) ] - state_sync.delete_expired_snapshots() + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert set(state_sync.get_snapshots(all_snapshots)) == {new_snapshot.snapshot_id} @@ 
-1755,7 +1865,7 @@ def test_delete_expired_snapshots_shared_dev_table( } assert not _get_cleanup_tasks(state_sync) # No dev table cleanup - state_sync.delete_expired_snapshots() + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert set(state_sync.get_snapshots(all_snapshots)) == {new_snapshot.snapshot_id} @@ -1801,7 +1911,7 @@ def test_delete_expired_snapshots_ignore_ttl( # default TTL = 1 week, nothing to clean up yet if we take TTL into account assert not _get_cleanup_tasks(state_sync) - state_sync.delete_expired_snapshots() + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert state_sync.snapshots_exist([snapshot_c.snapshot_id]) == {snapshot_c.snapshot_id} # If we ignore TTL, only snapshot_c should get cleaned up because snapshot_a and snapshot_b are part of an environment @@ -1809,7 +1919,9 @@ def test_delete_expired_snapshots_ignore_ttl( assert _get_cleanup_tasks(state_sync, ignore_ttl=True) == [ SnapshotTableCleanupTask(snapshot=snapshot_c.table_info, dev_table_only=False) ] - state_sync.delete_expired_snapshots(ignore_ttl=True) + state_sync.delete_expired_snapshots( + batch_range=ExpiredBatchRange.all_batch_range(), ignore_ttl=True + ) assert not state_sync.snapshots_exist([snapshot_c.snapshot_id]) @@ -1877,7 +1989,7 @@ def test_delete_expired_snapshots_cleanup_intervals( SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), SnapshotTableCleanupTask(snapshot=new_snapshot.table_info, dev_table_only=False), ] - state_sync.delete_expired_snapshots() + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert not get_snapshot_intervals(snapshot) @@ -1964,7 +2076,7 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_version( assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), ] - state_sync.delete_expired_snapshots() + 
state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert not state_sync.get_snapshots([snapshot]) # Check new snapshot's intervals @@ -2082,7 +2194,7 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version( # Delete the expired snapshot assert not _get_cleanup_tasks(state_sync) - state_sync.delete_expired_snapshots() + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert not state_sync.get_snapshots([snapshot]) # Check new snapshot's intervals @@ -2178,7 +2290,7 @@ def test_compact_intervals_after_cleanup( assert _get_cleanup_tasks(state_sync) == [ SnapshotTableCleanupTask(snapshot=snapshot_a.table_info, dev_table_only=True), ] - state_sync.delete_expired_snapshots() + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) assert state_sync.engine_adapter.fetchone("SELECT COUNT(*) FROM sqlmesh._intervals")[0] == 5 # type: ignore