From 010aa3db0e25a60eb9365606351c294d2ca55d19 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Tue, 25 Feb 2025 14:06:08 -0800 Subject: [PATCH 1/6] Chore: Refactor the interval state from the state sync --- .../state_sync/engine_adapter/__init__.py | 3 + .../facade.py} | 443 +--------------- .../state_sync/engine_adapter/interval.py | 495 ++++++++++++++++++ .../core/state_sync/engine_adapter/utils.py | 109 ++++ tests/core/test_environment.py | 2 +- tests/core/test_state_sync.py | 32 +- tests/dbt/test_transformation.py | 2 +- 7 files changed, 643 insertions(+), 443 deletions(-) create mode 100644 sqlmesh/core/state_sync/engine_adapter/__init__.py rename sqlmesh/core/state_sync/{engine_adapter.py => engine_adapter/facade.py} (78%) create mode 100644 sqlmesh/core/state_sync/engine_adapter/interval.py create mode 100644 sqlmesh/core/state_sync/engine_adapter/utils.py diff --git a/sqlmesh/core/state_sync/engine_adapter/__init__.py b/sqlmesh/core/state_sync/engine_adapter/__init__.py new file mode 100644 index 0000000000..86839f1797 --- /dev/null +++ b/sqlmesh/core/state_sync/engine_adapter/__init__.py @@ -0,0 +1,3 @@ +from sqlmesh.core.state_sync.engine_adapter.facade import EngineAdapterStateSync + +__all__ = ["EngineAdapterStateSync"] diff --git a/sqlmesh/core/state_sync/engine_adapter.py b/sqlmesh/core/state_sync/engine_adapter/facade.py similarity index 78% rename from sqlmesh/core/state_sync/engine_adapter.py rename to sqlmesh/core/state_sync/engine_adapter/facade.py index 22a4201c9a..1a4134f7f3 100644 --- a/sqlmesh/core/state_sync/engine_adapter.py +++ b/sqlmesh/core/state_sync/engine_adapter/facade.py @@ -68,7 +68,8 @@ Versions, ) from sqlmesh.core.state_sync.common import transactional -from sqlmesh.utils import major_minor, random_id, unique +from sqlmesh.core.state_sync.engine_adapter.interval import IntervalState +from sqlmesh.utils import major_minor, unique from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import TimeLike, now, now_timestamp, time_like_to_str, to_timestamp from sqlmesh.utils.errors import ConflictingPlanError, SQLMeshError @@ -106,7 +107,6 @@ class EngineAdapterStateSync(StateSync): context_path: The context path, used for caching snapshot models. 
""" - INTERVAL_BATCH_SIZE = 1000 SNAPSHOT_BATCH_SIZE = 1000 SNAPSHOT_MIGRATION_BATCH_SIZE = 500 @@ -117,13 +117,13 @@ def __init__( console: t.Optional[Console] = None, context_path: Path = Path(), ): + self.interval_state = IntervalState(engine_adapter, schema=schema) # Make sure that if an empty string is provided that we treat it as None self.schema = schema or None self.engine_adapter = engine_adapter self.console = console or get_console() self.snapshots_table = exp.table_("_snapshots", db=self.schema) self.environments_table = exp.table_("_environments", db=self.schema) - self.intervals_table = exp.table_("_intervals", db=self.schema) self.plan_dags_table = exp.table_("_plan_dags", db=self.schema) self.auto_restatements_table = exp.table_("_auto_restatements", db=self.schema) self.versions_table = exp.table_("_versions", db=self.schema) @@ -159,21 +159,6 @@ def __init__( "requirements": exp.DataType.build(blob_type), } - self._interval_columns_to_types = { - "id": exp.DataType.build(index_type), - "created_ts": exp.DataType.build("bigint"), - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build("text"), - "version": exp.DataType.build(index_type), - "dev_version": exp.DataType.build(index_type), - "start_ts": exp.DataType.build("bigint"), - "end_ts": exp.DataType.build("bigint"), - "is_dev": exp.DataType.build("boolean"), - "is_removed": exp.DataType.build("boolean"), - "is_compacted": exp.DataType.build("boolean"), - "is_pending_restatement": exp.DataType.build("boolean"), - } - self._auto_restatement_columns_to_types = { "snapshot_name": exp.DataType.build(index_type), "snapshot_version": exp.DataType.build(index_type), @@ -609,7 +594,7 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: if expired_snapshot_ids: self.delete_snapshots(expired_snapshot_ids) - self._cleanup_intervals(cleanup_targets, expired_snapshot_ids) + self.interval_state.cleanup_intervals(cleanup_targets, expired_snapshot_ids) return cleanup_targets @@ -665,7 +650,7 @@ def reset(self, default_catalog: t.Optional[str]) -> None: for table in ( self.snapshots_table, self.environments_table, - self.intervals_table, + self.interval_state.intervals_table, self.plan_dags_table, self.versions_table, ): @@ -878,7 +863,7 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]: snapshots.pop(missing_cached_snapshot_id, None) if snapshots and hydrate_intervals: - _, intervals = self._get_snapshot_intervals(snapshots.values()) + intervals = self.interval_state.get_snapshot_intervals(snapshots.values()) Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals) if duplicates: @@ -1038,52 +1023,7 @@ def add_interval( @transactional() def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: - def remove_partial_intervals( - intervals: t.List[Interval], snapshot_id: t.Optional[SnapshotId], *, is_dev: bool - ) -> t.List[Interval]: - results = [] - for start_ts, end_ts in intervals: - if start_ts < end_ts: - logger.info( - "Adding %s (%s, %s) for snapshot %s", - "dev interval" if is_dev else "interval", - time_like_to_str(start_ts), - time_like_to_str(end_ts), - snapshot_id, - ) - results.append((start_ts, end_ts)) - else: - logger.info( - "Skipping partial interval (%s, %s) for snapshot %s", - start_ts, - end_ts, - snapshot_id, - ) - return results - - intervals_to_insert = [] - for snapshot_intervals in snapshots_intervals: - snapshot_intervals = snapshot_intervals.copy( - update={ - "intervals": 
remove_partial_intervals( - snapshot_intervals.intervals, snapshot_intervals.snapshot_id, is_dev=False - ), - "dev_intervals": remove_partial_intervals( - snapshot_intervals.dev_intervals, - snapshot_intervals.snapshot_id, - is_dev=True, - ), - } - ) - if ( - snapshot_intervals.intervals - or snapshot_intervals.dev_intervals - or snapshot_intervals.pending_restatement_intervals - ): - intervals_to_insert.append(snapshot_intervals) - - if intervals_to_insert: - self._push_snapshot_intervals(intervals_to_insert) + self.interval_state.add_snapshots_intervals(snapshots_intervals) @transactional() def remove_intervals( @@ -1091,74 +1031,14 @@ def remove_intervals( snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], remove_shared_versions: bool = False, ) -> None: - intervals_to_remove: t.Sequence[ - t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval] - ] = snapshot_intervals - if remove_shared_versions: - name_version_mapping = {s.name_version: interval for s, interval in snapshot_intervals} - all_snapshots = [] - for where in self._snapshot_name_version_filter(name_version_mapping, alias=None): - all_snapshots.extend( - [ - SnapshotIntervals( - name=r[0], - identifier=r[1], - version=r[2], - dev_version=r[3], - intervals=[], - dev_intervals=[], - ) - for r in self._fetchall( - exp.select("name", "identifier", "version", "dev_version") - .from_(self.intervals_table) - .where(where) - .distinct() - ) - ] - ) - intervals_to_remove = [ - (snapshot, name_version_mapping[snapshot.name_version]) - for snapshot in all_snapshots - ] - - if logger.isEnabledFor(logging.INFO): - snapshot_ids = ", ".join(str(s.snapshot_id) for s, _ in intervals_to_remove) - logger.info("Removing interval for snapshots: %s", snapshot_ids) - - for is_dev in (True, False): - self.engine_adapter.insert_append( - self.intervals_table, - _intervals_to_df(intervals_to_remove, is_dev=is_dev, is_removed=True), - columns_to_types=self._interval_columns_to_types, - ) + self.interval_state.remove_intervals(snapshot_intervals, remove_shared_versions) @transactional() def compact_intervals(self) -> None: - interval_ids, snapshot_intervals = self._get_snapshot_intervals(uncompacted_only=True) - - logger.info( - "Compacting %s intervals for %s snapshots", len(interval_ids), len(snapshot_intervals) - ) - - self._push_snapshot_intervals(snapshot_intervals, is_compacted=True) - - if interval_ids: - for interval_id_batch in self._batches( - list(interval_ids), batch_size=self.INTERVAL_BATCH_SIZE - ): - self.engine_adapter.delete_from( - self.intervals_table, exp.column("id").isin(*interval_id_batch) - ) + self.interval_state.compact_intervals() def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]: - if not snapshots: - return [] - - _, intervals = self._get_snapshot_intervals(snapshots) - for s in snapshots: - s.intervals = [] - s.dev_intervals = [] - return Snapshot.hydrate_with_intervals_by_version(snapshots, intervals) + return self.interval_state.refresh_snapshot_intervals(snapshots) def max_interval_end_per_model( self, @@ -1179,35 +1059,7 @@ def max_interval_end_per_model( if not snapshots: return {} - table_alias = "intervals" - name_col = exp.column("name", table=table_alias) - version_col = exp.column("version", table=table_alias) - - result: t.Dict[str, int] = {} - - for where in self._snapshot_name_version_filter(snapshots, alias=table_alias): - query = ( - exp.select( - name_col, - exp.func("MAX", exp.column("end_ts", table=table_alias)).as_("max_end_ts"), - ) - 
.from_(exp.to_table(self.intervals_table).as_(table_alias)) - .where(where, copy=False) - .where( - exp.and_( - exp.to_column("is_dev").not_(), - exp.to_column("is_removed").not_(), - exp.to_column("is_pending_restatement").not_(), - ), - copy=False, - ) - .group_by(name_col, version_col, copy=False) - ) - - for name, max_end in self._fetchall(query): - result[name] = max_end - - return result + return self.interval_state.max_interval_end_per_model(snapshots) def recycle(self) -> None: self.engine_adapter.recycle() @@ -1215,228 +1067,6 @@ def recycle(self) -> None: def close(self) -> None: self.engine_adapter.close() - def _cleanup_intervals( - self, - cleanup_targets: t.List[SnapshotTableCleanupTask], - expired_snapshot_ids: t.Collection[SnapshotIdLike], - ) -> None: - # Cleanup can only happen for compacted intervals - self.compact_intervals() - # Delete intervals for non-dev tables that are no longer used - self._delete_intervals_by_version(cleanup_targets) - # Delete dev intervals for dev tables that are no longer used - self._delete_intervals_by_dev_version(cleanup_targets) - # Nullify the snapshot identifiers of interval records for snapshots that have been deleted - self._update_intervals_for_deleted_snapshots(expired_snapshot_ids) - - def _update_intervals_for_deleted_snapshots( - self, snapshot_ids: t.Collection[SnapshotIdLike] - ) -> None: - """Nullifies the snapshot identifiers of dev interval records and snapshot identifiers and dev versions of - non-dev interval records for snapshots that have been deleted so that they can be compacted efficiently. - """ - if not snapshot_ids: - return - - for where in self._snapshot_id_filter(snapshot_ids, alias=None): - # Nullify the identifier for dev intervals - # Set is_compacted to False so that it's compacted during the next compaction - self.engine_adapter.update_table( - self.intervals_table, - {"identifier": None, "is_compacted": False}, - where=where.and_(exp.column("is_dev")), - ) - # Nullify both identifier and dev version for non-dev intervals - # Set is_compacted to False so that it's compacted during the next compaction - self.engine_adapter.update_table( - self.intervals_table, - {"identifier": None, "dev_version": None, "is_compacted": False}, - where=where.and_(exp.column("is_dev").not_()), - ) - - def _delete_intervals_by_dev_version(self, targets: t.List[SnapshotTableCleanupTask]) -> None: - """Deletes dev intervals for snapshot dev versions that are no longer used.""" - dev_keys_to_delete = [ - SnapshotNameVersion(name=t.snapshot.name, version=t.snapshot.dev_version) - for t in targets - if t.dev_table_only - ] - if not dev_keys_to_delete: - return - - for where in self._snapshot_name_version_filter( - dev_keys_to_delete, version_column_name="dev_version", alias=None - ): - self.engine_adapter.delete_from(self.intervals_table, where.and_(exp.column("is_dev"))) - - def _delete_intervals_by_version(self, targets: t.List[SnapshotTableCleanupTask]) -> None: - """Deletes intervals for snapshot versions that are no longer used.""" - non_dev_keys_to_delete = [t.snapshot for t in targets if not t.dev_table_only] - if not non_dev_keys_to_delete: - return - - for where in self._snapshot_name_version_filter(non_dev_keys_to_delete, alias=None): - self.engine_adapter.delete_from(self.intervals_table, where) - - def _get_snapshot_intervals( - self, - snapshots: t.Optional[t.Collection[SnapshotNameVersionLike]] = None, - uncompacted_only: bool = False, - ) -> t.Tuple[t.Set[str], t.List[SnapshotIntervals]]: - if not snapshots and 
snapshots is not None: - return (set(), []) - - query = self._get_snapshot_intervals_query(uncompacted_only) - - interval_ids: t.Set[str] = set() - intervals: t.Dict[ - t.Tuple[str, str, t.Optional[str], t.Optional[str]], SnapshotIntervals - ] = {} - - for where in ( - self._snapshot_name_version_filter(snapshots, alias="intervals") - if snapshots - else [None] - ): - rows = self._fetchall(query.where(where)) - for ( - interval_id, - name, - identifier, - version, - dev_version, - start, - end, - is_dev, - is_removed, - is_pending_restatement, - ) in rows: - interval_ids.add(interval_id) - merge_key = (name, version, dev_version, identifier) - # Pending restatement intervals are merged by name and version - pending_restatement_interval_merge_key = (name, version, None, None) - - if merge_key not in intervals: - intervals[merge_key] = SnapshotIntervals( - name=name, - identifier=identifier, - version=version, - dev_version=dev_version, - ) - - if pending_restatement_interval_merge_key not in intervals: - intervals[pending_restatement_interval_merge_key] = SnapshotIntervals( - name=name, - identifier=None, - version=version, - dev_version=None, - ) - - if is_removed: - if is_dev: - intervals[merge_key].remove_dev_interval(start, end) - else: - intervals[merge_key].remove_interval(start, end) - elif is_pending_restatement: - intervals[ - pending_restatement_interval_merge_key - ].add_pending_restatement_interval(start, end) - else: - if is_dev: - intervals[merge_key].add_dev_interval(start, end) - else: - intervals[merge_key].add_interval(start, end) - # Remove all pending restatement intervals recorded before the current interval has been added - intervals[ - pending_restatement_interval_merge_key - ].remove_pending_restatement_interval(start, end) - - return interval_ids, [i for i in intervals.values() if not i.is_empty()] - - def _get_snapshot_intervals_query(self, uncompacted_only: bool) -> exp.Select: - query = ( - exp.select( - "id", - exp.column("name", table="intervals"), - exp.column("identifier", table="intervals"), - exp.column("version", table="intervals"), - exp.column("dev_version", table="intervals"), - "start_ts", - "end_ts", - "is_dev", - "is_removed", - "is_pending_restatement", - ) - .from_(exp.to_table(self.intervals_table).as_("intervals")) - .order_by( - exp.column("name", table="intervals"), - exp.column("version", table="intervals"), - "created_ts", - "is_removed", - "is_pending_restatement", - ) - ) - if uncompacted_only: - query.join( - exp.select("name", "version") - .from_(exp.to_table(self.intervals_table).as_("intervals")) - .where(exp.column("is_compacted").not_()) - .distinct() - .subquery(alias="uncompacted"), - on=exp.and_( - exp.column("name", table="intervals").eq( - exp.column("name", table="uncompacted") - ), - exp.column("version", table="intervals").eq( - exp.column("version", table="uncompacted") - ), - ), - copy=False, - ) - return query - - def _push_snapshot_intervals( - self, - snapshots: t.Iterable[t.Union[Snapshot, SnapshotIntervals]], - is_compacted: bool = False, - ) -> None: - new_intervals = [] - for snapshot in snapshots: - logger.info("Pushing intervals for snapshot %s", snapshot.snapshot_id) - for start_ts, end_ts in snapshot.intervals: - new_intervals.append( - _interval_to_df( - snapshot, start_ts, end_ts, is_dev=False, is_compacted=is_compacted - ) - ) - for start_ts, end_ts in snapshot.dev_intervals: - new_intervals.append( - _interval_to_df( - snapshot, start_ts, end_ts, is_dev=True, is_compacted=is_compacted - ) - ) - - # Make sure 
that all pending restatement intervals are recorded last - for snapshot in snapshots: - for start_ts, end_ts in snapshot.pending_restatement_intervals: - new_intervals.append( - _interval_to_df( - snapshot, - start_ts, - end_ts, - is_dev=False, - is_compacted=is_compacted, - is_pending_restatement=True, - ) - ) - - if new_intervals: - self.engine_adapter.insert_append( - self.intervals_table, - pd.DataFrame(new_intervals), - columns_to_types=self._interval_columns_to_types, - ) - def _restore_table( self, table_name: TableName, @@ -1500,7 +1130,11 @@ def rollback(self) -> None: """Rollback to the previous migration.""" logger.info("Starting migration rollback.") tables = (self.snapshots_table, self.environments_table, self.versions_table) - optional_tables = (self.intervals_table, self.plan_dags_table, self.auto_restatements_table) + optional_tables = ( + self.interval_state.intervals_table, + self.plan_dags_table, + self.auto_restatements_table, + ) versions = self.get_versions(validate=False) if versions.schema_version == 0: # Clean up state tables @@ -1529,7 +1163,7 @@ def _backup_state(self) -> None: self.snapshots_table, self.environments_table, self.versions_table, - self.intervals_table, + self.interval_state.intervals_table, self.plan_dags_table, self.auto_restatements_table, ): @@ -1883,49 +1517,6 @@ def _transaction(self) -> t.Iterator[None]: yield -def _intervals_to_df( - snapshot_intervals: t.Sequence[t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval]], - is_dev: bool, - is_removed: bool, -) -> pd.DataFrame: - return pd.DataFrame( - [ - _interval_to_df( - s, - *interval, - is_dev=is_dev, - is_removed=is_removed, - ) - for s, interval in snapshot_intervals - ] - ) - - -def _interval_to_df( - snapshot: t.Union[SnapshotInfoLike, SnapshotIntervals], - start_ts: int, - end_ts: int, - is_dev: bool = False, - is_removed: bool = False, - is_compacted: bool = False, - is_pending_restatement: bool = False, -) -> t.Dict[str, t.Any]: - return { - "id": random_id(), - "created_ts": now_timestamp(), - "name": snapshot.name, - "identifier": snapshot.identifier if not is_pending_restatement else None, - "version": snapshot.version, - "dev_version": snapshot.dev_version if not is_pending_restatement else None, - "start_ts": start_ts, - "end_ts": end_ts, - "is_dev": is_dev, - "is_removed": is_removed, - "is_compacted": is_compacted, - "is_pending_restatement": is_pending_restatement, - } - - def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame: return pd.DataFrame( [ diff --git a/sqlmesh/core/state_sync/engine_adapter/interval.py b/sqlmesh/core/state_sync/engine_adapter/interval.py new file mode 100644 index 0000000000..786a0cc698 --- /dev/null +++ b/sqlmesh/core/state_sync/engine_adapter/interval.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import typing as t +import logging +import pandas as pd + +from sqlglot import exp + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.engine_adapter.utils import ( + snapshot_name_version_filter, + snapshot_id_filter, + create_batches, + fetchall, +) +from sqlmesh.core.snapshot import ( + SnapshotIntervals, + SnapshotIdLike, + SnapshotNameVersionLike, + SnapshotTableCleanupTask, + SnapshotNameVersion, + SnapshotInfoLike, + Snapshot, + SnapshotId, +) +from sqlmesh.core.snapshot.definition import Interval +from sqlmesh.utils.migration import index_text_type +from sqlmesh.utils import random_id +from sqlmesh.utils.date import now_timestamp, time_like_to_str + + +logger = 
logging.getLogger(__name__) + + +class IntervalState: + INTERVAL_BATCH_SIZE = 1000 + + def __init__( + self, + engine_adapter: EngineAdapter, + schema: t.Optional[str] = None, + table_name: t.Optional[str] = None, + ): + self.engine_adapter = engine_adapter + self.intervals_table = exp.table_(table_name or "_intervals", db=schema) + + index_type = index_text_type(engine_adapter.dialect) + self._interval_columns_to_types = { + "id": exp.DataType.build(index_type), + "created_ts": exp.DataType.build("bigint"), + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build("text"), + "version": exp.DataType.build(index_type), + "dev_version": exp.DataType.build(index_type), + "start_ts": exp.DataType.build("bigint"), + "end_ts": exp.DataType.build("bigint"), + "is_dev": exp.DataType.build("boolean"), + "is_removed": exp.DataType.build("boolean"), + "is_compacted": exp.DataType.build("boolean"), + "is_pending_restatement": exp.DataType.build("boolean"), + } + + def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: + def remove_partial_intervals( + intervals: t.List[Interval], snapshot_id: t.Optional[SnapshotId], *, is_dev: bool + ) -> t.List[Interval]: + results = [] + for start_ts, end_ts in intervals: + if start_ts < end_ts: + logger.info( + "Adding %s (%s, %s) for snapshot %s", + "dev interval" if is_dev else "interval", + time_like_to_str(start_ts), + time_like_to_str(end_ts), + snapshot_id, + ) + results.append((start_ts, end_ts)) + else: + logger.info( + "Skipping partial interval (%s, %s) for snapshot %s", + start_ts, + end_ts, + snapshot_id, + ) + return results + + intervals_to_insert = [] + for snapshot_intervals in snapshots_intervals: + snapshot_intervals = snapshot_intervals.copy( + update={ + "intervals": remove_partial_intervals( + snapshot_intervals.intervals, snapshot_intervals.snapshot_id, is_dev=False + ), + "dev_intervals": remove_partial_intervals( + snapshot_intervals.dev_intervals, + snapshot_intervals.snapshot_id, + is_dev=True, + ), + } + ) + if ( + snapshot_intervals.intervals + or snapshot_intervals.dev_intervals + or snapshot_intervals.pending_restatement_intervals + ): + intervals_to_insert.append(snapshot_intervals) + + if intervals_to_insert: + self._push_snapshot_intervals(intervals_to_insert) + + def remove_intervals( + self, + snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], + remove_shared_versions: bool = False, + ) -> None: + intervals_to_remove: t.Sequence[ + t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval] + ] = snapshot_intervals + if remove_shared_versions: + name_version_mapping = {s.name_version: interval for s, interval in snapshot_intervals} + all_snapshots = [] + for where in snapshot_name_version_filter( + self.engine_adapter, name_version_mapping, alias=None + ): + all_snapshots.extend( + [ + SnapshotIntervals( + name=r[0], + identifier=r[1], + version=r[2], + dev_version=r[3], + intervals=[], + dev_intervals=[], + ) + for r in fetchall( + self.engine_adapter, + exp.select("name", "identifier", "version", "dev_version") + .from_(self.intervals_table) + .where(where) + .distinct(), + ) + ] + ) + intervals_to_remove = [ + (snapshot, name_version_mapping[snapshot.name_version]) + for snapshot in all_snapshots + ] + + if logger.isEnabledFor(logging.INFO): + snapshot_ids = ", ".join(str(s.snapshot_id) for s, _ in intervals_to_remove) + logger.info("Removing interval for snapshots: %s", snapshot_ids) + + for is_dev in (True, False): + 
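# Removals are recorded as append-only tombstone rows (is_removed=True)
+            # rather than deleted in place; _get_snapshot_intervals reconciles
+            # them when interval records are read back and compacted.
+            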
self.engine_adapter.insert_append( + self.intervals_table, + _intervals_to_df(intervals_to_remove, is_dev=is_dev, is_removed=True), + columns_to_types=self._interval_columns_to_types, + ) + + def get_snapshot_intervals( + self, snapshots: t.Collection[SnapshotNameVersionLike] + ) -> t.List[SnapshotIntervals]: + return self._get_snapshot_intervals(snapshots)[1] + + def compact_intervals(self) -> None: + interval_ids, snapshot_intervals = self._get_snapshot_intervals(uncompacted_only=True) + + logger.info( + "Compacting %s intervals for %s snapshots", len(interval_ids), len(snapshot_intervals) + ) + + self._push_snapshot_intervals(snapshot_intervals, is_compacted=True) + + if interval_ids: + for interval_id_batch in create_batches( + list(interval_ids), batch_size=self.INTERVAL_BATCH_SIZE + ): + self.engine_adapter.delete_from( + self.intervals_table, exp.column("id").isin(*interval_id_batch) + ) + + def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]: + if not snapshots: + return [] + + _, intervals = self._get_snapshot_intervals(snapshots) + for s in snapshots: + s.intervals = [] + s.dev_intervals = [] + return Snapshot.hydrate_with_intervals_by_version(snapshots, intervals) + + def max_interval_end_per_model( + self, snapshots: t.Collection[SnapshotNameVersionLike] + ) -> t.Dict[str, int]: + if not snapshots: + return {} + + table_alias = "intervals" + name_col = exp.column("name", table=table_alias) + version_col = exp.column("version", table=table_alias) + + result: t.Dict[str, int] = {} + + for where in snapshot_name_version_filter( + self.engine_adapter, snapshots, alias=table_alias + ): + query = ( + exp.select( + name_col, + exp.func("MAX", exp.column("end_ts", table=table_alias)).as_("max_end_ts"), + ) + .from_(exp.to_table(self.intervals_table).as_(table_alias)) + .where(where, copy=False) + .where( + exp.and_( + exp.to_column("is_dev").not_(), + exp.to_column("is_removed").not_(), + exp.to_column("is_pending_restatement").not_(), + ), + copy=False, + ) + .group_by(name_col, version_col, copy=False) + ) + + for name, max_end in fetchall(self.engine_adapter, query): + result[name] = max_end + + return result + + def cleanup_intervals( + self, + cleanup_targets: t.List[SnapshotTableCleanupTask], + expired_snapshot_ids: t.Collection[SnapshotIdLike], + ) -> None: + # Cleanup can only happen for compacted intervals + self.compact_intervals() + # Delete intervals for non-dev tables that are no longer used + self._delete_intervals_by_version(cleanup_targets) + # Delete dev intervals for dev tables that are no longer used + self._delete_intervals_by_dev_version(cleanup_targets) + # Nullify the snapshot identifiers of interval records for snapshots that have been deleted + self._update_intervals_for_deleted_snapshots(expired_snapshot_ids) + + def _push_snapshot_intervals( + self, + snapshots: t.Iterable[t.Union[Snapshot, SnapshotIntervals]], + is_compacted: bool = False, + ) -> None: + new_intervals = [] + for snapshot in snapshots: + logger.info("Pushing intervals for snapshot %s", snapshot.snapshot_id) + for start_ts, end_ts in snapshot.intervals: + new_intervals.append( + _interval_to_df( + snapshot, start_ts, end_ts, is_dev=False, is_compacted=is_compacted + ) + ) + for start_ts, end_ts in snapshot.dev_intervals: + new_intervals.append( + _interval_to_df( + snapshot, start_ts, end_ts, is_dev=True, is_compacted=is_compacted + ) + ) + + # Make sure that all pending restatement intervals are recorded last + for snapshot in snapshots: + for start_ts, 
end_ts in snapshot.pending_restatement_intervals: + new_intervals.append( + _interval_to_df( + snapshot, + start_ts, + end_ts, + is_dev=False, + is_compacted=is_compacted, + is_pending_restatement=True, + ) + ) + + if new_intervals: + self.engine_adapter.insert_append( + self.intervals_table, + pd.DataFrame(new_intervals), + columns_to_types=self._interval_columns_to_types, + ) + + def _get_snapshot_intervals( + self, + snapshots: t.Optional[t.Collection[SnapshotNameVersionLike]] = None, + uncompacted_only: bool = False, + ) -> t.Tuple[t.Set[str], t.List[SnapshotIntervals]]: + if not snapshots and snapshots is not None: + return (set(), []) + + query = self._get_snapshot_intervals_query(uncompacted_only) + + interval_ids: t.Set[str] = set() + intervals: t.Dict[ + t.Tuple[str, str, t.Optional[str], t.Optional[str]], SnapshotIntervals + ] = {} + + for where in ( + snapshot_name_version_filter(self.engine_adapter, snapshots, alias="intervals") + if snapshots + else [None] + ): + rows = fetchall(self.engine_adapter, query.where(where)) + for ( + interval_id, + name, + identifier, + version, + dev_version, + start, + end, + is_dev, + is_removed, + is_pending_restatement, + ) in rows: + interval_ids.add(interval_id) + merge_key = (name, version, dev_version, identifier) + # Pending restatement intervals are merged by name and version + pending_restatement_interval_merge_key = (name, version, None, None) + + if merge_key not in intervals: + intervals[merge_key] = SnapshotIntervals( + name=name, + identifier=identifier, + version=version, + dev_version=dev_version, + ) + + if pending_restatement_interval_merge_key not in intervals: + intervals[pending_restatement_interval_merge_key] = SnapshotIntervals( + name=name, + identifier=None, + version=version, + dev_version=None, + ) + + if is_removed: + if is_dev: + intervals[merge_key].remove_dev_interval(start, end) + else: + intervals[merge_key].remove_interval(start, end) + elif is_pending_restatement: + intervals[ + pending_restatement_interval_merge_key + ].add_pending_restatement_interval(start, end) + else: + if is_dev: + intervals[merge_key].add_dev_interval(start, end) + else: + intervals[merge_key].add_interval(start, end) + # Remove all pending restatement intervals recorded before the current interval has been added + intervals[ + pending_restatement_interval_merge_key + ].remove_pending_restatement_interval(start, end) + + return interval_ids, [i for i in intervals.values() if not i.is_empty()] + + def _get_snapshot_intervals_query(self, uncompacted_only: bool) -> exp.Select: + query = ( + exp.select( + "id", + exp.column("name", table="intervals"), + exp.column("identifier", table="intervals"), + exp.column("version", table="intervals"), + exp.column("dev_version", table="intervals"), + "start_ts", + "end_ts", + "is_dev", + "is_removed", + "is_pending_restatement", + ) + .from_(exp.to_table(self.intervals_table).as_("intervals")) + .order_by( + exp.column("name", table="intervals"), + exp.column("version", table="intervals"), + "created_ts", + "is_removed", + "is_pending_restatement", + ) + ) + if uncompacted_only: + query.join( + exp.select("name", "version") + .from_(exp.to_table(self.intervals_table).as_("intervals")) + .where(exp.column("is_compacted").not_()) + .distinct() + .subquery(alias="uncompacted"), + on=exp.and_( + exp.column("name", table="intervals").eq( + exp.column("name", table="uncompacted") + ), + exp.column("version", table="intervals").eq( + exp.column("version", table="uncompacted") + ), + ), + copy=False, + ) + 
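# Rows are ordered by (name, version, created_ts) with the removal and
+        # pending-restatement flags last, so the caller can merge them
+        # deterministically in creation order.
+        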
return query + + def _update_intervals_for_deleted_snapshots( + self, snapshot_ids: t.Collection[SnapshotIdLike] + ) -> None: + """Nullifies the snapshot identifiers of dev interval records and snapshot identifiers and dev versions of + non-dev interval records for snapshots that have been deleted so that they can be compacted efficiently. + """ + if not snapshot_ids: + return + + for where in snapshot_id_filter(self.engine_adapter, snapshot_ids, alias=None): + # Nullify the identifier for dev intervals + # Set is_compacted to False so that it's compacted during the next compaction + self.engine_adapter.update_table( + self.intervals_table, + {"identifier": None, "is_compacted": False}, + where=where.and_(exp.column("is_dev")), + ) + # Nullify both identifier and dev version for non-dev intervals + # Set is_compacted to False so that it's compacted during the next compaction + self.engine_adapter.update_table( + self.intervals_table, + {"identifier": None, "dev_version": None, "is_compacted": False}, + where=where.and_(exp.column("is_dev").not_()), + ) + + def _delete_intervals_by_dev_version(self, targets: t.List[SnapshotTableCleanupTask]) -> None: + """Deletes dev intervals for snapshot dev versions that are no longer used.""" + dev_keys_to_delete = [ + SnapshotNameVersion(name=t.snapshot.name, version=t.snapshot.dev_version) + for t in targets + if t.dev_table_only + ] + if not dev_keys_to_delete: + return + + for where in snapshot_name_version_filter( + self.engine_adapter, dev_keys_to_delete, version_column_name="dev_version", alias=None + ): + self.engine_adapter.delete_from(self.intervals_table, where.and_(exp.column("is_dev"))) + + def _delete_intervals_by_version(self, targets: t.List[SnapshotTableCleanupTask]) -> None: + """Deletes intervals for snapshot versions that are no longer used.""" + non_dev_keys_to_delete = [t.snapshot for t in targets if not t.dev_table_only] + if not non_dev_keys_to_delete: + return + + for where in snapshot_name_version_filter( + self.engine_adapter, non_dev_keys_to_delete, alias=None + ): + self.engine_adapter.delete_from(self.intervals_table, where) + + +def _intervals_to_df( + snapshot_intervals: t.Sequence[t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval]], + is_dev: bool, + is_removed: bool, +) -> pd.DataFrame: + return pd.DataFrame( + [ + _interval_to_df( + s, + *interval, + is_dev=is_dev, + is_removed=is_removed, + ) + for s, interval in snapshot_intervals + ] + ) + + +def _interval_to_df( + snapshot: t.Union[SnapshotInfoLike, SnapshotIntervals], + start_ts: int, + end_ts: int, + is_dev: bool = False, + is_removed: bool = False, + is_compacted: bool = False, + is_pending_restatement: bool = False, +) -> t.Dict[str, t.Any]: + return { + "id": random_id(), + "created_ts": now_timestamp(), + "name": snapshot.name, + "identifier": snapshot.identifier if not is_pending_restatement else None, + "version": snapshot.version, + "dev_version": snapshot.dev_version if not is_pending_restatement else None, + "start_ts": start_ts, + "end_ts": end_ts, + "is_dev": is_dev, + "is_removed": is_removed, + "is_compacted": is_compacted, + "is_pending_restatement": is_pending_restatement, + } diff --git a/sqlmesh/core/state_sync/engine_adapter/utils.py b/sqlmesh/core/state_sync/engine_adapter/utils.py new file mode 100644 index 0000000000..5520f94b87 --- /dev/null +++ b/sqlmesh/core/state_sync/engine_adapter/utils.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp +from sqlmesh.core.engine_adapter 
import EngineAdapter +from sqlmesh.core.snapshot import SnapshotIdLike, SnapshotNameVersionLike + + +T = t.TypeVar("T") + + +DEFAULT_BATCH_SIZE = 1000 + + +def snapshot_id_filter( + engine_adapter: EngineAdapter, + snapshot_ids: t.Iterable[SnapshotIdLike], + alias: t.Optional[str] = None, + batch_size: int = DEFAULT_BATCH_SIZE, +) -> t.Iterator[exp.Condition]: + name_identifiers = sorted( + {(snapshot_id.name, snapshot_id.identifier) for snapshot_id in snapshot_ids} + ) + batches = create_batches(name_identifiers, batch_size=batch_size) + + if not name_identifiers: + yield exp.false() + elif engine_adapter.SUPPORTS_TUPLE_IN: + for identifiers in batches: + yield t.cast( + exp.Tuple, + exp.convert( + ( + exp.column("name", table=alias), + exp.column("identifier", table=alias), + ) + ), + ).isin(*identifiers) + else: + for identifiers in batches: + yield exp.or_( + *[ + exp.and_( + exp.column("name", table=alias).eq(name), + exp.column("identifier", table=alias).eq(identifier), + ) + for name, identifier in identifiers + ] + ) + + +def snapshot_name_version_filter( + engine_adapter: EngineAdapter, + snapshot_name_versions: t.Iterable[SnapshotNameVersionLike], + version_column_name: str = "version", + alias: t.Optional[str] = "snapshots", + column_prefix: t.Optional[str] = None, + batch_size: int = DEFAULT_BATCH_SIZE, +) -> t.Iterator[exp.Condition]: + name_versions = sorted({(s.name, s.version) for s in snapshot_name_versions}) + batches = create_batches(name_versions, batch_size=batch_size) + + name_column_name = "name" + if column_prefix: + name_column_name = f"{column_prefix}_{name_column_name}" + version_column_name = f"{column_prefix}_{version_column_name}" + + name_column = exp.column(name_column_name, table=alias) + version_column = exp.column(version_column_name, table=alias) + + if not name_versions: + yield exp.false() + elif engine_adapter.SUPPORTS_TUPLE_IN: + for versions in batches: + yield t.cast( + exp.Tuple, + exp.convert( + ( + name_column, + version_column, + ) + ), + ).isin(*versions) + else: + for versions in batches: + yield exp.or_( + *[ + exp.and_( + name_column.eq(name), + version_column.eq(version), + ) + for name, version in versions + ] + ) + + +def create_batches(l: t.List[T], batch_size: int = 1000) -> t.List[t.List[T]]: + return [l[i : i + batch_size] for i in range(0, len(l), batch_size)] + + +def fetchone( + engine_adapter: EngineAdapter, query: t.Union[exp.Expression, str] +) -> t.Optional[t.Tuple]: + return engine_adapter.fetchone(query, ignore_unsupported_errors=True, quote_identifiers=True) + + +def fetchall(engine_adapter: EngineAdapter, query: t.Union[exp.Expression, str]) -> t.List[t.Tuple]: + return engine_adapter.fetchall(query, ignore_unsupported_errors=True, quote_identifiers=True) diff --git a/tests/core/test_environment.py b/tests/core/test_environment.py index 163bba47d8..228269cc6c 100644 --- a/tests/core/test_environment.py +++ b/tests/core/test_environment.py @@ -2,7 +2,7 @@ from sqlmesh.core.environment import Environment, EnvironmentNamingInfo from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo -from sqlmesh.core.state_sync.engine_adapter import _environment_to_df +from sqlmesh.core.state_sync.engine_adapter.facade import _environment_to_df def test_sanitize_name(): diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py index a3b5f3ee29..9b3205f487 100644 --- a/tests/core/test_state_sync.py +++ b/tests/core/test_state_sync.py @@ -143,7 +143,7 @@ def test_push_snapshots( snapshot_b.snapshot_id: snapshot_b, } - 
logger = logging.getLogger("sqlmesh.core.state_sync.engine_adapter") + logger = logging.getLogger("sqlmesh.core.state_sync.engine_adapter.facade") with patch.object(logger, "error") as mock_logger: state_sync.push_snapshots([snapshot_a]) assert str({snapshot_a.snapshot_id}) == mock_logger.call_args[0][1] @@ -214,7 +214,7 @@ def test_snapshots_exists(state_sync: EngineAdapterStateSync, snapshots: t.List[ @pytest.fixture def get_snapshot_intervals(state_sync) -> t.Callable[[Snapshot], t.Optional[SnapshotIntervals]]: def _get_snapshot_intervals(snapshot: Snapshot) -> t.Optional[SnapshotIntervals]: - intervals = state_sync._get_snapshot_intervals([snapshot])[-1] + intervals = state_sync.interval_state.get_snapshot_intervals([snapshot]) return intervals[0] if intervals else None return _get_snapshot_intervals @@ -500,7 +500,7 @@ def test_compact_intervals_delete_batches( ) delete_from_mock = mocker.patch.object(state_sync.engine_adapter, "delete_from") - state_sync.INTERVAL_BATCH_SIZE = 2 + state_sync.interval_state.INTERVAL_BATCH_SIZE = 2 state_sync.push_snapshots([snapshot]) @@ -512,7 +512,9 @@ def test_compact_intervals_delete_batches( state_sync.compact_intervals() - delete_from_mock.assert_has_calls([call(state_sync.intervals_table, mocker.ANY)] * 3) + delete_from_mock.assert_has_calls( + [call(state_sync.interval_state.intervals_table, mocker.ANY)] * 3 + ) def test_promote_snapshots(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): @@ -1234,7 +1236,7 @@ def test_delete_expired_snapshots_promoted( env.snapshots_ = [] state_sync.promote(env) - now_timestamp_mock = mocker.patch("sqlmesh.core.state_sync.engine_adapter.now_timestamp") + now_timestamp_mock = mocker.patch("sqlmesh.core.state_sync.engine_adapter.facade.now_timestamp") now_timestamp_mock.return_value = now_timestamp() + 11000 assert state_sync.delete_expired_snapshots() == [ @@ -1494,7 +1496,7 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_version( # Check all intervals assert sorted( - state_sync._get_snapshot_intervals([snapshot, new_snapshot])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), key=lambda x: x.identifier or "", ) == [ SnapshotIntervals( @@ -1529,7 +1531,7 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_version( # Check all intervals assert sorted( - state_sync._get_snapshot_intervals([snapshot, new_snapshot])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), key=lambda x: x.identifier or "", ) == [ # The intervals of the old snapshot is preserved with the null identifier @@ -1607,7 +1609,7 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version( # Check all intervals assert sorted( - state_sync._get_snapshot_intervals([snapshot, new_snapshot])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), key=lambda x: x.identifier or "", ) == [ SnapshotIntervals( @@ -1642,7 +1644,7 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version( # Check all intervals assert sorted( - state_sync._get_snapshot_intervals([snapshot, new_snapshot])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), key=lambda x: x.identifier or "", ) == [ SnapshotIntervals( @@ -1754,7 +1756,7 @@ def test_compact_intervals_after_cleanup( assert ( sorted( - state_sync._get_snapshot_intervals([snapshot_a, snapshot_b, snapshot_c])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b, snapshot_c]), key=lambda x: (x.identifier or 
"", x.dev_version or ""), ) == expected_intervals @@ -1765,7 +1767,7 @@ def test_compact_intervals_after_cleanup( assert state_sync.engine_adapter.fetchone("SELECT COUNT(*) FROM sqlmesh._intervals")[0] == 4 # type: ignore assert ( sorted( - state_sync._get_snapshot_intervals([snapshot_a, snapshot_b, snapshot_c])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b, snapshot_c]), key=lambda x: (x.identifier or "", x.dev_version or ""), ) == expected_intervals @@ -2216,7 +2218,7 @@ def test_first_migration_failure(duck_conn, mocker: MockerFixture, tmp_path) -> assert not state_sync.engine_adapter.table_exists(state_sync.snapshots_table) assert not state_sync.engine_adapter.table_exists(state_sync.environments_table) assert not state_sync.engine_adapter.table_exists(state_sync.versions_table) - assert not state_sync.engine_adapter.table_exists(state_sync.intervals_table) + assert not state_sync.engine_adapter.table_exists(state_sync.interval_state.intervals_table) def test_migrate_rows(state_sync: EngineAdapterStateSync, mocker: MockerFixture) -> None: @@ -3125,7 +3127,7 @@ def test_compact_intervals_pending_restatement_shared_version( state_sync.add_interval(snapshot_b, "2020-01-03", "2020-01-03") assert ( sorted( - state_sync._get_snapshot_intervals([snapshot_a, snapshot_b])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b]), key=lambda x: (x.name, x.identifier or ""), ) == expected_intervals @@ -3210,7 +3212,7 @@ def test_compact_intervals_pending_restatement_shared_version( state_sync.add_interval(snapshot_a, "2020-01-04", "2020-01-04") assert ( sorted( - state_sync._get_snapshot_intervals([snapshot_a, snapshot_b])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b]), key=lambda x: (x.name, x.identifier or ""), ) == expected_intervals @@ -3269,7 +3271,7 @@ def test_compact_intervals_pending_restatement_shared_version( state_sync.add_interval(snapshot_b, "2020-01-05", "2020-01-05") assert ( sorted( - state_sync._get_snapshot_intervals([snapshot_a, snapshot_b])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b]), key=lambda x: (x.name, x.identifier or ""), ) == expected_intervals diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 8c0b0631c8..2abb73765b 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -27,7 +27,7 @@ ViewKind, ) from sqlmesh.core.model.kind import SCDType2ByColumnKind, SCDType2ByTimeKind -from sqlmesh.core.state_sync.engine_adapter import _snapshot_to_json +from sqlmesh.core.state_sync.engine_adapter.facade import _snapshot_to_json from sqlmesh.dbt.builtin import _relation_info_to_relation from sqlmesh.dbt.column import ( ColumnConfig, From b2e281d6f85fb7a8d62f9f476091daba5ba9003f Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Tue, 25 Feb 2025 14:24:08 -0800 Subject: [PATCH 2/6] Chore: Refactor the environment state from the state sync --- .../state_sync/engine_adapter/environment.py | 257 ++++++++++++++++++ .../core/state_sync/engine_adapter/facade.py | 210 ++------------ tests/core/test_environment.py | 2 +- tests/core/test_state_sync.py | 4 +- 4 files changed, 284 insertions(+), 189 deletions(-) create mode 100644 sqlmesh/core/state_sync/engine_adapter/environment.py diff --git a/sqlmesh/core/state_sync/engine_adapter/environment.py b/sqlmesh/core/state_sync/engine_adapter/environment.py new file mode 100644 index 0000000000..ed7ebd0382 --- /dev/null +++ 
b/sqlmesh/core/state_sync/engine_adapter/environment.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +import typing as t +import pandas as pd +import json +import logging +from sqlglot import exp + +from sqlmesh.core import constants as c +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.engine_adapter.utils import ( + fetchall, + fetchone, +) +from sqlmesh.core.environment import Environment +from sqlmesh.utils.migration import index_text_type, blob_text_type +from sqlmesh.utils.date import now_timestamp, time_like_to_str +from sqlmesh.utils.errors import SQLMeshError + + +logger = logging.getLogger(__name__) + + +class EnvironmentState: + def __init__( + self, + engine_adapter: EngineAdapter, + schema: t.Optional[str] = None, + table_name: t.Optional[str] = None, + ): + self.engine_adapter = engine_adapter + self.environments_table = exp.table_(table_name or "_environments", db=schema) + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + self._environment_columns_to_types = { + "name": exp.DataType.build(index_type), + "snapshots": exp.DataType.build(blob_type), + "start_at": exp.DataType.build("text"), + "end_at": exp.DataType.build("text"), + "plan_id": exp.DataType.build("text"), + "previous_plan_id": exp.DataType.build("text"), + "expiration_ts": exp.DataType.build("bigint"), + "finalized_ts": exp.DataType.build("bigint"), + "promoted_snapshot_ids": exp.DataType.build(blob_type), + "suffix_target": exp.DataType.build("text"), + "catalog_name_override": exp.DataType.build("text"), + "previous_finalized_snapshots": exp.DataType.build(blob_type), + "normalize_name": exp.DataType.build("boolean"), + "requirements": exp.DataType.build(blob_type), + } + + def update_environment(self, environment: Environment) -> None: + """Updates the environment. + + Args: + environment: The environment + """ + self.engine_adapter.delete_from( + self.environments_table, + where=exp.EQ( + this=exp.column("name"), + expression=exp.Literal.string(environment.name), + ), + ) + + self.engine_adapter.insert_append( + self.environments_table, + _environment_to_df(environment), + columns_to_types=self._environment_columns_to_types, + ) + + def invalidate_environment(self, name: str) -> None: + """Invalidates the environment. + + Args: + name: The name of the environment + """ + name = name.lower() + if name == c.PROD: + raise SQLMeshError("Cannot invalidate the production environment.") + + filter_expr = exp.column("name").eq(name) + + self.engine_adapter.update_table( + self.environments_table, + {"expiration_ts": now_timestamp()}, + where=filter_expr, + ) + + def finalize(self, environment: Environment) -> None: + """Finalize the target environment, indicating that this environment has been + fully promoted and is ready for use. + + Args: + environment: The target environment to finalize. 
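+
+        Raises:
+            SQLMeshError: if the environment doesn't exist or its stored plan ID
+                no longer matches the plan being finalized.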
+ """ + logger.info("Finalizing environment '%s'", environment.name) + + environment_filter = exp.column("name").eq(exp.Literal.string(environment.name)) + + stored_plan_id_query = ( + exp.select("plan_id") + .from_(self.environments_table) + .where(environment_filter, copy=False) + .lock(copy=False) + ) + stored_plan_id_row = fetchone(self.engine_adapter, stored_plan_id_query) + + if not stored_plan_id_row: + raise SQLMeshError(f"Missing environment '{environment.name}' can't be finalized") + + stored_plan_id = stored_plan_id_row[0] + if stored_plan_id != environment.plan_id: + raise SQLMeshError( + f"Plan '{environment.plan_id}' is no longer valid for the target environment '{environment.name}'. " + f"Stored plan ID: '{stored_plan_id}'. Please recreate the plan and try again" + ) + + environment.finalized_ts = now_timestamp() + self.engine_adapter.update_table( + self.environments_table, + {"finalized_ts": environment.finalized_ts}, + where=environment_filter, + ) + + def delete_expired_environments(self) -> t.List[Environment]: + """Deletes expired environments. + + Returns: + A list of deleted environments. + """ + now_ts = now_timestamp() + filter_expr = exp.LTE( + this=exp.column("expiration_ts"), + expression=exp.Literal.number(now_ts), + ) + + rows = fetchall( + self.engine_adapter, + self._environments_query( + where=filter_expr, + lock_for_update=True, + ), + ) + environments = [self._environment_from_row(r) for r in rows] + + self.engine_adapter.delete_from( + self.environments_table, + where=filter_expr, + ) + + return environments + + def get_environments(self) -> t.List[Environment]: + """Fetches all environments. + + Returns: + A list of all environments. + """ + return [ + self._environment_from_row(row) + for row in fetchall(self.engine_adapter, self._environments_query()) + ] + + def get_environments_summary(self) -> t.Dict[str, int]: + """Fetches all environment names along with expiry datetime. + + Returns: + A dict of all environment names along with expiry datetime. + """ + return dict( + fetchall( + self.engine_adapter, + self._environments_query(required_fields=["name", "expiration_ts"]), + ), + ) + + def get_environment( + self, environment: str, lock_for_update: bool = False + ) -> t.Optional[Environment]: + """Fetches the environment if it exists. + + Args: + environment: The environment + lock_for_update: Lock the snapshot rows for future update + + Returns: + The environment object. 
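+            Returns None if no environment with the given name exists.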
+ """ + row = fetchone( + self.engine_adapter, + self._environments_query( + where=exp.EQ( + this=exp.column("name"), + expression=exp.Literal.string(environment), + ), + lock_for_update=lock_for_update, + ), + ) + + if not row: + return None + + env = self._environment_from_row(row) + return env + + def _environment_from_row(self, row: t.Tuple[str, ...]) -> Environment: + return Environment(**{field: row[i] for i, field in enumerate(Environment.all_fields())}) + + def _environments_query( + self, + where: t.Optional[str | exp.Expression] = None, + lock_for_update: bool = False, + required_fields: t.Optional[t.List[str]] = None, + ) -> exp.Select: + query_fields = required_fields if required_fields else Environment.all_fields() + query = ( + exp.select(*(exp.to_identifier(field) for field in query_fields)) + .from_(self.environments_table) + .where(where) + ) + if lock_for_update: + return query.lock(copy=False) + return query + + +def _environment_to_df(environment: Environment) -> pd.DataFrame: + return pd.DataFrame( + [ + { + "name": environment.name, + "snapshots": json.dumps(environment.snapshot_dicts()), + "start_at": time_like_to_str(environment.start_at), + "end_at": time_like_to_str(environment.end_at) if environment.end_at else None, + "plan_id": environment.plan_id, + "previous_plan_id": environment.previous_plan_id, + "expiration_ts": environment.expiration_ts, + "finalized_ts": environment.finalized_ts, + "promoted_snapshot_ids": ( + json.dumps(environment.promoted_snapshot_id_dicts()) + if environment.promoted_snapshot_ids is not None + else None + ), + "suffix_target": environment.suffix_target.value, + "catalog_name_override": environment.catalog_name_override, + "previous_finalized_snapshots": ( + json.dumps(environment.previous_finalized_snapshot_dicts()) + if environment.previous_finalized_snapshots is not None + else None + ), + "normalize_name": environment.normalize_name, + "requirements": json.dumps(environment.requirements), + } + ] + ) diff --git a/sqlmesh/core/state_sync/engine_adapter/facade.py b/sqlmesh/core/state_sync/engine_adapter/facade.py index 1a4134f7f3..7ced4522d8 100644 --- a/sqlmesh/core/state_sync/engine_adapter/facade.py +++ b/sqlmesh/core/state_sync/engine_adapter/facade.py @@ -69,11 +69,12 @@ ) from sqlmesh.core.state_sync.common import transactional from sqlmesh.core.state_sync.engine_adapter.interval import IntervalState +from sqlmesh.core.state_sync.engine_adapter.environment import EnvironmentState from sqlmesh.utils import major_minor, unique from sqlmesh.utils.dag import DAG -from sqlmesh.utils.date import TimeLike, now, now_timestamp, time_like_to_str, to_timestamp +from sqlmesh.utils.date import TimeLike, now, now_timestamp, to_timestamp from sqlmesh.utils.errors import ConflictingPlanError, SQLMeshError -from sqlmesh.utils.migration import blob_text_type, index_text_type +from sqlmesh.utils.migration import index_text_type from sqlmesh.utils.pydantic import PydanticModel logger = logging.getLogger(__name__) @@ -118,18 +119,17 @@ def __init__( context_path: Path = Path(), ): self.interval_state = IntervalState(engine_adapter, schema=schema) + self.environment_state = EnvironmentState(engine_adapter, schema=schema) # Make sure that if an empty string is provided that we treat it as None self.schema = schema or None self.engine_adapter = engine_adapter self.console = console or get_console() self.snapshots_table = exp.table_("_snapshots", db=self.schema) - self.environments_table = exp.table_("_environments", db=self.schema) 
self.plan_dags_table = exp.table_("_plan_dags", db=self.schema) self.auto_restatements_table = exp.table_("_auto_restatements", db=self.schema) self.versions_table = exp.table_("_versions", db=self.schema) index_type = index_text_type(engine_adapter.dialect) - blob_type = blob_text_type(engine_adapter.dialect) self._snapshot_columns_to_types = { "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), @@ -142,23 +142,6 @@ def __init__( "unrestorable": exp.DataType.build("boolean"), } - self._environment_columns_to_types = { - "name": exp.DataType.build(index_type), - "snapshots": exp.DataType.build(blob_type), - "start_at": exp.DataType.build("text"), - "end_at": exp.DataType.build("text"), - "plan_id": exp.DataType.build("text"), - "previous_plan_id": exp.DataType.build("text"), - "expiration_ts": exp.DataType.build("bigint"), - "finalized_ts": exp.DataType.build("bigint"), - "promoted_snapshot_ids": exp.DataType.build(blob_type), - "suffix_target": exp.DataType.build("text"), - "catalog_name_override": exp.DataType.build("text"), - "previous_finalized_snapshots": exp.DataType.build(blob_type), - "normalize_name": exp.DataType.build("boolean"), - "requirements": exp.DataType.build(blob_type), - } - self._auto_restatement_columns_to_types = { "snapshot_name": exp.DataType.build(index_type), "snapshot_version": exp.DataType.build(index_type), @@ -270,7 +253,9 @@ def promote( f"Missing snapshots {missing}. Make sure to push and backfill your snapshots." ) - existing_environment = self._get_environment(environment.name, lock_for_update=True) + existing_environment = self.environment_state.get_environment( + environment.name, lock_for_update=True + ) existing_table_infos = ( {table_info.name: table_info for table_info in existing_environment.promoted_snapshots} @@ -320,7 +305,7 @@ def promote( # Only promote new snapshots. added_table_infos -= set(existing_environment.promoted_snapshots) - self._update_environment(environment) + self.environment_state.update_environment(environment) removed = {existing_table_infos[name] for name in missing_models}.union( views_that_changed_location @@ -385,34 +370,7 @@ def finalize(self, environment: Environment) -> None: Args: environment: The target environment to finalize. """ - logger.info("Finalizing environment '%s'", environment.name) - - environment_filter = exp.column("name").eq(exp.Literal.string(environment.name)) - - stored_plan_id_query = ( - exp.select("plan_id") - .from_(self.environments_table) - .where(environment_filter, copy=False) - .lock(copy=False) - ) - stored_plan_id_row = self._fetchone(stored_plan_id_query) - - if not stored_plan_id_row: - raise SQLMeshError(f"Missing environment '{environment.name}' can't be finalized") - - stored_plan_id = stored_plan_id_row[0] - if stored_plan_id != environment.plan_id: - raise SQLMeshError( - f"Plan '{environment.plan_id}' is no longer valid for the target environment '{environment.name}'. " - f"Stored plan ID: '{stored_plan_id}'. 
Please recreate the plan and try again" - ) - - environment.finalized_ts = now_timestamp() - self.engine_adapter.update_table( - self.environments_table, - {"finalized_ts": environment.finalized_ts}, - where=environment_filter, - ) + self.environment_state.finalize(environment) @transactional() def unpause_snapshots( @@ -512,17 +470,7 @@ def _update_versions( ) def invalidate_environment(self, name: str) -> None: - name = name.lower() - if name == c.PROD: - raise SQLMeshError("Cannot invalidate the production environment.") - - filter_expr = exp.column("name").eq(name) - - self.engine_adapter.update_table( - self.environments_table, - {"expiration_ts": now_timestamp()}, - where=filter_expr, - ) + self.environment_state.invalidate_environment(name) @transactional() def delete_expired_snapshots( @@ -599,26 +547,7 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: return cleanup_targets def delete_expired_environments(self) -> t.List[Environment]: - now_ts = now_timestamp() - filter_expr = exp.LTE( - this=exp.column("expiration_ts"), - expression=exp.Literal.number(now_ts), - ) - - rows = self._fetchall( - self._environments_query( - where=filter_expr, - lock_for_update=True, - ) - ) - environments = [self._environment_from_row(r) for r in rows] - - self.engine_adapter.delete_from( - self.environments_table, - where=filter_expr, - ) - - return environments + return self.environment_state.delete_expired_environments() def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: if not snapshot_ids: @@ -649,7 +578,7 @@ def reset(self, default_catalog: t.Optional[str]) -> None: """Resets the state store to the state when it was first initialized.""" for table in ( self.snapshots_table, - self.environments_table, + self.environment_state.environments_table, self.interval_state.intervals_table, self.plan_dags_table, self.versions_table, @@ -679,21 +608,6 @@ def update_auto_restatements( columns_to_types=self._auto_restatement_columns_to_types, ) - def _update_environment(self, environment: Environment) -> None: - self.engine_adapter.delete_from( - self.environments_table, - where=exp.EQ( - this=exp.column("name"), - expression=exp.Literal.string(environment.name), - ), - ) - - self.engine_adapter.insert_append( - self.environments_table, - _environment_to_df(environment), - columns_to_types=self._environment_columns_to_types, - ) - def _update_snapshots( self, snapshots: t.Iterable[SnapshotIdLike], @@ -710,7 +624,7 @@ def _update_snapshots( ) def get_environment(self, environment: str) -> t.Optional[Environment]: - return self._get_environment(environment) + return self.environment_state.get_environment(environment) def get_environments(self) -> t.List[Environment]: """Fetches all environments. @@ -718,9 +632,7 @@ def get_environments(self) -> t.List[Environment]: Returns: A list of all environments. """ - return [ - self._environment_from_row(row) for row in self._fetchall(self._environments_query()) - ] + return self.environment_state.get_environments() def get_environments_summary(self) -> t.Dict[str, int]: """Fetches all environment names along with expiry datetime. @@ -728,28 +640,7 @@ def get_environments_summary(self) -> t.Dict[str, int]: Returns: A dict of all environment names along with expiry datetime. 
""" - return dict( - self._fetchall(self._environments_query(required_fields=["name", "expiration_ts"])), - ) - - def _environment_from_row(self, row: t.Tuple[str, ...]) -> Environment: - return Environment(**{field: row[i] for i, field in enumerate(Environment.all_fields())}) - - def _environments_query( - self, - where: t.Optional[str | exp.Expression] = None, - lock_for_update: bool = False, - required_fields: t.Optional[t.List[str]] = None, - ) -> exp.Select: - query_fields = required_fields if required_fields else Environment.all_fields() - query = ( - exp.select(*(exp.to_identifier(field) for field in query_fields)) - .from_(self.environments_table) - .where(where) - ) - if lock_for_update: - return query.lock(copy=False) - return query + return self.environment_state.get_environments_summary() def get_snapshots( self, @@ -983,34 +874,6 @@ def _get_versions(self, lock_for_update: bool = False) -> Versions: schema_version=row[0], sqlglot_version=row[1], sqlmesh_version=seq_get(row, 2) ) - def _get_environment( - self, environment: str, lock_for_update: bool = False - ) -> t.Optional[Environment]: - """Fetches the environment if it exists. - - Args: - environment: The environment - lock_for_update: Lock the snapshot rows for future update - - Returns: - The environment object. - """ - row = self._fetchone( - self._environments_query( - where=exp.EQ( - this=exp.column("name"), - expression=exp.Literal.string(environment), - ), - lock_for_update=lock_for_update, - ) - ) - - if not row: - return None - - env = self._environment_from_row(row) - return env - @transactional() def add_interval( self, @@ -1046,7 +909,7 @@ def max_interval_end_per_model( models: t.Optional[t.Set[str]] = None, ensure_finalized_snapshots: bool = False, ) -> t.Dict[str, int]: - env = self._get_environment(environment) + env = self.get_environment(environment) if not env: return {} @@ -1129,7 +992,11 @@ def migrate( def rollback(self) -> None: """Rollback to the previous migration.""" logger.info("Starting migration rollback.") - tables = (self.snapshots_table, self.environments_table, self.versions_table) + tables = ( + self.snapshots_table, + self.environment_state.environments_table, + self.versions_table, + ) optional_tables = ( self.interval_state.intervals_table, self.plan_dags_table, @@ -1161,7 +1028,7 @@ def state_type(self) -> str: def _backup_state(self) -> None: for table in ( self.snapshots_table, - self.environments_table, + self.environment_state.environments_table, self.versions_table, self.interval_state.intervals_table, self.plan_dags_table, @@ -1406,7 +1273,7 @@ def _migrate_environment_rows( self.console.start_env_migration_progress(len(updated_environments)) for environment in updated_environments: - self._update_environment(environment) + self.environment_state.update_environment(environment) self.console.update_env_migration_progress(1) if updated_prod_environment: @@ -1536,37 +1403,6 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame: ) -def _environment_to_df(environment: Environment) -> pd.DataFrame: - return pd.DataFrame( - [ - { - "name": environment.name, - "snapshots": json.dumps(environment.snapshot_dicts()), - "start_at": time_like_to_str(environment.start_at), - "end_at": time_like_to_str(environment.end_at) if environment.end_at else None, - "plan_id": environment.plan_id, - "previous_plan_id": environment.previous_plan_id, - "expiration_ts": environment.expiration_ts, - "finalized_ts": environment.finalized_ts, - "promoted_snapshot_ids": ( - 
json.dumps(environment.promoted_snapshot_id_dicts()) - if environment.promoted_snapshot_ids is not None - else None - ), - "suffix_target": environment.suffix_target.value, - "catalog_name_override": environment.catalog_name_override, - "previous_finalized_snapshots": ( - json.dumps(environment.previous_finalized_snapshot_dicts()) - if environment.previous_finalized_snapshots is not None - else None - ), - "normalize_name": environment.normalize_name, - "requirements": json.dumps(environment.requirements), - } - ] - ) - - def _auto_restatements_to_df(auto_restatements: t.Dict[SnapshotNameVersion, int]) -> pd.DataFrame: return pd.DataFrame( [ diff --git a/tests/core/test_environment.py b/tests/core/test_environment.py index 228269cc6c..8de10318e6 100644 --- a/tests/core/test_environment.py +++ b/tests/core/test_environment.py @@ -2,7 +2,7 @@ from sqlmesh.core.environment import Environment, EnvironmentNamingInfo from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo -from sqlmesh.core.state_sync.engine_adapter.facade import _environment_to_df +from sqlmesh.core.state_sync.engine_adapter.environment import _environment_to_df def test_sanitize_name(): diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py index 9b3205f487..2ae8e9cc98 100644 --- a/tests/core/test_state_sync.py +++ b/tests/core/test_state_sync.py @@ -2216,7 +2216,9 @@ def test_first_migration_failure(duck_conn, mocker: MockerFixture, tmp_path) -> ): state_sync.migrate(default_catalog=None) assert not state_sync.engine_adapter.table_exists(state_sync.snapshots_table) - assert not state_sync.engine_adapter.table_exists(state_sync.environments_table) + assert not state_sync.engine_adapter.table_exists( + state_sync.environment_state.environments_table + ) assert not state_sync.engine_adapter.table_exists(state_sync.versions_table) assert not state_sync.engine_adapter.table_exists(state_sync.interval_state.intervals_table) From 6fa55d57f41be1c77c1e5e7707f4d05afd7bbe64 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Tue, 25 Feb 2025 16:03:20 -0800 Subject: [PATCH 3/6] Chore: Refactor the snapshot state from the state sync --- .../state_sync/engine_adapter/environment.py | 3 +- .../core/state_sync/engine_adapter/facade.py | 708 +--------------- .../state_sync/engine_adapter/interval.py | 30 +- .../state_sync/engine_adapter/snapshot.py | 790 ++++++++++++++++++ .../core/state_sync/engine_adapter/utils.py | 9 +- tests/core/test_state_sync.py | 34 +- tests/dbt/test_transformation.py | 2 +- 7 files changed, 873 insertions(+), 703 deletions(-) create mode 100644 sqlmesh/core/state_sync/engine_adapter/snapshot.py diff --git a/sqlmesh/core/state_sync/engine_adapter/environment.py b/sqlmesh/core/state_sync/engine_adapter/environment.py index ed7ebd0382..1150d3f9e2 100644 --- a/sqlmesh/core/state_sync/engine_adapter/environment.py +++ b/sqlmesh/core/state_sync/engine_adapter/environment.py @@ -26,10 +26,9 @@ def __init__( self, engine_adapter: EngineAdapter, schema: t.Optional[str] = None, - table_name: t.Optional[str] = None, ): self.engine_adapter = engine_adapter - self.environments_table = exp.table_(table_name or "_environments", db=schema) + self.environments_table = exp.table_("_environments", db=schema) index_type = index_text_type(engine_adapter.dialect) blob_type = blob_text_type(engine_adapter.dialect) diff --git a/sqlmesh/core/state_sync/engine_adapter/facade.py b/sqlmesh/core/state_sync/engine_adapter/facade.py index 7ced4522d8..7658bfe6b7 100644 --- 
a/sqlmesh/core/state_sync/engine_adapter/facade.py +++ b/sqlmesh/core/state_sync/engine_adapter/facade.py @@ -21,13 +21,11 @@ import logging import time import typing as t -from collections import defaultdict from copy import deepcopy from pathlib import Path from datetime import datetime import pandas as pd -from pydantic import Field from sqlglot import __version__ as SQLGLOT_VERSION from sqlglot import exp from sqlglot.helper import seq_get @@ -37,25 +35,20 @@ from sqlmesh.core.console import Console, get_console from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.environment import Environment -from sqlmesh.core.model import ModelKindName, SeedModel -from sqlmesh.core.node import IntervalUnit from sqlmesh.core.snapshot import ( Node, Snapshot, - SnapshotChangeCategory, SnapshotFingerprint, SnapshotId, SnapshotIdLike, SnapshotInfoLike, SnapshotIntervals, SnapshotNameVersion, - SnapshotNameVersionLike, SnapshotTableCleanupTask, SnapshotTableInfo, fingerprint_from_node, start_date, ) -from sqlmesh.core.snapshot.cache import SnapshotCache from sqlmesh.core.snapshot.definition import ( Interval, _parents_from_node, @@ -70,12 +63,12 @@ from sqlmesh.core.state_sync.common import transactional from sqlmesh.core.state_sync.engine_adapter.interval import IntervalState from sqlmesh.core.state_sync.engine_adapter.environment import EnvironmentState -from sqlmesh.utils import major_minor, unique +from sqlmesh.core.state_sync.engine_adapter.snapshot import SnapshotState +from sqlmesh.utils import major_minor from sqlmesh.utils.dag import DAG -from sqlmesh.utils.date import TimeLike, now, now_timestamp, to_timestamp +from sqlmesh.utils.date import TimeLike, now_timestamp, to_timestamp from sqlmesh.utils.errors import ConflictingPlanError, SQLMeshError from sqlmesh.utils.migration import index_text_type -from sqlmesh.utils.pydantic import PydanticModel logger = logging.getLogger(__name__) @@ -120,42 +113,23 @@ def __init__( ): self.interval_state = IntervalState(engine_adapter, schema=schema) self.environment_state = EnvironmentState(engine_adapter, schema=schema) + self.snapshot_state = SnapshotState( + engine_adapter, schema=schema, context_path=context_path + ) # Make sure that if an empty string is provided that we treat it as None self.schema = schema or None self.engine_adapter = engine_adapter self.console = console or get_console() - self.snapshots_table = exp.table_("_snapshots", db=self.schema) self.plan_dags_table = exp.table_("_plan_dags", db=self.schema) - self.auto_restatements_table = exp.table_("_auto_restatements", db=self.schema) self.versions_table = exp.table_("_versions", db=self.schema) index_type = index_text_type(engine_adapter.dialect) - self._snapshot_columns_to_types = { - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build("text"), - "updated_ts": exp.DataType.build("bigint"), - "unpaused_ts": exp.DataType.build("bigint"), - "ttl_ms": exp.DataType.build("bigint"), - "unrestorable": exp.DataType.build("boolean"), - } - - self._auto_restatement_columns_to_types = { - "snapshot_name": exp.DataType.build(index_type), - "snapshot_version": exp.DataType.build(index_type), - "next_auto_restatement_ts": exp.DataType.build("bigint"), - } - self._version_columns_to_types = { "schema_version": exp.DataType.build("int"), "sqlglot_version": exp.DataType.build(index_type), "sqlmesh_version": 
exp.DataType.build(index_type), } - self._snapshot_cache = SnapshotCache(context_path / c.CACHE) - def _fetchone(self, query: t.Union[exp.Expression, str]) -> t.Optional[t.Tuple]: return self.engine_adapter.fetchone( query, ignore_unsupported_errors=True, quote_identifiers=True @@ -175,7 +149,7 @@ def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: insert all the local snapshots. This can be made safer with locks or merge/upsert. Args: - snapshot_ids: Iterable of snapshot ids to bulk push. + snapshots: The snapshots to push. """ snapshots_by_id = {} for snapshot in snapshots: @@ -198,30 +172,8 @@ def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: snapshots_by_id.pop(sid) snapshots = snapshots_by_id.values() - if snapshots: - self._push_snapshots(snapshots) - for snapshot in snapshots: - self._snapshot_cache.put(snapshot) - - def _push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = False) -> None: - if overwrite: - snapshots = tuple(snapshots) - self.delete_snapshots(snapshots) - - snapshots_to_store = [] - - for snapshot in snapshots: - if isinstance(snapshot.node, SeedModel): - seed_model = t.cast(SeedModel, snapshot.node) - snapshot = snapshot.copy(update={"node": seed_model.to_dehydrated()}) - snapshots_to_store.append(snapshot) - - self.engine_adapter.insert_append( - self.snapshots_table, - _snapshots_to_df(snapshots_to_store), - columns_to_types=self._snapshot_columns_to_types, - ) + self.snapshot_state.push_snapshots(snapshots) @transactional() def promote( @@ -282,7 +234,7 @@ def promote( "Please recreate the plan and try again" ) if no_gaps_snapshot_names != set(): - snapshots = self._get_snapshots(environment.snapshots).values() + snapshots = self.get_snapshots(environment.snapshots).values() self._ensure_no_gaps( snapshots, existing_environment, @@ -290,7 +242,7 @@ def promote( ) demoted_snapshots = set(existing_environment.snapshots) - set(environment.snapshots) # Update the updated_at attribute. 
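         # ("Touching" bumps updated_ts; a snapshot expires once updated_ts + ttl_ms has passed.)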
- self._update_snapshots(demoted_snapshots) + self.snapshot_state.touch_snapshots(demoted_snapshots) missing_models = set(existing_table_infos) - { snapshot.name for snapshot in environment.promoted_snapshots @@ -334,7 +286,7 @@ def _ensure_no_gaps( and target_snapshots_by_name[s.name].version != s.version } - prev_snapshots = self._get_snapshots( + prev_snapshots = self.get_snapshots( changed_version_prev_snapshots_by_name.values() ).values() cache: t.Dict[str, datetime] = {} @@ -376,76 +328,7 @@ def finalize(self, environment: Environment) -> None: def unpause_snapshots( self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike ) -> None: - current_ts = now() - - target_snapshot_ids = {s.snapshot_id for s in snapshots} - same_version_snapshots = self._get_snapshots_with_same_version( - snapshots, lock_for_update=True - ) - target_snapshots_by_version = { - (s.name, s.version): s - for s in same_version_snapshots - if s.snapshot_id in target_snapshot_ids - } - - unpaused_snapshots: t.Dict[int, t.List[SnapshotId]] = defaultdict(list) - paused_snapshots: t.List[SnapshotId] = [] - unrestorable_snapshots: t.List[SnapshotId] = [] - - for snapshot in same_version_snapshots: - is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids - if is_target_snapshot and not snapshot.unpaused_ts: - logger.info("Unpausing snapshot %s", snapshot.snapshot_id) - snapshot.set_unpaused_ts(unpaused_dt) - assert snapshot.unpaused_ts is not None - unpaused_snapshots[snapshot.unpaused_ts].append(snapshot.snapshot_id) - elif not is_target_snapshot: - target_snapshot = target_snapshots_by_version[(snapshot.name, snapshot.version)] - if ( - target_snapshot.normalized_effective_from_ts - and not target_snapshot.disable_restatement - ): - # Making sure that there are no overlapping intervals. 
- effective_from_ts = target_snapshot.normalized_effective_from_ts - logger.info( - "Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s", - target_snapshot.effective_from, - snapshot.snapshot_id, - target_snapshot.snapshot_id, - ) - full_snapshot = snapshot.full_snapshot - self.remove_intervals( - [ - ( - full_snapshot, - full_snapshot.get_removal_interval(effective_from_ts, current_ts), - ) - ] - ) - - if snapshot.unpaused_ts: - logger.info("Pausing snapshot %s", snapshot.snapshot_id) - snapshot.set_unpaused_ts(None) - paused_snapshots.append(snapshot.snapshot_id) - - if ( - not snapshot.is_forward_only - and target_snapshot.is_forward_only - and not snapshot.unrestorable - ): - logger.info("Marking snapshot %s as unrestorable", snapshot.snapshot_id) - snapshot.unrestorable = True - unrestorable_snapshots.append(snapshot.snapshot_id) - - if unpaused_snapshots: - for unpaused_ts, snapshot_ids in unpaused_snapshots.items(): - self._update_snapshots(snapshot_ids, unpaused_ts=unpaused_ts) - - if paused_snapshots: - self._update_snapshots(paused_snapshots, unpaused_ts=None) - - if unrestorable_snapshots: - self._update_snapshots(unrestorable_snapshots, unrestorable=True) + self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt, self.interval_state) def _update_versions( self, @@ -476,152 +359,43 @@ def invalidate_environment(self, name: str) -> None: def delete_expired_snapshots( self, ignore_ttl: bool = False ) -> t.List[SnapshotTableCleanupTask]: - current_ts = now_timestamp(minute_floor=False) - - expired_query = exp.select("name", "identifier", "version").from_(self.snapshots_table) - - if not ignore_ttl: - expired_query = expired_query.where( - (exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts - ) - - expired_candidates = { - SnapshotId(name=name, identifier=identifier): SnapshotNameVersion( - name=name, version=version - ) - for name, identifier, version in self._fetchall(expired_query) - } - if not expired_candidates: - return [] - - promoted_snapshot_ids = { - snapshot.snapshot_id - for environment in self.get_environments() - for snapshot in environment.snapshots - } - - def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: - return ( - snapshot.snapshot_id in promoted_snapshot_ids - or snapshot.snapshot_id not in expired_candidates - ) - - unique_expired_versions = unique(expired_candidates.values()) - version_batches = self._batches(unique_expired_versions) - cleanup_targets = [] - expired_snapshot_ids = set() - for versions_batch in version_batches: - snapshots = self._get_snapshots_with_same_version(versions_batch) - - snapshots_by_version = defaultdict(set) - snapshots_by_dev_version = defaultdict(set) - for s in snapshots: - snapshots_by_version[(s.name, s.version)].add(s.snapshot_id) - snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id) - - expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)] - expired_snapshot_ids.update([s.snapshot_id for s in expired_snapshots]) - - for snapshot in expired_snapshots: - shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)] - shared_version_snapshots.discard(snapshot.snapshot_id) - - shared_dev_version_snapshots = snapshots_by_dev_version[ - (snapshot.name, snapshot.dev_version) - ] - shared_dev_version_snapshots.discard(snapshot.snapshot_id) - - if not shared_dev_version_snapshots: - cleanup_targets.append( - SnapshotTableCleanupTask( - snapshot=snapshot.full_snapshot.table_info, - 
dev_table_only=bool(shared_version_snapshots), - ) - ) - - if expired_snapshot_ids: - self.delete_snapshots(expired_snapshot_ids) - + expired_snapshot_ids, cleanup_targets = self.snapshot_state.delete_expired_snapshots( + self.environment_state.get_environments(), ignore_ttl=ignore_ttl + ) self.interval_state.cleanup_intervals(cleanup_targets, expired_snapshot_ids) - return cleanup_targets def delete_expired_environments(self) -> t.List[Environment]: return self.environment_state.delete_expired_environments() def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: - if not snapshot_ids: - return - for where in self._snapshot_id_filter(snapshot_ids): - self.engine_adapter.delete_from(self.snapshots_table, where=where) + self.snapshot_state.delete_snapshots(snapshot_ids) def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: - return self._snapshot_ids_exist(snapshot_ids, self.snapshots_table) + return self.snapshot_state.snapshots_exist(snapshot_ids) def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]: - names = set(names) - - if not names: - return names - - query = ( - exp.select("name") - .from_(self.snapshots_table) - .where(exp.column("name").isin(*names)) - .distinct() - ) - if exclude_external: - query = query.where(exp.column("kind_name").neq(ModelKindName.EXTERNAL.value)) - return {name for (name,) in self._fetchall(query)} + return self.snapshot_state.nodes_exist(names, exclude_external) def reset(self, default_catalog: t.Optional[str]) -> None: """Resets the state store to the state when it was first initialized.""" for table in ( - self.snapshots_table, + self.snapshot_state.snapshots_table, + self.snapshot_state.auto_restatements_table, self.environment_state.environments_table, self.interval_state.intervals_table, self.plan_dags_table, self.versions_table, ): self.engine_adapter.drop_table(table) - self._snapshot_cache.clear() + self.snapshot_state.clear_cache() self.migrate(default_catalog) @transactional() def update_auto_restatements( self, next_auto_restatement_ts: t.Dict[SnapshotNameVersion, t.Optional[int]] ) -> None: - for where in self._snapshot_name_version_filter( - next_auto_restatement_ts, column_prefix="snapshot", alias=None - ): - self.engine_adapter.delete_from(self.auto_restatements_table, where=where) - - next_auto_restatement_ts_filtered = { - k: v for k, v in next_auto_restatement_ts.items() if v is not None - } - if not next_auto_restatement_ts_filtered: - return - - self.engine_adapter.insert_append( - self.auto_restatements_table, - _auto_restatements_to_df(next_auto_restatement_ts_filtered), - columns_to_types=self._auto_restatement_columns_to_types, - ) - - def _update_snapshots( - self, - snapshots: t.Iterable[SnapshotIdLike], - **kwargs: t.Any, - ) -> None: - properties = kwargs - properties["updated_ts"] = now_timestamp() - - for where in self._snapshot_id_filter(snapshots): - self.engine_adapter.update_table( - self.snapshots_table, - properties, - where=where, - ) + self.snapshot_state.update_auto_restatements(next_auto_restatement_ts) def get_environment(self, environment: str) -> t.Optional[Environment]: return self.environment_state.get_environment(environment) @@ -646,216 +420,19 @@ def get_snapshots( self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]], ) -> t.Dict[SnapshotId, Snapshot]: - if snapshot_ids is None: - raise SQLMeshError("Must provide snapshot IDs to fetch snapshots.") - return self._get_snapshots(snapshot_ids) - - def 
_get_snapshots( - self, - snapshot_ids: t.Iterable[SnapshotIdLike], - lock_for_update: bool = False, - hydrate_intervals: bool = True, - ) -> t.Dict[SnapshotId, Snapshot]: - """Fetches specified snapshots or all snapshots. + """Fetches snapshots from the state. Args: - snapshot_ids: The collection of snapshot like objects to fetch. - lock_for_update: Lock the snapshot rows for future update - hydrate_intervals: Whether to hydrate result snapshots with intervals. + snapshot_ids: The snapshot IDs to fetch. Returns: - A dictionary of snapshot ids to snapshots for ones that could be found. + A dict of snapshots. """ - duplicates: t.Dict[SnapshotId, Snapshot] = {} - - def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]: - fetched_snapshots: t.Dict[SnapshotId, Snapshot] = {} - for query in self._get_snapshots_expressions(snapshot_ids_to_load, lock_for_update): - for ( - serialized_snapshot, - _, - _, - _, - updated_ts, - unpaused_ts, - unrestorable, - next_auto_restatement_ts, - ) in self._fetchall(query): - snapshot = parse_snapshot( - serialized_snapshot=serialized_snapshot, - updated_ts=updated_ts, - unpaused_ts=unpaused_ts, - unrestorable=unrestorable, - next_auto_restatement_ts=next_auto_restatement_ts, - ) - snapshot_id = snapshot.snapshot_id - if snapshot_id in fetched_snapshots: - other = duplicates.get(snapshot_id, fetched_snapshots[snapshot_id]) - duplicates[snapshot_id] = ( - snapshot if snapshot.updated_ts > other.updated_ts else other - ) - fetched_snapshots[snapshot_id] = duplicates[snapshot_id] - else: - fetched_snapshots[snapshot_id] = snapshot - return fetched_snapshots.values() - - snapshots, cached_snapshots = self._snapshot_cache.get_or_load( - {s.snapshot_id for s in snapshot_ids}, _loader - ) - - if cached_snapshots: - cached_snapshots_in_state: t.Set[SnapshotId] = set() - for where in self._snapshot_id_filter(cached_snapshots): - query = ( - exp.select( - "name", - "identifier", - "updated_ts", - "unpaused_ts", - "unrestorable", - "next_auto_restatement_ts", - ) - .from_(exp.to_table(self.snapshots_table).as_("snapshots")) - .join( - exp.to_table(self.auto_restatements_table).as_("auto_restatements"), - on=exp.and_( - exp.column("name", table="snapshots").eq( - exp.column("snapshot_name", table="auto_restatements") - ), - exp.column("version", table="snapshots").eq( - exp.column("snapshot_version", table="auto_restatements") - ), - ), - join_type="left", - copy=False, - ) - .where(where) - ) - if lock_for_update: - query = query.lock(copy=False) - for ( - name, - identifier, - updated_ts, - unpaused_ts, - unrestorable, - next_auto_restatement_ts, - ) in self._fetchall(query): - snapshot_id = SnapshotId(name=name, identifier=identifier) - snapshot = snapshots[snapshot_id] - snapshot.updated_ts = updated_ts - snapshot.unpaused_ts = unpaused_ts - snapshot.unrestorable = unrestorable - snapshot.next_auto_restatement_ts = next_auto_restatement_ts - cached_snapshots_in_state.add(snapshot_id) - - missing_cached_snapshots = cached_snapshots - cached_snapshots_in_state - for missing_cached_snapshot_id in missing_cached_snapshots: - snapshots.pop(missing_cached_snapshot_id, None) - - if snapshots and hydrate_intervals: - intervals = self.interval_state.get_snapshot_intervals(snapshots.values()) - Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals) - - if duplicates: - self._push_snapshots(duplicates.values(), overwrite=True) - logger.error("Found duplicate snapshots in the state store.") - + snapshots = 
self.snapshot_state.get_snapshots(snapshot_ids) + intervals = self.interval_state.get_snapshot_intervals(snapshots.values()) + Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals) return snapshots - def _get_snapshots_expressions( - self, - snapshot_ids: t.Iterable[SnapshotIdLike], - lock_for_update: bool = False, - batch_size: t.Optional[int] = None, - ) -> t.Iterator[exp.Expression]: - for where in self._snapshot_id_filter( - snapshot_ids, alias="snapshots", batch_size=batch_size - ): - query = ( - exp.select( - "snapshots.snapshot", - "snapshots.name", - "snapshots.identifier", - "snapshots.version", - "snapshots.updated_ts", - "snapshots.unpaused_ts", - "snapshots.unrestorable", - "auto_restatements.next_auto_restatement_ts", - ) - .from_(exp.to_table(self.snapshots_table).as_("snapshots")) - .join( - exp.to_table(self.auto_restatements_table).as_("auto_restatements"), - on=exp.and_( - exp.column("name", table="snapshots").eq( - exp.column("snapshot_name", table="auto_restatements") - ), - exp.column("version", table="snapshots").eq( - exp.column("snapshot_version", table="auto_restatements") - ), - ), - join_type="left", - copy=False, - ) - .where(where) - ) - if lock_for_update: - query = query.lock(copy=False) - yield query - - def _get_snapshots_with_same_version( - self, - snapshots: t.Collection[SnapshotNameVersionLike], - lock_for_update: bool = False, - ) -> t.List[SharedVersionSnapshot]: - """Fetches all snapshots that share the same version as the snapshots. - - The output includes the snapshots with the specified identifiers. - - Args: - snapshots: The collection of target name / version pairs. - lock_for_update: Lock the snapshot rows for future update - - Returns: - The list of Snapshot objects. - """ - if not snapshots: - return [] - - snapshot_rows = [] - - for where in self._snapshot_name_version_filter(snapshots): - query = ( - exp.select( - "snapshot", - "name", - "identifier", - "version", - "updated_ts", - "unpaused_ts", - "unrestorable", - ) - .from_(exp.to_table(self.snapshots_table).as_("snapshots")) - .where(where) - ) - if lock_for_update: - query = query.lock(copy=False) - - snapshot_rows.extend(self._fetchall(query)) - - return [ - SharedVersionSnapshot.from_snapshot_record( - name=name, - identifier=identifier, - version=version, - updated_ts=updated_ts, - unpaused_ts=unpaused_ts, - unrestorable=unrestorable, - snapshot=snapshot, - ) - for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable in snapshot_rows - ] - def _get_versions(self, lock_for_update: bool = False) -> Versions: no_version = Versions() @@ -993,14 +570,14 @@ def rollback(self) -> None: """Rollback to the previous migration.""" logger.info("Starting migration rollback.") tables = ( - self.snapshots_table, + self.snapshot_state.snapshots_table, self.environment_state.environments_table, self.versions_table, ) optional_tables = ( self.interval_state.intervals_table, self.plan_dags_table, - self.auto_restatements_table, + self.snapshot_state.auto_restatements_table, ) versions = self.get_versions(validate=False) if versions.schema_version == 0: @@ -1027,12 +604,12 @@ def state_type(self) -> str: @transactional() def _backup_state(self) -> None: for table in ( - self.snapshots_table, + self.snapshot_state.snapshots_table, self.environment_state.environments_table, self.versions_table, self.interval_state.intervals_table, self.plan_dags_table, - self.auto_restatements_table, + self.snapshot_state.auto_restatements_table, ): if 
self.engine_adapter.table_exists(table): backup_name = _backup_table_name(table) @@ -1040,10 +617,6 @@ def _backup_state(self) -> None: self.engine_adapter.create_table_like(backup_name, table) self.engine_adapter.insert_append(backup_name, exp.select("*").from_(table)) - def _snapshot_count(self) -> int: - result = self._fetchone(exp.select("COUNT(*)").from_(self.snapshots_table)) - return result[0] if result else 0 - def _apply_migrations( self, default_catalog: t.Optional[str], @@ -1061,13 +634,13 @@ def _apply_migrations( if not skip_backup and should_backup: self._backup_state() - snapshot_count_before = self._snapshot_count() if versions.schema_version else None + snapshot_count_before = self.snapshot_state.count() if versions.schema_version else None for migration in migrations: logger.info(f"Applying migration {migration}") migration.migrate(self, default_catalog=default_catalog) - snapshot_count_after = self._snapshot_count() + snapshot_count_after = self.snapshot_state.count() if snapshot_count_before is not None and snapshot_count_before != snapshot_count_after: scripts = f"{versions.schema_version} - {versions.schema_version + len(migrations)}" @@ -1113,7 +686,7 @@ def _migrate_snapshot_rows( exp.select( "name", "identifier", "snapshot", "updated_ts", "unpaused_ts", "unrestorable" ) - .from_(self.snapshots_table) + .from_(self.snapshot_state.snapshots_table) .where(where) .lock() ) @@ -1150,7 +723,7 @@ def _push_new_snapshots() -> None: ] if new_snapshots_to_push: logger.info("Pushing %s migrated snapshots", len(new_snapshots_to_push)) - self._push_snapshots(new_snapshots_to_push) + self.snapshot_state.push_snapshots(new_snapshots_to_push) new_snapshots.clear() snapshot_id_mapping.clear() @@ -1284,17 +857,6 @@ def _migrate_environment_rows( self.console.stop_env_migration_progress() - def _snapshot_ids_exist( - self, snapshot_ids: t.Iterable[SnapshotIdLike], table_name: exp.Table - ) -> t.Set[SnapshotId]: - return { - SnapshotId(name=name, identifier=identifier) - for where in self._snapshot_id_filter(snapshot_ids) - for name, identifier in self._fetchall( - exp.select("name", "identifier").from_(table_name).where(where) - ) - } - def _snapshot_id_filter( self, snapshot_ids: t.Iterable[SnapshotIdLike], @@ -1331,49 +893,6 @@ def _snapshot_id_filter( ] ) - def _snapshot_name_version_filter( - self, - snapshot_name_versions: t.Iterable[SnapshotNameVersionLike], - version_column_name: str = "version", - alias: t.Optional[str] = "snapshots", - column_prefix: t.Optional[str] = None, - ) -> t.Iterator[exp.Condition]: - name_versions = sorted({(s.name, s.version) for s in snapshot_name_versions}) - batches = self._batches(name_versions) - - name_column_name = "name" - if column_prefix: - name_column_name = f"{column_prefix}_{name_column_name}" - version_column_name = f"{column_prefix}_{version_column_name}" - - name_column = exp.column(name_column_name, table=alias) - version_column = exp.column(version_column_name, table=alias) - - if not name_versions: - yield exp.false() - elif self.engine_adapter.SUPPORTS_TUPLE_IN: - for versions in batches: - yield t.cast( - exp.Tuple, - exp.convert( - ( - name_column, - version_column, - ) - ), - ).isin(*versions) - else: - for versions in batches: - yield exp.or_( - *[ - exp.and_( - name_column.eq(name), - version_column.eq(version), - ) - for name, version in versions - ] - ) - def _batches(self, l: t.List[T], batch_size: t.Optional[int] = None) -> t.List[t.List[T]]: batch_size = batch_size or self.SNAPSHOT_BATCH_SIZE return [l[i : i + 
batch_size] for i in range(0, len(l), batch_size)] @@ -1384,76 +903,12 @@ def _transaction(self) -> t.Iterator[None]: yield -def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame: - return pd.DataFrame( - [ - { - "name": snapshot.name, - "identifier": snapshot.identifier, - "version": snapshot.version, - "snapshot": _snapshot_to_json(snapshot), - "kind_name": snapshot.model_kind_name.value if snapshot.model_kind_name else None, - "updated_ts": snapshot.updated_ts, - "unpaused_ts": snapshot.unpaused_ts, - "ttl_ms": snapshot.ttl_ms, - "unrestorable": snapshot.unrestorable, - } - for snapshot in snapshots - ] - ) - - -def _auto_restatements_to_df(auto_restatements: t.Dict[SnapshotNameVersion, int]) -> pd.DataFrame: - return pd.DataFrame( - [ - { - "snapshot_name": name_version.name, - "snapshot_version": name_version.version, - "next_auto_restatement_ts": ts, - } - for name_version, ts in auto_restatements.items() - ] - ) - - def _backup_table_name(table_name: TableName) -> exp.Table: table = exp.to_table(table_name).copy() table.set("this", exp.to_identifier(table.name + "_backup")) return table -def _snapshot_to_json(snapshot: Snapshot) -> str: - return snapshot.json( - exclude={ - "intervals", - "dev_intervals", - "pending_restatement_intervals", - "updated_ts", - "unpaused_ts", - "unrestorable", - "next_auto_restatement_ts", - } - ) - - -def parse_snapshot( - serialized_snapshot: str, - updated_ts: int, - unpaused_ts: t.Optional[int], - unrestorable: bool, - next_auto_restatement_ts: t.Optional[int], -) -> Snapshot: - return Snapshot( - **{ - **json.loads(serialized_snapshot), - "updated_ts": updated_ts, - "unpaused_ts": unpaused_ts, - "unrestorable": unrestorable, - "next_auto_restatement_ts": next_auto_restatement_ts, - } - ) - - class LazilyParsedSnapshots: def __init__(self, raw_snapshots: t.Dict[SnapshotId, t.Dict[str, t.Any]]): self._raw_snapshots = raw_snapshots @@ -1476,94 +931,3 @@ def __getitem__(self, snapshot_id: SnapshotId) -> Snapshot: if snapshot is None: raise KeyError(snapshot_id) return snapshot - - -class SharedVersionSnapshot(PydanticModel): - """A stripped down version of a snapshot that is used for fetching snapshots that share the same version - with a significantly reduced parsing overhead. - """ - - name: str - version: str - dev_version_: t.Optional[str] = Field(alias="dev_version") - identifier: str - fingerprint: SnapshotFingerprint - interval_unit: IntervalUnit - change_category: SnapshotChangeCategory - updated_ts: int - unpaused_ts: t.Optional[int] - unrestorable: bool - disable_restatement: bool - effective_from: t.Optional[TimeLike] - raw_snapshot: t.Dict[str, t.Any] - - @property - def snapshot_id(self) -> SnapshotId: - return SnapshotId(name=self.name, identifier=self.identifier) - - @property - def is_forward_only(self) -> bool: - return self.change_category == SnapshotChangeCategory.FORWARD_ONLY - - @property - def normalized_effective_from_ts(self) -> t.Optional[int]: - return ( - to_timestamp(self.interval_unit.cron_floor(self.effective_from)) - if self.effective_from - else None - ) - - @property - def dev_version(self) -> str: - return self.dev_version_ or self.fingerprint.to_version() - - @property - def full_snapshot(self) -> Snapshot: - return Snapshot( - **{ - **self.raw_snapshot, - "updated_ts": self.updated_ts, - "unpaused_ts": self.unpaused_ts, - "unrestorable": self.unrestorable, - } - ) - - def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None: - """Sets the timestamp for when this snapshot was unpaused. 
- - Args: - unpaused_dt: The datetime object of when this snapshot was unpaused. - """ - self.unpaused_ts = ( - to_timestamp(self.interval_unit.cron_floor(unpaused_dt)) if unpaused_dt else None - ) - - @classmethod - def from_snapshot_record( - cls, - *, - name: str, - identifier: str, - version: str, - updated_ts: int, - unpaused_ts: t.Optional[int], - unrestorable: bool, - snapshot: str, - ) -> SharedVersionSnapshot: - raw_snapshot = json.loads(snapshot) - raw_node = raw_snapshot["node"] - return SharedVersionSnapshot( - name=name, - version=version, - dev_version=raw_snapshot.get("dev_version"), - identifier=identifier, - fingerprint=raw_snapshot["fingerprint"], - interval_unit=raw_node.get("interval_unit", IntervalUnit.from_cron(raw_node["cron"])), - change_category=raw_snapshot["change_category"], - updated_ts=updated_ts, - unpaused_ts=unpaused_ts, - unrestorable=unrestorable, - disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False), - effective_from=raw_snapshot.get("effective_from"), - raw_snapshot=raw_snapshot, - ) diff --git a/sqlmesh/core/state_sync/engine_adapter/interval.py b/sqlmesh/core/state_sync/engine_adapter/interval.py index 786a0cc698..e44b11ada1 100644 --- a/sqlmesh/core/state_sync/engine_adapter/interval.py +++ b/sqlmesh/core/state_sync/engine_adapter/interval.py @@ -34,6 +34,7 @@ class IntervalState: INTERVAL_BATCH_SIZE = 1000 + SNAPSHOT_BATCH_SIZE = 1000 def __init__( self, @@ -120,7 +121,10 @@ def remove_intervals( name_version_mapping = {s.name_version: interval for s, interval in snapshot_intervals} all_snapshots = [] for where in snapshot_name_version_filter( - self.engine_adapter, name_version_mapping, alias=None + self.engine_adapter, + name_version_mapping, + alias=None, + batch_size=self.SNAPSHOT_BATCH_SIZE, ): all_snapshots.extend( [ @@ -202,7 +206,7 @@ def max_interval_end_per_model( result: t.Dict[str, int] = {} for where in snapshot_name_version_filter( - self.engine_adapter, snapshots, alias=table_alias + self.engine_adapter, snapshots, alias=table_alias, batch_size=self.SNAPSHOT_BATCH_SIZE ): query = ( exp.select( @@ -299,7 +303,12 @@ def _get_snapshot_intervals( ] = {} for where in ( - snapshot_name_version_filter(self.engine_adapter, snapshots, alias="intervals") + snapshot_name_version_filter( + self.engine_adapter, + snapshots, + alias="intervals", + batch_size=self.SNAPSHOT_BATCH_SIZE, + ) if snapshots else [None] ): @@ -409,7 +418,9 @@ def _update_intervals_for_deleted_snapshots( if not snapshot_ids: return - for where in snapshot_id_filter(self.engine_adapter, snapshot_ids, alias=None): + for where in snapshot_id_filter( + self.engine_adapter, snapshot_ids, alias=None, batch_size=self.SNAPSHOT_BATCH_SIZE + ): # Nullify the identifier for dev intervals # Set is_compacted to False so that it's compacted during the next compaction self.engine_adapter.update_table( @@ -436,7 +447,11 @@ def _delete_intervals_by_dev_version(self, targets: t.List[SnapshotTableCleanupT return for where in snapshot_name_version_filter( - self.engine_adapter, dev_keys_to_delete, version_column_name="dev_version", alias=None + self.engine_adapter, + dev_keys_to_delete, + version_column_name="dev_version", + alias=None, + batch_size=self.SNAPSHOT_BATCH_SIZE, ): self.engine_adapter.delete_from(self.intervals_table, where.and_(exp.column("is_dev"))) @@ -447,7 +462,10 @@ def _delete_intervals_by_version(self, targets: t.List[SnapshotTableCleanupTask] return for where in snapshot_name_version_filter( - self.engine_adapter, non_dev_keys_to_delete, 
alias=None + self.engine_adapter, + non_dev_keys_to_delete, + alias=None, + batch_size=self.SNAPSHOT_BATCH_SIZE, ): self.engine_adapter.delete_from(self.intervals_table, where) diff --git a/sqlmesh/core/state_sync/engine_adapter/snapshot.py b/sqlmesh/core/state_sync/engine_adapter/snapshot.py new file mode 100644 index 0000000000..d8ce8d7af8 --- /dev/null +++ b/sqlmesh/core/state_sync/engine_adapter/snapshot.py @@ -0,0 +1,790 @@ +from __future__ import annotations + +import typing as t +import pandas as pd +import json +import logging +from pathlib import Path +from collections import defaultdict +from sqlglot import exp +from pydantic import Field + +from sqlmesh.core import constants as c +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.engine_adapter.utils import ( + snapshot_name_version_filter, + snapshot_id_filter, + fetchall, + create_batches, +) +from sqlmesh.core.node import IntervalUnit +from sqlmesh.core.environment import Environment +from sqlmesh.core.model import SeedModel, ModelKindName +from sqlmesh.core.snapshot.cache import SnapshotCache +from sqlmesh.core.snapshot import ( + SnapshotIdLike, + SnapshotNameVersionLike, + SnapshotTableCleanupTask, + SnapshotNameVersion, + SnapshotInfoLike, + Snapshot, + SnapshotId, + SnapshotFingerprint, + SnapshotChangeCategory, +) +from sqlmesh.utils.migration import index_text_type, blob_text_type +from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp +from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.pydantic import PydanticModel +from sqlmesh.utils import unique + +if t.TYPE_CHECKING: + from sqlmesh.core.state_sync.engine_adapter.interval import IntervalState + + +logger = logging.getLogger(__name__) + + +class SnapshotState: + SNAPSHOT_BATCH_SIZE = 1000 + + def __init__( + self, + engine_adapter: EngineAdapter, + schema: t.Optional[str] = None, + context_path: Path = Path(), + ): + self.engine_adapter = engine_adapter + self.snapshots_table = exp.table_("_snapshots", db=schema) + self.auto_restatements_table = exp.table_("_auto_restatements", db=schema) + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + self._snapshot_columns_to_types = { + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build("text"), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + } + + self._auto_restatement_columns_to_types = { + "snapshot_name": exp.DataType.build(index_type), + "snapshot_version": exp.DataType.build(index_type), + "next_auto_restatement_ts": exp.DataType.build("bigint"), + } + + self._snapshot_cache = SnapshotCache(context_path / c.CACHE) + + def push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = False) -> None: + """Pushes snapshots to the state store. + + Args: + snapshots: The snapshots to push. + overwrite: Whether to overwrite existing snapshots. 
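+
+        When overwrite is set, snapshots with matching IDs are deleted before the insert.
+        Pushed snapshots are also written to the local snapshot cache.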
+ """ + if overwrite: + snapshots = tuple(snapshots) + self.delete_snapshots(snapshots) + + snapshots_to_store = [] + + for snapshot in snapshots: + if isinstance(snapshot.node, SeedModel): + seed_model = t.cast(SeedModel, snapshot.node) + snapshot = snapshot.copy(update={"node": seed_model.to_dehydrated()}) + snapshots_to_store.append(snapshot) + + self.engine_adapter.insert_append( + self.snapshots_table, + _snapshots_to_df(snapshots_to_store), + columns_to_types=self._snapshot_columns_to_types, + ) + + for snapshot in snapshots: + self._snapshot_cache.put(snapshot) + + def unpause_snapshots( + self, + snapshots: t.Collection[SnapshotInfoLike], + unpaused_dt: TimeLike, + interval_state: IntervalState, + ) -> None: + """Unpauses given snapshots while pausing all other snapshots that share the same version. + + Args: + snapshots: The snapshots to unpause. + unpaused_dt: The timestamp to unpause the snapshots at. + interval_state: The interval state to use to remove intervals when needed. + """ + current_ts = now() + + target_snapshot_ids = {s.snapshot_id for s in snapshots} + same_version_snapshots = self._get_snapshots_with_same_version( + snapshots, lock_for_update=True + ) + target_snapshots_by_version = { + (s.name, s.version): s + for s in same_version_snapshots + if s.snapshot_id in target_snapshot_ids + } + + unpaused_snapshots: t.Dict[int, t.List[SnapshotId]] = defaultdict(list) + paused_snapshots: t.List[SnapshotId] = [] + unrestorable_snapshots: t.List[SnapshotId] = [] + + for snapshot in same_version_snapshots: + is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids + if is_target_snapshot and not snapshot.unpaused_ts: + logger.info("Unpausing snapshot %s", snapshot.snapshot_id) + snapshot.set_unpaused_ts(unpaused_dt) + assert snapshot.unpaused_ts is not None + unpaused_snapshots[snapshot.unpaused_ts].append(snapshot.snapshot_id) + elif not is_target_snapshot: + target_snapshot = target_snapshots_by_version[(snapshot.name, snapshot.version)] + if ( + target_snapshot.normalized_effective_from_ts + and not target_snapshot.disable_restatement + ): + # Making sure that there are no overlapping intervals. 
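+                    # The superseding target restates data from its effective_from onward,
+                    # so this paused snapshot's intervals in that range are removed below.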
+ effective_from_ts = target_snapshot.normalized_effective_from_ts + logger.info( + "Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s", + target_snapshot.effective_from, + snapshot.snapshot_id, + target_snapshot.snapshot_id, + ) + full_snapshot = snapshot.full_snapshot + interval_state.remove_intervals( + [ + ( + full_snapshot, + full_snapshot.get_removal_interval(effective_from_ts, current_ts), + ) + ] + ) + + if snapshot.unpaused_ts: + logger.info("Pausing snapshot %s", snapshot.snapshot_id) + snapshot.set_unpaused_ts(None) + paused_snapshots.append(snapshot.snapshot_id) + + if ( + not snapshot.is_forward_only + and target_snapshot.is_forward_only + and not snapshot.unrestorable + ): + logger.info("Marking snapshot %s as unrestorable", snapshot.snapshot_id) + snapshot.unrestorable = True + unrestorable_snapshots.append(snapshot.snapshot_id) + + if unpaused_snapshots: + for unpaused_ts, snapshot_ids in unpaused_snapshots.items(): + self._update_snapshots(snapshot_ids, unpaused_ts=unpaused_ts) + + if paused_snapshots: + self._update_snapshots(paused_snapshots, unpaused_ts=None) + + if unrestorable_snapshots: + self._update_snapshots(unrestorable_snapshots, unrestorable=True) + + def delete_expired_snapshots( + self, environments: t.Iterable[Environment], ignore_ttl: bool = False + ) -> t.Tuple[t.Set[SnapshotId], t.List[SnapshotTableCleanupTask]]: + """Deletes expired snapshots. + + Args: + ignore_ttl: Whether to ignore the TTL of the snapshots. + + Returns: + A tuple of expired snapshot IDs and cleanup targets. + """ + current_ts = now_timestamp(minute_floor=False) + + expired_query = exp.select("name", "identifier", "version").from_(self.snapshots_table) + + if not ignore_ttl: + expired_query = expired_query.where( + (exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts + ) + + expired_candidates = { + SnapshotId(name=name, identifier=identifier): SnapshotNameVersion( + name=name, version=version + ) + for name, identifier, version in fetchall(self.engine_adapter, expired_query) + } + if not expired_candidates: + return set(), [] + + promoted_snapshot_ids = { + snapshot.snapshot_id + for environment in environments + for snapshot in environment.snapshots + } + + def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: + return ( + snapshot.snapshot_id in promoted_snapshot_ids + or snapshot.snapshot_id not in expired_candidates + ) + + unique_expired_versions = unique(expired_candidates.values()) + version_batches = create_batches( + unique_expired_versions, batch_size=self.SNAPSHOT_BATCH_SIZE + ) + cleanup_targets = [] + expired_snapshot_ids = set() + for versions_batch in version_batches: + snapshots = self._get_snapshots_with_same_version(versions_batch) + + snapshots_by_version = defaultdict(set) + snapshots_by_dev_version = defaultdict(set) + for s in snapshots: + snapshots_by_version[(s.name, s.version)].add(s.snapshot_id) + snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id) + + expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)] + expired_snapshot_ids.update([s.snapshot_id for s in expired_snapshots]) + + for snapshot in expired_snapshots: + shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)] + shared_version_snapshots.discard(snapshot.snapshot_id) + + shared_dev_version_snapshots = snapshots_by_dev_version[ + (snapshot.name, snapshot.dev_version) + ] + shared_dev_version_snapshots.discard(snapshot.snapshot_id) + + if not shared_dev_version_snapshots: + 
cleanup_targets.append(
+                        SnapshotTableCleanupTask(
+                            snapshot=snapshot.full_snapshot.table_info,
+                            dev_table_only=bool(shared_version_snapshots),
+                        )
+                    )
+
+        if expired_snapshot_ids:
+            self.delete_snapshots(expired_snapshot_ids)
+        return expired_snapshot_ids, cleanup_targets
+
+    def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
+        """Deletes snapshots.
+
+        Args:
+            snapshot_ids: The snapshot IDs to delete.
+        """
+        if not snapshot_ids:
+            return
+        for where in snapshot_id_filter(
+            self.engine_adapter, snapshot_ids, batch_size=self.SNAPSHOT_BATCH_SIZE
+        ):
+            self.engine_adapter.delete_from(self.snapshots_table, where=where)
+
+    def touch_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
+        """Touches the given snapshots by setting their updated_ts to the current timestamp.
+
+        Args:
+            snapshot_ids: The snapshot IDs to touch.
+        """
+        self._update_snapshots(snapshot_ids)
+
+    def get_snapshots(
+        self,
+        snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]],
+    ) -> t.Dict[SnapshotId, Snapshot]:
+        """Fetches snapshots.
+
+        Args:
+            snapshot_ids: The snapshot IDs to fetch.
+
+        Returns:
+            A dictionary of snapshot IDs to snapshots.
+        """
+        if snapshot_ids is None:
+            raise SQLMeshError("Must provide snapshot IDs to fetch snapshots.")
+        return self._get_snapshots(snapshot_ids)
+
+    def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
+        """Checks if snapshots exist.
+
+        Args:
+            snapshot_ids: The snapshot IDs to check.
+
+        Returns:
+            The subset of the given snapshot IDs that exist in the state.
+        """
+        return {
+            SnapshotId(name=name, identifier=identifier)
+            for where in snapshot_id_filter(
+                self.engine_adapter, snapshot_ids, batch_size=self.SNAPSHOT_BATCH_SIZE
+            )
+            for name, identifier in fetchall(
+                self.engine_adapter,
+                exp.select("name", "identifier").from_(self.snapshots_table).where(where),
+            )
+        }
+
+    def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]:
+        """Checks if nodes with given names exist.
+
+        Args:
+            names: The node names to check.
+            exclude_external: Whether to exclude external nodes.
+
+        Returns:
+            A set of node names that exist.
+        """
+        names = set(names)
+
+        if not names:
+            return names
+
+        query = (
+            exp.select("name")
+            .from_(self.snapshots_table)
+            .where(exp.column("name").isin(*names))
+            .distinct()
+        )
+        if exclude_external:
+            query = query.where(exp.column("kind_name").neq(ModelKindName.EXTERNAL.value))
+        return {name for (name,) in fetchall(self.engine_adapter, query)}
+
+    def update_auto_restatements(
+        self, next_auto_restatement_ts: t.Dict[SnapshotNameVersion, t.Optional[int]]
+    ) -> None:
+        """Updates the auto restatement timestamps.
+
+        Args:
+            next_auto_restatement_ts: A dictionary of snapshot name version to the next auto restatement timestamp.
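+
+        Existing entries are deleted first; entries mapped to None are not re-inserted,
+        which effectively clears them.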
+ """ + for where in snapshot_name_version_filter( + self.engine_adapter, + next_auto_restatement_ts, + column_prefix="snapshot", + alias=None, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + self.engine_adapter.delete_from(self.auto_restatements_table, where=where) + + next_auto_restatement_ts_filtered = { + k: v for k, v in next_auto_restatement_ts.items() if v is not None + } + if not next_auto_restatement_ts_filtered: + return + + self.engine_adapter.insert_append( + self.auto_restatements_table, + _auto_restatements_to_df(next_auto_restatement_ts_filtered), + columns_to_types=self._auto_restatement_columns_to_types, + ) + + def count(self) -> int: + """Counts the number of snapshots in the state.""" + result = self.engine_adapter.fetchone(exp.select("COUNT(*)").from_(self.snapshots_table)) + return result[0] if result else 0 + + def clear_cache(self) -> None: + """Clears the snapshot cache.""" + self._snapshot_cache.clear() + + def _update_snapshots( + self, + snapshots: t.Iterable[SnapshotIdLike], + **kwargs: t.Any, + ) -> None: + properties = kwargs + properties["updated_ts"] = now_timestamp() + + for where in snapshot_id_filter( + self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + self.engine_adapter.update_table( + self.snapshots_table, + properties, + where=where, + ) + + def _push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = False) -> None: + if overwrite: + snapshots = tuple(snapshots) + self.delete_snapshots(snapshots) + + snapshots_to_store = [] + + for snapshot in snapshots: + if isinstance(snapshot.node, SeedModel): + seed_model = t.cast(SeedModel, snapshot.node) + snapshot = snapshot.copy(update={"node": seed_model.to_dehydrated()}) + snapshots_to_store.append(snapshot) + + self.engine_adapter.insert_append( + self.snapshots_table, + _snapshots_to_df(snapshots_to_store), + columns_to_types=self._snapshot_columns_to_types, + ) + + def _get_snapshots( + self, + snapshot_ids: t.Iterable[SnapshotIdLike], + lock_for_update: bool = False, + ) -> t.Dict[SnapshotId, Snapshot]: + """Fetches specified snapshots or all snapshots. + + Args: + snapshot_ids: The collection of snapshot like objects to fetch. + lock_for_update: Lock the snapshot rows for future update + + Returns: + A dictionary of snapshot ids to snapshots for ones that could be found. 
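+
+        Snapshots are loaded through the local snapshot cache; for cached snapshots only
+        the mutable attributes (updated_ts, unpaused_ts, etc.) are re-read from the state.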
+ """ + duplicates: t.Dict[SnapshotId, Snapshot] = {} + + def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]: + fetched_snapshots: t.Dict[SnapshotId, Snapshot] = {} + for query in self._get_snapshots_expressions(snapshot_ids_to_load, lock_for_update): + for ( + serialized_snapshot, + _, + _, + _, + updated_ts, + unpaused_ts, + unrestorable, + next_auto_restatement_ts, + ) in fetchall(self.engine_adapter, query): + snapshot = parse_snapshot( + serialized_snapshot=serialized_snapshot, + updated_ts=updated_ts, + unpaused_ts=unpaused_ts, + unrestorable=unrestorable, + next_auto_restatement_ts=next_auto_restatement_ts, + ) + snapshot_id = snapshot.snapshot_id + if snapshot_id in fetched_snapshots: + other = duplicates.get(snapshot_id, fetched_snapshots[snapshot_id]) + duplicates[snapshot_id] = ( + snapshot if snapshot.updated_ts > other.updated_ts else other + ) + fetched_snapshots[snapshot_id] = duplicates[snapshot_id] + else: + fetched_snapshots[snapshot_id] = snapshot + return fetched_snapshots.values() + + snapshots, cached_snapshots = self._snapshot_cache.get_or_load( + {s.snapshot_id for s in snapshot_ids}, _loader + ) + + if cached_snapshots: + cached_snapshots_in_state: t.Set[SnapshotId] = set() + for where in snapshot_id_filter( + self.engine_adapter, cached_snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + query = ( + exp.select( + "name", + "identifier", + "updated_ts", + "unpaused_ts", + "unrestorable", + "next_auto_restatement_ts", + ) + .from_(exp.to_table(self.snapshots_table).as_("snapshots")) + .join( + exp.to_table(self.auto_restatements_table).as_("auto_restatements"), + on=exp.and_( + exp.column("name", table="snapshots").eq( + exp.column("snapshot_name", table="auto_restatements") + ), + exp.column("version", table="snapshots").eq( + exp.column("snapshot_version", table="auto_restatements") + ), + ), + join_type="left", + copy=False, + ) + .where(where) + ) + if lock_for_update: + query = query.lock(copy=False) + for ( + name, + identifier, + updated_ts, + unpaused_ts, + unrestorable, + next_auto_restatement_ts, + ) in fetchall(self.engine_adapter, query): + snapshot_id = SnapshotId(name=name, identifier=identifier) + snapshot = snapshots[snapshot_id] + snapshot.updated_ts = updated_ts + snapshot.unpaused_ts = unpaused_ts + snapshot.unrestorable = unrestorable + snapshot.next_auto_restatement_ts = next_auto_restatement_ts + cached_snapshots_in_state.add(snapshot_id) + + missing_cached_snapshots = cached_snapshots - cached_snapshots_in_state + for missing_cached_snapshot_id in missing_cached_snapshots: + snapshots.pop(missing_cached_snapshot_id, None) + + if duplicates: + self.push_snapshots(duplicates.values(), overwrite=True) + logger.error("Found duplicate snapshots in the state store.") + + return snapshots + + def _get_snapshots_expressions( + self, + snapshot_ids: t.Iterable[SnapshotIdLike], + lock_for_update: bool = False, + ) -> t.Iterator[exp.Expression]: + for where in snapshot_id_filter( + self.engine_adapter, + snapshot_ids, + alias="snapshots", + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + query = ( + exp.select( + "snapshots.snapshot", + "snapshots.name", + "snapshots.identifier", + "snapshots.version", + "snapshots.updated_ts", + "snapshots.unpaused_ts", + "snapshots.unrestorable", + "auto_restatements.next_auto_restatement_ts", + ) + .from_(exp.to_table(self.snapshots_table).as_("snapshots")) + .join( + exp.to_table(self.auto_restatements_table).as_("auto_restatements"), + on=exp.and_( + exp.column("name", 
table="snapshots").eq( + exp.column("snapshot_name", table="auto_restatements") + ), + exp.column("version", table="snapshots").eq( + exp.column("snapshot_version", table="auto_restatements") + ), + ), + join_type="left", + copy=False, + ) + .where(where) + ) + if lock_for_update: + query = query.lock(copy=False) + yield query + + def _get_snapshots_with_same_version( + self, + snapshots: t.Collection[SnapshotNameVersionLike], + lock_for_update: bool = False, + ) -> t.List[SharedVersionSnapshot]: + """Fetches all snapshots that share the same version as the snapshots. + + The output includes the snapshots with the specified identifiers. + + Args: + snapshots: The collection of target name / version pairs. + lock_for_update: Lock the snapshot rows for future update + + Returns: + The list of Snapshot objects. + """ + if not snapshots: + return [] + + snapshot_rows = [] + + for where in snapshot_name_version_filter( + self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + query = ( + exp.select( + "snapshot", + "name", + "identifier", + "version", + "updated_ts", + "unpaused_ts", + "unrestorable", + ) + .from_(exp.to_table(self.snapshots_table).as_("snapshots")) + .where(where) + ) + if lock_for_update: + query = query.lock(copy=False) + + snapshot_rows.extend(fetchall(self.engine_adapter, query)) + + return [ + SharedVersionSnapshot.from_snapshot_record( + name=name, + identifier=identifier, + version=version, + updated_ts=updated_ts, + unpaused_ts=unpaused_ts, + unrestorable=unrestorable, + snapshot=snapshot, + ) + for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable in snapshot_rows + ] + + +def parse_snapshot( + serialized_snapshot: str, + updated_ts: int, + unpaused_ts: t.Optional[int], + unrestorable: bool, + next_auto_restatement_ts: t.Optional[int], +) -> Snapshot: + return Snapshot( + **{ + **json.loads(serialized_snapshot), + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "unrestorable": unrestorable, + "next_auto_restatement_ts": next_auto_restatement_ts, + } + ) + + +def _snapshot_to_json(snapshot: Snapshot) -> str: + return snapshot.json( + exclude={ + "intervals", + "dev_intervals", + "pending_restatement_intervals", + "updated_ts", + "unpaused_ts", + "unrestorable", + "next_auto_restatement_ts", + } + ) + + +def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame: + return pd.DataFrame( + [ + { + "name": snapshot.name, + "identifier": snapshot.identifier, + "version": snapshot.version, + "snapshot": _snapshot_to_json(snapshot), + "kind_name": snapshot.model_kind_name.value if snapshot.model_kind_name else None, + "updated_ts": snapshot.updated_ts, + "unpaused_ts": snapshot.unpaused_ts, + "ttl_ms": snapshot.ttl_ms, + "unrestorable": snapshot.unrestorable, + } + for snapshot in snapshots + ] + ) + + +def _auto_restatements_to_df(auto_restatements: t.Dict[SnapshotNameVersion, int]) -> pd.DataFrame: + return pd.DataFrame( + [ + { + "snapshot_name": name_version.name, + "snapshot_version": name_version.version, + "next_auto_restatement_ts": ts, + } + for name_version, ts in auto_restatements.items() + ] + ) + + +class SharedVersionSnapshot(PydanticModel): + """A stripped down version of a snapshot that is used for fetching snapshots that share the same version + with a significantly reduced parsing overhead. 
+ """ + + name: str + version: str + dev_version_: t.Optional[str] = Field(alias="dev_version") + identifier: str + fingerprint: SnapshotFingerprint + interval_unit: IntervalUnit + change_category: SnapshotChangeCategory + updated_ts: int + unpaused_ts: t.Optional[int] + unrestorable: bool + disable_restatement: bool + effective_from: t.Optional[TimeLike] + raw_snapshot: t.Dict[str, t.Any] + + @property + def snapshot_id(self) -> SnapshotId: + return SnapshotId(name=self.name, identifier=self.identifier) + + @property + def is_forward_only(self) -> bool: + return self.change_category == SnapshotChangeCategory.FORWARD_ONLY + + @property + def normalized_effective_from_ts(self) -> t.Optional[int]: + return ( + to_timestamp(self.interval_unit.cron_floor(self.effective_from)) + if self.effective_from + else None + ) + + @property + def dev_version(self) -> str: + return self.dev_version_ or self.fingerprint.to_version() + + @property + def full_snapshot(self) -> Snapshot: + return Snapshot( + **{ + **self.raw_snapshot, + "updated_ts": self.updated_ts, + "unpaused_ts": self.unpaused_ts, + "unrestorable": self.unrestorable, + } + ) + + def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None: + """Sets the timestamp for when this snapshot was unpaused. + + Args: + unpaused_dt: The datetime object of when this snapshot was unpaused. + """ + self.unpaused_ts = ( + to_timestamp(self.interval_unit.cron_floor(unpaused_dt)) if unpaused_dt else None + ) + + @classmethod + def from_snapshot_record( + cls, + *, + name: str, + identifier: str, + version: str, + updated_ts: int, + unpaused_ts: t.Optional[int], + unrestorable: bool, + snapshot: str, + ) -> SharedVersionSnapshot: + raw_snapshot = json.loads(snapshot) + raw_node = raw_snapshot["node"] + return SharedVersionSnapshot( + name=name, + version=version, + dev_version=raw_snapshot.get("dev_version"), + identifier=identifier, + fingerprint=raw_snapshot["fingerprint"], + interval_unit=raw_node.get("interval_unit", IntervalUnit.from_cron(raw_node["cron"])), + change_category=raw_snapshot["change_category"], + updated_ts=updated_ts, + unpaused_ts=unpaused_ts, + unrestorable=unrestorable, + disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False), + effective_from=raw_snapshot.get("effective_from"), + raw_snapshot=raw_snapshot, + ) diff --git a/sqlmesh/core/state_sync/engine_adapter/utils.py b/sqlmesh/core/state_sync/engine_adapter/utils.py index 5520f94b87..951153da83 100644 --- a/sqlmesh/core/state_sync/engine_adapter/utils.py +++ b/sqlmesh/core/state_sync/engine_adapter/utils.py @@ -10,14 +10,11 @@ T = t.TypeVar("T") -DEFAULT_BATCH_SIZE = 1000 - - def snapshot_id_filter( engine_adapter: EngineAdapter, snapshot_ids: t.Iterable[SnapshotIdLike], + batch_size: int, alias: t.Optional[str] = None, - batch_size: int = DEFAULT_BATCH_SIZE, ) -> t.Iterator[exp.Condition]: name_identifiers = sorted( {(snapshot_id.name, snapshot_id.identifier) for snapshot_id in snapshot_ids} @@ -53,10 +50,10 @@ def snapshot_id_filter( def snapshot_name_version_filter( engine_adapter: EngineAdapter, snapshot_name_versions: t.Iterable[SnapshotNameVersionLike], + batch_size: int, version_column_name: str = "version", alias: t.Optional[str] = "snapshots", column_prefix: t.Optional[str] = None, - batch_size: int = DEFAULT_BATCH_SIZE, ) -> t.Iterator[exp.Condition]: name_versions = sorted({(s.name, s.version) for s in snapshot_name_versions}) batches = create_batches(name_versions, batch_size=batch_size) @@ -95,7 +92,7 @@ def 
snapshot_name_version_filter( ) -def create_batches(l: t.List[T], batch_size: int = 1000) -> t.List[t.List[T]]: +def create_batches(l: t.List[T], batch_size: int) -> t.List[t.List[T]]: return [l[i : i + batch_size] for i in range(0, len(l), batch_size)] diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py index 2ae8e9cc98..1e2d2eee3d 100644 --- a/tests/core/test_state_sync.py +++ b/tests/core/test_state_sync.py @@ -195,10 +195,10 @@ def test_duplicates(state_sync: EngineAdapterStateSync, make_snapshot: t.Callabl snapshot_b.updated_ts = snapshot_a.updated_ts + 1 snapshot_c.updated_ts = 0 state_sync.push_snapshots([snapshot_a]) - state_sync._push_snapshots([snapshot_a]) - state_sync._push_snapshots([snapshot_b]) - state_sync._push_snapshots([snapshot_c]) - state_sync._snapshot_cache.clear() + state_sync.snapshot_state.push_snapshots([snapshot_a]) + state_sync.snapshot_state.push_snapshots([snapshot_b]) + state_sync.snapshot_state.push_snapshots([snapshot_c]) + state_sync.snapshot_state.clear_cache() assert ( state_sync.get_snapshots([snapshot_a])[snapshot_a.snapshot_id].updated_ts == snapshot_b.updated_ts @@ -1236,7 +1236,9 @@ def test_delete_expired_snapshots_promoted( env.snapshots_ = [] state_sync.promote(env) - now_timestamp_mock = mocker.patch("sqlmesh.core.state_sync.engine_adapter.facade.now_timestamp") + now_timestamp_mock = mocker.patch( + "sqlmesh.core.state_sync.engine_adapter.snapshot.now_timestamp" + ) now_timestamp_mock.return_value = now_timestamp() + 11000 assert state_sync.delete_expired_snapshots() == [ @@ -2215,7 +2217,7 @@ def test_first_migration_failure(duck_conn, mocker: MockerFixture, tmp_path) -> match="SQLMesh migration failed.", ): state_sync.migrate(default_catalog=None) - assert not state_sync.engine_adapter.table_exists(state_sync.snapshots_table) + assert not state_sync.engine_adapter.table_exists(state_sync.snapshot_state.snapshots_table) assert not state_sync.engine_adapter.table_exists( state_sync.environment_state.environments_table ) @@ -2379,7 +2381,7 @@ def test_seed_hydration( assert snapshot.model.is_hydrated assert snapshot.model.seed.content == "header\n1\n2" - state_sync._snapshot_cache.clear() + state_sync.snapshot_state.clear_cache() stored_snapshot = state_sync.get_snapshots([snapshot.snapshot_id])[snapshot.snapshot_id] assert isinstance(stored_snapshot.model, SeedModel) assert not stored_snapshot.model.is_hydrated @@ -2770,8 +2772,8 @@ def test_get_snapshots(mocker): def test_snapshot_batching(state_sync, mocker, make_snapshot): mock = mocker.Mock() - state_sync.SNAPSHOT_BATCH_SIZE = 2 - state_sync.engine_adapter = mock + state_sync.snapshot_state.SNAPSHOT_BATCH_SIZE = 2 + state_sync.snapshot_state.engine_adapter = mock snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1")), "1") snapshot_b = make_snapshot(SqlModel(name="a", query=parse_one("select 2")), "2") @@ -2842,13 +2844,12 @@ def test_snapshot_batching(state_sync, mocker, make_snapshot): ], ] - snapshots = state_sync._get_snapshots( + snapshots = state_sync.snapshot_state.get_snapshots( ( SnapshotId(name="a", identifier="1"), SnapshotId(name="a", identifier="2"), SnapshotId(name="a", identifier="3"), ), - hydrate_intervals=False, ) assert len(snapshots) == 3 calls = mock.fetchall.call_args_list @@ -2887,13 +2888,12 @@ def test_snapshot_cache( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, mocker: MockerFixture ): cache_mock = mocker.Mock() - state_sync._snapshot_cache = cache_mock + state_sync.snapshot_state._snapshot_cache = 
cache_mock snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 1"))) cache_mock.get_or_load.return_value = ({snapshot.snapshot_id: snapshot}, {snapshot.snapshot_id}) - # Use _push_snapshots to bypass cache. - state_sync._push_snapshots([snapshot]) + state_sync.snapshot_state.push_snapshots([snapshot]) assert state_sync.get_snapshots([snapshot.snapshot_id]) == {snapshot.snapshot_id: snapshot} cache_mock.get_or_load.assert_called_once_with({snapshot.snapshot_id}, mocker.ANY) @@ -2901,7 +2901,9 @@ def test_snapshot_cache( # Update the snapshot in the state and make sure this update is reflected on the cached instance. assert snapshot.unpaused_ts is None assert not snapshot.unrestorable - state_sync._update_snapshots([snapshot.snapshot_id], unpaused_ts=1, unrestorable=True) + state_sync.snapshot_state._update_snapshots( + [snapshot.snapshot_id], unpaused_ts=1, unrestorable=True + ) new_snapshot = state_sync.get_snapshots([snapshot.snapshot_id])[snapshot.snapshot_id] assert new_snapshot.unpaused_ts == 1 assert new_snapshot.unrestorable @@ -2916,7 +2918,7 @@ def test_update_auto_restatements(state_sync: EngineAdapterStateSync, make_snaps snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("select 2")), version="2") snapshot_c = make_snapshot(SqlModel(name="c", query=parse_one("select 3")), version="3") - state_sync._push_snapshots([snapshot_a, snapshot_b, snapshot_c]) + state_sync.snapshot_state.push_snapshots([snapshot_a, snapshot_b, snapshot_c]) next_auto_restatement_ts: t.Dict[SnapshotNameVersion, t.Optional[int]] = { snapshot_a.name_version: 1, diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 2abb73765b..2fc1573928 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -27,7 +27,7 @@ ViewKind, ) from sqlmesh.core.model.kind import SCDType2ByColumnKind, SCDType2ByTimeKind -from sqlmesh.core.state_sync.engine_adapter.facade import _snapshot_to_json +from sqlmesh.core.state_sync.engine_adapter.snapshot import _snapshot_to_json from sqlmesh.dbt.builtin import _relation_info_to_relation from sqlmesh.dbt.column import ( ColumnConfig, From a2b24abb0df6cfc0b9dfee4257777bfdbfe062ff Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Tue, 25 Feb 2025 16:16:16 -0800 Subject: [PATCH 4/6] Chore: Refactor the version state from the state sync --- sqlmesh/core/state_sync/base.py | 5 +- .../core/state_sync/engine_adapter/facade.py | 201 ++++++------------ .../core/state_sync/engine_adapter/utils.py | 12 ++ .../core/state_sync/engine_adapter/version.py | 72 +++++++ tests/core/test_state_sync.py | 20 +- 5 files changed, 156 insertions(+), 154 deletions(-) create mode 100644 sqlmesh/core/state_sync/engine_adapter/version.py diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 4113737970..f07e531fad 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -252,12 +252,9 @@ def raise_error( return versions @abc.abstractmethod - def _get_versions(self, lock_for_update: bool = False) -> Versions: + def _get_versions(self) -> Versions: """Queries the store to get the current versions of SQLMesh and deps. - Args: - lock_for_update: Whether or not the usage of this method plans to update the row. - Returns: The versions object. 
""" diff --git a/sqlmesh/core/state_sync/engine_adapter/facade.py b/sqlmesh/core/state_sync/engine_adapter/facade.py index 7658bfe6b7..6a29975c29 100644 --- a/sqlmesh/core/state_sync/engine_adapter/facade.py +++ b/sqlmesh/core/state_sync/engine_adapter/facade.py @@ -25,10 +25,8 @@ from pathlib import Path from datetime import datetime -import pandas as pd from sqlglot import __version__ as SQLGLOT_VERSION from sqlglot import exp -from sqlglot.helper import seq_get from sqlmesh.core import analytics from sqlmesh.core import constants as c @@ -55,7 +53,6 @@ ) from sqlmesh.core.state_sync.base import ( MIGRATIONS, - SCHEMA_VERSION, PromotionResult, StateSync, Versions, @@ -64,11 +61,12 @@ from sqlmesh.core.state_sync.engine_adapter.interval import IntervalState from sqlmesh.core.state_sync.engine_adapter.environment import EnvironmentState from sqlmesh.core.state_sync.engine_adapter.snapshot import SnapshotState +from sqlmesh.core.state_sync.engine_adapter.version import VersionState +from sqlmesh.core.state_sync.engine_adapter.utils import snapshot_id_filter, SQLMESH_VERSION from sqlmesh.utils import major_minor from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import TimeLike, now_timestamp, to_timestamp from sqlmesh.utils.errors import ConflictingPlanError, SQLMeshError -from sqlmesh.utils.migration import index_text_type logger = logging.getLogger(__name__) @@ -79,14 +77,6 @@ if t.TYPE_CHECKING: from sqlmesh.core._typing import TableName -try: - # We can't import directly from the root package due to circular dependency - from sqlmesh._version import __version__ as SQLMESH_VERSION # type: ignore -except ImportError: - logger.error( - 'Unable to set __version__, run "pip install -e ." or "python setup.py develop" first.' - ) - class EngineAdapterStateSync(StateSync): """Manages state of nodes and snapshot with an existing engine adapter. 
@@ -116,19 +106,12 @@ def __init__( self.snapshot_state = SnapshotState( engine_adapter, schema=schema, context_path=context_path ) + self.version_state = VersionState(engine_adapter, schema=schema) # Make sure that if an empty string is provided that we treat it as None self.schema = schema or None self.engine_adapter = engine_adapter self.console = console or get_console() self.plan_dags_table = exp.table_("_plan_dags", db=self.schema) - self.versions_table = exp.table_("_versions", db=self.schema) - - index_type = index_text_type(engine_adapter.dialect) - self._version_columns_to_types = { - "schema_version": exp.DataType.build("int"), - "sqlglot_version": exp.DataType.build(index_type), - "sqlmesh_version": exp.DataType.build(index_type), - } def _fetchone(self, query: t.Union[exp.Expression, str]) -> t.Optional[t.Tuple]: return self.engine_adapter.fetchone( @@ -173,7 +156,7 @@ def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: snapshots = snapshots_by_id.values() if snapshots: - self.snapshot_state.push_snapshots(snapshots) + self._push_snapshots(snapshots) @transactional() def promote( @@ -271,49 +254,6 @@ def promote( ), ) - def _ensure_no_gaps( - self, - target_snapshots: t.Iterable[Snapshot], - target_environment: Environment, - snapshot_names: t.Optional[t.Set[str]], - ) -> None: - target_snapshots_by_name = {s.name: s for s in target_snapshots} - - changed_version_prev_snapshots_by_name = { - s.name: s - for s in target_environment.snapshots - if s.name in target_snapshots_by_name - and target_snapshots_by_name[s.name].version != s.version - } - - prev_snapshots = self.get_snapshots( - changed_version_prev_snapshots_by_name.values() - ).values() - cache: t.Dict[str, datetime] = {} - - for prev_snapshot in prev_snapshots: - target_snapshot = target_snapshots_by_name[prev_snapshot.name] - if ( - (snapshot_names is None or prev_snapshot.name in snapshot_names) - and target_snapshot.is_incremental - and prev_snapshot.is_incremental - and prev_snapshot.intervals - ): - start = to_timestamp( - start_date(target_snapshot, target_snapshots_by_name.values(), cache) - ) - end = prev_snapshot.intervals[-1][1] - - if start < end: - missing_intervals = target_snapshot.missing_intervals( - start, end, end_bounded=True - ) - - if missing_intervals: - raise SQLMeshError( - f"Detected gaps in snapshot {target_snapshot.snapshot_id}: {missing_intervals}" - ) - @transactional() def finalize(self, environment: Environment) -> None: """Finalize the target environment, indicating that this environment has been @@ -330,28 +270,6 @@ def unpause_snapshots( ) -> None: self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt, self.interval_state) - def _update_versions( - self, - schema_version: int = SCHEMA_VERSION, - sqlglot_version: str = SQLGLOT_VERSION, - sqlmesh_version: str = SQLMESH_VERSION, - ) -> None: - self.engine_adapter.delete_from(self.versions_table, "TRUE") - - self.engine_adapter.insert_append( - self.versions_table, - pd.DataFrame( - [ - { - "schema_version": schema_version, - "sqlglot_version": sqlglot_version, - "sqlmesh_version": sqlmesh_version, - } - ] - ), - columns_to_types=self._version_columns_to_types, - ) - def invalidate_environment(self, name: str) -> None: self.environment_state.invalidate_environment(name) @@ -385,7 +303,7 @@ def reset(self, default_catalog: t.Optional[str]) -> None: self.environment_state.environments_table, self.interval_state.intervals_table, self.plan_dags_table, - self.versions_table, + self.version_state.versions_table, ): 
self.engine_adapter.drop_table(table) self.snapshot_state.clear_cache() @@ -433,24 +351,6 @@ def get_snapshots( Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals) return snapshots - def _get_versions(self, lock_for_update: bool = False) -> Versions: - no_version = Versions() - - if not self.engine_adapter.table_exists(self.versions_table): - return no_version - - query = exp.select("*").from_(self.versions_table) - if lock_for_update: - query.lock(copy=False) - - row = self._fetchone(query) - if not row: - return no_version - - return Versions( - schema_version=row[0], sqlglot_version=row[1], sqlmesh_version=seq_get(row, 2) - ) - @transactional() def add_interval( self, @@ -507,6 +407,9 @@ def recycle(self) -> None: def close(self) -> None: self.engine_adapter.close() + def _get_versions(self) -> Versions: + return self.version_state.get_versions() + def _restore_table( self, table_name: TableName, @@ -540,7 +443,7 @@ def migrate( self._migrate_rows(promoted_snapshots_only) # Cleanup plan DAGs since we currently don't migrate snapshot records that are in there. self.engine_adapter.delete_from(self.plan_dags_table, "TRUE") - self._update_versions() + self.version_state.update_versions() analytics.collector.on_migration_end( from_sqlmesh_version=versions.sqlmesh_version, @@ -572,7 +475,7 @@ def rollback(self) -> None: tables = ( self.snapshot_state.snapshots_table, self.environment_state.environments_table, - self.versions_table, + self.version_state.versions_table, ) optional_tables = ( self.interval_state.intervals_table, @@ -606,7 +509,7 @@ def _backup_state(self) -> None: for table in ( self.snapshot_state.snapshots_table, self.environment_state.environments_table, - self.versions_table, + self.version_state.versions_table, self.interval_state.intervals_table, self.plan_dags_table, self.snapshot_state.auto_restatements_table, @@ -723,7 +626,7 @@ def _push_new_snapshots() -> None: ] if new_snapshots_to_push: logger.info("Pushing %s migrated snapshots", len(new_snapshots_to_push)) - self.snapshot_state.push_snapshots(new_snapshots_to_push) + self._push_snapshots(new_snapshots_to_push) new_snapshots.clear() snapshot_id_mapping.clear() @@ -861,41 +764,59 @@ def _snapshot_id_filter( self, snapshot_ids: t.Iterable[SnapshotIdLike], alias: t.Optional[str] = None, - batch_size: t.Optional[int] = None, ) -> t.Iterator[exp.Condition]: - name_identifiers = sorted( - {(snapshot_id.name, snapshot_id.identifier) for snapshot_id in snapshot_ids} + yield from snapshot_id_filter( + self.engine_adapter, + snapshot_ids, + alias=alias, + batch_size=self.SNAPSHOT_BATCH_SIZE, ) - batches = self._batches(name_identifiers, batch_size=batch_size) - - if not name_identifiers: - yield exp.false() - elif self.engine_adapter.SUPPORTS_TUPLE_IN: - for identifiers in batches: - yield t.cast( - exp.Tuple, - exp.convert( - ( - exp.column("name", table=alias), - exp.column("identifier", table=alias), - ) - ), - ).isin(*identifiers) - else: - for identifiers in batches: - yield exp.or_( - *[ - exp.and_( - exp.column("name", table=alias).eq(name), - exp.column("identifier", table=alias).eq(identifier), - ) - for name, identifier in identifiers - ] + + def _push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = False) -> None: + self.snapshot_state.push_snapshots(snapshots, overwrite=overwrite) + + def _ensure_no_gaps( + self, + target_snapshots: t.Iterable[Snapshot], + target_environment: Environment, + snapshot_names: t.Optional[t.Set[str]], + ) -> None: + target_snapshots_by_name = 
{s.name: s for s in target_snapshots} + + changed_version_prev_snapshots_by_name = { + s.name: s + for s in target_environment.snapshots + if s.name in target_snapshots_by_name + and target_snapshots_by_name[s.name].version != s.version + } + + prev_snapshots = self.get_snapshots( + changed_version_prev_snapshots_by_name.values() + ).values() + cache: t.Dict[str, datetime] = {} + + for prev_snapshot in prev_snapshots: + target_snapshot = target_snapshots_by_name[prev_snapshot.name] + if ( + (snapshot_names is None or prev_snapshot.name in snapshot_names) + and target_snapshot.is_incremental + and prev_snapshot.is_incremental + and prev_snapshot.intervals + ): + start = to_timestamp( + start_date(target_snapshot, target_snapshots_by_name.values(), cache) ) + end = prev_snapshot.intervals[-1][1] + + if start < end: + missing_intervals = target_snapshot.missing_intervals( + start, end, end_bounded=True + ) - def _batches(self, l: t.List[T], batch_size: t.Optional[int] = None) -> t.List[t.List[T]]: - batch_size = batch_size or self.SNAPSHOT_BATCH_SIZE - return [l[i : i + batch_size] for i in range(0, len(l), batch_size)] + if missing_intervals: + raise SQLMeshError( + f"Detected gaps in snapshot {target_snapshot.snapshot_id}: {missing_intervals}" + ) @contextlib.contextmanager def _transaction(self) -> t.Iterator[None]: diff --git a/sqlmesh/core/state_sync/engine_adapter/utils.py b/sqlmesh/core/state_sync/engine_adapter/utils.py index 951153da83..e5ffda6486 100644 --- a/sqlmesh/core/state_sync/engine_adapter/utils.py +++ b/sqlmesh/core/state_sync/engine_adapter/utils.py @@ -1,12 +1,24 @@ from __future__ import annotations import typing as t +import logging from sqlglot import exp from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.snapshot import SnapshotIdLike, SnapshotNameVersionLike +logger = logging.getLogger(__name__) + +try: + # We can't import directly from the root package due to circular dependency + from sqlmesh._version import __version__ as SQLMESH_VERSION # noqa +except ImportError: + logger.error( + 'Unable to set __version__, run "pip install -e ." or "python setup.py develop" first.' 
+ ) + + T = t.TypeVar("T") diff --git a/sqlmesh/core/state_sync/engine_adapter/version.py b/sqlmesh/core/state_sync/engine_adapter/version.py new file mode 100644 index 0000000000..82cad4e21f --- /dev/null +++ b/sqlmesh/core/state_sync/engine_adapter/version.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import logging +import typing as t + +import pandas as pd +from sqlglot import __version__ as SQLGLOT_VERSION +from sqlglot import exp +from sqlglot.helper import seq_get + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.engine_adapter.utils import ( + fetchone, + SQLMESH_VERSION, +) +from sqlmesh.core.state_sync.base import ( + SCHEMA_VERSION, + Versions, +) +from sqlmesh.utils.migration import index_text_type + +logger = logging.getLogger(__name__) + + +class VersionState: + def __init__(self, engine_adapter: EngineAdapter, schema: t.Optional[str] = None): + self.engine_adapter = engine_adapter + self.versions_table = exp.table_("_versions", db=schema) + + index_type = index_text_type(engine_adapter.dialect) + self._version_columns_to_types = { + "schema_version": exp.DataType.build("int"), + "sqlglot_version": exp.DataType.build(index_type), + "sqlmesh_version": exp.DataType.build(index_type), + } + + def update_versions( + self, + schema_version: int = SCHEMA_VERSION, + sqlglot_version: str = SQLGLOT_VERSION, + sqlmesh_version: str = SQLMESH_VERSION, + ) -> None: + self.engine_adapter.delete_from(self.versions_table, "TRUE") + + self.engine_adapter.insert_append( + self.versions_table, + pd.DataFrame( + [ + { + "schema_version": schema_version, + "sqlglot_version": sqlglot_version, + "sqlmesh_version": sqlmesh_version, + } + ] + ), + columns_to_types=self._version_columns_to_types, + ) + + def get_versions(self) -> Versions: + no_version = Versions() + + if not self.engine_adapter.table_exists(self.versions_table): + return no_version + + query = exp.select("*").from_(self.versions_table) + row = fetchone(self.engine_adapter, query) + if not row: + return no_version + + return Versions( + schema_version=row[0], sqlglot_version=row[1], sqlmesh_version=seq_get(row, 2) + ) diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py index 1e2d2eee3d..0e16367ca6 100644 --- a/tests/core/test_state_sync.py +++ b/tests/core/test_state_sync.py @@ -106,7 +106,7 @@ def promote_snapshots( def delete_versions(state_sync: EngineAdapterStateSync) -> None: - state_sync.engine_adapter.drop_table(state_sync.versions_table) + state_sync.engine_adapter.drop_table(state_sync.version_state.versions_table) def test_push_snapshots( @@ -2053,7 +2053,7 @@ def test_version_schema(state_sync: EngineAdapterStateSync, tmp_path) -> None: state_sync.migrate(default_catalog=None) # migration version is behind, always raise - state_sync._update_versions(schema_version=SCHEMA_VERSION + 1) + state_sync.version_state.update_versions(schema_version=SCHEMA_VERSION + 1) error = ( rf"SQLMesh \(local\) is using version '{SCHEMA_VERSION}' which is behind '{SCHEMA_VERSION + 1}' \(remote\). 
" rf"""Please upgrade SQLMesh \('pip install --upgrade "sqlmesh=={SQLMESH_VERSION}"' command\).""" @@ -2066,7 +2066,7 @@ def test_version_schema(state_sync: EngineAdapterStateSync, tmp_path) -> None: state_sync.get_versions(validate=False) # migration version is ahead, only raise when validate is true - state_sync._update_versions(schema_version=SCHEMA_VERSION - 1) + state_sync.version_state.update_versions(schema_version=SCHEMA_VERSION - 1) with pytest.raises( SQLMeshError, match=rf"SQLMesh \(local\) is using version '{SCHEMA_VERSION}' which is ahead of '{SCHEMA_VERSION - 1}'", @@ -2087,7 +2087,7 @@ def test_version_sqlmesh(state_sync: EngineAdapterStateSync) -> None: else f"{int(patch) + 1}" ) sqlmesh_version_patch_bump = f"{major}.{minor}.{new_patch}" - state_sync._update_versions(sqlmesh_version=sqlmesh_version_patch_bump) + state_sync.version_state.update_versions(sqlmesh_version=sqlmesh_version_patch_bump) state_sync.get_versions(validate=False) # sqlmesh version is behind @@ -2096,7 +2096,7 @@ def test_version_sqlmesh(state_sync: EngineAdapterStateSync) -> None: rf"SQLMesh \(local\) is using version '{SQLMESH_VERSION}' which is behind '{sqlmesh_version_minor_bump}' \(remote\). " rf"""Please upgrade SQLMesh \('pip install --upgrade "sqlmesh=={sqlmesh_version_minor_bump}"' command\).""" ) - state_sync._update_versions(sqlmesh_version=sqlmesh_version_minor_bump) + state_sync.version_state.update_versions(sqlmesh_version=sqlmesh_version_minor_bump) with pytest.raises(SQLMeshError, match=error): state_sync.get_versions() state_sync.get_versions(validate=False) @@ -2104,7 +2104,7 @@ def test_version_sqlmesh(state_sync: EngineAdapterStateSync) -> None: # sqlmesh version is ahead sqlmesh_version_minor_decrease = f"{major}.{int(minor) - 1}.{patch}" error = rf"SQLMesh \(local\) is using version '{SQLMESH_VERSION}' which is ahead of '{sqlmesh_version_minor_decrease}'" - state_sync._update_versions(sqlmesh_version=sqlmesh_version_minor_decrease) + state_sync.version_state.update_versions(sqlmesh_version=sqlmesh_version_minor_decrease) with pytest.raises(SQLMeshError, match=error): state_sync.get_versions() state_sync.get_versions(validate=False) @@ -2114,7 +2114,7 @@ def test_version_sqlglot(state_sync: EngineAdapterStateSync) -> None: # patch version sqlglot doesn't matter major, minor, patch, *_ = SQLGLOT_VERSION.split(".") sqlglot_version = f"{major}.{minor}.{int(patch) + 1}" - state_sync._update_versions(sqlglot_version=sqlglot_version) + state_sync.version_state.update_versions(sqlglot_version=sqlglot_version) state_sync.get_versions(validate=False) # sqlglot version is behind @@ -2123,7 +2123,7 @@ def test_version_sqlglot(state_sync: EngineAdapterStateSync) -> None: rf"SQLGlot \(local\) is using version '{SQLGLOT_VERSION}' which is behind '{sqlglot_version}' \(remote\). 
" rf"""Please upgrade SQLGlot \('pip install --upgrade "sqlglot=={sqlglot_version}"' command\).""" ) - state_sync._update_versions(sqlglot_version=sqlglot_version) + state_sync.version_state.update_versions(sqlglot_version=sqlglot_version) with pytest.raises(SQLMeshError, match=error): state_sync.get_versions() state_sync.get_versions(validate=False) @@ -2131,7 +2131,7 @@ def test_version_sqlglot(state_sync: EngineAdapterStateSync) -> None: # sqlglot version is ahead sqlglot_version = f"{major}.{int(minor) - 1}.{patch}" error = rf"SQLGlot \(local\) is using version '{SQLGLOT_VERSION}' which is ahead of '{sqlglot_version}'" - state_sync._update_versions(sqlglot_version=sqlglot_version) + state_sync.version_state.update_versions(sqlglot_version=sqlglot_version) with pytest.raises(SQLMeshError, match=error): state_sync.get_versions() state_sync.get_versions(validate=False) @@ -2221,7 +2221,7 @@ def test_first_migration_failure(duck_conn, mocker: MockerFixture, tmp_path) -> assert not state_sync.engine_adapter.table_exists( state_sync.environment_state.environments_table ) - assert not state_sync.engine_adapter.table_exists(state_sync.versions_table) + assert not state_sync.engine_adapter.table_exists(state_sync.version_state.versions_table) assert not state_sync.engine_adapter.table_exists(state_sync.interval_state.intervals_table) From b152421a919c99247b5e8aec280c886a49e39f0d Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Tue, 25 Feb 2025 19:08:14 -0800 Subject: [PATCH 5/6] Chore: Refactor the state migration logic from the state sync --- .../core/state_sync/engine_adapter/facade.py | 430 +--------------- .../state_sync/engine_adapter/migrator.py | 459 ++++++++++++++++++ tests/core/test_state_sync.py | 22 +- 3 files changed, 494 insertions(+), 417 deletions(-) create mode 100644 sqlmesh/core/state_sync/engine_adapter/migrator.py diff --git a/sqlmesh/core/state_sync/engine_adapter/facade.py b/sqlmesh/core/state_sync/engine_adapter/facade.py index 6a29975c29..67e24140a4 100644 --- a/sqlmesh/core/state_sync/engine_adapter/facade.py +++ b/sqlmesh/core/state_sync/engine_adapter/facade.py @@ -17,26 +17,18 @@ from __future__ import annotations import contextlib -import json import logging -import time import typing as t -from copy import deepcopy from pathlib import Path from datetime import datetime -from sqlglot import __version__ as SQLGLOT_VERSION from sqlglot import exp -from sqlmesh.core import analytics -from sqlmesh.core import constants as c from sqlmesh.core.console import Console, get_console from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.environment import Environment from sqlmesh.core.snapshot import ( - Node, Snapshot, - SnapshotFingerprint, SnapshotId, SnapshotIdLike, SnapshotInfoLike, @@ -44,15 +36,12 @@ SnapshotNameVersion, SnapshotTableCleanupTask, SnapshotTableInfo, - fingerprint_from_node, start_date, ) from sqlmesh.core.snapshot.definition import ( Interval, - _parents_from_node, ) from sqlmesh.core.state_sync.base import ( - MIGRATIONS, PromotionResult, StateSync, Versions, @@ -62,10 +51,8 @@ from sqlmesh.core.state_sync.engine_adapter.environment import EnvironmentState from sqlmesh.core.state_sync.engine_adapter.snapshot import SnapshotState from sqlmesh.core.state_sync.engine_adapter.version import VersionState -from sqlmesh.core.state_sync.engine_adapter.utils import snapshot_id_filter, SQLMESH_VERSION -from sqlmesh.utils import major_minor -from sqlmesh.utils.dag import DAG -from sqlmesh.utils.date import TimeLike, now_timestamp, 
to_timestamp +from sqlmesh.core.state_sync.engine_adapter.migrator import StateMigrator +from sqlmesh.utils.date import TimeLike, to_timestamp from sqlmesh.utils.errors import ConflictingPlanError, SQLMeshError logger = logging.getLogger(__name__) @@ -91,9 +78,6 @@ class EngineAdapterStateSync(StateSync): context_path: The context path, used for caching snapshot models. """ - SNAPSHOT_BATCH_SIZE = 1000 - SNAPSHOT_MIGRATION_BATCH_SIZE = 500 - def __init__( self, engine_adapter: EngineAdapter, @@ -101,27 +85,26 @@ def __init__( console: t.Optional[Console] = None, context_path: Path = Path(), ): + self.plan_dags_table = exp.table_("_plan_dags", db=schema) self.interval_state = IntervalState(engine_adapter, schema=schema) self.environment_state = EnvironmentState(engine_adapter, schema=schema) self.snapshot_state = SnapshotState( engine_adapter, schema=schema, context_path=context_path ) self.version_state = VersionState(engine_adapter, schema=schema) + self.migrator = StateMigrator( + engine_adapter, + version_state=self.version_state, + snapshot_state=self.snapshot_state, + environment_state=self.environment_state, + interval_state=self.interval_state, + plan_dags_table=self.plan_dags_table, + console=console, + ) # Make sure that if an empty string is provided that we treat it as None self.schema = schema or None self.engine_adapter = engine_adapter self.console = console or get_console() - self.plan_dags_table = exp.table_("_plan_dags", db=self.schema) - - def _fetchone(self, query: t.Union[exp.Expression, str]) -> t.Optional[t.Tuple]: - return self.engine_adapter.fetchone( - query, ignore_unsupported_errors=True, quote_identifiers=True - ) - - def _fetchall(self, query: t.Union[exp.Expression, str]) -> t.List[t.Tuple]: - return self.engine_adapter.fetchall( - query, ignore_unsupported_errors=True, quote_identifiers=True - ) @transactional() def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: @@ -407,20 +390,6 @@ def recycle(self) -> None: def close(self) -> None: self.engine_adapter.close() - def _get_versions(self) -> Versions: - return self.version_state.get_versions() - - def _restore_table( - self, - table_name: TableName, - backup_table_name: TableName, - ) -> None: - self.engine_adapter.drop_table(table_name) - self.engine_adapter.rename_table( - old_table_name=backup_table_name, - new_table_name=table_name, - ) - @transactional() def migrate( self, @@ -429,352 +398,27 @@ def migrate( promoted_snapshots_only: bool = True, ) -> None: """Migrate the state sync to the latest SQLMesh / SQLGlot version.""" - versions = self.get_versions(validate=False) - - migration_start_ts = time.perf_counter() - - try: - migrate_rows = self._apply_migrations(default_catalog, skip_backup) - - if not migrate_rows and major_minor(SQLMESH_VERSION) == versions.minor_sqlmesh_version: - return - - if migrate_rows: - self._migrate_rows(promoted_snapshots_only) - # Cleanup plan DAGs since we currently don't migrate snapshot records that are in there. 
- self.engine_adapter.delete_from(self.plan_dags_table, "TRUE") - self.version_state.update_versions() - - analytics.collector.on_migration_end( - from_sqlmesh_version=versions.sqlmesh_version, - state_sync_type=self.state_type(), - migration_time_sec=time.perf_counter() - migration_start_ts, - ) - except Exception as e: - if skip_backup: - logger.error("Backup was skipped so no rollback was attempted.") - else: - self.rollback() - - analytics.collector.on_migration_end( - from_sqlmesh_version=versions.sqlmesh_version, - state_sync_type=self.state_type(), - migration_time_sec=time.perf_counter() - migration_start_ts, - error=e, - ) - - self.console.log_migration_status(success=False) - raise SQLMeshError("SQLMesh migration failed.") from e - - self.console.log_migration_status() + self.migrator.migrate( + self, + default_catalog, + skip_backup=skip_backup, + promoted_snapshots_only=promoted_snapshots_only, + ) @transactional() def rollback(self) -> None: """Rollback to the previous migration.""" - logger.info("Starting migration rollback.") - tables = ( - self.snapshot_state.snapshots_table, - self.environment_state.environments_table, - self.version_state.versions_table, - ) - optional_tables = ( - self.interval_state.intervals_table, - self.plan_dags_table, - self.snapshot_state.auto_restatements_table, - ) - versions = self.get_versions(validate=False) - if versions.schema_version == 0: - # Clean up state tables - for table in tables + optional_tables: - self.engine_adapter.drop_table(table) - else: - if not all( - self.engine_adapter.table_exists(_backup_table_name(table)) for table in tables - ): - raise SQLMeshError("There are no prior migrations to roll back to.") - for table in tables: - self._restore_table(table, _backup_table_name(table)) - - for optional_table in optional_tables: - if self.engine_adapter.table_exists(_backup_table_name(optional_table)): - self._restore_table(optional_table, _backup_table_name(optional_table)) - - logger.info("Migration rollback successful.") + self.migrator.rollback() def state_type(self) -> str: return self.engine_adapter.dialect - @transactional() - def _backup_state(self) -> None: - for table in ( - self.snapshot_state.snapshots_table, - self.environment_state.environments_table, - self.version_state.versions_table, - self.interval_state.intervals_table, - self.plan_dags_table, - self.snapshot_state.auto_restatements_table, - ): - if self.engine_adapter.table_exists(table): - backup_name = _backup_table_name(table) - self.engine_adapter.drop_table(backup_name) - self.engine_adapter.create_table_like(backup_name, table) - self.engine_adapter.insert_append(backup_name, exp.select("*").from_(table)) - - def _apply_migrations( - self, - default_catalog: t.Optional[str], - skip_backup: bool, - ) -> bool: - versions = self.get_versions(validate=False) - migrations = MIGRATIONS[versions.schema_version :] - should_backup = any( - [ - migrations, - major_minor(SQLGLOT_VERSION) != versions.minor_sqlglot_version, - major_minor(SQLMESH_VERSION) != versions.minor_sqlmesh_version, - ] - ) - if not skip_backup and should_backup: - self._backup_state() - - snapshot_count_before = self.snapshot_state.count() if versions.schema_version else None - - for migration in migrations: - logger.info(f"Applying migration {migration}") - migration.migrate(self, default_catalog=default_catalog) - - snapshot_count_after = self.snapshot_state.count() - - if snapshot_count_before is not None and snapshot_count_before != snapshot_count_after: - scripts = 
f"{versions.schema_version} - {versions.schema_version + len(migrations)}" - raise SQLMeshError( - f"Number of snapshots before ({snapshot_count_before}) and after " - f"({snapshot_count_after}) applying migration scripts {scripts} does not match. " - "Please file an issue issue at https://github.com/TobikoData/sqlmesh/issues/new." - ) - - migrate_snapshots_and_environments = ( - bool(migrations) or major_minor(SQLGLOT_VERSION) != versions.minor_sqlglot_version - ) - return migrate_snapshots_and_environments - - def _migrate_rows(self, promoted_snapshots_only: bool) -> None: - logger.info("Fetching environments") - environments = self.get_environments() - # Only migrate snapshots that are part of at least one environment. - snapshots_to_migrate = ( - {s.snapshot_id for e in environments for s in e.snapshots} - if promoted_snapshots_only - else None - ) - snapshot_mapping = self._migrate_snapshot_rows(snapshots_to_migrate) - if not snapshot_mapping: - logger.info("No changes to snapshots detected") - return - self._migrate_environment_rows(environments, snapshot_mapping) - - def _migrate_snapshot_rows( - self, snapshots: t.Optional[t.Set[SnapshotId]] - ) -> t.Dict[SnapshotId, SnapshotTableInfo]: - logger.info("Migrating snapshot rows...") - raw_snapshots = { - SnapshotId(name=name, identifier=identifier): { - **json.loads(raw_snapshot), - "updated_ts": updated_ts, - "unpaused_ts": unpaused_ts, - "unrestorable": unrestorable, - } - for where in (self._snapshot_id_filter(snapshots) if snapshots is not None else [None]) - for name, identifier, raw_snapshot, updated_ts, unpaused_ts, unrestorable in self._fetchall( - exp.select( - "name", "identifier", "snapshot", "updated_ts", "unpaused_ts", "unrestorable" - ) - .from_(self.snapshot_state.snapshots_table) - .where(where) - .lock() - ) - } - if not raw_snapshots: - return {} - - dag: DAG[SnapshotId] = DAG() - for snapshot_id, raw_snapshot in raw_snapshots.items(): - parent_ids = [SnapshotId.parse_obj(p_id) for p_id in raw_snapshot.get("parents", [])] - dag.add(snapshot_id, [p_id for p_id in parent_ids if p_id in raw_snapshots]) - - reversed_dag_raw = dag.reversed.graph - - self.console.start_snapshot_migration_progress(len(raw_snapshots)) - - parsed_snapshots = LazilyParsedSnapshots(raw_snapshots) - all_snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo] = {} - snapshot_id_mapping: t.Dict[SnapshotId, SnapshotId] = {} - new_snapshots: t.Dict[SnapshotId, Snapshot] = {} - visited: t.Set[SnapshotId] = set() - - def _push_new_snapshots() -> None: - all_snapshot_mapping.update( - { - from_id: new_snapshots[to_id].table_info - for from_id, to_id in snapshot_id_mapping.items() - } - ) - - existing_new_snapshots = self.snapshots_exist(new_snapshots) - new_snapshots_to_push = [ - s for s in new_snapshots.values() if s.snapshot_id not in existing_new_snapshots - ] - if new_snapshots_to_push: - logger.info("Pushing %s migrated snapshots", len(new_snapshots_to_push)) - self._push_snapshots(new_snapshots_to_push) - new_snapshots.clear() - snapshot_id_mapping.clear() - - def _visit( - snapshot_id: SnapshotId, fingerprint_cache: t.Dict[str, SnapshotFingerprint] - ) -> None: - if snapshot_id in visited or snapshot_id not in raw_snapshots: - return - visited.add(snapshot_id) - - snapshot = parsed_snapshots[snapshot_id] - node = snapshot.node - - node_seen = set() - node_queue = {snapshot_id} - nodes: t.Dict[str, Node] = {} - while node_queue: - next_snapshot_id = node_queue.pop() - next_snapshot = parsed_snapshots.get(next_snapshot_id) - - if next_snapshot_id 
in node_seen or not next_snapshot: - continue - - node_seen.add(next_snapshot_id) - node_queue.update(next_snapshot.parents) - nodes[next_snapshot.name] = next_snapshot.node - - new_snapshot = deepcopy(snapshot) - try: - new_snapshot.fingerprint = fingerprint_from_node( - node, - nodes=nodes, - cache=fingerprint_cache, - ) - new_snapshot.parents = tuple( - SnapshotId( - name=parent_node.fqn, - identifier=fingerprint_from_node( - parent_node, - nodes=nodes, - cache=fingerprint_cache, - ).to_identifier(), - ) - for parent_node in _parents_from_node(node, nodes).values() - ) - except Exception: - logger.exception("Could not compute fingerprint for %s", snapshot.snapshot_id) - return - - # Reset the effective_from date for the new snapshot to avoid unexpected backfills. - new_snapshot.effective_from = None - new_snapshot.previous_versions = snapshot.all_versions - new_snapshot.migrated = True - if not new_snapshot.dev_version_: - new_snapshot.dev_version_ = snapshot.dev_version - - self.console.update_snapshot_migration_progress(1) - - # Visit children and evict them from the parsed_snapshots cache after. - for child in reversed_dag_raw.get(snapshot_id, []): - # Make sure to copy the fingerprint cache to avoid sharing it between different child snapshots with the same name. - _visit(child, fingerprint_cache.copy()) - parsed_snapshots.evict(child) - - if new_snapshot.fingerprint == snapshot.fingerprint: - logger.debug(f"{new_snapshot.snapshot_id} is unchanged.") - return - - new_snapshot_id = new_snapshot.snapshot_id - - if new_snapshot_id in raw_snapshots: - # Mapped to an existing snapshot. - new_snapshots[new_snapshot_id] = parsed_snapshots[new_snapshot_id] - logger.debug("Migrated snapshot %s already exists", new_snapshot_id) - elif ( - new_snapshot_id not in new_snapshots - or new_snapshot.updated_ts > new_snapshots[new_snapshot_id].updated_ts - ): - new_snapshots[new_snapshot_id] = new_snapshot - - snapshot_id_mapping[snapshot.snapshot_id] = new_snapshot_id - logger.debug("%s mapped to %s", snapshot.snapshot_id, new_snapshot_id) - - if len(new_snapshots) >= self.SNAPSHOT_MIGRATION_BATCH_SIZE: - _push_new_snapshots() - - for root_snapshot_id in dag.roots: - _visit(root_snapshot_id, {}) - - if new_snapshots: - _push_new_snapshots() - - self.console.stop_snapshot_migration_progress() - return all_snapshot_mapping - - def _migrate_environment_rows( - self, - environments: t.List[Environment], - snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo], - ) -> None: - logger.info("Migrating environment rows...") - - updated_prod_environment: t.Optional[Environment] = None - updated_environments = [] - for environment in environments: - snapshots = [ - ( - snapshot_mapping[info.snapshot_id] - if info.snapshot_id in snapshot_mapping - else info - ) - for info in environment.snapshots - ] - - if snapshots != environment.snapshots: - environment.snapshots_ = snapshots - updated_environments.append(environment) - if environment.name == c.PROD: - updated_prod_environment = environment - self.console.start_env_migration_progress(len(updated_environments)) - - for environment in updated_environments: - self.environment_state.update_environment(environment) - self.console.update_env_migration_progress(1) - - if updated_prod_environment: - try: - self.unpause_snapshots(updated_prod_environment.snapshots, now_timestamp()) - except Exception: - logger.warning("Failed to unpause migrated snapshots", exc_info=True) - - self.console.stop_env_migration_progress() - - def _snapshot_id_filter( - self, - 
snapshot_ids: t.Iterable[SnapshotIdLike], - alias: t.Optional[str] = None, - ) -> t.Iterator[exp.Condition]: - yield from snapshot_id_filter( - self.engine_adapter, - snapshot_ids, - alias=alias, - batch_size=self.SNAPSHOT_BATCH_SIZE, - ) - def _push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = False) -> None: self.snapshot_state.push_snapshots(snapshots, overwrite=overwrite) + def _get_versions(self) -> Versions: + return self.version_state.get_versions() + def _ensure_no_gaps( self, target_snapshots: t.Iterable[Snapshot], @@ -822,33 +466,3 @@ def _ensure_no_gaps( def _transaction(self) -> t.Iterator[None]: with self.engine_adapter.transaction(): yield - - -def _backup_table_name(table_name: TableName) -> exp.Table: - table = exp.to_table(table_name).copy() - table.set("this", exp.to_identifier(table.name + "_backup")) - return table - - -class LazilyParsedSnapshots: - def __init__(self, raw_snapshots: t.Dict[SnapshotId, t.Dict[str, t.Any]]): - self._raw_snapshots = raw_snapshots - self._parsed_snapshots: t.Dict[SnapshotId, t.Optional[Snapshot]] = {} - - def get(self, snapshot_id: SnapshotId) -> t.Optional[Snapshot]: - if snapshot_id not in self._parsed_snapshots: - raw_snapshot = self._raw_snapshots.get(snapshot_id) - if raw_snapshot: - self._parsed_snapshots[snapshot_id] = Snapshot.parse_obj(raw_snapshot) - else: - self._parsed_snapshots[snapshot_id] = None - return self._parsed_snapshots[snapshot_id] - - def evict(self, snapshot_id: SnapshotId) -> None: - self._parsed_snapshots.pop(snapshot_id, None) - - def __getitem__(self, snapshot_id: SnapshotId) -> Snapshot: - snapshot = self.get(snapshot_id) - if snapshot is None: - raise KeyError(snapshot_id) - return snapshot diff --git a/sqlmesh/core/state_sync/engine_adapter/migrator.py b/sqlmesh/core/state_sync/engine_adapter/migrator.py new file mode 100644 index 0000000000..c2a3f78b4b --- /dev/null +++ b/sqlmesh/core/state_sync/engine_adapter/migrator.py @@ -0,0 +1,459 @@ +from __future__ import annotations + +import json +import logging +import time +import typing as t +from copy import deepcopy + +from sqlglot import __version__ as SQLGLOT_VERSION +from sqlglot import exp + +from sqlmesh.core import analytics +from sqlmesh.core import constants as c +from sqlmesh.core.console import Console, get_console +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.environment import Environment +from sqlmesh.core.snapshot import ( + Node, + Snapshot, + SnapshotFingerprint, + SnapshotId, + SnapshotTableInfo, + fingerprint_from_node, +) +from sqlmesh.core.snapshot.definition import ( + _parents_from_node, +) +from sqlmesh.core.state_sync.base import ( + MIGRATIONS, +) +from sqlmesh.core.state_sync.base import StateSync +from sqlmesh.core.state_sync.engine_adapter.environment import EnvironmentState +from sqlmesh.core.state_sync.engine_adapter.interval import IntervalState +from sqlmesh.core.state_sync.engine_adapter.snapshot import SnapshotState +from sqlmesh.core.state_sync.engine_adapter.version import VersionState +from sqlmesh.core.state_sync.engine_adapter.utils import ( + SQLMESH_VERSION, + snapshot_id_filter, + fetchall, +) +from sqlmesh.utils import major_minor +from sqlmesh.utils.dag import DAG +from sqlmesh.utils.date import now_timestamp +from sqlmesh.utils.errors import SQLMeshError + +logger = logging.getLogger(__name__) + + +if t.TYPE_CHECKING: + from sqlmesh.core._typing import TableName + + +class StateMigrator: + SNAPSHOT_BATCH_SIZE = 1000 + SNAPSHOT_MIGRATION_BATCH_SIZE = 500 + + def 
__init__( + self, + engine_adapter: EngineAdapter, + version_state: VersionState, + snapshot_state: SnapshotState, + environment_state: EnvironmentState, + interval_state: IntervalState, + plan_dags_table: TableName, + console: t.Optional[Console] = None, + ): + self.engine_adapter = engine_adapter + self.console = console or get_console() + self.version_state = version_state + self.snapshot_state = snapshot_state + self.environment_state = environment_state + self.interval_state = interval_state + self.plan_dags_table = plan_dags_table + + self._state_tables = [ + self.snapshot_state.snapshots_table, + self.environment_state.environments_table, + self.version_state.versions_table, + ] + self._optional_state_tables = [ + self.interval_state.intervals_table, + self.plan_dags_table, + self.snapshot_state.auto_restatements_table, + ] + + def migrate( + self, + state_sync: StateSync, + default_catalog: t.Optional[str], + skip_backup: bool = False, + promoted_snapshots_only: bool = True, + ) -> None: + """Migrate the state sync to the latest SQLMesh / SQLGlot version.""" + versions = self.version_state.get_versions() + migration_start_ts = time.perf_counter() + + try: + migrate_rows = self._apply_migrations(state_sync, default_catalog, skip_backup) + + if not migrate_rows and major_minor(SQLMESH_VERSION) == versions.minor_sqlmesh_version: + return + + if migrate_rows: + self._migrate_rows(promoted_snapshots_only) + # Cleanup plan DAGs since we currently don't migrate snapshot records that are in there. + self.engine_adapter.delete_from(self.plan_dags_table, "TRUE") + self.version_state.update_versions() + + analytics.collector.on_migration_end( + from_sqlmesh_version=versions.sqlmesh_version, + state_sync_type=self.engine_adapter.dialect, + migration_time_sec=time.perf_counter() - migration_start_ts, + ) + except Exception as e: + if skip_backup: + logger.error("Backup was skipped so no rollback was attempted.") + else: + self.rollback() + + analytics.collector.on_migration_end( + from_sqlmesh_version=versions.sqlmesh_version, + state_sync_type=self.engine_adapter.dialect, + migration_time_sec=time.perf_counter() - migration_start_ts, + error=e, + ) + + self.console.log_migration_status(success=False) + raise SQLMeshError("SQLMesh migration failed.") from e + + self.console.log_migration_status() + + def rollback(self) -> None: + """Rollback to the previous migration.""" + logger.info("Starting migration rollback.") + versions = self.version_state.get_versions() + if versions.schema_version == 0: + # Clean up state tables + for table in self._state_tables + self._optional_state_tables: + self.engine_adapter.drop_table(table) + else: + if not all( + self.engine_adapter.table_exists(_backup_table_name(table)) + for table in self._state_tables + ): + raise SQLMeshError("There are no prior migrations to roll back to.") + for table in self._state_tables: + self._restore_table(table, _backup_table_name(table)) + + for optional_table in self._optional_state_tables: + if self.engine_adapter.table_exists(_backup_table_name(optional_table)): + self._restore_table(optional_table, _backup_table_name(optional_table)) + + logger.info("Migration rollback successful.") + + def _apply_migrations( + self, + state_sync: StateSync, + default_catalog: t.Optional[str], + skip_backup: bool, + ) -> bool: + versions = self.version_state.get_versions() + migrations = MIGRATIONS[versions.schema_version :] + should_backup = any( + [ + migrations, + major_minor(SQLGLOT_VERSION) != versions.minor_sqlglot_version, + 
major_minor(SQLMESH_VERSION) != versions.minor_sqlmesh_version, + ] + ) + if not skip_backup and should_backup: + self._backup_state() + + snapshot_count_before = self.snapshot_state.count() if versions.schema_version else None + + for migration in migrations: + logger.info(f"Applying migration {migration}") + migration.migrate(state_sync, default_catalog=default_catalog) + + snapshot_count_after = self.snapshot_state.count() + + if snapshot_count_before is not None and snapshot_count_before != snapshot_count_after: + scripts = f"{versions.schema_version} - {versions.schema_version + len(migrations)}" + raise SQLMeshError( + f"Number of snapshots before ({snapshot_count_before}) and after " + f"({snapshot_count_after}) applying migration scripts {scripts} does not match. " + "Please file an issue at https://github.com/TobikoData/sqlmesh/issues/new." + ) + + migrate_snapshots_and_environments = ( + bool(migrations) or major_minor(SQLGLOT_VERSION) != versions.minor_sqlglot_version + ) + return migrate_snapshots_and_environments + + def _migrate_rows(self, promoted_snapshots_only: bool) -> None: + logger.info("Fetching environments") + environments = self.environment_state.get_environments() + # Only migrate snapshots that are part of at least one environment. + snapshots_to_migrate = ( + {s.snapshot_id for e in environments for s in e.snapshots} + if promoted_snapshots_only + else None + ) + snapshot_mapping = self._migrate_snapshot_rows(snapshots_to_migrate) + if not snapshot_mapping: + logger.info("No changes to snapshots detected") + return + self._migrate_environment_rows(environments, snapshot_mapping) + + def _migrate_snapshot_rows( + self, snapshots: t.Optional[t.Set[SnapshotId]] + ) -> t.Dict[SnapshotId, SnapshotTableInfo]: + logger.info("Migrating snapshot rows...") + raw_snapshots = { + SnapshotId(name=name, identifier=identifier): { + **json.loads(raw_snapshot), + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "unrestorable": unrestorable, + } + for where in ( + snapshot_id_filter( + self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE + ) + if snapshots is not None + else [None] + ) + for name, identifier, raw_snapshot, updated_ts, unpaused_ts, unrestorable in fetchall( + self.engine_adapter, + exp.select( + "name", "identifier", "snapshot", "updated_ts", "unpaused_ts", "unrestorable" + ) + .from_(self.snapshot_state.snapshots_table) + .where(where) + .lock(), + ) + } + if not raw_snapshots: + return {} + + dag: DAG[SnapshotId] = DAG() + for snapshot_id, raw_snapshot in raw_snapshots.items(): + parent_ids = [SnapshotId.parse_obj(p_id) for p_id in raw_snapshot.get("parents", [])] + dag.add(snapshot_id, [p_id for p_id in parent_ids if p_id in raw_snapshots]) + + reversed_dag_raw = dag.reversed.graph + + self.console.start_snapshot_migration_progress(len(raw_snapshots)) + + parsed_snapshots = LazilyParsedSnapshots(raw_snapshots) + all_snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo] = {} + snapshot_id_mapping: t.Dict[SnapshotId, SnapshotId] = {} + new_snapshots: t.Dict[SnapshotId, Snapshot] = {} + visited: t.Set[SnapshotId] = set() + + def _push_new_snapshots() -> None: + all_snapshot_mapping.update( + { + from_id: new_snapshots[to_id].table_info + for from_id, to_id in snapshot_id_mapping.items() + } + ) + + existing_new_snapshots = self.snapshot_state.snapshots_exist(new_snapshots) + new_snapshots_to_push = [ + s for s in new_snapshots.values() if s.snapshot_id not in existing_new_snapshots + ] + if new_snapshots_to_push: + 
logger.info("Pushing %s migrated snapshots", len(new_snapshots_to_push)) + self._push_snapshots(new_snapshots_to_push) + new_snapshots.clear() + snapshot_id_mapping.clear() + + def _visit( + snapshot_id: SnapshotId, fingerprint_cache: t.Dict[str, SnapshotFingerprint] + ) -> None: + if snapshot_id in visited or snapshot_id not in raw_snapshots: + return + visited.add(snapshot_id) + + snapshot = parsed_snapshots[snapshot_id] + node = snapshot.node + + node_seen = set() + node_queue = {snapshot_id} + nodes: t.Dict[str, Node] = {} + while node_queue: + next_snapshot_id = node_queue.pop() + next_snapshot = parsed_snapshots.get(next_snapshot_id) + + if next_snapshot_id in node_seen or not next_snapshot: + continue + + node_seen.add(next_snapshot_id) + node_queue.update(next_snapshot.parents) + nodes[next_snapshot.name] = next_snapshot.node + + new_snapshot = deepcopy(snapshot) + try: + new_snapshot.fingerprint = fingerprint_from_node( + node, + nodes=nodes, + cache=fingerprint_cache, + ) + new_snapshot.parents = tuple( + SnapshotId( + name=parent_node.fqn, + identifier=fingerprint_from_node( + parent_node, + nodes=nodes, + cache=fingerprint_cache, + ).to_identifier(), + ) + for parent_node in _parents_from_node(node, nodes).values() + ) + except Exception: + logger.exception("Could not compute fingerprint for %s", snapshot.snapshot_id) + return + + # Reset the effective_from date for the new snapshot to avoid unexpected backfills. + new_snapshot.effective_from = None + new_snapshot.previous_versions = snapshot.all_versions + new_snapshot.migrated = True + if not new_snapshot.dev_version_: + new_snapshot.dev_version_ = snapshot.dev_version + + self.console.update_snapshot_migration_progress(1) + + # Visit children and evict them from the parsed_snapshots cache after. + for child in reversed_dag_raw.get(snapshot_id, []): + # Make sure to copy the fingerprint cache to avoid sharing it between different child snapshots with the same name. + _visit(child, fingerprint_cache.copy()) + parsed_snapshots.evict(child) + + if new_snapshot.fingerprint == snapshot.fingerprint: + logger.debug(f"{new_snapshot.snapshot_id} is unchanged.") + return + + new_snapshot_id = new_snapshot.snapshot_id + + if new_snapshot_id in raw_snapshots: + # Mapped to an existing snapshot. 
+                new_snapshots[new_snapshot_id] = parsed_snapshots[new_snapshot_id]
+                logger.debug("Migrated snapshot %s already exists", new_snapshot_id)
+            elif (
+                new_snapshot_id not in new_snapshots
+                or new_snapshot.updated_ts > new_snapshots[new_snapshot_id].updated_ts
+            ):
+                new_snapshots[new_snapshot_id] = new_snapshot
+
+            snapshot_id_mapping[snapshot.snapshot_id] = new_snapshot_id
+            logger.debug("%s mapped to %s", snapshot.snapshot_id, new_snapshot_id)
+
+            if len(new_snapshots) >= self.SNAPSHOT_MIGRATION_BATCH_SIZE:
+                _push_new_snapshots()
+
+        for root_snapshot_id in dag.roots:
+            _visit(root_snapshot_id, {})
+
+        if new_snapshots:
+            _push_new_snapshots()
+
+        self.console.stop_snapshot_migration_progress()
+        return all_snapshot_mapping
+
+    def _migrate_environment_rows(
+        self,
+        environments: t.List[Environment],
+        snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo],
+    ) -> None:
+        logger.info("Migrating environment rows...")
+
+        updated_prod_environment: t.Optional[Environment] = None
+        updated_environments = []
+        for environment in environments:
+            snapshots = [
+                (
+                    snapshot_mapping[info.snapshot_id]
+                    if info.snapshot_id in snapshot_mapping
+                    else info
+                )
+                for info in environment.snapshots
+            ]
+
+            if snapshots != environment.snapshots:
+                environment.snapshots_ = snapshots
+                updated_environments.append(environment)
+                if environment.name == c.PROD:
+                    updated_prod_environment = environment
+        self.console.start_env_migration_progress(len(updated_environments))
+
+        for environment in updated_environments:
+            self.environment_state.update_environment(environment)
+            self.console.update_env_migration_progress(1)
+
+        if updated_prod_environment:
+            try:
+                self.snapshot_state.unpause_snapshots(
+                    updated_prod_environment.snapshots, now_timestamp(), self.interval_state
+                )
+            except Exception:
+                logger.warning("Failed to unpause migrated snapshots", exc_info=True)
+
+        self.console.stop_env_migration_progress()
+
+    def _backup_state(self) -> None:
+        for table in [
+            *self._state_tables,
+            *self._optional_state_tables,
+        ]:
+            if self.engine_adapter.table_exists(table):
+                with self.engine_adapter.transaction():
+                    backup_name = _backup_table_name(table)
+                    self.engine_adapter.drop_table(backup_name)
+                    self.engine_adapter.create_table_like(backup_name, table)
+                    self.engine_adapter.insert_append(backup_name, exp.select("*").from_(table))
+
+    def _restore_table(
+        self,
+        table_name: TableName,
+        backup_table_name: TableName,
+    ) -> None:
+        self.engine_adapter.drop_table(table_name)
+        self.engine_adapter.rename_table(
+            old_table_name=backup_table_name,
+            new_table_name=table_name,
+        )
+
+    def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None:
+        self.snapshot_state.push_snapshots(snapshots)
+
+
+def _backup_table_name(table_name: TableName) -> exp.Table:
+    table = exp.to_table(table_name).copy()
+    table.set("this", exp.to_identifier(table.name + "_backup"))
+    return table
+
+
+class LazilyParsedSnapshots:
+    def __init__(self, raw_snapshots: t.Dict[SnapshotId, t.Dict[str, t.Any]]):
+        self._raw_snapshots = raw_snapshots
+        self._parsed_snapshots: t.Dict[SnapshotId, t.Optional[Snapshot]] = {}
+
+    def get(self, snapshot_id: SnapshotId) -> t.Optional[Snapshot]:
+        if snapshot_id not in self._parsed_snapshots:
+            raw_snapshot = self._raw_snapshots.get(snapshot_id)
+            if raw_snapshot:
+                self._parsed_snapshots[snapshot_id] = Snapshot.parse_obj(raw_snapshot)
+            else:
+                self._parsed_snapshots[snapshot_id] = None
+        return self._parsed_snapshots[snapshot_id]
+
+    def evict(self, snapshot_id: SnapshotId) -> None:
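+        """Drop the parsed snapshot from the cache so peak memory stays bounded during migration."""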
+        self._parsed_snapshots.pop(snapshot_id, None)
+
+    def __getitem__(self, snapshot_id: SnapshotId) -> Snapshot:
+        snapshot = self.get(snapshot_id)
+        if snapshot is None:
+            raise KeyError(snapshot_id)
+        return snapshot
diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py
index 0e16367ca6..c435ea75ff 100644
--- a/tests/core/test_state_sync.py
+++ b/tests/core/test_state_sync.py
@@ -403,7 +403,7 @@ def test_refresh_snapshot_intervals(
 def test_get_snapshot_intervals(
     state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, get_snapshot_intervals
 ) -> None:
-    state_sync.SNAPSHOT_BATCH_SIZE = 1
+    state_sync.interval_state.SNAPSHOT_BATCH_SIZE = 1
 
     snapshot_a = make_snapshot(
         SqlModel(
@@ -1164,7 +1164,7 @@ def test_delete_expired_snapshots_seed(
 def test_delete_expired_snapshots_batching(
     state_sync: EngineAdapterStateSync, make_snapshot: t.Callable
 ):
-    state_sync.SNAPSHOT_BATCH_SIZE = 1
+    state_sync.snapshot_state.SNAPSHOT_BATCH_SIZE = 1
 
     now_ts = now_timestamp()
     snapshot_a = make_snapshot(
@@ -2150,8 +2150,12 @@ def test_empty_versions() -> None:
 def test_migrate(state_sync: EngineAdapterStateSync, mocker: MockerFixture, tmp_path) -> None:
     from sqlmesh import __version__ as SQLMESH_VERSION
 
-    migrate_rows_mock = mocker.patch("sqlmesh.core.state_sync.EngineAdapterStateSync._migrate_rows")
-    backup_state_mock = mocker.patch("sqlmesh.core.state_sync.EngineAdapterStateSync._backup_state")
+    migrate_rows_mock = mocker.patch(
+        "sqlmesh.core.state_sync.engine_adapter.migrator.StateMigrator._migrate_rows"
+    )
+    backup_state_mock = mocker.patch(
+        "sqlmesh.core.state_sync.engine_adapter.migrator.StateMigrator._backup_state"
+    )
     state_sync.migrate(default_catalog=None)
     migrate_rows_mock.assert_not_called()
     backup_state_mock.assert_not_called()
@@ -2185,8 +2189,8 @@ def test_rollback(state_sync: EngineAdapterStateSync, mocker: MockerFixture) ->
     ):
         state_sync.rollback()
 
-    restore_table_spy = mocker.spy(state_sync, "_restore_table")
-    state_sync._backup_state()
+    restore_table_spy = mocker.spy(state_sync.migrator, "_restore_table")
+    state_sync.migrator._backup_state()
     state_sync.rollback()
 
     calls = {(a.sql(), b.sql()) for (a, b), _ in restore_table_spy.call_args_list}
@@ -2211,7 +2215,7 @@ def test_first_migration_failure(duck_conn, mocker: MockerFixture, tmp_path) ->
     state_sync = EngineAdapterStateSync(
         create_engine_adapter(lambda: duck_conn, "duckdb"), schema=c.SQLMESH, context_path=tmp_path
     )
-    mocker.patch.object(state_sync, "_migrate_rows", side_effect=Exception("mocked error"))
+    mocker.patch.object(state_sync.migrator, "_migrate_rows", side_effect=Exception("mocked error"))
     with pytest.raises(
         SQLMeshError,
         match="SQLMesh migration failed.",
@@ -2319,7 +2323,7 @@ def test_backup_state(state_sync: EngineAdapterStateSync, mocker: MockerFixture)
         },
     )
 
-    state_sync._backup_state()
+    state_sync.migrator._backup_state()
     pd.testing.assert_frame_equal(
         state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots"),
         state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots_backup"),
@@ -2344,7 +2348,7 @@ def test_restore_snapshots_table(state_sync: EngineAdapterStateSync) -> None:
         "select count(*) from sqlmesh._snapshots"
     )
     assert old_snapshots_count == (12,)
-    state_sync._backup_state()
+    state_sync.migrator._backup_state()
    state_sync.engine_adapter.delete_from("sqlmesh._snapshots", "TRUE")
 
     snapshots_count = state_sync.engine_adapter.fetchone("select count(*) from sqlmesh._snapshots")

From 7a50d5d34f453cf073268b26f2bcfa46d7d16177 Mon Sep 17 00:00:00 2001
From: Iaroslav Zeigerman
Date: Tue, 25 Feb 2025 20:40:26 -0800
Subject: [PATCH 6/6] Rename the package to db

---
 sqlmesh/core/state_sync/__init__.py | 2 +-
 sqlmesh/core/state_sync/db/__init__.py | 3 +++
 .../{engine_adapter => db}/environment.py | 2 +-
 .../state_sync/{engine_adapter => db}/facade.py | 17 +++++++++--------
 .../{engine_adapter => db}/interval.py | 2 +-
 .../{engine_adapter => db}/migrator.py | 15 +++++++++------
 .../{engine_adapter => db}/snapshot.py | 11 +++--------
 .../state_sync/{engine_adapter => db}/utils.py | 0
 .../{engine_adapter => db}/version.py | 2 +-
 .../core/state_sync/engine_adapter/__init__.py | 3 ---
 tests/core/test_context.py | 2 +-
 tests/core/test_environment.py | 2 +-
 tests/core/test_state_sync.py | 12 +++++-------
 tests/dbt/test_transformation.py | 2 +-
 14 files changed, 36 insertions(+), 39 deletions(-)
 create mode 100644 sqlmesh/core/state_sync/db/__init__.py
 rename sqlmesh/core/state_sync/{engine_adapter => db}/environment.py (99%)
 rename sqlmesh/core/state_sync/{engine_adapter => db}/facade.py (96%)
 rename sqlmesh/core/state_sync/{engine_adapter => db}/interval.py (99%)
 rename sqlmesh/core/state_sync/{engine_adapter => db}/migrator.py (97%)
 rename sqlmesh/core/state_sync/{engine_adapter => db}/snapshot.py (98%)
 rename sqlmesh/core/state_sync/{engine_adapter => db}/utils.py (100%)
 rename sqlmesh/core/state_sync/{engine_adapter => db}/version.py (97%)
 delete mode 100644 sqlmesh/core/state_sync/engine_adapter/__init__.py

diff --git a/sqlmesh/core/state_sync/__init__.py b/sqlmesh/core/state_sync/__init__.py
index 4626de3c16..1585d6211f 100644
--- a/sqlmesh/core/state_sync/__init__.py
+++ b/sqlmesh/core/state_sync/__init__.py
@@ -21,4 +21,4 @@
 )
 from sqlmesh.core.state_sync.cache import CachingStateSync as CachingStateSync
 from sqlmesh.core.state_sync.common import cleanup_expired_views as cleanup_expired_views
-from sqlmesh.core.state_sync.engine_adapter import EngineAdapterStateSync as EngineAdapterStateSync
+from sqlmesh.core.state_sync.db import EngineAdapterStateSync as EngineAdapterStateSync
diff --git a/sqlmesh/core/state_sync/db/__init__.py b/sqlmesh/core/state_sync/db/__init__.py
new file mode 100644
index 0000000000..3292449359
--- /dev/null
+++ b/sqlmesh/core/state_sync/db/__init__.py
@@ -0,0 +1,3 @@
+from sqlmesh.core.state_sync.db.facade import EngineAdapterStateSync
+
+__all__ = ["EngineAdapterStateSync"]
diff --git a/sqlmesh/core/state_sync/engine_adapter/environment.py b/sqlmesh/core/state_sync/db/environment.py
similarity index 99%
rename from sqlmesh/core/state_sync/engine_adapter/environment.py
rename to sqlmesh/core/state_sync/db/environment.py
index 1150d3f9e2..8fcc6787e1 100644
--- a/sqlmesh/core/state_sync/engine_adapter/environment.py
+++ b/sqlmesh/core/state_sync/db/environment.py
@@ -8,7 +8,7 @@
 from sqlmesh.core import constants as c
 from sqlmesh.core.engine_adapter import EngineAdapter
-from sqlmesh.core.state_sync.engine_adapter.utils import (
+from sqlmesh.core.state_sync.db.utils import (
     fetchall,
     fetchone,
 )
diff --git a/sqlmesh/core/state_sync/engine_adapter/facade.py b/sqlmesh/core/state_sync/db/facade.py
similarity index 96%
rename from sqlmesh/core/state_sync/engine_adapter/facade.py
rename to sqlmesh/core/state_sync/db/facade.py
index 67e24140a4..3210d5ea35 100644
--- a/sqlmesh/core/state_sync/engine_adapter/facade.py
+++ b/sqlmesh/core/state_sync/db/facade.py
@@ -47,11 +47,11 @@
     Versions,
 )
 from sqlmesh.core.state_sync.common import transactional
-from sqlmesh.core.state_sync.engine_adapter.interval import IntervalState
-from sqlmesh.core.state_sync.engine_adapter.environment import EnvironmentState
-from sqlmesh.core.state_sync.engine_adapter.snapshot import SnapshotState
-from sqlmesh.core.state_sync.engine_adapter.version import VersionState
-from sqlmesh.core.state_sync.engine_adapter.migrator import StateMigrator
+from sqlmesh.core.state_sync.db.interval import IntervalState
+from sqlmesh.core.state_sync.db.environment import EnvironmentState
+from sqlmesh.core.state_sync.db.snapshot import SnapshotState
+from sqlmesh.core.state_sync.db.version import VersionState
+from sqlmesh.core.state_sync.db.migrator import StateMigrator
 from sqlmesh.utils.date import TimeLike, to_timestamp
 from sqlmesh.utils.errors import ConflictingPlanError, SQLMeshError
@@ -62,7 +62,7 @@
 
 if t.TYPE_CHECKING:
-    from sqlmesh.core._typing import TableName
+    pass
 
 
 class EngineAdapterStateSync(StateSync):
@@ -266,6 +266,7 @@ def delete_expired_snapshots(
         self.interval_state.cleanup_intervals(cleanup_targets, expired_snapshot_ids)
         return cleanup_targets
 
+    @transactional()
     def delete_expired_environments(self) -> t.List[Environment]:
         return self.environment_state.delete_expired_environments()
@@ -413,8 +414,8 @@ def rollback(self) -> None:
     def state_type(self) -> str:
         return self.engine_adapter.dialect
 
-    def _push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = False) -> None:
-        self.snapshot_state.push_snapshots(snapshots, overwrite=overwrite)
+    def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None:
+        self.snapshot_state.push_snapshots(snapshots)
 
     def _get_versions(self) -> Versions:
         return self.version_state.get_versions()
diff --git a/sqlmesh/core/state_sync/engine_adapter/interval.py b/sqlmesh/core/state_sync/db/interval.py
similarity index 99%
rename from sqlmesh/core/state_sync/engine_adapter/interval.py
rename to sqlmesh/core/state_sync/db/interval.py
index e44b11ada1..944e4b650c 100644
--- a/sqlmesh/core/state_sync/engine_adapter/interval.py
+++ b/sqlmesh/core/state_sync/db/interval.py
@@ -7,7 +7,7 @@
 from sqlglot import exp
 
 from sqlmesh.core.engine_adapter import EngineAdapter
-from sqlmesh.core.state_sync.engine_adapter.utils import (
+from sqlmesh.core.state_sync.db.utils import (
     snapshot_name_version_filter,
     snapshot_id_filter,
     create_batches,
diff --git a/sqlmesh/core/state_sync/engine_adapter/migrator.py b/sqlmesh/core/state_sync/db/migrator.py
similarity index 97%
rename from sqlmesh/core/state_sync/engine_adapter/migrator.py
rename to sqlmesh/core/state_sync/db/migrator.py
index c2a3f78b4b..56ab974e10 100644
--- a/sqlmesh/core/state_sync/engine_adapter/migrator.py
+++ b/sqlmesh/core/state_sync/db/migrator.py
@@ -29,11 +29,11 @@
     MIGRATIONS,
 )
 from sqlmesh.core.state_sync.base import StateSync
-from sqlmesh.core.state_sync.engine_adapter.environment import EnvironmentState
-from sqlmesh.core.state_sync.engine_adapter.interval import IntervalState
-from sqlmesh.core.state_sync.engine_adapter.snapshot import SnapshotState
-from sqlmesh.core.state_sync.engine_adapter.version import VersionState
-from sqlmesh.core.state_sync.engine_adapter.utils import (
+from sqlmesh.core.state_sync.db.environment import EnvironmentState
+from sqlmesh.core.state_sync.db.interval import IntervalState
+from sqlmesh.core.state_sync.db.snapshot import SnapshotState
+from sqlmesh.core.state_sync.db.version import VersionState
+from sqlmesh.core.state_sync.db.utils import (
     SQLMESH_VERSION,
     snapshot_id_filter,
     fetchall,
@@ -389,7 +389,7 @@ def _migrate_environment_rows(
         self.console.start_env_migration_progress(len(updated_environments))
 
         for environment in updated_environments:
-            self.environment_state.update_environment(environment)
+            self._update_environment(environment)
             self.console.update_env_migration_progress(1)
 
         if updated_prod_environment:
@@ -425,6 +425,9 @@ def _restore_table(
             new_table_name=table_name,
         )
 
+    def _update_environment(self, environment: Environment) -> None:
+        self.environment_state.update_environment(environment)
+
     def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None:
         self.snapshot_state.push_snapshots(snapshots)
 
diff --git a/sqlmesh/core/state_sync/engine_adapter/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py
similarity index 98%
rename from sqlmesh/core/state_sync/engine_adapter/snapshot.py
rename to sqlmesh/core/state_sync/db/snapshot.py
index d8ce8d7af8..6b7d64c57b 100644
--- a/sqlmesh/core/state_sync/engine_adapter/snapshot.py
+++ b/sqlmesh/core/state_sync/db/snapshot.py
@@ -11,7 +11,7 @@
 from sqlmesh.core import constants as c
 from sqlmesh.core.engine_adapter import EngineAdapter
-from sqlmesh.core.state_sync.engine_adapter.utils import (
+from sqlmesh.core.state_sync.db.utils import (
     snapshot_name_version_filter,
     snapshot_id_filter,
     fetchall,
@@ -39,7 +39,7 @@
 from sqlmesh.utils import unique
 
 if t.TYPE_CHECKING:
-    from sqlmesh.core.state_sync.engine_adapter.interval import IntervalState
+    from sqlmesh.core.state_sync.db.interval import IntervalState
 
 logger = logging.getLogger(__name__)
@@ -409,13 +409,8 @@ def _update_snapshots(
             where=where,
         )
 
-    def _push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = False) -> None:
-        if overwrite:
-            snapshots = tuple(snapshots)
-            self.delete_snapshots(snapshots)
-
+    def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None:
         snapshots_to_store = []
-
         for snapshot in snapshots:
             if isinstance(snapshot.node, SeedModel):
                 seed_model = t.cast(SeedModel, snapshot.node)
diff --git a/sqlmesh/core/state_sync/engine_adapter/utils.py b/sqlmesh/core/state_sync/db/utils.py
similarity index 100%
rename from sqlmesh/core/state_sync/engine_adapter/utils.py
rename to sqlmesh/core/state_sync/db/utils.py
diff --git a/sqlmesh/core/state_sync/engine_adapter/version.py b/sqlmesh/core/state_sync/db/version.py
similarity index 97%
rename from sqlmesh/core/state_sync/engine_adapter/version.py
rename to sqlmesh/core/state_sync/db/version.py
index 82cad4e21f..8ef6860b92 100644
--- a/sqlmesh/core/state_sync/engine_adapter/version.py
+++ b/sqlmesh/core/state_sync/db/version.py
@@ -9,7 +9,7 @@
 from sqlglot.helper import seq_get
 
 from sqlmesh.core.engine_adapter import EngineAdapter
-from sqlmesh.core.state_sync.engine_adapter.utils import (
+from sqlmesh.core.state_sync.db.utils import (
     fetchone,
     SQLMESH_VERSION,
 )
diff --git a/sqlmesh/core/state_sync/engine_adapter/__init__.py b/sqlmesh/core/state_sync/engine_adapter/__init__.py
deleted file mode 100644
index 86839f1797..0000000000
--- a/sqlmesh/core/state_sync/engine_adapter/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from sqlmesh.core.state_sync.engine_adapter.facade import EngineAdapterStateSync
-
-__all__ = ["EngineAdapterStateSync"]
diff --git a/tests/core/test_context.py b/tests/core/test_context.py
index 871ab3bc07..5a449b4137 100644
--- a/tests/core/test_context.py
+++ b/tests/core/test_context.py
@@ -34,7 +34,7 @@
 from sqlmesh.core.model.kind import ModelKindName
 from sqlmesh.core.plan import BuiltInPlanEvaluator, PlanBuilder
 from sqlmesh.core.state_sync.cache import CachingStateSync
-from sqlmesh.core.state_sync.engine_adapter import EngineAdapterStateSync
+from sqlmesh.core.state_sync.db import EngineAdapterStateSync
 from sqlmesh.utils.connection_pool import SingletonConnectionPool, ThreadLocalConnectionPool
 from sqlmesh.utils.date import (
     make_inclusive_end,
diff --git a/tests/core/test_environment.py b/tests/core/test_environment.py
index 8de10318e6..307f220c25 100644
--- a/tests/core/test_environment.py
+++ b/tests/core/test_environment.py
@@ -2,7 +2,7 @@
 
 from sqlmesh.core.environment import Environment, EnvironmentNamingInfo
 from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo
-from sqlmesh.core.state_sync.engine_adapter.environment import _environment_to_df
+from sqlmesh.core.state_sync.db.environment import _environment_to_df
 
 
 def test_sanitize_name():
diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py
index c435ea75ff..f1008a10dc 100644
--- a/tests/core/test_state_sync.py
+++ b/tests/core/test_state_sync.py
@@ -143,7 +143,7 @@ def test_push_snapshots(
         snapshot_b.snapshot_id: snapshot_b,
     }
 
-    logger = logging.getLogger("sqlmesh.core.state_sync.engine_adapter.facade")
+    logger = logging.getLogger("sqlmesh.core.state_sync.db.facade")
     with patch.object(logger, "error") as mock_logger:
         state_sync.push_snapshots([snapshot_a])
         assert str({snapshot_a.snapshot_id}) == mock_logger.call_args[0][1]
@@ -1236,9 +1236,7 @@ def test_delete_expired_snapshots_promoted(
     env.snapshots_ = []
     state_sync.promote(env)
 
-    now_timestamp_mock = mocker.patch(
-        "sqlmesh.core.state_sync.engine_adapter.snapshot.now_timestamp"
-    )
+    now_timestamp_mock = mocker.patch("sqlmesh.core.state_sync.db.snapshot.now_timestamp")
     now_timestamp_mock.return_value = now_timestamp() + 11000
 
     assert state_sync.delete_expired_snapshots() == [
@@ -2151,10 +2149,10 @@ def test_migrate(state_sync: EngineAdapterStateSync, mocker: MockerFixture, tmp_
     from sqlmesh import __version__ as SQLMESH_VERSION
 
     migrate_rows_mock = mocker.patch(
-        "sqlmesh.core.state_sync.engine_adapter.migrator.StateMigrator._migrate_rows"
+        "sqlmesh.core.state_sync.db.migrator.StateMigrator._migrate_rows"
     )
     backup_state_mock = mocker.patch(
-        "sqlmesh.core.state_sync.engine_adapter.migrator.StateMigrator._backup_state"
+        "sqlmesh.core.state_sync.db.migrator.StateMigrator._backup_state"
     )
     state_sync.migrate(default_catalog=None)
     migrate_rows_mock.assert_not_called()
@@ -2353,7 +2351,7 @@ def test_restore_snapshots_table(state_sync: EngineAdapterStateSync) -> None:
     state_sync.engine_adapter.delete_from("sqlmesh._snapshots", "TRUE")
 
     snapshots_count = state_sync.engine_adapter.fetchone("select count(*) from sqlmesh._snapshots")
     assert snapshots_count == (0,)
-    state_sync._restore_table(
+    state_sync.migrator._restore_table(
         table_name="sqlmesh._snapshots",
         backup_table_name="sqlmesh._snapshots_backup",
     )
diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py
index 2fc1573928..7de86e72b3 100644
--- a/tests/dbt/test_transformation.py
+++ b/tests/dbt/test_transformation.py
@@ -27,7 +27,7 @@
     ViewKind,
 )
 from sqlmesh.core.model.kind import SCDType2ByColumnKind, SCDType2ByTimeKind
-from sqlmesh.core.state_sync.engine_adapter.snapshot import _snapshot_to_json
+from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json
 from sqlmesh.dbt.builtin import _relation_info_to_relation
 from sqlmesh.dbt.column import (
     ColumnConfig,