diff --git a/sqlmesh/core/state_sync/__init__.py b/sqlmesh/core/state_sync/__init__.py index 4626de3c16..1585d6211f 100644 --- a/sqlmesh/core/state_sync/__init__.py +++ b/sqlmesh/core/state_sync/__init__.py @@ -21,4 +21,4 @@ ) from sqlmesh.core.state_sync.cache import CachingStateSync as CachingStateSync from sqlmesh.core.state_sync.common import cleanup_expired_views as cleanup_expired_views -from sqlmesh.core.state_sync.engine_adapter import EngineAdapterStateSync as EngineAdapterStateSync +from sqlmesh.core.state_sync.db import EngineAdapterStateSync as EngineAdapterStateSync diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 4113737970..f07e531fad 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -252,12 +252,9 @@ def raise_error( return versions @abc.abstractmethod - def _get_versions(self, lock_for_update: bool = False) -> Versions: + def _get_versions(self) -> Versions: """Queries the store to get the current versions of SQLMesh and deps. - Args: - lock_for_update: Whether or not the usage of this method plans to update the row. - Returns: The versions object. 
""" diff --git a/sqlmesh/core/state_sync/db/__init__.py b/sqlmesh/core/state_sync/db/__init__.py new file mode 100644 index 0000000000..3292449359 --- /dev/null +++ b/sqlmesh/core/state_sync/db/__init__.py @@ -0,0 +1,3 @@ +from sqlmesh.core.state_sync.db.facade import EngineAdapterStateSync + +__all__ = ["EngineAdapterStateSync"] diff --git a/sqlmesh/core/state_sync/db/environment.py b/sqlmesh/core/state_sync/db/environment.py new file mode 100644 index 0000000000..8fcc6787e1 --- /dev/null +++ b/sqlmesh/core/state_sync/db/environment.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import typing as t +import pandas as pd +import json +import logging +from sqlglot import exp + +from sqlmesh.core import constants as c +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.db.utils import ( + fetchall, + fetchone, +) +from sqlmesh.core.environment import Environment +from sqlmesh.utils.migration import index_text_type, blob_text_type +from sqlmesh.utils.date import now_timestamp, time_like_to_str +from sqlmesh.utils.errors import SQLMeshError + + +logger = logging.getLogger(__name__) + + +class EnvironmentState: + def __init__( + self, + engine_adapter: EngineAdapter, + schema: t.Optional[str] = None, + ): + self.engine_adapter = engine_adapter + self.environments_table = exp.table_("_environments", db=schema) + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + self._environment_columns_to_types = { + "name": exp.DataType.build(index_type), + "snapshots": exp.DataType.build(blob_type), + "start_at": exp.DataType.build("text"), + "end_at": exp.DataType.build("text"), + "plan_id": exp.DataType.build("text"), + "previous_plan_id": exp.DataType.build("text"), + "expiration_ts": exp.DataType.build("bigint"), + "finalized_ts": exp.DataType.build("bigint"), + "promoted_snapshot_ids": exp.DataType.build(blob_type), + "suffix_target": exp.DataType.build("text"), + 
"catalog_name_override": exp.DataType.build("text"), + "previous_finalized_snapshots": exp.DataType.build(blob_type), + "normalize_name": exp.DataType.build("boolean"), + "requirements": exp.DataType.build(blob_type), + } + + def update_environment(self, environment: Environment) -> None: + """Updates the environment. + + Args: + environment: The environment + """ + self.engine_adapter.delete_from( + self.environments_table, + where=exp.EQ( + this=exp.column("name"), + expression=exp.Literal.string(environment.name), + ), + ) + + self.engine_adapter.insert_append( + self.environments_table, + _environment_to_df(environment), + columns_to_types=self._environment_columns_to_types, + ) + + def invalidate_environment(self, name: str) -> None: + """Invalidates the environment. + + Args: + name: The name of the environment + """ + name = name.lower() + if name == c.PROD: + raise SQLMeshError("Cannot invalidate the production environment.") + + filter_expr = exp.column("name").eq(name) + + self.engine_adapter.update_table( + self.environments_table, + {"expiration_ts": now_timestamp()}, + where=filter_expr, + ) + + def finalize(self, environment: Environment) -> None: + """Finalize the target environment, indicating that this environment has been + fully promoted and is ready for use. + + Args: + environment: The target environment to finalize. 
+ """ + logger.info("Finalizing environment '%s'", environment.name) + + environment_filter = exp.column("name").eq(exp.Literal.string(environment.name)) + + stored_plan_id_query = ( + exp.select("plan_id") + .from_(self.environments_table) + .where(environment_filter, copy=False) + .lock(copy=False) + ) + stored_plan_id_row = fetchone(self.engine_adapter, stored_plan_id_query) + + if not stored_plan_id_row: + raise SQLMeshError(f"Missing environment '{environment.name}' can't be finalized") + + stored_plan_id = stored_plan_id_row[0] + if stored_plan_id != environment.plan_id: + raise SQLMeshError( + f"Plan '{environment.plan_id}' is no longer valid for the target environment '{environment.name}'. " + f"Stored plan ID: '{stored_plan_id}'. Please recreate the plan and try again" + ) + + environment.finalized_ts = now_timestamp() + self.engine_adapter.update_table( + self.environments_table, + {"finalized_ts": environment.finalized_ts}, + where=environment_filter, + ) + + def delete_expired_environments(self) -> t.List[Environment]: + """Deletes expired environments. + + Returns: + A list of deleted environments. + """ + now_ts = now_timestamp() + filter_expr = exp.LTE( + this=exp.column("expiration_ts"), + expression=exp.Literal.number(now_ts), + ) + + rows = fetchall( + self.engine_adapter, + self._environments_query( + where=filter_expr, + lock_for_update=True, + ), + ) + environments = [self._environment_from_row(r) for r in rows] + + self.engine_adapter.delete_from( + self.environments_table, + where=filter_expr, + ) + + return environments + + def get_environments(self) -> t.List[Environment]: + """Fetches all environments. + + Returns: + A list of all environments. + """ + return [ + self._environment_from_row(row) + for row in fetchall(self.engine_adapter, self._environments_query()) + ] + + def get_environments_summary(self) -> t.Dict[str, int]: + """Fetches all environment names along with expiry datetime. 
+ + Returns: + A dict of all environment names along with expiry datetime. + """ + return dict( + fetchall( + self.engine_adapter, + self._environments_query(required_fields=["name", "expiration_ts"]), + ), + ) + + def get_environment( + self, environment: str, lock_for_update: bool = False + ) -> t.Optional[Environment]: + """Fetches the environment if it exists. + + Args: + environment: The environment + lock_for_update: Lock the snapshot rows for future update + + Returns: + The environment object. + """ + row = fetchone( + self.engine_adapter, + self._environments_query( + where=exp.EQ( + this=exp.column("name"), + expression=exp.Literal.string(environment), + ), + lock_for_update=lock_for_update, + ), + ) + + if not row: + return None + + env = self._environment_from_row(row) + return env + + def _environment_from_row(self, row: t.Tuple[str, ...]) -> Environment: + return Environment(**{field: row[i] for i, field in enumerate(Environment.all_fields())}) + + def _environments_query( + self, + where: t.Optional[str | exp.Expression] = None, + lock_for_update: bool = False, + required_fields: t.Optional[t.List[str]] = None, + ) -> exp.Select: + query_fields = required_fields if required_fields else Environment.all_fields() + query = ( + exp.select(*(exp.to_identifier(field) for field in query_fields)) + .from_(self.environments_table) + .where(where) + ) + if lock_for_update: + return query.lock(copy=False) + return query + + +def _environment_to_df(environment: Environment) -> pd.DataFrame: + return pd.DataFrame( + [ + { + "name": environment.name, + "snapshots": json.dumps(environment.snapshot_dicts()), + "start_at": time_like_to_str(environment.start_at), + "end_at": time_like_to_str(environment.end_at) if environment.end_at else None, + "plan_id": environment.plan_id, + "previous_plan_id": environment.previous_plan_id, + "expiration_ts": environment.expiration_ts, + "finalized_ts": environment.finalized_ts, + "promoted_snapshot_ids": ( + 
json.dumps(environment.promoted_snapshot_id_dicts()) + if environment.promoted_snapshot_ids is not None + else None + ), + "suffix_target": environment.suffix_target.value, + "catalog_name_override": environment.catalog_name_override, + "previous_finalized_snapshots": ( + json.dumps(environment.previous_finalized_snapshot_dicts()) + if environment.previous_finalized_snapshots is not None + else None + ), + "normalize_name": environment.normalize_name, + "requirements": json.dumps(environment.requirements), + } + ] + ) diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py new file mode 100644 index 0000000000..3210d5ea35 --- /dev/null +++ b/sqlmesh/core/state_sync/db/facade.py @@ -0,0 +1,469 @@ +""" +# StateSync + +State sync is how SQLMesh keeps track of environments and their states, e.g. snapshots. + +# StateReader + +StateReader provides a subset of the functionalities of the StateSync class. As its name +implies, it only allows for read-only operations on snapshots and environment states. + +# EngineAdapterStateSync + +The provided `sqlmesh.core.state_sync.EngineAdapterStateSync` leverages an existing engine +adapter to read and write state to the underlying data store. 
+""" + +from __future__ import annotations + +import contextlib +import logging +import typing as t +from pathlib import Path +from datetime import datetime + +from sqlglot import exp + +from sqlmesh.core.console import Console, get_console +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.environment import Environment +from sqlmesh.core.snapshot import ( + Snapshot, + SnapshotId, + SnapshotIdLike, + SnapshotInfoLike, + SnapshotIntervals, + SnapshotNameVersion, + SnapshotTableCleanupTask, + SnapshotTableInfo, + start_date, +) +from sqlmesh.core.snapshot.definition import ( + Interval, +) +from sqlmesh.core.state_sync.base import ( + PromotionResult, + StateSync, + Versions, +) +from sqlmesh.core.state_sync.common import transactional +from sqlmesh.core.state_sync.db.interval import IntervalState +from sqlmesh.core.state_sync.db.environment import EnvironmentState +from sqlmesh.core.state_sync.db.snapshot import SnapshotState +from sqlmesh.core.state_sync.db.version import VersionState +from sqlmesh.core.state_sync.db.migrator import StateMigrator +from sqlmesh.utils.date import TimeLike, to_timestamp +from sqlmesh.utils.errors import ConflictingPlanError, SQLMeshError + +logger = logging.getLogger(__name__) + + +T = t.TypeVar("T") + + +if t.TYPE_CHECKING: + pass + + +class EngineAdapterStateSync(StateSync): + """Manages state of nodes and snapshot with an existing engine adapter. + + This state sync is convenient to use because it requires no additional setup. + You can reuse the same engine/warehouse that your data is stored in. + + Args: + engine_adapter: The EngineAdapter to use to store and fetch snapshots. + schema: The schema to store state metadata in. If None or empty string then no schema is defined + console: The console to log information to. + context_path: The context path, used for caching snapshot models. 
+ """ + + def __init__( + self, + engine_adapter: EngineAdapter, + schema: t.Optional[str], + console: t.Optional[Console] = None, + context_path: Path = Path(), + ): + self.plan_dags_table = exp.table_("_plan_dags", db=schema) + self.interval_state = IntervalState(engine_adapter, schema=schema) + self.environment_state = EnvironmentState(engine_adapter, schema=schema) + self.snapshot_state = SnapshotState( + engine_adapter, schema=schema, context_path=context_path + ) + self.version_state = VersionState(engine_adapter, schema=schema) + self.migrator = StateMigrator( + engine_adapter, + version_state=self.version_state, + snapshot_state=self.snapshot_state, + environment_state=self.environment_state, + interval_state=self.interval_state, + plan_dags_table=self.plan_dags_table, + console=console, + ) + # Make sure that if an empty string is provided that we treat it as None + self.schema = schema or None + self.engine_adapter = engine_adapter + self.console = console or get_console() + + @transactional() + def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: + """Pushes snapshots to the state store, merging them with existing ones. + + This method first finds all existing snapshots in the store and merges them with + the local snapshots. It will then delete all existing snapshots and then + insert all the local snapshots. This can be made safer with locks or merge/upsert. + + Args: + snapshots: The snapshots to push. + """ + snapshots_by_id = {} + for snapshot in snapshots: + if not snapshot.version: + raise SQLMeshError( + f"Snapshot {snapshot} has not been versioned yet. Create a plan before pushing a snapshot." + ) + snapshots_by_id[snapshot.snapshot_id] = snapshot + + existing = self.snapshots_exist(snapshots_by_id) + + if existing: + logger.error( + "Snapshots %s already exists. This could be due to a concurrent plan or a hash collision. 
If this is a hash collision, add a stamp to your model.", + str(existing), + ) + + for sid in tuple(snapshots_by_id): + if sid in existing: + snapshots_by_id.pop(sid) + + snapshots = snapshots_by_id.values() + if snapshots: + self._push_snapshots(snapshots) + + @transactional() + def promote( + self, + environment: Environment, + no_gaps_snapshot_names: t.Optional[t.Set[str]] = None, + ) -> PromotionResult: + """Update the environment to reflect the current state. + + This method verifies that snapshots have been pushed. + + Args: + environment: The environment to promote. + no_gaps_snapshot_names: A set of snapshot names to check for data gaps. If None, + all snapshots will be checked. The data gap check ensures that models that are already a + part of the target environment have no data gaps when compared against previous + snapshots for same models. + + Returns: + A tuple of (added snapshot table infos, removed snapshot table infos, and environment target suffix for the removed table infos) + """ + logger.info("Promoting environment '%s'", environment.name) + + missing = {s.snapshot_id for s in environment.snapshots} - self.snapshots_exist( + environment.snapshots + ) + if missing: + raise SQLMeshError( + f"Missing snapshots {missing}. Make sure to push and backfill your snapshots." 
+ ) + + existing_environment = self.environment_state.get_environment( + environment.name, lock_for_update=True + ) + + existing_table_infos = ( + {table_info.name: table_info for table_info in existing_environment.promoted_snapshots} + if existing_environment + else {} + ) + table_infos = {table_info.name: table_info for table_info in environment.promoted_snapshots} + views_that_changed_location: t.Set[SnapshotTableInfo] = set() + if existing_environment: + views_that_changed_location = { + existing_table_info + for name, existing_table_info in existing_table_infos.items() + if name in table_infos + and existing_table_info.qualified_view_name.for_environment( + existing_environment.naming_info + ) + != table_infos[name].qualified_view_name.for_environment(environment.naming_info) + } + if not existing_environment.expired: + if environment.previous_plan_id != existing_environment.plan_id: + raise ConflictingPlanError( + f"Plan '{environment.plan_id}' is no longer valid for the target environment '{environment.name}'. " + f"Expected previous plan ID: '{environment.previous_plan_id}', actual previous plan ID: '{existing_environment.plan_id}'. " + "Please recreate the plan and try again" + ) + if no_gaps_snapshot_names != set(): + snapshots = self.get_snapshots(environment.snapshots).values() + self._ensure_no_gaps( + snapshots, + existing_environment, + no_gaps_snapshot_names, + ) + demoted_snapshots = set(existing_environment.snapshots) - set(environment.snapshots) + # Update the updated_at attribute. + self.snapshot_state.touch_snapshots(demoted_snapshots) + + missing_models = set(existing_table_infos) - { + snapshot.name for snapshot in environment.promoted_snapshots + } + + added_table_infos = set(table_infos.values()) + if ( + existing_environment + and existing_environment.finalized_ts + and not existing_environment.expired + ): + # Only promote new snapshots. 
+ added_table_infos -= set(existing_environment.promoted_snapshots) + + self.environment_state.update_environment(environment) + + removed = {existing_table_infos[name] for name in missing_models}.union( + views_that_changed_location + ) + + return PromotionResult( + added=sorted(added_table_infos), + removed=list(removed), + removed_environment_naming_info=( + existing_environment.naming_info if removed and existing_environment else None + ), + ) + + @transactional() + def finalize(self, environment: Environment) -> None: + """Finalize the target environment, indicating that this environment has been + fully promoted and is ready for use. + + Args: + environment: The target environment to finalize. + """ + self.environment_state.finalize(environment) + + @transactional() + def unpause_snapshots( + self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike + ) -> None: + self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt, self.interval_state) + + def invalidate_environment(self, name: str) -> None: + self.environment_state.invalidate_environment(name) + + @transactional() + def delete_expired_snapshots( + self, ignore_ttl: bool = False + ) -> t.List[SnapshotTableCleanupTask]: + expired_snapshot_ids, cleanup_targets = self.snapshot_state.delete_expired_snapshots( + self.environment_state.get_environments(), ignore_ttl=ignore_ttl + ) + self.interval_state.cleanup_intervals(cleanup_targets, expired_snapshot_ids) + return cleanup_targets + + @transactional() + def delete_expired_environments(self) -> t.List[Environment]: + return self.environment_state.delete_expired_environments() + + def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: + self.snapshot_state.delete_snapshots(snapshot_ids) + + def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: + return self.snapshot_state.snapshots_exist(snapshot_ids) + + def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> 
t.Set[str]: + return self.snapshot_state.nodes_exist(names, exclude_external) + + def reset(self, default_catalog: t.Optional[str]) -> None: + """Resets the state store to the state when it was first initialized.""" + for table in ( + self.snapshot_state.snapshots_table, + self.snapshot_state.auto_restatements_table, + self.environment_state.environments_table, + self.interval_state.intervals_table, + self.plan_dags_table, + self.version_state.versions_table, + ): + self.engine_adapter.drop_table(table) + self.snapshot_state.clear_cache() + self.migrate(default_catalog) + + @transactional() + def update_auto_restatements( + self, next_auto_restatement_ts: t.Dict[SnapshotNameVersion, t.Optional[int]] + ) -> None: + self.snapshot_state.update_auto_restatements(next_auto_restatement_ts) + + def get_environment(self, environment: str) -> t.Optional[Environment]: + return self.environment_state.get_environment(environment) + + def get_environments(self) -> t.List[Environment]: + """Fetches all environments. + + Returns: + A list of all environments. + """ + return self.environment_state.get_environments() + + def get_environments_summary(self) -> t.Dict[str, int]: + """Fetches all environment names along with expiry datetime. + + Returns: + A dict of all environment names along with expiry datetime. + """ + return self.environment_state.get_environments_summary() + + def get_snapshots( + self, + snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]], + ) -> t.Dict[SnapshotId, Snapshot]: + """Fetches snapshots from the state. + + Args: + snapshot_ids: The snapshot IDs to fetch. + + Returns: + A dict of snapshots. 
+ """ + snapshots = self.snapshot_state.get_snapshots(snapshot_ids) + intervals = self.interval_state.get_snapshot_intervals(snapshots.values()) + Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals) + return snapshots + + @transactional() + def add_interval( + self, + snapshot: Snapshot, + start: TimeLike, + end: TimeLike, + is_dev: bool = False, + ) -> None: + super().add_interval(snapshot, start, end, is_dev) + + @transactional() + def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: + self.interval_state.add_snapshots_intervals(snapshots_intervals) + + @transactional() + def remove_intervals( + self, + snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], + remove_shared_versions: bool = False, + ) -> None: + self.interval_state.remove_intervals(snapshot_intervals, remove_shared_versions) + + @transactional() + def compact_intervals(self) -> None: + self.interval_state.compact_intervals() + + def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]: + return self.interval_state.refresh_snapshot_intervals(snapshots) + + def max_interval_end_per_model( + self, + environment: str, + models: t.Optional[t.Set[str]] = None, + ensure_finalized_snapshots: bool = False, + ) -> t.Dict[str, int]: + env = self.get_environment(environment) + if not env: + return {} + + snapshots = ( + env.snapshots if not ensure_finalized_snapshots else env.finalized_or_current_snapshots + ) + if models is not None: + snapshots = [s for s in snapshots if s.name in models] + + if not snapshots: + return {} + + return self.interval_state.max_interval_end_per_model(snapshots) + + def recycle(self) -> None: + self.engine_adapter.recycle() + + def close(self) -> None: + self.engine_adapter.close() + + @transactional() + def migrate( + self, + default_catalog: t.Optional[str], + skip_backup: bool = False, + promoted_snapshots_only: bool = True, + ) -> None: + """Migrate the state sync 
to the latest SQLMesh / SQLGlot version.""" + self.migrator.migrate( + self, + default_catalog, + skip_backup=skip_backup, + promoted_snapshots_only=promoted_snapshots_only, + ) + + @transactional() + def rollback(self) -> None: + """Rollback to the previous migration.""" + self.migrator.rollback() + + def state_type(self) -> str: + return self.engine_adapter.dialect + + def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: + self.snapshot_state.push_snapshots(snapshots) + + def _get_versions(self) -> Versions: + return self.version_state.get_versions() + + def _ensure_no_gaps( + self, + target_snapshots: t.Iterable[Snapshot], + target_environment: Environment, + snapshot_names: t.Optional[t.Set[str]], + ) -> None: + target_snapshots_by_name = {s.name: s for s in target_snapshots} + + changed_version_prev_snapshots_by_name = { + s.name: s + for s in target_environment.snapshots + if s.name in target_snapshots_by_name + and target_snapshots_by_name[s.name].version != s.version + } + + prev_snapshots = self.get_snapshots( + changed_version_prev_snapshots_by_name.values() + ).values() + cache: t.Dict[str, datetime] = {} + + for prev_snapshot in prev_snapshots: + target_snapshot = target_snapshots_by_name[prev_snapshot.name] + if ( + (snapshot_names is None or prev_snapshot.name in snapshot_names) + and target_snapshot.is_incremental + and prev_snapshot.is_incremental + and prev_snapshot.intervals + ): + start = to_timestamp( + start_date(target_snapshot, target_snapshots_by_name.values(), cache) + ) + end = prev_snapshot.intervals[-1][1] + + if start < end: + missing_intervals = target_snapshot.missing_intervals( + start, end, end_bounded=True + ) + + if missing_intervals: + raise SQLMeshError( + f"Detected gaps in snapshot {target_snapshot.snapshot_id}: {missing_intervals}" + ) + + @contextlib.contextmanager + def _transaction(self) -> t.Iterator[None]: + with self.engine_adapter.transaction(): + yield diff --git a/sqlmesh/core/state_sync/db/interval.py 
b/sqlmesh/core/state_sync/db/interval.py new file mode 100644 index 0000000000..944e4b650c --- /dev/null +++ b/sqlmesh/core/state_sync/db/interval.py @@ -0,0 +1,513 @@ +from __future__ import annotations + +import typing as t +import logging +import pandas as pd + +from sqlglot import exp + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.db.utils import ( + snapshot_name_version_filter, + snapshot_id_filter, + create_batches, + fetchall, +) +from sqlmesh.core.snapshot import ( + SnapshotIntervals, + SnapshotIdLike, + SnapshotNameVersionLike, + SnapshotTableCleanupTask, + SnapshotNameVersion, + SnapshotInfoLike, + Snapshot, + SnapshotId, +) +from sqlmesh.core.snapshot.definition import Interval +from sqlmesh.utils.migration import index_text_type +from sqlmesh.utils import random_id +from sqlmesh.utils.date import now_timestamp, time_like_to_str + + +logger = logging.getLogger(__name__) + + +class IntervalState: + INTERVAL_BATCH_SIZE = 1000 + SNAPSHOT_BATCH_SIZE = 1000 + + def __init__( + self, + engine_adapter: EngineAdapter, + schema: t.Optional[str] = None, + table_name: t.Optional[str] = None, + ): + self.engine_adapter = engine_adapter + self.intervals_table = exp.table_(table_name or "_intervals", db=schema) + + index_type = index_text_type(engine_adapter.dialect) + self._interval_columns_to_types = { + "id": exp.DataType.build(index_type), + "created_ts": exp.DataType.build("bigint"), + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build("text"), + "version": exp.DataType.build(index_type), + "dev_version": exp.DataType.build(index_type), + "start_ts": exp.DataType.build("bigint"), + "end_ts": exp.DataType.build("bigint"), + "is_dev": exp.DataType.build("boolean"), + "is_removed": exp.DataType.build("boolean"), + "is_compacted": exp.DataType.build("boolean"), + "is_pending_restatement": exp.DataType.build("boolean"), + } + + def add_snapshots_intervals(self, snapshots_intervals: 
t.Sequence[SnapshotIntervals]) -> None: + def remove_partial_intervals( + intervals: t.List[Interval], snapshot_id: t.Optional[SnapshotId], *, is_dev: bool + ) -> t.List[Interval]: + results = [] + for start_ts, end_ts in intervals: + if start_ts < end_ts: + logger.info( + "Adding %s (%s, %s) for snapshot %s", + "dev interval" if is_dev else "interval", + time_like_to_str(start_ts), + time_like_to_str(end_ts), + snapshot_id, + ) + results.append((start_ts, end_ts)) + else: + logger.info( + "Skipping partial interval (%s, %s) for snapshot %s", + start_ts, + end_ts, + snapshot_id, + ) + return results + + intervals_to_insert = [] + for snapshot_intervals in snapshots_intervals: + snapshot_intervals = snapshot_intervals.copy( + update={ + "intervals": remove_partial_intervals( + snapshot_intervals.intervals, snapshot_intervals.snapshot_id, is_dev=False + ), + "dev_intervals": remove_partial_intervals( + snapshot_intervals.dev_intervals, + snapshot_intervals.snapshot_id, + is_dev=True, + ), + } + ) + if ( + snapshot_intervals.intervals + or snapshot_intervals.dev_intervals + or snapshot_intervals.pending_restatement_intervals + ): + intervals_to_insert.append(snapshot_intervals) + + if intervals_to_insert: + self._push_snapshot_intervals(intervals_to_insert) + + def remove_intervals( + self, + snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], + remove_shared_versions: bool = False, + ) -> None: + intervals_to_remove: t.Sequence[ + t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval] + ] = snapshot_intervals + if remove_shared_versions: + name_version_mapping = {s.name_version: interval for s, interval in snapshot_intervals} + all_snapshots = [] + for where in snapshot_name_version_filter( + self.engine_adapter, + name_version_mapping, + alias=None, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + all_snapshots.extend( + [ + SnapshotIntervals( + name=r[0], + identifier=r[1], + version=r[2], + dev_version=r[3], + intervals=[], + 
dev_intervals=[], + ) + for r in fetchall( + self.engine_adapter, + exp.select("name", "identifier", "version", "dev_version") + .from_(self.intervals_table) + .where(where) + .distinct(), + ) + ] + ) + intervals_to_remove = [ + (snapshot, name_version_mapping[snapshot.name_version]) + for snapshot in all_snapshots + ] + + if logger.isEnabledFor(logging.INFO): + snapshot_ids = ", ".join(str(s.snapshot_id) for s, _ in intervals_to_remove) + logger.info("Removing interval for snapshots: %s", snapshot_ids) + + for is_dev in (True, False): + self.engine_adapter.insert_append( + self.intervals_table, + _intervals_to_df(intervals_to_remove, is_dev=is_dev, is_removed=True), + columns_to_types=self._interval_columns_to_types, + ) + + def get_snapshot_intervals( + self, snapshots: t.Collection[SnapshotNameVersionLike] + ) -> t.List[SnapshotIntervals]: + return self._get_snapshot_intervals(snapshots)[1] + + def compact_intervals(self) -> None: + interval_ids, snapshot_intervals = self._get_snapshot_intervals(uncompacted_only=True) + + logger.info( + "Compacting %s intervals for %s snapshots", len(interval_ids), len(snapshot_intervals) + ) + + self._push_snapshot_intervals(snapshot_intervals, is_compacted=True) + + if interval_ids: + for interval_id_batch in create_batches( + list(interval_ids), batch_size=self.INTERVAL_BATCH_SIZE + ): + self.engine_adapter.delete_from( + self.intervals_table, exp.column("id").isin(*interval_id_batch) + ) + + def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]: + if not snapshots: + return [] + + _, intervals = self._get_snapshot_intervals(snapshots) + for s in snapshots: + s.intervals = [] + s.dev_intervals = [] + return Snapshot.hydrate_with_intervals_by_version(snapshots, intervals) + + def max_interval_end_per_model( + self, snapshots: t.Collection[SnapshotNameVersionLike] + ) -> t.Dict[str, int]: + if not snapshots: + return {} + + table_alias = "intervals" + name_col = exp.column("name", 
table=table_alias) + version_col = exp.column("version", table=table_alias) + + result: t.Dict[str, int] = {} + + for where in snapshot_name_version_filter( + self.engine_adapter, snapshots, alias=table_alias, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + query = ( + exp.select( + name_col, + exp.func("MAX", exp.column("end_ts", table=table_alias)).as_("max_end_ts"), + ) + .from_(exp.to_table(self.intervals_table).as_(table_alias)) + .where(where, copy=False) + .where( + exp.and_( + exp.to_column("is_dev").not_(), + exp.to_column("is_removed").not_(), + exp.to_column("is_pending_restatement").not_(), + ), + copy=False, + ) + .group_by(name_col, version_col, copy=False) + ) + + for name, max_end in fetchall(self.engine_adapter, query): + result[name] = max_end + + return result + + def cleanup_intervals( + self, + cleanup_targets: t.List[SnapshotTableCleanupTask], + expired_snapshot_ids: t.Collection[SnapshotIdLike], + ) -> None: + # Cleanup can only happen for compacted intervals + self.compact_intervals() + # Delete intervals for non-dev tables that are no longer used + self._delete_intervals_by_version(cleanup_targets) + # Delete dev intervals for dev tables that are no longer used + self._delete_intervals_by_dev_version(cleanup_targets) + # Nullify the snapshot identifiers of interval records for snapshots that have been deleted + self._update_intervals_for_deleted_snapshots(expired_snapshot_ids) + + def _push_snapshot_intervals( + self, + snapshots: t.Iterable[t.Union[Snapshot, SnapshotIntervals]], + is_compacted: bool = False, + ) -> None: + new_intervals = [] + for snapshot in snapshots: + logger.info("Pushing intervals for snapshot %s", snapshot.snapshot_id) + for start_ts, end_ts in snapshot.intervals: + new_intervals.append( + _interval_to_df( + snapshot, start_ts, end_ts, is_dev=False, is_compacted=is_compacted + ) + ) + for start_ts, end_ts in snapshot.dev_intervals: + new_intervals.append( + _interval_to_df( + snapshot, start_ts, end_ts, is_dev=True, 
                        is_compacted=is_compacted
                    )
                )

        # Make sure that all pending restatement intervals are recorded last
        # NOTE(review): `snapshots` is iterated a second time here; callers must pass a
        # re-iterable collection (list/tuple), not a one-shot generator — confirm at call sites.
        for snapshot in snapshots:
            for start_ts, end_ts in snapshot.pending_restatement_intervals:
                new_intervals.append(
                    _interval_to_df(
                        snapshot,
                        start_ts,
                        end_ts,
                        is_dev=False,
                        is_compacted=is_compacted,
                        is_pending_restatement=True,
                    )
                )

        if new_intervals:
            self.engine_adapter.insert_append(
                self.intervals_table,
                pd.DataFrame(new_intervals),
                columns_to_types=self._interval_columns_to_types,
            )

    def _get_snapshot_intervals(
        self,
        snapshots: t.Optional[t.Collection[SnapshotNameVersionLike]] = None,
        uncompacted_only: bool = False,
    ) -> t.Tuple[t.Set[str], t.List[SnapshotIntervals]]:
        # Returns (ids of all interval rows read, merged SnapshotIntervals).
        # `snapshots is None` means "all snapshots"; an empty collection means "none".
        if not snapshots and snapshots is not None:
            return (set(), [])

        query = self._get_snapshot_intervals_query(uncompacted_only)

        interval_ids: t.Set[str] = set()
        intervals: t.Dict[
            t.Tuple[str, str, t.Optional[str], t.Optional[str]], SnapshotIntervals
        ] = {}

        for where in (
            snapshot_name_version_filter(
                self.engine_adapter,
                snapshots,
                alias="intervals",
                batch_size=self.SNAPSHOT_BATCH_SIZE,
            )
            if snapshots
            else [None]
        ):
            rows = fetchall(self.engine_adapter, query.where(where))
            for (
                interval_id,
                name,
                identifier,
                version,
                dev_version,
                start,
                end,
                is_dev,
                is_removed,
                is_pending_restatement,
            ) in rows:
                interval_ids.add(interval_id)
                merge_key = (name, version, dev_version, identifier)
                # Pending restatement intervals are merged by name and version
                pending_restatement_interval_merge_key = (name, version, None, None)

                if merge_key not in intervals:
                    intervals[merge_key] = SnapshotIntervals(
                        name=name,
                        identifier=identifier,
                        version=version,
                        dev_version=dev_version,
                    )

                if pending_restatement_interval_merge_key not in intervals:
                    intervals[pending_restatement_interval_merge_key] = SnapshotIntervals(
                        name=name,
                        identifier=None,
                        version=version,
                        dev_version=None,
                    )

                # Rows are ordered by (name, version, created_ts, is_removed,
                # is_pending_restatement) — see _get_snapshot_intervals_query — so the
                # add/remove replay below reproduces the chronological interval state.
                if is_removed:
                    if is_dev:
                        intervals[merge_key].remove_dev_interval(start, end)
                    else:
                        intervals[merge_key].remove_interval(start, end)
                elif is_pending_restatement:
                    intervals[
                        pending_restatement_interval_merge_key
                    ].add_pending_restatement_interval(start, end)
                else:
                    if is_dev:
                        intervals[merge_key].add_dev_interval(start, end)
                    else:
                        intervals[merge_key].add_interval(start, end)
                    # Remove all pending restatement intervals recorded before the current interval has been added
                    intervals[
                        pending_restatement_interval_merge_key
                    ].remove_pending_restatement_interval(start, end)

        return interval_ids, [i for i in intervals.values() if not i.is_empty()]

    def _get_snapshot_intervals_query(self, uncompacted_only: bool) -> exp.Select:
        # Base SELECT over the intervals table; when uncompacted_only is set, joins a
        # distinct (name, version) subquery so ALL rows of a name/version that has at
        # least one uncompacted row are returned (compaction needs the full history).
        query = (
            exp.select(
                "id",
                exp.column("name", table="intervals"),
                exp.column("identifier", table="intervals"),
                exp.column("version", table="intervals"),
                exp.column("dev_version", table="intervals"),
                "start_ts",
                "end_ts",
                "is_dev",
                "is_removed",
                "is_pending_restatement",
            )
            .from_(exp.to_table(self.intervals_table).as_("intervals"))
            .order_by(
                exp.column("name", table="intervals"),
                exp.column("version", table="intervals"),
                "created_ts",
                "is_removed",
                "is_pending_restatement",
            )
        )
        if uncompacted_only:
            # copy=False: mutate `query` in place; the joined subquery restricts rows
            # without changing the projection.
            query.join(
                exp.select("name", "version")
                .from_(exp.to_table(self.intervals_table).as_("intervals"))
                .where(exp.column("is_compacted").not_())
                .distinct()
                .subquery(alias="uncompacted"),
                on=exp.and_(
                    exp.column("name", table="intervals").eq(
                        exp.column("name", table="uncompacted")
                    ),
                    exp.column("version", table="intervals").eq(
                        exp.column("version", table="uncompacted")
                    ),
                ),
                copy=False,
            )
        return query

    def _update_intervals_for_deleted_snapshots(
        self, snapshot_ids: t.Collection[SnapshotIdLike]
    ) -> None:
        """Nullifies the snapshot identifiers of dev interval records and snapshot identifiers and dev versions of
        non-dev
        interval records for snapshots that have been deleted so that they can be compacted efficiently.
        """
        if not snapshot_ids:
            return

        for where in snapshot_id_filter(
            self.engine_adapter, snapshot_ids, alias=None, batch_size=self.SNAPSHOT_BATCH_SIZE
        ):
            # Nullify the identifier for dev intervals
            # Set is_compacted to False so that it's compacted during the next compaction
            self.engine_adapter.update_table(
                self.intervals_table,
                {"identifier": None, "is_compacted": False},
                where=where.and_(exp.column("is_dev")),
            )
            # Nullify both identifier and dev version for non-dev intervals
            # Set is_compacted to False so that it's compacted during the next compaction
            self.engine_adapter.update_table(
                self.intervals_table,
                {"identifier": None, "dev_version": None, "is_compacted": False},
                where=where.and_(exp.column("is_dev").not_()),
            )

    def _delete_intervals_by_dev_version(self, targets: t.List[SnapshotTableCleanupTask]) -> None:
        """Deletes dev intervals for snapshot dev versions that are no longer used."""
        # Note: the snapshot's dev_version is deliberately stored in the `version`
        # field of the key and matched against the `dev_version` column below.
        dev_keys_to_delete = [
            SnapshotNameVersion(name=t.snapshot.name, version=t.snapshot.dev_version)
            for t in targets
            if t.dev_table_only
        ]
        if not dev_keys_to_delete:
            return

        for where in snapshot_name_version_filter(
            self.engine_adapter,
            dev_keys_to_delete,
            version_column_name="dev_version",
            alias=None,
            batch_size=self.SNAPSHOT_BATCH_SIZE,
        ):
            self.engine_adapter.delete_from(self.intervals_table, where.and_(exp.column("is_dev")))

    def _delete_intervals_by_version(self, targets: t.List[SnapshotTableCleanupTask]) -> None:
        """Deletes intervals for snapshot versions that are no longer used."""
        non_dev_keys_to_delete = [t.snapshot for t in targets if not t.dev_table_only]
        if not non_dev_keys_to_delete:
            return

        for where in snapshot_name_version_filter(
            self.engine_adapter,
            non_dev_keys_to_delete,
            alias=None,
            batch_size=self.SNAPSHOT_BATCH_SIZE,
        ):
            self.engine_adapter.delete_from(self.intervals_table, where)


def _intervals_to_df(
    snapshot_intervals: t.Sequence[t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval]],
    is_dev: bool,
    is_removed: bool,
) -> pd.DataFrame:
    """Builds a DataFrame of interval records, one row per (snapshot, interval) pair."""
    return pd.DataFrame(
        [
            _interval_to_df(
                s,
                *interval,
                is_dev=is_dev,
                is_removed=is_removed,
            )
            for s, interval in snapshot_intervals
        ]
    )


def _interval_to_df(
    snapshot: t.Union[SnapshotInfoLike, SnapshotIntervals],
    start_ts: int,
    end_ts: int,
    is_dev: bool = False,
    is_removed: bool = False,
    is_compacted: bool = False,
    is_pending_restatement: bool = False,
) -> t.Dict[str, t.Any]:
    """Builds a single interval record dict keyed by the intervals table columns."""
    return {
        "id": random_id(),
        "created_ts": now_timestamp(),
        "name": snapshot.name,
        # Pending restatement records are keyed by name/version only.
        "identifier": snapshot.identifier if not is_pending_restatement else None,
        "version": snapshot.version,
        "dev_version": snapshot.dev_version if not is_pending_restatement else None,
        "start_ts": start_ts,
        "end_ts": end_ts,
        "is_dev": is_dev,
        "is_removed": is_removed,
        "is_compacted": is_compacted,
        "is_pending_restatement": is_pending_restatement,
    }
diff --git a/sqlmesh/core/state_sync/db/migrator.py b/sqlmesh/core/state_sync/db/migrator.py
new file mode 100644
index 0000000000..56ab974e10
--- /dev/null
+++ b/sqlmesh/core/state_sync/db/migrator.py
@@ -0,0 +1,462 @@
from __future__ import annotations

import json
import logging
import time
import typing as t
from copy import deepcopy

from sqlglot import __version__ as SQLGLOT_VERSION
from sqlglot import exp

from sqlmesh.core import analytics
from sqlmesh.core import constants as c
from sqlmesh.core.console import Console, get_console
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.environment import Environment
from sqlmesh.core.snapshot import (
    Node,
    Snapshot,
    SnapshotFingerprint,
    SnapshotId,
    SnapshotTableInfo,
    fingerprint_from_node,
)
from sqlmesh.core.snapshot.definition import (
    _parents_from_node,
)
from sqlmesh.core.state_sync.base import (
    MIGRATIONS,
)
from sqlmesh.core.state_sync.base import StateSync
from sqlmesh.core.state_sync.db.environment import EnvironmentState
from sqlmesh.core.state_sync.db.interval import IntervalState
from sqlmesh.core.state_sync.db.snapshot import SnapshotState
from sqlmesh.core.state_sync.db.version import VersionState
from sqlmesh.core.state_sync.db.utils import (
    SQLMESH_VERSION,
    snapshot_id_filter,
    fetchall,
)
from sqlmesh.utils import major_minor
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import now_timestamp
from sqlmesh.utils.errors import SQLMeshError

logger = logging.getLogger(__name__)


if t.TYPE_CHECKING:
    from sqlmesh.core._typing import TableName


class StateMigrator:
    """Applies schema migration scripts and re-fingerprints stored snapshot / environment rows."""

    # Max snapshot ids per SQL filter batch.
    SNAPSHOT_BATCH_SIZE = 1000
    # Max migrated snapshots accumulated before being flushed to the state store.
    SNAPSHOT_MIGRATION_BATCH_SIZE = 500

    def __init__(
        self,
        engine_adapter: EngineAdapter,
        version_state: VersionState,
        snapshot_state: SnapshotState,
        environment_state: EnvironmentState,
        interval_state: IntervalState,
        plan_dags_table: TableName,
        console: t.Optional[Console] = None,
    ):
        self.engine_adapter = engine_adapter
        self.console = console or get_console()
        self.version_state = version_state
        self.snapshot_state = snapshot_state
        self.environment_state = environment_state
        self.interval_state = interval_state
        self.plan_dags_table = plan_dags_table

        # Tables that must exist for rollback to be possible.
        self._state_tables = [
            self.snapshot_state.snapshots_table,
            self.environment_state.environments_table,
            self.version_state.versions_table,
        ]
        # Tables that are backed up / restored only when present.
        self._optional_state_tables = [
            self.interval_state.intervals_table,
            self.plan_dags_table,
            self.snapshot_state.auto_restatements_table,
        ]

    def migrate(
        self,
        state_sync: StateSync,
        default_catalog: t.Optional[str],
        skip_backup: bool = False,
        promoted_snapshots_only: bool = True,
    ) -> None:
        """Migrate the state sync to the latest SQLMesh / SQLGlot version."""
        versions = self.version_state.get_versions()
        migration_start_ts = time.perf_counter()

        try:
            migrate_rows = self._apply_migrations(state_sync, default_catalog, skip_backup)

            if not migrate_rows and major_minor(SQLMESH_VERSION) == versions.minor_sqlmesh_version:
                return

            if migrate_rows:
                self._migrate_rows(promoted_snapshots_only)
                # Cleanup plan DAGs since we currently don't migrate snapshot records that are in there.
                self.engine_adapter.delete_from(self.plan_dags_table, "TRUE")
            # NOTE(review): versions are bumped even when only the SQLMesh/SQLGlot
            # minor version changed (no row migration) — confirm intended placement.
            self.version_state.update_versions()

            analytics.collector.on_migration_end(
                from_sqlmesh_version=versions.sqlmesh_version,
                state_sync_type=self.engine_adapter.dialect,
                migration_time_sec=time.perf_counter() - migration_start_ts,
            )
        except Exception as e:
            if skip_backup:
                logger.error("Backup was skipped so no rollback was attempted.")
            else:
                self.rollback()

            analytics.collector.on_migration_end(
                from_sqlmesh_version=versions.sqlmesh_version,
                state_sync_type=self.engine_adapter.dialect,
                migration_time_sec=time.perf_counter() - migration_start_ts,
                error=e,
            )

            self.console.log_migration_status(success=False)
            raise SQLMeshError("SQLMesh migration failed.") from e

        self.console.log_migration_status()

    def rollback(self) -> None:
        """Rollback to the previous migration."""
        logger.info("Starting migration rollback.")
        versions = self.version_state.get_versions()
        if versions.schema_version == 0:
            # Schema version 0 means the state was freshly created by the failed
            # migration attempt, so there is nothing to restore — drop everything.
            # Clean up state tables
            for table in self._state_tables + self._optional_state_tables:
                self.engine_adapter.drop_table(table)
        else:
            if not all(
                self.engine_adapter.table_exists(_backup_table_name(table))
                for table in self._state_tables
            ):
                raise SQLMeshError("There are no prior migrations to roll back to.")
            for table in self._state_tables:
                self._restore_table(table, _backup_table_name(table))

            for optional_table in self._optional_state_tables:
                if self.engine_adapter.table_exists(_backup_table_name(optional_table)):
                    self._restore_table(optional_table,
_backup_table_name(optional_table)) + + logger.info("Migration rollback successful.") + + def _apply_migrations( + self, + state_sync: StateSync, + default_catalog: t.Optional[str], + skip_backup: bool, + ) -> bool: + versions = self.version_state.get_versions() + migrations = MIGRATIONS[versions.schema_version :] + should_backup = any( + [ + migrations, + major_minor(SQLGLOT_VERSION) != versions.minor_sqlglot_version, + major_minor(SQLMESH_VERSION) != versions.minor_sqlmesh_version, + ] + ) + if not skip_backup and should_backup: + self._backup_state() + + snapshot_count_before = self.snapshot_state.count() if versions.schema_version else None + + for migration in migrations: + logger.info(f"Applying migration {migration}") + migration.migrate(state_sync, default_catalog=default_catalog) + + snapshot_count_after = self.snapshot_state.count() + + if snapshot_count_before is not None and snapshot_count_before != snapshot_count_after: + scripts = f"{versions.schema_version} - {versions.schema_version + len(migrations)}" + raise SQLMeshError( + f"Number of snapshots before ({snapshot_count_before}) and after " + f"({snapshot_count_after}) applying migration scripts {scripts} does not match. " + "Please file an issue issue at https://github.com/TobikoData/sqlmesh/issues/new." + ) + + migrate_snapshots_and_environments = ( + bool(migrations) or major_minor(SQLGLOT_VERSION) != versions.minor_sqlglot_version + ) + return migrate_snapshots_and_environments + + def _migrate_rows(self, promoted_snapshots_only: bool) -> None: + logger.info("Fetching environments") + environments = self.environment_state.get_environments() + # Only migrate snapshots that are part of at least one environment. 
        snapshots_to_migrate = (
            {s.snapshot_id for e in environments for s in e.snapshots}
            if promoted_snapshots_only
            else None
        )
        snapshot_mapping = self._migrate_snapshot_rows(snapshots_to_migrate)
        if not snapshot_mapping:
            logger.info("No changes to snapshots detected")
            return
        self._migrate_environment_rows(environments, snapshot_mapping)

    def _migrate_snapshot_rows(
        self, snapshots: t.Optional[t.Set[SnapshotId]]
    ) -> t.Dict[SnapshotId, SnapshotTableInfo]:
        # Re-fingerprints stored snapshots (None means all of them) and returns a
        # mapping from each changed old snapshot id to its new table info.
        logger.info("Migrating snapshot rows...")
        raw_snapshots = {
            SnapshotId(name=name, identifier=identifier): {
                **json.loads(raw_snapshot),
                "updated_ts": updated_ts,
                "unpaused_ts": unpaused_ts,
                "unrestorable": unrestorable,
            }
            for where in (
                snapshot_id_filter(
                    self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE
                )
                if snapshots is not None
                else [None]
            )
            for name, identifier, raw_snapshot, updated_ts, unpaused_ts, unrestorable in fetchall(
                self.engine_adapter,
                exp.select(
                    "name", "identifier", "snapshot", "updated_ts", "unpaused_ts", "unrestorable"
                )
                .from_(self.snapshot_state.snapshots_table)
                .where(where)
                .lock(),
            )
        }
        if not raw_snapshots:
            return {}

        # Build the parent DAG so snapshots are visited parents-first (fingerprints
        # depend on parent fingerprints).
        dag: DAG[SnapshotId] = DAG()
        for snapshot_id, raw_snapshot in raw_snapshots.items():
            parent_ids = [SnapshotId.parse_obj(p_id) for p_id in raw_snapshot.get("parents", [])]
            dag.add(snapshot_id, [p_id for p_id in parent_ids if p_id in raw_snapshots])

        reversed_dag_raw = dag.reversed.graph

        self.console.start_snapshot_migration_progress(len(raw_snapshots))

        parsed_snapshots = LazilyParsedSnapshots(raw_snapshots)
        all_snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo] = {}
        snapshot_id_mapping: t.Dict[SnapshotId, SnapshotId] = {}
        new_snapshots: t.Dict[SnapshotId, Snapshot] = {}
        visited: t.Set[SnapshotId] = set()

        def _push_new_snapshots() -> None:
            # Flushes the accumulated batch: records the old->new table info mapping,
            # stores snapshots that don't already exist, then resets the batch state.
            all_snapshot_mapping.update(
                {
                    from_id: new_snapshots[to_id].table_info
                    for from_id, to_id in snapshot_id_mapping.items()
                }
            )

            existing_new_snapshots = self.snapshot_state.snapshots_exist(new_snapshots)
            new_snapshots_to_push = [
                s for s in new_snapshots.values() if s.snapshot_id not in existing_new_snapshots
            ]
            if new_snapshots_to_push:
                logger.info("Pushing %s migrated snapshots", len(new_snapshots_to_push))
                self._push_snapshots(new_snapshots_to_push)
            new_snapshots.clear()
            snapshot_id_mapping.clear()

        def _visit(
            snapshot_id: SnapshotId, fingerprint_cache: t.Dict[str, SnapshotFingerprint]
        ) -> None:
            # Recomputes the fingerprint of one snapshot, then recurses into its children.
            if snapshot_id in visited or snapshot_id not in raw_snapshots:
                return
            visited.add(snapshot_id)

            snapshot = parsed_snapshots[snapshot_id]
            node = snapshot.node

            # Collect the transitive closure of parent nodes needed for fingerprinting.
            node_seen = set()
            node_queue = {snapshot_id}
            nodes: t.Dict[str, Node] = {}
            while node_queue:
                next_snapshot_id = node_queue.pop()
                next_snapshot = parsed_snapshots.get(next_snapshot_id)

                if next_snapshot_id in node_seen or not next_snapshot:
                    continue

                node_seen.add(next_snapshot_id)
                node_queue.update(next_snapshot.parents)
                nodes[next_snapshot.name] = next_snapshot.node

            new_snapshot = deepcopy(snapshot)
            try:
                new_snapshot.fingerprint = fingerprint_from_node(
                    node,
                    nodes=nodes,
                    cache=fingerprint_cache,
                )
                new_snapshot.parents = tuple(
                    SnapshotId(
                        name=parent_node.fqn,
                        identifier=fingerprint_from_node(
                            parent_node,
                            nodes=nodes,
                            cache=fingerprint_cache,
                        ).to_identifier(),
                    )
                    for parent_node in _parents_from_node(node, nodes).values()
                )
            except Exception:
                # Best-effort: a snapshot whose fingerprint can't be computed is left as-is.
                logger.exception("Could not compute fingerprint for %s", snapshot.snapshot_id)
                return

            # Reset the effective_from date for the new snapshot to avoid unexpected backfills.
            new_snapshot.effective_from = None
            new_snapshot.previous_versions = snapshot.all_versions
            new_snapshot.migrated = True
            if not new_snapshot.dev_version_:
                new_snapshot.dev_version_ = snapshot.dev_version

            self.console.update_snapshot_migration_progress(1)

            # Visit children and evict them from the parsed_snapshots cache after.
            for child in reversed_dag_raw.get(snapshot_id, []):
                # Make sure to copy the fingerprint cache to avoid sharing it between different child snapshots with the same name.
                _visit(child, fingerprint_cache.copy())
                parsed_snapshots.evict(child)

            if new_snapshot.fingerprint == snapshot.fingerprint:
                logger.debug(f"{new_snapshot.snapshot_id} is unchanged.")
                return

            new_snapshot_id = new_snapshot.snapshot_id

            if new_snapshot_id in raw_snapshots:
                # Mapped to an existing snapshot.
                new_snapshots[new_snapshot_id] = parsed_snapshots[new_snapshot_id]
                logger.debug("Migrated snapshot %s already exists", new_snapshot_id)
            elif (
                new_snapshot_id not in new_snapshots
                or new_snapshot.updated_ts > new_snapshots[new_snapshot_id].updated_ts
            ):
                # Keep the most recently updated candidate when several old
                # snapshots collapse onto the same new id.
                new_snapshots[new_snapshot_id] = new_snapshot

            snapshot_id_mapping[snapshot.snapshot_id] = new_snapshot_id
            logger.debug("%s mapped to %s", snapshot.snapshot_id, new_snapshot_id)

            if len(new_snapshots) >= self.SNAPSHOT_MIGRATION_BATCH_SIZE:
                _push_new_snapshots()

        for root_snapshot_id in dag.roots:
            _visit(root_snapshot_id, {})

        if new_snapshots:
            _push_new_snapshots()

        self.console.stop_snapshot_migration_progress()
        return all_snapshot_mapping

    def _migrate_environment_rows(
        self,
        environments: t.List[Environment],
        snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo],
    ) -> None:
        # Rewrites environment snapshot references using the old->new mapping.
        logger.info("Migrating environment rows...")

        updated_prod_environment: t.Optional[Environment] = None
        updated_environments = []
        for environment in environments:
            snapshots = [
                (
                    snapshot_mapping[info.snapshot_id]
                    if info.snapshot_id in snapshot_mapping
else info + ) + for info in environment.snapshots + ] + + if snapshots != environment.snapshots: + environment.snapshots_ = snapshots + updated_environments.append(environment) + if environment.name == c.PROD: + updated_prod_environment = environment + self.console.start_env_migration_progress(len(updated_environments)) + + for environment in updated_environments: + self._update_environment(environment) + self.console.update_env_migration_progress(1) + + if updated_prod_environment: + try: + self.snapshot_state.unpause_snapshots( + updated_prod_environment.snapshots, now_timestamp(), self.interval_state + ) + except Exception: + logger.warning("Failed to unpause migrated snapshots", exc_info=True) + + self.console.stop_env_migration_progress() + + def _backup_state(self) -> None: + for table in [ + *self._state_tables, + *self._optional_state_tables, + ]: + if self.engine_adapter.table_exists(table): + with self.engine_adapter.transaction(): + backup_name = _backup_table_name(table) + self.engine_adapter.drop_table(backup_name) + self.engine_adapter.create_table_like(backup_name, table) + self.engine_adapter.insert_append(backup_name, exp.select("*").from_(table)) + + def _restore_table( + self, + table_name: TableName, + backup_table_name: TableName, + ) -> None: + self.engine_adapter.drop_table(table_name) + self.engine_adapter.rename_table( + old_table_name=backup_table_name, + new_table_name=table_name, + ) + + def _update_environment(self, environment: Environment) -> None: + self.environment_state.update_environment(environment) + + def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: + self.snapshot_state.push_snapshots(snapshots) + + +def _backup_table_name(table_name: TableName) -> exp.Table: + table = exp.to_table(table_name).copy() + table.set("this", exp.to_identifier(table.name + "_backup")) + return table + + +class LazilyParsedSnapshots: + def __init__(self, raw_snapshots: t.Dict[SnapshotId, t.Dict[str, t.Any]]): + self._raw_snapshots 
= raw_snapshots + self._parsed_snapshots: t.Dict[SnapshotId, t.Optional[Snapshot]] = {} + + def get(self, snapshot_id: SnapshotId) -> t.Optional[Snapshot]: + if snapshot_id not in self._parsed_snapshots: + raw_snapshot = self._raw_snapshots.get(snapshot_id) + if raw_snapshot: + self._parsed_snapshots[snapshot_id] = Snapshot.parse_obj(raw_snapshot) + else: + self._parsed_snapshots[snapshot_id] = None + return self._parsed_snapshots[snapshot_id] + + def evict(self, snapshot_id: SnapshotId) -> None: + self._parsed_snapshots.pop(snapshot_id, None) + + def __getitem__(self, snapshot_id: SnapshotId) -> Snapshot: + snapshot = self.get(snapshot_id) + if snapshot is None: + raise KeyError(snapshot_id) + return snapshot diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py new file mode 100644 index 0000000000..6b7d64c57b --- /dev/null +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -0,0 +1,785 @@ +from __future__ import annotations + +import typing as t +import pandas as pd +import json +import logging +from pathlib import Path +from collections import defaultdict +from sqlglot import exp +from pydantic import Field + +from sqlmesh.core import constants as c +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.db.utils import ( + snapshot_name_version_filter, + snapshot_id_filter, + fetchall, + create_batches, +) +from sqlmesh.core.node import IntervalUnit +from sqlmesh.core.environment import Environment +from sqlmesh.core.model import SeedModel, ModelKindName +from sqlmesh.core.snapshot.cache import SnapshotCache +from sqlmesh.core.snapshot import ( + SnapshotIdLike, + SnapshotNameVersionLike, + SnapshotTableCleanupTask, + SnapshotNameVersion, + SnapshotInfoLike, + Snapshot, + SnapshotId, + SnapshotFingerprint, + SnapshotChangeCategory, +) +from sqlmesh.utils.migration import index_text_type, blob_text_type +from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp +from 
sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.pydantic import PydanticModel
from sqlmesh.utils import unique

if t.TYPE_CHECKING:
    from sqlmesh.core.state_sync.db.interval import IntervalState


logger = logging.getLogger(__name__)


class SnapshotState:
    """Manages snapshot records in the state database."""

    # Max snapshot ids per SQL filter batch.
    SNAPSHOT_BATCH_SIZE = 1000

    def __init__(
        self,
        engine_adapter: EngineAdapter,
        schema: t.Optional[str] = None,
        context_path: Path = Path(),
    ):
        self.engine_adapter = engine_adapter
        self.snapshots_table = exp.table_("_snapshots", db=schema)
        self.auto_restatements_table = exp.table_("_auto_restatements", db=schema)

        # Column types depend on the dialect's indexable / blob text types.
        index_type = index_text_type(engine_adapter.dialect)
        blob_type = blob_text_type(engine_adapter.dialect)
        self._snapshot_columns_to_types = {
            "name": exp.DataType.build(index_type),
            "identifier": exp.DataType.build(index_type),
            "version": exp.DataType.build(index_type),
            "snapshot": exp.DataType.build(blob_type),
            "kind_name": exp.DataType.build("text"),
            "updated_ts": exp.DataType.build("bigint"),
            "unpaused_ts": exp.DataType.build("bigint"),
            "ttl_ms": exp.DataType.build("bigint"),
            "unrestorable": exp.DataType.build("boolean"),
        }

        self._auto_restatement_columns_to_types = {
            "snapshot_name": exp.DataType.build(index_type),
            "snapshot_version": exp.DataType.build(index_type),
            "next_auto_restatement_ts": exp.DataType.build("bigint"),
        }

        self._snapshot_cache = SnapshotCache(context_path / c.CACHE)

    def push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = False) -> None:
        """Pushes snapshots to the state store.

        Args:
            snapshots: The snapshots to push.
            overwrite: Whether to overwrite existing snapshots.
        """
        if overwrite:
            snapshots = tuple(snapshots)
            self.delete_snapshots(snapshots)

        snapshots_to_store = []

        for snapshot in snapshots:
            # Seed models are stored without their seed content.
            if isinstance(snapshot.node, SeedModel):
                seed_model = t.cast(SeedModel, snapshot.node)
                snapshot = snapshot.copy(update={"node": seed_model.to_dehydrated()})
            snapshots_to_store.append(snapshot)

        self.engine_adapter.insert_append(
            self.snapshots_table,
            _snapshots_to_df(snapshots_to_store),
            columns_to_types=self._snapshot_columns_to_types,
        )

        # Cache the original (hydrated) snapshots, not the dehydrated copies.
        # NOTE(review): `snapshots` is iterated twice when overwrite is False; callers
        # must pass a re-iterable collection, not a generator — confirm at call sites.
        for snapshot in snapshots:
            self._snapshot_cache.put(snapshot)

    def unpause_snapshots(
        self,
        snapshots: t.Collection[SnapshotInfoLike],
        unpaused_dt: TimeLike,
        interval_state: IntervalState,
    ) -> None:
        """Unpauses given snapshots while pausing all other snapshots that share the same version.

        Args:
            snapshots: The snapshots to unpause.
            unpaused_dt: The timestamp to unpause the snapshots at.
            interval_state: The interval state to use to remove intervals when needed.
+ """ + current_ts = now() + + target_snapshot_ids = {s.snapshot_id for s in snapshots} + same_version_snapshots = self._get_snapshots_with_same_version( + snapshots, lock_for_update=True + ) + target_snapshots_by_version = { + (s.name, s.version): s + for s in same_version_snapshots + if s.snapshot_id in target_snapshot_ids + } + + unpaused_snapshots: t.Dict[int, t.List[SnapshotId]] = defaultdict(list) + paused_snapshots: t.List[SnapshotId] = [] + unrestorable_snapshots: t.List[SnapshotId] = [] + + for snapshot in same_version_snapshots: + is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids + if is_target_snapshot and not snapshot.unpaused_ts: + logger.info("Unpausing snapshot %s", snapshot.snapshot_id) + snapshot.set_unpaused_ts(unpaused_dt) + assert snapshot.unpaused_ts is not None + unpaused_snapshots[snapshot.unpaused_ts].append(snapshot.snapshot_id) + elif not is_target_snapshot: + target_snapshot = target_snapshots_by_version[(snapshot.name, snapshot.version)] + if ( + target_snapshot.normalized_effective_from_ts + and not target_snapshot.disable_restatement + ): + # Making sure that there are no overlapping intervals. 
+ effective_from_ts = target_snapshot.normalized_effective_from_ts + logger.info( + "Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s", + target_snapshot.effective_from, + snapshot.snapshot_id, + target_snapshot.snapshot_id, + ) + full_snapshot = snapshot.full_snapshot + interval_state.remove_intervals( + [ + ( + full_snapshot, + full_snapshot.get_removal_interval(effective_from_ts, current_ts), + ) + ] + ) + + if snapshot.unpaused_ts: + logger.info("Pausing snapshot %s", snapshot.snapshot_id) + snapshot.set_unpaused_ts(None) + paused_snapshots.append(snapshot.snapshot_id) + + if ( + not snapshot.is_forward_only + and target_snapshot.is_forward_only + and not snapshot.unrestorable + ): + logger.info("Marking snapshot %s as unrestorable", snapshot.snapshot_id) + snapshot.unrestorable = True + unrestorable_snapshots.append(snapshot.snapshot_id) + + if unpaused_snapshots: + for unpaused_ts, snapshot_ids in unpaused_snapshots.items(): + self._update_snapshots(snapshot_ids, unpaused_ts=unpaused_ts) + + if paused_snapshots: + self._update_snapshots(paused_snapshots, unpaused_ts=None) + + if unrestorable_snapshots: + self._update_snapshots(unrestorable_snapshots, unrestorable=True) + + def delete_expired_snapshots( + self, environments: t.Iterable[Environment], ignore_ttl: bool = False + ) -> t.Tuple[t.Set[SnapshotId], t.List[SnapshotTableCleanupTask]]: + """Deletes expired snapshots. + + Args: + ignore_ttl: Whether to ignore the TTL of the snapshots. + + Returns: + A tuple of expired snapshot IDs and cleanup targets. 
+ """ + current_ts = now_timestamp(minute_floor=False) + + expired_query = exp.select("name", "identifier", "version").from_(self.snapshots_table) + + if not ignore_ttl: + expired_query = expired_query.where( + (exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts + ) + + expired_candidates = { + SnapshotId(name=name, identifier=identifier): SnapshotNameVersion( + name=name, version=version + ) + for name, identifier, version in fetchall(self.engine_adapter, expired_query) + } + if not expired_candidates: + return set(), [] + + promoted_snapshot_ids = { + snapshot.snapshot_id + for environment in environments + for snapshot in environment.snapshots + } + + def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: + return ( + snapshot.snapshot_id in promoted_snapshot_ids + or snapshot.snapshot_id not in expired_candidates + ) + + unique_expired_versions = unique(expired_candidates.values()) + version_batches = create_batches( + unique_expired_versions, batch_size=self.SNAPSHOT_BATCH_SIZE + ) + cleanup_targets = [] + expired_snapshot_ids = set() + for versions_batch in version_batches: + snapshots = self._get_snapshots_with_same_version(versions_batch) + + snapshots_by_version = defaultdict(set) + snapshots_by_dev_version = defaultdict(set) + for s in snapshots: + snapshots_by_version[(s.name, s.version)].add(s.snapshot_id) + snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id) + + expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)] + expired_snapshot_ids.update([s.snapshot_id for s in expired_snapshots]) + + for snapshot in expired_snapshots: + shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)] + shared_version_snapshots.discard(snapshot.snapshot_id) + + shared_dev_version_snapshots = snapshots_by_dev_version[ + (snapshot.name, snapshot.dev_version) + ] + shared_dev_version_snapshots.discard(snapshot.snapshot_id) + + if not shared_dev_version_snapshots: + cleanup_targets.append( + 
SnapshotTableCleanupTask( + snapshot=snapshot.full_snapshot.table_info, + dev_table_only=bool(shared_version_snapshots), + ) + ) + + if expired_snapshot_ids: + self.delete_snapshots(expired_snapshot_ids) + return expired_snapshot_ids, cleanup_targets + + def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: + """Deletes snapshots. + + Args: + snapshot_ids: The snapshot IDs to delete. + """ + if not snapshot_ids: + return + for where in snapshot_id_filter( + self.engine_adapter, snapshot_ids, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + self.engine_adapter.delete_from(self.snapshots_table, where=where) + + def touch_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: + """Touch snapshots to set their updated_ts to the current timestamp. + + Args: + snapshot_ids: The snapshot IDs to touch. + """ + self._update_snapshots(snapshot_ids) + + def get_snapshots( + self, + snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]], + ) -> t.Dict[SnapshotId, Snapshot]: + """Fetches snapshots. + + Args: + snapshot_ids: The snapshot IDs to fetch. + + Returns: + A dictionary of snapshot IDs to snapshots. + """ + if snapshot_ids is None: + raise SQLMeshError("Must provide snapshot IDs to fetch snapshots.") + return self._get_snapshots(snapshot_ids) + + def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: + """Checks if snapshots exist. + + Args: + snapshot_ids: The snapshot IDs to check. + + Returns: + A set of snapshot IDs to check for existence. + """ + return { + SnapshotId(name=name, identifier=identifier) + for where in snapshot_id_filter( + self.engine_adapter, snapshot_ids, batch_size=self.SNAPSHOT_BATCH_SIZE + ) + for name, identifier in fetchall( + self.engine_adapter, + exp.select("name", "identifier").from_(self.snapshots_table).where(where), + ) + } + + def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]: + """Checks if nodes with given names exist. 
+ + Args: + names: The node names to check. + exclude_external: Whether to exclude external nodes. + + Returns: + A set of node names that exist. + """ + names = set(names) + + if not names: + return names + + query = ( + exp.select("name") + .from_(self.snapshots_table) + .where(exp.column("name").isin(*names)) + .distinct() + ) + if exclude_external: + query = query.where(exp.column("kind_name").neq(ModelKindName.EXTERNAL.value)) + return {name for (name,) in fetchall(self.engine_adapter, query)} + + def update_auto_restatements( + self, next_auto_restatement_ts: t.Dict[SnapshotNameVersion, t.Optional[int]] + ) -> None: + """Updates the auto restatement timestamps. + + Args: + next_auto_restatement_ts: A dictionary of snapshot name version to the next auto restatement timestamp. + """ + for where in snapshot_name_version_filter( + self.engine_adapter, + next_auto_restatement_ts, + column_prefix="snapshot", + alias=None, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + self.engine_adapter.delete_from(self.auto_restatements_table, where=where) + + next_auto_restatement_ts_filtered = { + k: v for k, v in next_auto_restatement_ts.items() if v is not None + } + if not next_auto_restatement_ts_filtered: + return + + self.engine_adapter.insert_append( + self.auto_restatements_table, + _auto_restatements_to_df(next_auto_restatement_ts_filtered), + columns_to_types=self._auto_restatement_columns_to_types, + ) + + def count(self) -> int: + """Counts the number of snapshots in the state.""" + result = self.engine_adapter.fetchone(exp.select("COUNT(*)").from_(self.snapshots_table)) + return result[0] if result else 0 + + def clear_cache(self) -> None: + """Clears the snapshot cache.""" + self._snapshot_cache.clear() + + def _update_snapshots( + self, + snapshots: t.Iterable[SnapshotIdLike], + **kwargs: t.Any, + ) -> None: + properties = kwargs + properties["updated_ts"] = now_timestamp() + + for where in snapshot_id_filter( + self.engine_adapter, snapshots, 
batch_size=self.SNAPSHOT_BATCH_SIZE
+        ):
+            self.engine_adapter.update_table(
+                self.snapshots_table,
+                properties,
+                where=where,
+            )
+
+    def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None:
+        snapshots_to_store = []
+        for snapshot in snapshots:
+            if isinstance(snapshot.node, SeedModel):
+                seed_model = t.cast(SeedModel, snapshot.node)
+                snapshot = snapshot.copy(update={"node": seed_model.to_dehydrated()})
+            snapshots_to_store.append(snapshot)
+
+        self.engine_adapter.insert_append(
+            self.snapshots_table,
+            _snapshots_to_df(snapshots_to_store),
+            columns_to_types=self._snapshot_columns_to_types,
+        )
+
+    def _get_snapshots(
+        self,
+        snapshot_ids: t.Iterable[SnapshotIdLike],
+        lock_for_update: bool = False,
+    ) -> t.Dict[SnapshotId, Snapshot]:
+        """Fetches the specified snapshots.
+
+        Args:
+            snapshot_ids: The collection of snapshot like objects to fetch.
+            lock_for_update: Lock the snapshot rows for future update
+
+        Returns:
+            A dictionary of snapshot ids to snapshots for ones that could be found.
+ """ + duplicates: t.Dict[SnapshotId, Snapshot] = {} + + def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]: + fetched_snapshots: t.Dict[SnapshotId, Snapshot] = {} + for query in self._get_snapshots_expressions(snapshot_ids_to_load, lock_for_update): + for ( + serialized_snapshot, + _, + _, + _, + updated_ts, + unpaused_ts, + unrestorable, + next_auto_restatement_ts, + ) in fetchall(self.engine_adapter, query): + snapshot = parse_snapshot( + serialized_snapshot=serialized_snapshot, + updated_ts=updated_ts, + unpaused_ts=unpaused_ts, + unrestorable=unrestorable, + next_auto_restatement_ts=next_auto_restatement_ts, + ) + snapshot_id = snapshot.snapshot_id + if snapshot_id in fetched_snapshots: + other = duplicates.get(snapshot_id, fetched_snapshots[snapshot_id]) + duplicates[snapshot_id] = ( + snapshot if snapshot.updated_ts > other.updated_ts else other + ) + fetched_snapshots[snapshot_id] = duplicates[snapshot_id] + else: + fetched_snapshots[snapshot_id] = snapshot + return fetched_snapshots.values() + + snapshots, cached_snapshots = self._snapshot_cache.get_or_load( + {s.snapshot_id for s in snapshot_ids}, _loader + ) + + if cached_snapshots: + cached_snapshots_in_state: t.Set[SnapshotId] = set() + for where in snapshot_id_filter( + self.engine_adapter, cached_snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + query = ( + exp.select( + "name", + "identifier", + "updated_ts", + "unpaused_ts", + "unrestorable", + "next_auto_restatement_ts", + ) + .from_(exp.to_table(self.snapshots_table).as_("snapshots")) + .join( + exp.to_table(self.auto_restatements_table).as_("auto_restatements"), + on=exp.and_( + exp.column("name", table="snapshots").eq( + exp.column("snapshot_name", table="auto_restatements") + ), + exp.column("version", table="snapshots").eq( + exp.column("snapshot_version", table="auto_restatements") + ), + ), + join_type="left", + copy=False, + ) + .where(where) + ) + if lock_for_update: + query = query.lock(copy=False) + for 
( + name, + identifier, + updated_ts, + unpaused_ts, + unrestorable, + next_auto_restatement_ts, + ) in fetchall(self.engine_adapter, query): + snapshot_id = SnapshotId(name=name, identifier=identifier) + snapshot = snapshots[snapshot_id] + snapshot.updated_ts = updated_ts + snapshot.unpaused_ts = unpaused_ts + snapshot.unrestorable = unrestorable + snapshot.next_auto_restatement_ts = next_auto_restatement_ts + cached_snapshots_in_state.add(snapshot_id) + + missing_cached_snapshots = cached_snapshots - cached_snapshots_in_state + for missing_cached_snapshot_id in missing_cached_snapshots: + snapshots.pop(missing_cached_snapshot_id, None) + + if duplicates: + self.push_snapshots(duplicates.values(), overwrite=True) + logger.error("Found duplicate snapshots in the state store.") + + return snapshots + + def _get_snapshots_expressions( + self, + snapshot_ids: t.Iterable[SnapshotIdLike], + lock_for_update: bool = False, + ) -> t.Iterator[exp.Expression]: + for where in snapshot_id_filter( + self.engine_adapter, + snapshot_ids, + alias="snapshots", + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + query = ( + exp.select( + "snapshots.snapshot", + "snapshots.name", + "snapshots.identifier", + "snapshots.version", + "snapshots.updated_ts", + "snapshots.unpaused_ts", + "snapshots.unrestorable", + "auto_restatements.next_auto_restatement_ts", + ) + .from_(exp.to_table(self.snapshots_table).as_("snapshots")) + .join( + exp.to_table(self.auto_restatements_table).as_("auto_restatements"), + on=exp.and_( + exp.column("name", table="snapshots").eq( + exp.column("snapshot_name", table="auto_restatements") + ), + exp.column("version", table="snapshots").eq( + exp.column("snapshot_version", table="auto_restatements") + ), + ), + join_type="left", + copy=False, + ) + .where(where) + ) + if lock_for_update: + query = query.lock(copy=False) + yield query + + def _get_snapshots_with_same_version( + self, + snapshots: t.Collection[SnapshotNameVersionLike], + lock_for_update: bool = False, + 
) -> t.List[SharedVersionSnapshot]: + """Fetches all snapshots that share the same version as the snapshots. + + The output includes the snapshots with the specified identifiers. + + Args: + snapshots: The collection of target name / version pairs. + lock_for_update: Lock the snapshot rows for future update + + Returns: + The list of Snapshot objects. + """ + if not snapshots: + return [] + + snapshot_rows = [] + + for where in snapshot_name_version_filter( + self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + query = ( + exp.select( + "snapshot", + "name", + "identifier", + "version", + "updated_ts", + "unpaused_ts", + "unrestorable", + ) + .from_(exp.to_table(self.snapshots_table).as_("snapshots")) + .where(where) + ) + if lock_for_update: + query = query.lock(copy=False) + + snapshot_rows.extend(fetchall(self.engine_adapter, query)) + + return [ + SharedVersionSnapshot.from_snapshot_record( + name=name, + identifier=identifier, + version=version, + updated_ts=updated_ts, + unpaused_ts=unpaused_ts, + unrestorable=unrestorable, + snapshot=snapshot, + ) + for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable in snapshot_rows + ] + + +def parse_snapshot( + serialized_snapshot: str, + updated_ts: int, + unpaused_ts: t.Optional[int], + unrestorable: bool, + next_auto_restatement_ts: t.Optional[int], +) -> Snapshot: + return Snapshot( + **{ + **json.loads(serialized_snapshot), + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "unrestorable": unrestorable, + "next_auto_restatement_ts": next_auto_restatement_ts, + } + ) + + +def _snapshot_to_json(snapshot: Snapshot) -> str: + return snapshot.json( + exclude={ + "intervals", + "dev_intervals", + "pending_restatement_intervals", + "updated_ts", + "unpaused_ts", + "unrestorable", + "next_auto_restatement_ts", + } + ) + + +def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame: + return pd.DataFrame( + [ + { + "name": snapshot.name, + "identifier": 
snapshot.identifier, + "version": snapshot.version, + "snapshot": _snapshot_to_json(snapshot), + "kind_name": snapshot.model_kind_name.value if snapshot.model_kind_name else None, + "updated_ts": snapshot.updated_ts, + "unpaused_ts": snapshot.unpaused_ts, + "ttl_ms": snapshot.ttl_ms, + "unrestorable": snapshot.unrestorable, + } + for snapshot in snapshots + ] + ) + + +def _auto_restatements_to_df(auto_restatements: t.Dict[SnapshotNameVersion, int]) -> pd.DataFrame: + return pd.DataFrame( + [ + { + "snapshot_name": name_version.name, + "snapshot_version": name_version.version, + "next_auto_restatement_ts": ts, + } + for name_version, ts in auto_restatements.items() + ] + ) + + +class SharedVersionSnapshot(PydanticModel): + """A stripped down version of a snapshot that is used for fetching snapshots that share the same version + with a significantly reduced parsing overhead. + """ + + name: str + version: str + dev_version_: t.Optional[str] = Field(alias="dev_version") + identifier: str + fingerprint: SnapshotFingerprint + interval_unit: IntervalUnit + change_category: SnapshotChangeCategory + updated_ts: int + unpaused_ts: t.Optional[int] + unrestorable: bool + disable_restatement: bool + effective_from: t.Optional[TimeLike] + raw_snapshot: t.Dict[str, t.Any] + + @property + def snapshot_id(self) -> SnapshotId: + return SnapshotId(name=self.name, identifier=self.identifier) + + @property + def is_forward_only(self) -> bool: + return self.change_category == SnapshotChangeCategory.FORWARD_ONLY + + @property + def normalized_effective_from_ts(self) -> t.Optional[int]: + return ( + to_timestamp(self.interval_unit.cron_floor(self.effective_from)) + if self.effective_from + else None + ) + + @property + def dev_version(self) -> str: + return self.dev_version_ or self.fingerprint.to_version() + + @property + def full_snapshot(self) -> Snapshot: + return Snapshot( + **{ + **self.raw_snapshot, + "updated_ts": self.updated_ts, + "unpaused_ts": self.unpaused_ts, + 
"unrestorable": self.unrestorable, + } + ) + + def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None: + """Sets the timestamp for when this snapshot was unpaused. + + Args: + unpaused_dt: The datetime object of when this snapshot was unpaused. + """ + self.unpaused_ts = ( + to_timestamp(self.interval_unit.cron_floor(unpaused_dt)) if unpaused_dt else None + ) + + @classmethod + def from_snapshot_record( + cls, + *, + name: str, + identifier: str, + version: str, + updated_ts: int, + unpaused_ts: t.Optional[int], + unrestorable: bool, + snapshot: str, + ) -> SharedVersionSnapshot: + raw_snapshot = json.loads(snapshot) + raw_node = raw_snapshot["node"] + return SharedVersionSnapshot( + name=name, + version=version, + dev_version=raw_snapshot.get("dev_version"), + identifier=identifier, + fingerprint=raw_snapshot["fingerprint"], + interval_unit=raw_node.get("interval_unit", IntervalUnit.from_cron(raw_node["cron"])), + change_category=raw_snapshot["change_category"], + updated_ts=updated_ts, + unpaused_ts=unpaused_ts, + unrestorable=unrestorable, + disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False), + effective_from=raw_snapshot.get("effective_from"), + raw_snapshot=raw_snapshot, + ) diff --git a/sqlmesh/core/state_sync/db/utils.py b/sqlmesh/core/state_sync/db/utils.py new file mode 100644 index 0000000000..e5ffda6486 --- /dev/null +++ b/sqlmesh/core/state_sync/db/utils.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import typing as t +import logging + +from sqlglot import exp +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.snapshot import SnapshotIdLike, SnapshotNameVersionLike + + +logger = logging.getLogger(__name__) + +try: + # We can't import directly from the root package due to circular dependency + from sqlmesh._version import __version__ as SQLMESH_VERSION # noqa +except ImportError: + logger.error( + 'Unable to set __version__, run "pip install -e ." 
or "python setup.py develop" first.' + ) + + +T = t.TypeVar("T") + + +def snapshot_id_filter( + engine_adapter: EngineAdapter, + snapshot_ids: t.Iterable[SnapshotIdLike], + batch_size: int, + alias: t.Optional[str] = None, +) -> t.Iterator[exp.Condition]: + name_identifiers = sorted( + {(snapshot_id.name, snapshot_id.identifier) for snapshot_id in snapshot_ids} + ) + batches = create_batches(name_identifiers, batch_size=batch_size) + + if not name_identifiers: + yield exp.false() + elif engine_adapter.SUPPORTS_TUPLE_IN: + for identifiers in batches: + yield t.cast( + exp.Tuple, + exp.convert( + ( + exp.column("name", table=alias), + exp.column("identifier", table=alias), + ) + ), + ).isin(*identifiers) + else: + for identifiers in batches: + yield exp.or_( + *[ + exp.and_( + exp.column("name", table=alias).eq(name), + exp.column("identifier", table=alias).eq(identifier), + ) + for name, identifier in identifiers + ] + ) + + +def snapshot_name_version_filter( + engine_adapter: EngineAdapter, + snapshot_name_versions: t.Iterable[SnapshotNameVersionLike], + batch_size: int, + version_column_name: str = "version", + alias: t.Optional[str] = "snapshots", + column_prefix: t.Optional[str] = None, +) -> t.Iterator[exp.Condition]: + name_versions = sorted({(s.name, s.version) for s in snapshot_name_versions}) + batches = create_batches(name_versions, batch_size=batch_size) + + name_column_name = "name" + if column_prefix: + name_column_name = f"{column_prefix}_{name_column_name}" + version_column_name = f"{column_prefix}_{version_column_name}" + + name_column = exp.column(name_column_name, table=alias) + version_column = exp.column(version_column_name, table=alias) + + if not name_versions: + yield exp.false() + elif engine_adapter.SUPPORTS_TUPLE_IN: + for versions in batches: + yield t.cast( + exp.Tuple, + exp.convert( + ( + name_column, + version_column, + ) + ), + ).isin(*versions) + else: + for versions in batches: + yield exp.or_( + *[ + exp.and_( + 
name_column.eq(name), + version_column.eq(version), + ) + for name, version in versions + ] + ) + + +def create_batches(l: t.List[T], batch_size: int) -> t.List[t.List[T]]: + return [l[i : i + batch_size] for i in range(0, len(l), batch_size)] + + +def fetchone( + engine_adapter: EngineAdapter, query: t.Union[exp.Expression, str] +) -> t.Optional[t.Tuple]: + return engine_adapter.fetchone(query, ignore_unsupported_errors=True, quote_identifiers=True) + + +def fetchall(engine_adapter: EngineAdapter, query: t.Union[exp.Expression, str]) -> t.List[t.Tuple]: + return engine_adapter.fetchall(query, ignore_unsupported_errors=True, quote_identifiers=True) diff --git a/sqlmesh/core/state_sync/db/version.py b/sqlmesh/core/state_sync/db/version.py new file mode 100644 index 0000000000..8ef6860b92 --- /dev/null +++ b/sqlmesh/core/state_sync/db/version.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import logging +import typing as t + +import pandas as pd +from sqlglot import __version__ as SQLGLOT_VERSION +from sqlglot import exp +from sqlglot.helper import seq_get + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.db.utils import ( + fetchone, + SQLMESH_VERSION, +) +from sqlmesh.core.state_sync.base import ( + SCHEMA_VERSION, + Versions, +) +from sqlmesh.utils.migration import index_text_type + +logger = logging.getLogger(__name__) + + +class VersionState: + def __init__(self, engine_adapter: EngineAdapter, schema: t.Optional[str] = None): + self.engine_adapter = engine_adapter + self.versions_table = exp.table_("_versions", db=schema) + + index_type = index_text_type(engine_adapter.dialect) + self._version_columns_to_types = { + "schema_version": exp.DataType.build("int"), + "sqlglot_version": exp.DataType.build(index_type), + "sqlmesh_version": exp.DataType.build(index_type), + } + + def update_versions( + self, + schema_version: int = SCHEMA_VERSION, + sqlglot_version: str = SQLGLOT_VERSION, + sqlmesh_version: str = 
SQLMESH_VERSION, + ) -> None: + self.engine_adapter.delete_from(self.versions_table, "TRUE") + + self.engine_adapter.insert_append( + self.versions_table, + pd.DataFrame( + [ + { + "schema_version": schema_version, + "sqlglot_version": sqlglot_version, + "sqlmesh_version": sqlmesh_version, + } + ] + ), + columns_to_types=self._version_columns_to_types, + ) + + def get_versions(self) -> Versions: + no_version = Versions() + + if not self.engine_adapter.table_exists(self.versions_table): + return no_version + + query = exp.select("*").from_(self.versions_table) + row = fetchone(self.engine_adapter, query) + if not row: + return no_version + + return Versions( + schema_version=row[0], sqlglot_version=row[1], sqlmesh_version=seq_get(row, 2) + ) diff --git a/sqlmesh/core/state_sync/engine_adapter.py b/sqlmesh/core/state_sync/engine_adapter.py deleted file mode 100644 index 22a4201c9a..0000000000 --- a/sqlmesh/core/state_sync/engine_adapter.py +++ /dev/null @@ -1,2142 +0,0 @@ -""" -# StateSync - -State sync is how SQLMesh keeps track of environments and their states, e.g. snapshots. - -# StateReader - -StateReader provides a subset of the functionalities of the StateSync class. As its name -implies, it only allows for read-only operations on snapshots and environment states. - -# EngineAdapterStateSync - -The provided `sqlmesh.core.state_sync.EngineAdapterStateSync` leverages an existing engine -adapter to read and write state to the underlying data store. 
-""" - -from __future__ import annotations - -import contextlib -import json -import logging -import time -import typing as t -from collections import defaultdict -from copy import deepcopy -from pathlib import Path -from datetime import datetime - -import pandas as pd -from pydantic import Field -from sqlglot import __version__ as SQLGLOT_VERSION -from sqlglot import exp -from sqlglot.helper import seq_get - -from sqlmesh.core import analytics -from sqlmesh.core import constants as c -from sqlmesh.core.console import Console, get_console -from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.core.environment import Environment -from sqlmesh.core.model import ModelKindName, SeedModel -from sqlmesh.core.node import IntervalUnit -from sqlmesh.core.snapshot import ( - Node, - Snapshot, - SnapshotChangeCategory, - SnapshotFingerprint, - SnapshotId, - SnapshotIdLike, - SnapshotInfoLike, - SnapshotIntervals, - SnapshotNameVersion, - SnapshotNameVersionLike, - SnapshotTableCleanupTask, - SnapshotTableInfo, - fingerprint_from_node, - start_date, -) -from sqlmesh.core.snapshot.cache import SnapshotCache -from sqlmesh.core.snapshot.definition import ( - Interval, - _parents_from_node, -) -from sqlmesh.core.state_sync.base import ( - MIGRATIONS, - SCHEMA_VERSION, - PromotionResult, - StateSync, - Versions, -) -from sqlmesh.core.state_sync.common import transactional -from sqlmesh.utils import major_minor, random_id, unique -from sqlmesh.utils.dag import DAG -from sqlmesh.utils.date import TimeLike, now, now_timestamp, time_like_to_str, to_timestamp -from sqlmesh.utils.errors import ConflictingPlanError, SQLMeshError -from sqlmesh.utils.migration import blob_text_type, index_text_type -from sqlmesh.utils.pydantic import PydanticModel - -logger = logging.getLogger(__name__) - - -T = t.TypeVar("T") - - -if t.TYPE_CHECKING: - from sqlmesh.core._typing import TableName - -try: - # We can't import directly from the root package due to circular dependency - from 
sqlmesh._version import __version__ as SQLMESH_VERSION # type: ignore -except ImportError: - logger.error( - 'Unable to set __version__, run "pip install -e ." or "python setup.py develop" first.' - ) - - -class EngineAdapterStateSync(StateSync): - """Manages state of nodes and snapshot with an existing engine adapter. - - This state sync is convenient to use because it requires no additional setup. - You can reuse the same engine/warehouse that your data is stored in. - - Args: - engine_adapter: The EngineAdapter to use to store and fetch snapshots. - schema: The schema to store state metadata in. If None or empty string then no schema is defined - console: The console to log information to. - context_path: The context path, used for caching snapshot models. - """ - - INTERVAL_BATCH_SIZE = 1000 - SNAPSHOT_BATCH_SIZE = 1000 - SNAPSHOT_MIGRATION_BATCH_SIZE = 500 - - def __init__( - self, - engine_adapter: EngineAdapter, - schema: t.Optional[str], - console: t.Optional[Console] = None, - context_path: Path = Path(), - ): - # Make sure that if an empty string is provided that we treat it as None - self.schema = schema or None - self.engine_adapter = engine_adapter - self.console = console or get_console() - self.snapshots_table = exp.table_("_snapshots", db=self.schema) - self.environments_table = exp.table_("_environments", db=self.schema) - self.intervals_table = exp.table_("_intervals", db=self.schema) - self.plan_dags_table = exp.table_("_plan_dags", db=self.schema) - self.auto_restatements_table = exp.table_("_auto_restatements", db=self.schema) - self.versions_table = exp.table_("_versions", db=self.schema) - - index_type = index_text_type(engine_adapter.dialect) - blob_type = blob_text_type(engine_adapter.dialect) - self._snapshot_columns_to_types = { - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": 
exp.DataType.build("text"), - "updated_ts": exp.DataType.build("bigint"), - "unpaused_ts": exp.DataType.build("bigint"), - "ttl_ms": exp.DataType.build("bigint"), - "unrestorable": exp.DataType.build("boolean"), - } - - self._environment_columns_to_types = { - "name": exp.DataType.build(index_type), - "snapshots": exp.DataType.build(blob_type), - "start_at": exp.DataType.build("text"), - "end_at": exp.DataType.build("text"), - "plan_id": exp.DataType.build("text"), - "previous_plan_id": exp.DataType.build("text"), - "expiration_ts": exp.DataType.build("bigint"), - "finalized_ts": exp.DataType.build("bigint"), - "promoted_snapshot_ids": exp.DataType.build(blob_type), - "suffix_target": exp.DataType.build("text"), - "catalog_name_override": exp.DataType.build("text"), - "previous_finalized_snapshots": exp.DataType.build(blob_type), - "normalize_name": exp.DataType.build("boolean"), - "requirements": exp.DataType.build(blob_type), - } - - self._interval_columns_to_types = { - "id": exp.DataType.build(index_type), - "created_ts": exp.DataType.build("bigint"), - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build("text"), - "version": exp.DataType.build(index_type), - "dev_version": exp.DataType.build(index_type), - "start_ts": exp.DataType.build("bigint"), - "end_ts": exp.DataType.build("bigint"), - "is_dev": exp.DataType.build("boolean"), - "is_removed": exp.DataType.build("boolean"), - "is_compacted": exp.DataType.build("boolean"), - "is_pending_restatement": exp.DataType.build("boolean"), - } - - self._auto_restatement_columns_to_types = { - "snapshot_name": exp.DataType.build(index_type), - "snapshot_version": exp.DataType.build(index_type), - "next_auto_restatement_ts": exp.DataType.build("bigint"), - } - - self._version_columns_to_types = { - "schema_version": exp.DataType.build("int"), - "sqlglot_version": exp.DataType.build(index_type), - "sqlmesh_version": exp.DataType.build(index_type), - } - - self._snapshot_cache = 
SnapshotCache(context_path / c.CACHE) - - def _fetchone(self, query: t.Union[exp.Expression, str]) -> t.Optional[t.Tuple]: - return self.engine_adapter.fetchone( - query, ignore_unsupported_errors=True, quote_identifiers=True - ) - - def _fetchall(self, query: t.Union[exp.Expression, str]) -> t.List[t.Tuple]: - return self.engine_adapter.fetchall( - query, ignore_unsupported_errors=True, quote_identifiers=True - ) - - @transactional() - def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: - """Pushes snapshots to the state store, merging them with existing ones. - - This method first finds all existing snapshots in the store and merges them with - the local snapshots. It will then delete all existing snapshots and then - insert all the local snapshots. This can be made safer with locks or merge/upsert. - - Args: - snapshot_ids: Iterable of snapshot ids to bulk push. - """ - snapshots_by_id = {} - for snapshot in snapshots: - if not snapshot.version: - raise SQLMeshError( - f"Snapshot {snapshot} has not been versioned yet. Create a plan before pushing a snapshot." - ) - snapshots_by_id[snapshot.snapshot_id] = snapshot - - existing = self.snapshots_exist(snapshots_by_id) - - if existing: - logger.error( - "Snapshots %s already exists. This could be due to a concurrent plan or a hash collision. 
If this is a hash collision, add a stamp to your model.", - str(existing), - ) - - for sid in tuple(snapshots_by_id): - if sid in existing: - snapshots_by_id.pop(sid) - - snapshots = snapshots_by_id.values() - - if snapshots: - self._push_snapshots(snapshots) - for snapshot in snapshots: - self._snapshot_cache.put(snapshot) - - def _push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = False) -> None: - if overwrite: - snapshots = tuple(snapshots) - self.delete_snapshots(snapshots) - - snapshots_to_store = [] - - for snapshot in snapshots: - if isinstance(snapshot.node, SeedModel): - seed_model = t.cast(SeedModel, snapshot.node) - snapshot = snapshot.copy(update={"node": seed_model.to_dehydrated()}) - snapshots_to_store.append(snapshot) - - self.engine_adapter.insert_append( - self.snapshots_table, - _snapshots_to_df(snapshots_to_store), - columns_to_types=self._snapshot_columns_to_types, - ) - - @transactional() - def promote( - self, - environment: Environment, - no_gaps_snapshot_names: t.Optional[t.Set[str]] = None, - ) -> PromotionResult: - """Update the environment to reflect the current state. - - This method verifies that snapshots have been pushed. - - Args: - environment: The environment to promote. - no_gaps_snapshot_names: A set of snapshot names to check for data gaps. If None, - all snapshots will be checked. The data gap check ensures that models that are already a - part of the target environment have no data gaps when compared against previous - snapshots for same models. - - Returns: - A tuple of (added snapshot table infos, removed snapshot table infos, and environment target suffix for the removed table infos) - """ - logger.info("Promoting environment '%s'", environment.name) - - missing = {s.snapshot_id for s in environment.snapshots} - self.snapshots_exist( - environment.snapshots - ) - if missing: - raise SQLMeshError( - f"Missing snapshots {missing}. Make sure to push and backfill your snapshots." 
- ) - - existing_environment = self._get_environment(environment.name, lock_for_update=True) - - existing_table_infos = ( - {table_info.name: table_info for table_info in existing_environment.promoted_snapshots} - if existing_environment - else {} - ) - table_infos = {table_info.name: table_info for table_info in environment.promoted_snapshots} - views_that_changed_location: t.Set[SnapshotTableInfo] = set() - if existing_environment: - views_that_changed_location = { - existing_table_info - for name, existing_table_info in existing_table_infos.items() - if name in table_infos - and existing_table_info.qualified_view_name.for_environment( - existing_environment.naming_info - ) - != table_infos[name].qualified_view_name.for_environment(environment.naming_info) - } - if not existing_environment.expired: - if environment.previous_plan_id != existing_environment.plan_id: - raise ConflictingPlanError( - f"Plan '{environment.plan_id}' is no longer valid for the target environment '{environment.name}'. " - f"Expected previous plan ID: '{environment.previous_plan_id}', actual previous plan ID: '{existing_environment.plan_id}'. " - "Please recreate the plan and try again" - ) - if no_gaps_snapshot_names != set(): - snapshots = self._get_snapshots(environment.snapshots).values() - self._ensure_no_gaps( - snapshots, - existing_environment, - no_gaps_snapshot_names, - ) - demoted_snapshots = set(existing_environment.snapshots) - set(environment.snapshots) - # Update the updated_at attribute. - self._update_snapshots(demoted_snapshots) - - missing_models = set(existing_table_infos) - { - snapshot.name for snapshot in environment.promoted_snapshots - } - - added_table_infos = set(table_infos.values()) - if ( - existing_environment - and existing_environment.finalized_ts - and not existing_environment.expired - ): - # Only promote new snapshots. 
- added_table_infos -= set(existing_environment.promoted_snapshots) - - self._update_environment(environment) - - removed = {existing_table_infos[name] for name in missing_models}.union( - views_that_changed_location - ) - - return PromotionResult( - added=sorted(added_table_infos), - removed=list(removed), - removed_environment_naming_info=( - existing_environment.naming_info if removed and existing_environment else None - ), - ) - - def _ensure_no_gaps( - self, - target_snapshots: t.Iterable[Snapshot], - target_environment: Environment, - snapshot_names: t.Optional[t.Set[str]], - ) -> None: - target_snapshots_by_name = {s.name: s for s in target_snapshots} - - changed_version_prev_snapshots_by_name = { - s.name: s - for s in target_environment.snapshots - if s.name in target_snapshots_by_name - and target_snapshots_by_name[s.name].version != s.version - } - - prev_snapshots = self._get_snapshots( - changed_version_prev_snapshots_by_name.values() - ).values() - cache: t.Dict[str, datetime] = {} - - for prev_snapshot in prev_snapshots: - target_snapshot = target_snapshots_by_name[prev_snapshot.name] - if ( - (snapshot_names is None or prev_snapshot.name in snapshot_names) - and target_snapshot.is_incremental - and prev_snapshot.is_incremental - and prev_snapshot.intervals - ): - start = to_timestamp( - start_date(target_snapshot, target_snapshots_by_name.values(), cache) - ) - end = prev_snapshot.intervals[-1][1] - - if start < end: - missing_intervals = target_snapshot.missing_intervals( - start, end, end_bounded=True - ) - - if missing_intervals: - raise SQLMeshError( - f"Detected gaps in snapshot {target_snapshot.snapshot_id}: {missing_intervals}" - ) - - @transactional() - def finalize(self, environment: Environment) -> None: - """Finalize the target environment, indicating that this environment has been - fully promoted and is ready for use. - - Args: - environment: The target environment to finalize. 
- """ - logger.info("Finalizing environment '%s'", environment.name) - - environment_filter = exp.column("name").eq(exp.Literal.string(environment.name)) - - stored_plan_id_query = ( - exp.select("plan_id") - .from_(self.environments_table) - .where(environment_filter, copy=False) - .lock(copy=False) - ) - stored_plan_id_row = self._fetchone(stored_plan_id_query) - - if not stored_plan_id_row: - raise SQLMeshError(f"Missing environment '{environment.name}' can't be finalized") - - stored_plan_id = stored_plan_id_row[0] - if stored_plan_id != environment.plan_id: - raise SQLMeshError( - f"Plan '{environment.plan_id}' is no longer valid for the target environment '{environment.name}'. " - f"Stored plan ID: '{stored_plan_id}'. Please recreate the plan and try again" - ) - - environment.finalized_ts = now_timestamp() - self.engine_adapter.update_table( - self.environments_table, - {"finalized_ts": environment.finalized_ts}, - where=environment_filter, - ) - - @transactional() - def unpause_snapshots( - self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike - ) -> None: - current_ts = now() - - target_snapshot_ids = {s.snapshot_id for s in snapshots} - same_version_snapshots = self._get_snapshots_with_same_version( - snapshots, lock_for_update=True - ) - target_snapshots_by_version = { - (s.name, s.version): s - for s in same_version_snapshots - if s.snapshot_id in target_snapshot_ids - } - - unpaused_snapshots: t.Dict[int, t.List[SnapshotId]] = defaultdict(list) - paused_snapshots: t.List[SnapshotId] = [] - unrestorable_snapshots: t.List[SnapshotId] = [] - - for snapshot in same_version_snapshots: - is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids - if is_target_snapshot and not snapshot.unpaused_ts: - logger.info("Unpausing snapshot %s", snapshot.snapshot_id) - snapshot.set_unpaused_ts(unpaused_dt) - assert snapshot.unpaused_ts is not None - unpaused_snapshots[snapshot.unpaused_ts].append(snapshot.snapshot_id) - elif not 
is_target_snapshot: - target_snapshot = target_snapshots_by_version[(snapshot.name, snapshot.version)] - if ( - target_snapshot.normalized_effective_from_ts - and not target_snapshot.disable_restatement - ): - # Making sure that there are no overlapping intervals. - effective_from_ts = target_snapshot.normalized_effective_from_ts - logger.info( - "Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s", - target_snapshot.effective_from, - snapshot.snapshot_id, - target_snapshot.snapshot_id, - ) - full_snapshot = snapshot.full_snapshot - self.remove_intervals( - [ - ( - full_snapshot, - full_snapshot.get_removal_interval(effective_from_ts, current_ts), - ) - ] - ) - - if snapshot.unpaused_ts: - logger.info("Pausing snapshot %s", snapshot.snapshot_id) - snapshot.set_unpaused_ts(None) - paused_snapshots.append(snapshot.snapshot_id) - - if ( - not snapshot.is_forward_only - and target_snapshot.is_forward_only - and not snapshot.unrestorable - ): - logger.info("Marking snapshot %s as unrestorable", snapshot.snapshot_id) - snapshot.unrestorable = True - unrestorable_snapshots.append(snapshot.snapshot_id) - - if unpaused_snapshots: - for unpaused_ts, snapshot_ids in unpaused_snapshots.items(): - self._update_snapshots(snapshot_ids, unpaused_ts=unpaused_ts) - - if paused_snapshots: - self._update_snapshots(paused_snapshots, unpaused_ts=None) - - if unrestorable_snapshots: - self._update_snapshots(unrestorable_snapshots, unrestorable=True) - - def _update_versions( - self, - schema_version: int = SCHEMA_VERSION, - sqlglot_version: str = SQLGLOT_VERSION, - sqlmesh_version: str = SQLMESH_VERSION, - ) -> None: - self.engine_adapter.delete_from(self.versions_table, "TRUE") - - self.engine_adapter.insert_append( - self.versions_table, - pd.DataFrame( - [ - { - "schema_version": schema_version, - "sqlglot_version": sqlglot_version, - "sqlmesh_version": sqlmesh_version, - } - ] - ), - columns_to_types=self._version_columns_to_types, - ) - - def 
invalidate_environment(self, name: str) -> None: - name = name.lower() - if name == c.PROD: - raise SQLMeshError("Cannot invalidate the production environment.") - - filter_expr = exp.column("name").eq(name) - - self.engine_adapter.update_table( - self.environments_table, - {"expiration_ts": now_timestamp()}, - where=filter_expr, - ) - - @transactional() - def delete_expired_snapshots( - self, ignore_ttl: bool = False - ) -> t.List[SnapshotTableCleanupTask]: - current_ts = now_timestamp(minute_floor=False) - - expired_query = exp.select("name", "identifier", "version").from_(self.snapshots_table) - - if not ignore_ttl: - expired_query = expired_query.where( - (exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts - ) - - expired_candidates = { - SnapshotId(name=name, identifier=identifier): SnapshotNameVersion( - name=name, version=version - ) - for name, identifier, version in self._fetchall(expired_query) - } - if not expired_candidates: - return [] - - promoted_snapshot_ids = { - snapshot.snapshot_id - for environment in self.get_environments() - for snapshot in environment.snapshots - } - - def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: - return ( - snapshot.snapshot_id in promoted_snapshot_ids - or snapshot.snapshot_id not in expired_candidates - ) - - unique_expired_versions = unique(expired_candidates.values()) - version_batches = self._batches(unique_expired_versions) - cleanup_targets = [] - expired_snapshot_ids = set() - for versions_batch in version_batches: - snapshots = self._get_snapshots_with_same_version(versions_batch) - - snapshots_by_version = defaultdict(set) - snapshots_by_dev_version = defaultdict(set) - for s in snapshots: - snapshots_by_version[(s.name, s.version)].add(s.snapshot_id) - snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id) - - expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)] - expired_snapshot_ids.update([s.snapshot_id for s in expired_snapshots]) - - for snapshot 
in expired_snapshots: - shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)] - shared_version_snapshots.discard(snapshot.snapshot_id) - - shared_dev_version_snapshots = snapshots_by_dev_version[ - (snapshot.name, snapshot.dev_version) - ] - shared_dev_version_snapshots.discard(snapshot.snapshot_id) - - if not shared_dev_version_snapshots: - cleanup_targets.append( - SnapshotTableCleanupTask( - snapshot=snapshot.full_snapshot.table_info, - dev_table_only=bool(shared_version_snapshots), - ) - ) - - if expired_snapshot_ids: - self.delete_snapshots(expired_snapshot_ids) - - self._cleanup_intervals(cleanup_targets, expired_snapshot_ids) - - return cleanup_targets - - def delete_expired_environments(self) -> t.List[Environment]: - now_ts = now_timestamp() - filter_expr = exp.LTE( - this=exp.column("expiration_ts"), - expression=exp.Literal.number(now_ts), - ) - - rows = self._fetchall( - self._environments_query( - where=filter_expr, - lock_for_update=True, - ) - ) - environments = [self._environment_from_row(r) for r in rows] - - self.engine_adapter.delete_from( - self.environments_table, - where=filter_expr, - ) - - return environments - - def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: - if not snapshot_ids: - return - for where in self._snapshot_id_filter(snapshot_ids): - self.engine_adapter.delete_from(self.snapshots_table, where=where) - - def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: - return self._snapshot_ids_exist(snapshot_ids, self.snapshots_table) - - def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]: - names = set(names) - - if not names: - return names - - query = ( - exp.select("name") - .from_(self.snapshots_table) - .where(exp.column("name").isin(*names)) - .distinct() - ) - if exclude_external: - query = query.where(exp.column("kind_name").neq(ModelKindName.EXTERNAL.value)) - return {name for (name,) in 
self._fetchall(query)} - - def reset(self, default_catalog: t.Optional[str]) -> None: - """Resets the state store to the state when it was first initialized.""" - for table in ( - self.snapshots_table, - self.environments_table, - self.intervals_table, - self.plan_dags_table, - self.versions_table, - ): - self.engine_adapter.drop_table(table) - self._snapshot_cache.clear() - self.migrate(default_catalog) - - @transactional() - def update_auto_restatements( - self, next_auto_restatement_ts: t.Dict[SnapshotNameVersion, t.Optional[int]] - ) -> None: - for where in self._snapshot_name_version_filter( - next_auto_restatement_ts, column_prefix="snapshot", alias=None - ): - self.engine_adapter.delete_from(self.auto_restatements_table, where=where) - - next_auto_restatement_ts_filtered = { - k: v for k, v in next_auto_restatement_ts.items() if v is not None - } - if not next_auto_restatement_ts_filtered: - return - - self.engine_adapter.insert_append( - self.auto_restatements_table, - _auto_restatements_to_df(next_auto_restatement_ts_filtered), - columns_to_types=self._auto_restatement_columns_to_types, - ) - - def _update_environment(self, environment: Environment) -> None: - self.engine_adapter.delete_from( - self.environments_table, - where=exp.EQ( - this=exp.column("name"), - expression=exp.Literal.string(environment.name), - ), - ) - - self.engine_adapter.insert_append( - self.environments_table, - _environment_to_df(environment), - columns_to_types=self._environment_columns_to_types, - ) - - def _update_snapshots( - self, - snapshots: t.Iterable[SnapshotIdLike], - **kwargs: t.Any, - ) -> None: - properties = kwargs - properties["updated_ts"] = now_timestamp() - - for where in self._snapshot_id_filter(snapshots): - self.engine_adapter.update_table( - self.snapshots_table, - properties, - where=where, - ) - - def get_environment(self, environment: str) -> t.Optional[Environment]: - return self._get_environment(environment) - - def get_environments(self) -> 
t.List[Environment]: - """Fetches all environments. - - Returns: - A list of all environments. - """ - return [ - self._environment_from_row(row) for row in self._fetchall(self._environments_query()) - ] - - def get_environments_summary(self) -> t.Dict[str, int]: - """Fetches all environment names along with expiry datetime. - - Returns: - A dict of all environment names along with expiry datetime. - """ - return dict( - self._fetchall(self._environments_query(required_fields=["name", "expiration_ts"])), - ) - - def _environment_from_row(self, row: t.Tuple[str, ...]) -> Environment: - return Environment(**{field: row[i] for i, field in enumerate(Environment.all_fields())}) - - def _environments_query( - self, - where: t.Optional[str | exp.Expression] = None, - lock_for_update: bool = False, - required_fields: t.Optional[t.List[str]] = None, - ) -> exp.Select: - query_fields = required_fields if required_fields else Environment.all_fields() - query = ( - exp.select(*(exp.to_identifier(field) for field in query_fields)) - .from_(self.environments_table) - .where(where) - ) - if lock_for_update: - return query.lock(copy=False) - return query - - def get_snapshots( - self, - snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]], - ) -> t.Dict[SnapshotId, Snapshot]: - if snapshot_ids is None: - raise SQLMeshError("Must provide snapshot IDs to fetch snapshots.") - return self._get_snapshots(snapshot_ids) - - def _get_snapshots( - self, - snapshot_ids: t.Iterable[SnapshotIdLike], - lock_for_update: bool = False, - hydrate_intervals: bool = True, - ) -> t.Dict[SnapshotId, Snapshot]: - """Fetches specified snapshots or all snapshots. - - Args: - snapshot_ids: The collection of snapshot like objects to fetch. - lock_for_update: Lock the snapshot rows for future update - hydrate_intervals: Whether to hydrate result snapshots with intervals. - - Returns: - A dictionary of snapshot ids to snapshots for ones that could be found. 
- """ - duplicates: t.Dict[SnapshotId, Snapshot] = {} - - def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]: - fetched_snapshots: t.Dict[SnapshotId, Snapshot] = {} - for query in self._get_snapshots_expressions(snapshot_ids_to_load, lock_for_update): - for ( - serialized_snapshot, - _, - _, - _, - updated_ts, - unpaused_ts, - unrestorable, - next_auto_restatement_ts, - ) in self._fetchall(query): - snapshot = parse_snapshot( - serialized_snapshot=serialized_snapshot, - updated_ts=updated_ts, - unpaused_ts=unpaused_ts, - unrestorable=unrestorable, - next_auto_restatement_ts=next_auto_restatement_ts, - ) - snapshot_id = snapshot.snapshot_id - if snapshot_id in fetched_snapshots: - other = duplicates.get(snapshot_id, fetched_snapshots[snapshot_id]) - duplicates[snapshot_id] = ( - snapshot if snapshot.updated_ts > other.updated_ts else other - ) - fetched_snapshots[snapshot_id] = duplicates[snapshot_id] - else: - fetched_snapshots[snapshot_id] = snapshot - return fetched_snapshots.values() - - snapshots, cached_snapshots = self._snapshot_cache.get_or_load( - {s.snapshot_id for s in snapshot_ids}, _loader - ) - - if cached_snapshots: - cached_snapshots_in_state: t.Set[SnapshotId] = set() - for where in self._snapshot_id_filter(cached_snapshots): - query = ( - exp.select( - "name", - "identifier", - "updated_ts", - "unpaused_ts", - "unrestorable", - "next_auto_restatement_ts", - ) - .from_(exp.to_table(self.snapshots_table).as_("snapshots")) - .join( - exp.to_table(self.auto_restatements_table).as_("auto_restatements"), - on=exp.and_( - exp.column("name", table="snapshots").eq( - exp.column("snapshot_name", table="auto_restatements") - ), - exp.column("version", table="snapshots").eq( - exp.column("snapshot_version", table="auto_restatements") - ), - ), - join_type="left", - copy=False, - ) - .where(where) - ) - if lock_for_update: - query = query.lock(copy=False) - for ( - name, - identifier, - updated_ts, - unpaused_ts, - unrestorable, - 
next_auto_restatement_ts, - ) in self._fetchall(query): - snapshot_id = SnapshotId(name=name, identifier=identifier) - snapshot = snapshots[snapshot_id] - snapshot.updated_ts = updated_ts - snapshot.unpaused_ts = unpaused_ts - snapshot.unrestorable = unrestorable - snapshot.next_auto_restatement_ts = next_auto_restatement_ts - cached_snapshots_in_state.add(snapshot_id) - - missing_cached_snapshots = cached_snapshots - cached_snapshots_in_state - for missing_cached_snapshot_id in missing_cached_snapshots: - snapshots.pop(missing_cached_snapshot_id, None) - - if snapshots and hydrate_intervals: - _, intervals = self._get_snapshot_intervals(snapshots.values()) - Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals) - - if duplicates: - self._push_snapshots(duplicates.values(), overwrite=True) - logger.error("Found duplicate snapshots in the state store.") - - return snapshots - - def _get_snapshots_expressions( - self, - snapshot_ids: t.Iterable[SnapshotIdLike], - lock_for_update: bool = False, - batch_size: t.Optional[int] = None, - ) -> t.Iterator[exp.Expression]: - for where in self._snapshot_id_filter( - snapshot_ids, alias="snapshots", batch_size=batch_size - ): - query = ( - exp.select( - "snapshots.snapshot", - "snapshots.name", - "snapshots.identifier", - "snapshots.version", - "snapshots.updated_ts", - "snapshots.unpaused_ts", - "snapshots.unrestorable", - "auto_restatements.next_auto_restatement_ts", - ) - .from_(exp.to_table(self.snapshots_table).as_("snapshots")) - .join( - exp.to_table(self.auto_restatements_table).as_("auto_restatements"), - on=exp.and_( - exp.column("name", table="snapshots").eq( - exp.column("snapshot_name", table="auto_restatements") - ), - exp.column("version", table="snapshots").eq( - exp.column("snapshot_version", table="auto_restatements") - ), - ), - join_type="left", - copy=False, - ) - .where(where) - ) - if lock_for_update: - query = query.lock(copy=False) - yield query - - def 
_get_snapshots_with_same_version( - self, - snapshots: t.Collection[SnapshotNameVersionLike], - lock_for_update: bool = False, - ) -> t.List[SharedVersionSnapshot]: - """Fetches all snapshots that share the same version as the snapshots. - - The output includes the snapshots with the specified identifiers. - - Args: - snapshots: The collection of target name / version pairs. - lock_for_update: Lock the snapshot rows for future update - - Returns: - The list of Snapshot objects. - """ - if not snapshots: - return [] - - snapshot_rows = [] - - for where in self._snapshot_name_version_filter(snapshots): - query = ( - exp.select( - "snapshot", - "name", - "identifier", - "version", - "updated_ts", - "unpaused_ts", - "unrestorable", - ) - .from_(exp.to_table(self.snapshots_table).as_("snapshots")) - .where(where) - ) - if lock_for_update: - query = query.lock(copy=False) - - snapshot_rows.extend(self._fetchall(query)) - - return [ - SharedVersionSnapshot.from_snapshot_record( - name=name, - identifier=identifier, - version=version, - updated_ts=updated_ts, - unpaused_ts=unpaused_ts, - unrestorable=unrestorable, - snapshot=snapshot, - ) - for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable in snapshot_rows - ] - - def _get_versions(self, lock_for_update: bool = False) -> Versions: - no_version = Versions() - - if not self.engine_adapter.table_exists(self.versions_table): - return no_version - - query = exp.select("*").from_(self.versions_table) - if lock_for_update: - query.lock(copy=False) - - row = self._fetchone(query) - if not row: - return no_version - - return Versions( - schema_version=row[0], sqlglot_version=row[1], sqlmesh_version=seq_get(row, 2) - ) - - def _get_environment( - self, environment: str, lock_for_update: bool = False - ) -> t.Optional[Environment]: - """Fetches the environment if it exists. 
- - Args: - environment: The environment - lock_for_update: Lock the snapshot rows for future update - - Returns: - The environment object. - """ - row = self._fetchone( - self._environments_query( - where=exp.EQ( - this=exp.column("name"), - expression=exp.Literal.string(environment), - ), - lock_for_update=lock_for_update, - ) - ) - - if not row: - return None - - env = self._environment_from_row(row) - return env - - @transactional() - def add_interval( - self, - snapshot: Snapshot, - start: TimeLike, - end: TimeLike, - is_dev: bool = False, - ) -> None: - super().add_interval(snapshot, start, end, is_dev) - - @transactional() - def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: - def remove_partial_intervals( - intervals: t.List[Interval], snapshot_id: t.Optional[SnapshotId], *, is_dev: bool - ) -> t.List[Interval]: - results = [] - for start_ts, end_ts in intervals: - if start_ts < end_ts: - logger.info( - "Adding %s (%s, %s) for snapshot %s", - "dev interval" if is_dev else "interval", - time_like_to_str(start_ts), - time_like_to_str(end_ts), - snapshot_id, - ) - results.append((start_ts, end_ts)) - else: - logger.info( - "Skipping partial interval (%s, %s) for snapshot %s", - start_ts, - end_ts, - snapshot_id, - ) - return results - - intervals_to_insert = [] - for snapshot_intervals in snapshots_intervals: - snapshot_intervals = snapshot_intervals.copy( - update={ - "intervals": remove_partial_intervals( - snapshot_intervals.intervals, snapshot_intervals.snapshot_id, is_dev=False - ), - "dev_intervals": remove_partial_intervals( - snapshot_intervals.dev_intervals, - snapshot_intervals.snapshot_id, - is_dev=True, - ), - } - ) - if ( - snapshot_intervals.intervals - or snapshot_intervals.dev_intervals - or snapshot_intervals.pending_restatement_intervals - ): - intervals_to_insert.append(snapshot_intervals) - - if intervals_to_insert: - self._push_snapshot_intervals(intervals_to_insert) - - @transactional() - def 
remove_intervals( - self, - snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], - remove_shared_versions: bool = False, - ) -> None: - intervals_to_remove: t.Sequence[ - t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval] - ] = snapshot_intervals - if remove_shared_versions: - name_version_mapping = {s.name_version: interval for s, interval in snapshot_intervals} - all_snapshots = [] - for where in self._snapshot_name_version_filter(name_version_mapping, alias=None): - all_snapshots.extend( - [ - SnapshotIntervals( - name=r[0], - identifier=r[1], - version=r[2], - dev_version=r[3], - intervals=[], - dev_intervals=[], - ) - for r in self._fetchall( - exp.select("name", "identifier", "version", "dev_version") - .from_(self.intervals_table) - .where(where) - .distinct() - ) - ] - ) - intervals_to_remove = [ - (snapshot, name_version_mapping[snapshot.name_version]) - for snapshot in all_snapshots - ] - - if logger.isEnabledFor(logging.INFO): - snapshot_ids = ", ".join(str(s.snapshot_id) for s, _ in intervals_to_remove) - logger.info("Removing interval for snapshots: %s", snapshot_ids) - - for is_dev in (True, False): - self.engine_adapter.insert_append( - self.intervals_table, - _intervals_to_df(intervals_to_remove, is_dev=is_dev, is_removed=True), - columns_to_types=self._interval_columns_to_types, - ) - - @transactional() - def compact_intervals(self) -> None: - interval_ids, snapshot_intervals = self._get_snapshot_intervals(uncompacted_only=True) - - logger.info( - "Compacting %s intervals for %s snapshots", len(interval_ids), len(snapshot_intervals) - ) - - self._push_snapshot_intervals(snapshot_intervals, is_compacted=True) - - if interval_ids: - for interval_id_batch in self._batches( - list(interval_ids), batch_size=self.INTERVAL_BATCH_SIZE - ): - self.engine_adapter.delete_from( - self.intervals_table, exp.column("id").isin(*interval_id_batch) - ) - - def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> 
t.List[Snapshot]: - if not snapshots: - return [] - - _, intervals = self._get_snapshot_intervals(snapshots) - for s in snapshots: - s.intervals = [] - s.dev_intervals = [] - return Snapshot.hydrate_with_intervals_by_version(snapshots, intervals) - - def max_interval_end_per_model( - self, - environment: str, - models: t.Optional[t.Set[str]] = None, - ensure_finalized_snapshots: bool = False, - ) -> t.Dict[str, int]: - env = self._get_environment(environment) - if not env: - return {} - - snapshots = ( - env.snapshots if not ensure_finalized_snapshots else env.finalized_or_current_snapshots - ) - if models is not None: - snapshots = [s for s in snapshots if s.name in models] - - if not snapshots: - return {} - - table_alias = "intervals" - name_col = exp.column("name", table=table_alias) - version_col = exp.column("version", table=table_alias) - - result: t.Dict[str, int] = {} - - for where in self._snapshot_name_version_filter(snapshots, alias=table_alias): - query = ( - exp.select( - name_col, - exp.func("MAX", exp.column("end_ts", table=table_alias)).as_("max_end_ts"), - ) - .from_(exp.to_table(self.intervals_table).as_(table_alias)) - .where(where, copy=False) - .where( - exp.and_( - exp.to_column("is_dev").not_(), - exp.to_column("is_removed").not_(), - exp.to_column("is_pending_restatement").not_(), - ), - copy=False, - ) - .group_by(name_col, version_col, copy=False) - ) - - for name, max_end in self._fetchall(query): - result[name] = max_end - - return result - - def recycle(self) -> None: - self.engine_adapter.recycle() - - def close(self) -> None: - self.engine_adapter.close() - - def _cleanup_intervals( - self, - cleanup_targets: t.List[SnapshotTableCleanupTask], - expired_snapshot_ids: t.Collection[SnapshotIdLike], - ) -> None: - # Cleanup can only happen for compacted intervals - self.compact_intervals() - # Delete intervals for non-dev tables that are no longer used - self._delete_intervals_by_version(cleanup_targets) - # Delete dev intervals for dev 
tables that are no longer used - self._delete_intervals_by_dev_version(cleanup_targets) - # Nullify the snapshot identifiers of interval records for snapshots that have been deleted - self._update_intervals_for_deleted_snapshots(expired_snapshot_ids) - - def _update_intervals_for_deleted_snapshots( - self, snapshot_ids: t.Collection[SnapshotIdLike] - ) -> None: - """Nullifies the snapshot identifiers of dev interval records and snapshot identifiers and dev versions of - non-dev interval records for snapshots that have been deleted so that they can be compacted efficiently. - """ - if not snapshot_ids: - return - - for where in self._snapshot_id_filter(snapshot_ids, alias=None): - # Nullify the identifier for dev intervals - # Set is_compacted to False so that it's compacted during the next compaction - self.engine_adapter.update_table( - self.intervals_table, - {"identifier": None, "is_compacted": False}, - where=where.and_(exp.column("is_dev")), - ) - # Nullify both identifier and dev version for non-dev intervals - # Set is_compacted to False so that it's compacted during the next compaction - self.engine_adapter.update_table( - self.intervals_table, - {"identifier": None, "dev_version": None, "is_compacted": False}, - where=where.and_(exp.column("is_dev").not_()), - ) - - def _delete_intervals_by_dev_version(self, targets: t.List[SnapshotTableCleanupTask]) -> None: - """Deletes dev intervals for snapshot dev versions that are no longer used.""" - dev_keys_to_delete = [ - SnapshotNameVersion(name=t.snapshot.name, version=t.snapshot.dev_version) - for t in targets - if t.dev_table_only - ] - if not dev_keys_to_delete: - return - - for where in self._snapshot_name_version_filter( - dev_keys_to_delete, version_column_name="dev_version", alias=None - ): - self.engine_adapter.delete_from(self.intervals_table, where.and_(exp.column("is_dev"))) - - def _delete_intervals_by_version(self, targets: t.List[SnapshotTableCleanupTask]) -> None: - """Deletes intervals for 
snapshot versions that are no longer used.""" - non_dev_keys_to_delete = [t.snapshot for t in targets if not t.dev_table_only] - if not non_dev_keys_to_delete: - return - - for where in self._snapshot_name_version_filter(non_dev_keys_to_delete, alias=None): - self.engine_adapter.delete_from(self.intervals_table, where) - - def _get_snapshot_intervals( - self, - snapshots: t.Optional[t.Collection[SnapshotNameVersionLike]] = None, - uncompacted_only: bool = False, - ) -> t.Tuple[t.Set[str], t.List[SnapshotIntervals]]: - if not snapshots and snapshots is not None: - return (set(), []) - - query = self._get_snapshot_intervals_query(uncompacted_only) - - interval_ids: t.Set[str] = set() - intervals: t.Dict[ - t.Tuple[str, str, t.Optional[str], t.Optional[str]], SnapshotIntervals - ] = {} - - for where in ( - self._snapshot_name_version_filter(snapshots, alias="intervals") - if snapshots - else [None] - ): - rows = self._fetchall(query.where(where)) - for ( - interval_id, - name, - identifier, - version, - dev_version, - start, - end, - is_dev, - is_removed, - is_pending_restatement, - ) in rows: - interval_ids.add(interval_id) - merge_key = (name, version, dev_version, identifier) - # Pending restatement intervals are merged by name and version - pending_restatement_interval_merge_key = (name, version, None, None) - - if merge_key not in intervals: - intervals[merge_key] = SnapshotIntervals( - name=name, - identifier=identifier, - version=version, - dev_version=dev_version, - ) - - if pending_restatement_interval_merge_key not in intervals: - intervals[pending_restatement_interval_merge_key] = SnapshotIntervals( - name=name, - identifier=None, - version=version, - dev_version=None, - ) - - if is_removed: - if is_dev: - intervals[merge_key].remove_dev_interval(start, end) - else: - intervals[merge_key].remove_interval(start, end) - elif is_pending_restatement: - intervals[ - pending_restatement_interval_merge_key - ].add_pending_restatement_interval(start, end) - else: - 
if is_dev: - intervals[merge_key].add_dev_interval(start, end) - else: - intervals[merge_key].add_interval(start, end) - # Remove all pending restatement intervals recorded before the current interval has been added - intervals[ - pending_restatement_interval_merge_key - ].remove_pending_restatement_interval(start, end) - - return interval_ids, [i for i in intervals.values() if not i.is_empty()] - - def _get_snapshot_intervals_query(self, uncompacted_only: bool) -> exp.Select: - query = ( - exp.select( - "id", - exp.column("name", table="intervals"), - exp.column("identifier", table="intervals"), - exp.column("version", table="intervals"), - exp.column("dev_version", table="intervals"), - "start_ts", - "end_ts", - "is_dev", - "is_removed", - "is_pending_restatement", - ) - .from_(exp.to_table(self.intervals_table).as_("intervals")) - .order_by( - exp.column("name", table="intervals"), - exp.column("version", table="intervals"), - "created_ts", - "is_removed", - "is_pending_restatement", - ) - ) - if uncompacted_only: - query.join( - exp.select("name", "version") - .from_(exp.to_table(self.intervals_table).as_("intervals")) - .where(exp.column("is_compacted").not_()) - .distinct() - .subquery(alias="uncompacted"), - on=exp.and_( - exp.column("name", table="intervals").eq( - exp.column("name", table="uncompacted") - ), - exp.column("version", table="intervals").eq( - exp.column("version", table="uncompacted") - ), - ), - copy=False, - ) - return query - - def _push_snapshot_intervals( - self, - snapshots: t.Iterable[t.Union[Snapshot, SnapshotIntervals]], - is_compacted: bool = False, - ) -> None: - new_intervals = [] - for snapshot in snapshots: - logger.info("Pushing intervals for snapshot %s", snapshot.snapshot_id) - for start_ts, end_ts in snapshot.intervals: - new_intervals.append( - _interval_to_df( - snapshot, start_ts, end_ts, is_dev=False, is_compacted=is_compacted - ) - ) - for start_ts, end_ts in snapshot.dev_intervals: - new_intervals.append( - 
_interval_to_df( - snapshot, start_ts, end_ts, is_dev=True, is_compacted=is_compacted - ) - ) - - # Make sure that all pending restatement intervals are recorded last - for snapshot in snapshots: - for start_ts, end_ts in snapshot.pending_restatement_intervals: - new_intervals.append( - _interval_to_df( - snapshot, - start_ts, - end_ts, - is_dev=False, - is_compacted=is_compacted, - is_pending_restatement=True, - ) - ) - - if new_intervals: - self.engine_adapter.insert_append( - self.intervals_table, - pd.DataFrame(new_intervals), - columns_to_types=self._interval_columns_to_types, - ) - - def _restore_table( - self, - table_name: TableName, - backup_table_name: TableName, - ) -> None: - self.engine_adapter.drop_table(table_name) - self.engine_adapter.rename_table( - old_table_name=backup_table_name, - new_table_name=table_name, - ) - - @transactional() - def migrate( - self, - default_catalog: t.Optional[str], - skip_backup: bool = False, - promoted_snapshots_only: bool = True, - ) -> None: - """Migrate the state sync to the latest SQLMesh / SQLGlot version.""" - versions = self.get_versions(validate=False) - - migration_start_ts = time.perf_counter() - - try: - migrate_rows = self._apply_migrations(default_catalog, skip_backup) - - if not migrate_rows and major_minor(SQLMESH_VERSION) == versions.minor_sqlmesh_version: - return - - if migrate_rows: - self._migrate_rows(promoted_snapshots_only) - # Cleanup plan DAGs since we currently don't migrate snapshot records that are in there. 
- self.engine_adapter.delete_from(self.plan_dags_table, "TRUE") - self._update_versions() - - analytics.collector.on_migration_end( - from_sqlmesh_version=versions.sqlmesh_version, - state_sync_type=self.state_type(), - migration_time_sec=time.perf_counter() - migration_start_ts, - ) - except Exception as e: - if skip_backup: - logger.error("Backup was skipped so no rollback was attempted.") - else: - self.rollback() - - analytics.collector.on_migration_end( - from_sqlmesh_version=versions.sqlmesh_version, - state_sync_type=self.state_type(), - migration_time_sec=time.perf_counter() - migration_start_ts, - error=e, - ) - - self.console.log_migration_status(success=False) - raise SQLMeshError("SQLMesh migration failed.") from e - - self.console.log_migration_status() - - @transactional() - def rollback(self) -> None: - """Rollback to the previous migration.""" - logger.info("Starting migration rollback.") - tables = (self.snapshots_table, self.environments_table, self.versions_table) - optional_tables = (self.intervals_table, self.plan_dags_table, self.auto_restatements_table) - versions = self.get_versions(validate=False) - if versions.schema_version == 0: - # Clean up state tables - for table in tables + optional_tables: - self.engine_adapter.drop_table(table) - else: - if not all( - self.engine_adapter.table_exists(_backup_table_name(table)) for table in tables - ): - raise SQLMeshError("There are no prior migrations to roll back to.") - for table in tables: - self._restore_table(table, _backup_table_name(table)) - - for optional_table in optional_tables: - if self.engine_adapter.table_exists(_backup_table_name(optional_table)): - self._restore_table(optional_table, _backup_table_name(optional_table)) - - logger.info("Migration rollback successful.") - - def state_type(self) -> str: - return self.engine_adapter.dialect - - @transactional() - def _backup_state(self) -> None: - for table in ( - self.snapshots_table, - self.environments_table, - self.versions_table, 
- self.intervals_table, - self.plan_dags_table, - self.auto_restatements_table, - ): - if self.engine_adapter.table_exists(table): - backup_name = _backup_table_name(table) - self.engine_adapter.drop_table(backup_name) - self.engine_adapter.create_table_like(backup_name, table) - self.engine_adapter.insert_append(backup_name, exp.select("*").from_(table)) - - def _snapshot_count(self) -> int: - result = self._fetchone(exp.select("COUNT(*)").from_(self.snapshots_table)) - return result[0] if result else 0 - - def _apply_migrations( - self, - default_catalog: t.Optional[str], - skip_backup: bool, - ) -> bool: - versions = self.get_versions(validate=False) - migrations = MIGRATIONS[versions.schema_version :] - should_backup = any( - [ - migrations, - major_minor(SQLGLOT_VERSION) != versions.minor_sqlglot_version, - major_minor(SQLMESH_VERSION) != versions.minor_sqlmesh_version, - ] - ) - if not skip_backup and should_backup: - self._backup_state() - - snapshot_count_before = self._snapshot_count() if versions.schema_version else None - - for migration in migrations: - logger.info(f"Applying migration {migration}") - migration.migrate(self, default_catalog=default_catalog) - - snapshot_count_after = self._snapshot_count() - - if snapshot_count_before is not None and snapshot_count_before != snapshot_count_after: - scripts = f"{versions.schema_version} - {versions.schema_version + len(migrations)}" - raise SQLMeshError( - f"Number of snapshots before ({snapshot_count_before}) and after " - f"({snapshot_count_after}) applying migration scripts {scripts} does not match. " - "Please file an issue issue at https://github.com/TobikoData/sqlmesh/issues/new." 
- ) - - migrate_snapshots_and_environments = ( - bool(migrations) or major_minor(SQLGLOT_VERSION) != versions.minor_sqlglot_version - ) - return migrate_snapshots_and_environments - - def _migrate_rows(self, promoted_snapshots_only: bool) -> None: - logger.info("Fetching environments") - environments = self.get_environments() - # Only migrate snapshots that are part of at least one environment. - snapshots_to_migrate = ( - {s.snapshot_id for e in environments for s in e.snapshots} - if promoted_snapshots_only - else None - ) - snapshot_mapping = self._migrate_snapshot_rows(snapshots_to_migrate) - if not snapshot_mapping: - logger.info("No changes to snapshots detected") - return - self._migrate_environment_rows(environments, snapshot_mapping) - - def _migrate_snapshot_rows( - self, snapshots: t.Optional[t.Set[SnapshotId]] - ) -> t.Dict[SnapshotId, SnapshotTableInfo]: - logger.info("Migrating snapshot rows...") - raw_snapshots = { - SnapshotId(name=name, identifier=identifier): { - **json.loads(raw_snapshot), - "updated_ts": updated_ts, - "unpaused_ts": unpaused_ts, - "unrestorable": unrestorable, - } - for where in (self._snapshot_id_filter(snapshots) if snapshots is not None else [None]) - for name, identifier, raw_snapshot, updated_ts, unpaused_ts, unrestorable in self._fetchall( - exp.select( - "name", "identifier", "snapshot", "updated_ts", "unpaused_ts", "unrestorable" - ) - .from_(self.snapshots_table) - .where(where) - .lock() - ) - } - if not raw_snapshots: - return {} - - dag: DAG[SnapshotId] = DAG() - for snapshot_id, raw_snapshot in raw_snapshots.items(): - parent_ids = [SnapshotId.parse_obj(p_id) for p_id in raw_snapshot.get("parents", [])] - dag.add(snapshot_id, [p_id for p_id in parent_ids if p_id in raw_snapshots]) - - reversed_dag_raw = dag.reversed.graph - - self.console.start_snapshot_migration_progress(len(raw_snapshots)) - - parsed_snapshots = LazilyParsedSnapshots(raw_snapshots) - all_snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo] = {} 
- snapshot_id_mapping: t.Dict[SnapshotId, SnapshotId] = {} - new_snapshots: t.Dict[SnapshotId, Snapshot] = {} - visited: t.Set[SnapshotId] = set() - - def _push_new_snapshots() -> None: - all_snapshot_mapping.update( - { - from_id: new_snapshots[to_id].table_info - for from_id, to_id in snapshot_id_mapping.items() - } - ) - - existing_new_snapshots = self.snapshots_exist(new_snapshots) - new_snapshots_to_push = [ - s for s in new_snapshots.values() if s.snapshot_id not in existing_new_snapshots - ] - if new_snapshots_to_push: - logger.info("Pushing %s migrated snapshots", len(new_snapshots_to_push)) - self._push_snapshots(new_snapshots_to_push) - new_snapshots.clear() - snapshot_id_mapping.clear() - - def _visit( - snapshot_id: SnapshotId, fingerprint_cache: t.Dict[str, SnapshotFingerprint] - ) -> None: - if snapshot_id in visited or snapshot_id not in raw_snapshots: - return - visited.add(snapshot_id) - - snapshot = parsed_snapshots[snapshot_id] - node = snapshot.node - - node_seen = set() - node_queue = {snapshot_id} - nodes: t.Dict[str, Node] = {} - while node_queue: - next_snapshot_id = node_queue.pop() - next_snapshot = parsed_snapshots.get(next_snapshot_id) - - if next_snapshot_id in node_seen or not next_snapshot: - continue - - node_seen.add(next_snapshot_id) - node_queue.update(next_snapshot.parents) - nodes[next_snapshot.name] = next_snapshot.node - - new_snapshot = deepcopy(snapshot) - try: - new_snapshot.fingerprint = fingerprint_from_node( - node, - nodes=nodes, - cache=fingerprint_cache, - ) - new_snapshot.parents = tuple( - SnapshotId( - name=parent_node.fqn, - identifier=fingerprint_from_node( - parent_node, - nodes=nodes, - cache=fingerprint_cache, - ).to_identifier(), - ) - for parent_node in _parents_from_node(node, nodes).values() - ) - except Exception: - logger.exception("Could not compute fingerprint for %s", snapshot.snapshot_id) - return - - # Reset the effective_from date for the new snapshot to avoid unexpected backfills. 
- new_snapshot.effective_from = None - new_snapshot.previous_versions = snapshot.all_versions - new_snapshot.migrated = True - if not new_snapshot.dev_version_: - new_snapshot.dev_version_ = snapshot.dev_version - - self.console.update_snapshot_migration_progress(1) - - # Visit children and evict them from the parsed_snapshots cache after. - for child in reversed_dag_raw.get(snapshot_id, []): - # Make sure to copy the fingerprint cache to avoid sharing it between different child snapshots with the same name. - _visit(child, fingerprint_cache.copy()) - parsed_snapshots.evict(child) - - if new_snapshot.fingerprint == snapshot.fingerprint: - logger.debug(f"{new_snapshot.snapshot_id} is unchanged.") - return - - new_snapshot_id = new_snapshot.snapshot_id - - if new_snapshot_id in raw_snapshots: - # Mapped to an existing snapshot. - new_snapshots[new_snapshot_id] = parsed_snapshots[new_snapshot_id] - logger.debug("Migrated snapshot %s already exists", new_snapshot_id) - elif ( - new_snapshot_id not in new_snapshots - or new_snapshot.updated_ts > new_snapshots[new_snapshot_id].updated_ts - ): - new_snapshots[new_snapshot_id] = new_snapshot - - snapshot_id_mapping[snapshot.snapshot_id] = new_snapshot_id - logger.debug("%s mapped to %s", snapshot.snapshot_id, new_snapshot_id) - - if len(new_snapshots) >= self.SNAPSHOT_MIGRATION_BATCH_SIZE: - _push_new_snapshots() - - for root_snapshot_id in dag.roots: - _visit(root_snapshot_id, {}) - - if new_snapshots: - _push_new_snapshots() - - self.console.stop_snapshot_migration_progress() - return all_snapshot_mapping - - def _migrate_environment_rows( - self, - environments: t.List[Environment], - snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo], - ) -> None: - logger.info("Migrating environment rows...") - - updated_prod_environment: t.Optional[Environment] = None - updated_environments = [] - for environment in environments: - snapshots = [ - ( - snapshot_mapping[info.snapshot_id] - if info.snapshot_id in snapshot_mapping - 
else info - ) - for info in environment.snapshots - ] - - if snapshots != environment.snapshots: - environment.snapshots_ = snapshots - updated_environments.append(environment) - if environment.name == c.PROD: - updated_prod_environment = environment - self.console.start_env_migration_progress(len(updated_environments)) - - for environment in updated_environments: - self._update_environment(environment) - self.console.update_env_migration_progress(1) - - if updated_prod_environment: - try: - self.unpause_snapshots(updated_prod_environment.snapshots, now_timestamp()) - except Exception: - logger.warning("Failed to unpause migrated snapshots", exc_info=True) - - self.console.stop_env_migration_progress() - - def _snapshot_ids_exist( - self, snapshot_ids: t.Iterable[SnapshotIdLike], table_name: exp.Table - ) -> t.Set[SnapshotId]: - return { - SnapshotId(name=name, identifier=identifier) - for where in self._snapshot_id_filter(snapshot_ids) - for name, identifier in self._fetchall( - exp.select("name", "identifier").from_(table_name).where(where) - ) - } - - def _snapshot_id_filter( - self, - snapshot_ids: t.Iterable[SnapshotIdLike], - alias: t.Optional[str] = None, - batch_size: t.Optional[int] = None, - ) -> t.Iterator[exp.Condition]: - name_identifiers = sorted( - {(snapshot_id.name, snapshot_id.identifier) for snapshot_id in snapshot_ids} - ) - batches = self._batches(name_identifiers, batch_size=batch_size) - - if not name_identifiers: - yield exp.false() - elif self.engine_adapter.SUPPORTS_TUPLE_IN: - for identifiers in batches: - yield t.cast( - exp.Tuple, - exp.convert( - ( - exp.column("name", table=alias), - exp.column("identifier", table=alias), - ) - ), - ).isin(*identifiers) - else: - for identifiers in batches: - yield exp.or_( - *[ - exp.and_( - exp.column("name", table=alias).eq(name), - exp.column("identifier", table=alias).eq(identifier), - ) - for name, identifier in identifiers - ] - ) - - def _snapshot_name_version_filter( - self, - 
snapshot_name_versions: t.Iterable[SnapshotNameVersionLike], - version_column_name: str = "version", - alias: t.Optional[str] = "snapshots", - column_prefix: t.Optional[str] = None, - ) -> t.Iterator[exp.Condition]: - name_versions = sorted({(s.name, s.version) for s in snapshot_name_versions}) - batches = self._batches(name_versions) - - name_column_name = "name" - if column_prefix: - name_column_name = f"{column_prefix}_{name_column_name}" - version_column_name = f"{column_prefix}_{version_column_name}" - - name_column = exp.column(name_column_name, table=alias) - version_column = exp.column(version_column_name, table=alias) - - if not name_versions: - yield exp.false() - elif self.engine_adapter.SUPPORTS_TUPLE_IN: - for versions in batches: - yield t.cast( - exp.Tuple, - exp.convert( - ( - name_column, - version_column, - ) - ), - ).isin(*versions) - else: - for versions in batches: - yield exp.or_( - *[ - exp.and_( - name_column.eq(name), - version_column.eq(version), - ) - for name, version in versions - ] - ) - - def _batches(self, l: t.List[T], batch_size: t.Optional[int] = None) -> t.List[t.List[T]]: - batch_size = batch_size or self.SNAPSHOT_BATCH_SIZE - return [l[i : i + batch_size] for i in range(0, len(l), batch_size)] - - @contextlib.contextmanager - def _transaction(self) -> t.Iterator[None]: - with self.engine_adapter.transaction(): - yield - - -def _intervals_to_df( - snapshot_intervals: t.Sequence[t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval]], - is_dev: bool, - is_removed: bool, -) -> pd.DataFrame: - return pd.DataFrame( - [ - _interval_to_df( - s, - *interval, - is_dev=is_dev, - is_removed=is_removed, - ) - for s, interval in snapshot_intervals - ] - ) - - -def _interval_to_df( - snapshot: t.Union[SnapshotInfoLike, SnapshotIntervals], - start_ts: int, - end_ts: int, - is_dev: bool = False, - is_removed: bool = False, - is_compacted: bool = False, - is_pending_restatement: bool = False, -) -> t.Dict[str, t.Any]: - return { - 
"id": random_id(), - "created_ts": now_timestamp(), - "name": snapshot.name, - "identifier": snapshot.identifier if not is_pending_restatement else None, - "version": snapshot.version, - "dev_version": snapshot.dev_version if not is_pending_restatement else None, - "start_ts": start_ts, - "end_ts": end_ts, - "is_dev": is_dev, - "is_removed": is_removed, - "is_compacted": is_compacted, - "is_pending_restatement": is_pending_restatement, - } - - -def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame: - return pd.DataFrame( - [ - { - "name": snapshot.name, - "identifier": snapshot.identifier, - "version": snapshot.version, - "snapshot": _snapshot_to_json(snapshot), - "kind_name": snapshot.model_kind_name.value if snapshot.model_kind_name else None, - "updated_ts": snapshot.updated_ts, - "unpaused_ts": snapshot.unpaused_ts, - "ttl_ms": snapshot.ttl_ms, - "unrestorable": snapshot.unrestorable, - } - for snapshot in snapshots - ] - ) - - -def _environment_to_df(environment: Environment) -> pd.DataFrame: - return pd.DataFrame( - [ - { - "name": environment.name, - "snapshots": json.dumps(environment.snapshot_dicts()), - "start_at": time_like_to_str(environment.start_at), - "end_at": time_like_to_str(environment.end_at) if environment.end_at else None, - "plan_id": environment.plan_id, - "previous_plan_id": environment.previous_plan_id, - "expiration_ts": environment.expiration_ts, - "finalized_ts": environment.finalized_ts, - "promoted_snapshot_ids": ( - json.dumps(environment.promoted_snapshot_id_dicts()) - if environment.promoted_snapshot_ids is not None - else None - ), - "suffix_target": environment.suffix_target.value, - "catalog_name_override": environment.catalog_name_override, - "previous_finalized_snapshots": ( - json.dumps(environment.previous_finalized_snapshot_dicts()) - if environment.previous_finalized_snapshots is not None - else None - ), - "normalize_name": environment.normalize_name, - "requirements": json.dumps(environment.requirements), 
- } - ] - ) - - -def _auto_restatements_to_df(auto_restatements: t.Dict[SnapshotNameVersion, int]) -> pd.DataFrame: - return pd.DataFrame( - [ - { - "snapshot_name": name_version.name, - "snapshot_version": name_version.version, - "next_auto_restatement_ts": ts, - } - for name_version, ts in auto_restatements.items() - ] - ) - - -def _backup_table_name(table_name: TableName) -> exp.Table: - table = exp.to_table(table_name).copy() - table.set("this", exp.to_identifier(table.name + "_backup")) - return table - - -def _snapshot_to_json(snapshot: Snapshot) -> str: - return snapshot.json( - exclude={ - "intervals", - "dev_intervals", - "pending_restatement_intervals", - "updated_ts", - "unpaused_ts", - "unrestorable", - "next_auto_restatement_ts", - } - ) - - -def parse_snapshot( - serialized_snapshot: str, - updated_ts: int, - unpaused_ts: t.Optional[int], - unrestorable: bool, - next_auto_restatement_ts: t.Optional[int], -) -> Snapshot: - return Snapshot( - **{ - **json.loads(serialized_snapshot), - "updated_ts": updated_ts, - "unpaused_ts": unpaused_ts, - "unrestorable": unrestorable, - "next_auto_restatement_ts": next_auto_restatement_ts, - } - ) - - -class LazilyParsedSnapshots: - def __init__(self, raw_snapshots: t.Dict[SnapshotId, t.Dict[str, t.Any]]): - self._raw_snapshots = raw_snapshots - self._parsed_snapshots: t.Dict[SnapshotId, t.Optional[Snapshot]] = {} - - def get(self, snapshot_id: SnapshotId) -> t.Optional[Snapshot]: - if snapshot_id not in self._parsed_snapshots: - raw_snapshot = self._raw_snapshots.get(snapshot_id) - if raw_snapshot: - self._parsed_snapshots[snapshot_id] = Snapshot.parse_obj(raw_snapshot) - else: - self._parsed_snapshots[snapshot_id] = None - return self._parsed_snapshots[snapshot_id] - - def evict(self, snapshot_id: SnapshotId) -> None: - self._parsed_snapshots.pop(snapshot_id, None) - - def __getitem__(self, snapshot_id: SnapshotId) -> Snapshot: - snapshot = self.get(snapshot_id) - if snapshot is None: - raise KeyError(snapshot_id) 
- return snapshot - - -class SharedVersionSnapshot(PydanticModel): - """A stripped down version of a snapshot that is used for fetching snapshots that share the same version - with a significantly reduced parsing overhead. - """ - - name: str - version: str - dev_version_: t.Optional[str] = Field(alias="dev_version") - identifier: str - fingerprint: SnapshotFingerprint - interval_unit: IntervalUnit - change_category: SnapshotChangeCategory - updated_ts: int - unpaused_ts: t.Optional[int] - unrestorable: bool - disable_restatement: bool - effective_from: t.Optional[TimeLike] - raw_snapshot: t.Dict[str, t.Any] - - @property - def snapshot_id(self) -> SnapshotId: - return SnapshotId(name=self.name, identifier=self.identifier) - - @property - def is_forward_only(self) -> bool: - return self.change_category == SnapshotChangeCategory.FORWARD_ONLY - - @property - def normalized_effective_from_ts(self) -> t.Optional[int]: - return ( - to_timestamp(self.interval_unit.cron_floor(self.effective_from)) - if self.effective_from - else None - ) - - @property - def dev_version(self) -> str: - return self.dev_version_ or self.fingerprint.to_version() - - @property - def full_snapshot(self) -> Snapshot: - return Snapshot( - **{ - **self.raw_snapshot, - "updated_ts": self.updated_ts, - "unpaused_ts": self.unpaused_ts, - "unrestorable": self.unrestorable, - } - ) - - def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None: - """Sets the timestamp for when this snapshot was unpaused. - - Args: - unpaused_dt: The datetime object of when this snapshot was unpaused. 
- """ - self.unpaused_ts = ( - to_timestamp(self.interval_unit.cron_floor(unpaused_dt)) if unpaused_dt else None - ) - - @classmethod - def from_snapshot_record( - cls, - *, - name: str, - identifier: str, - version: str, - updated_ts: int, - unpaused_ts: t.Optional[int], - unrestorable: bool, - snapshot: str, - ) -> SharedVersionSnapshot: - raw_snapshot = json.loads(snapshot) - raw_node = raw_snapshot["node"] - return SharedVersionSnapshot( - name=name, - version=version, - dev_version=raw_snapshot.get("dev_version"), - identifier=identifier, - fingerprint=raw_snapshot["fingerprint"], - interval_unit=raw_node.get("interval_unit", IntervalUnit.from_cron(raw_node["cron"])), - change_category=raw_snapshot["change_category"], - updated_ts=updated_ts, - unpaused_ts=unpaused_ts, - unrestorable=unrestorable, - disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False), - effective_from=raw_snapshot.get("effective_from"), - raw_snapshot=raw_snapshot, - ) diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 871ab3bc07..5a449b4137 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -34,7 +34,7 @@ from sqlmesh.core.model.kind import ModelKindName from sqlmesh.core.plan import BuiltInPlanEvaluator, PlanBuilder from sqlmesh.core.state_sync.cache import CachingStateSync -from sqlmesh.core.state_sync.engine_adapter import EngineAdapterStateSync +from sqlmesh.core.state_sync.db import EngineAdapterStateSync from sqlmesh.utils.connection_pool import SingletonConnectionPool, ThreadLocalConnectionPool from sqlmesh.utils.date import ( make_inclusive_end, diff --git a/tests/core/test_environment.py b/tests/core/test_environment.py index 163bba47d8..307f220c25 100644 --- a/tests/core/test_environment.py +++ b/tests/core/test_environment.py @@ -2,7 +2,7 @@ from sqlmesh.core.environment import Environment, EnvironmentNamingInfo from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo -from 
sqlmesh.core.state_sync.engine_adapter import _environment_to_df +from sqlmesh.core.state_sync.db.environment import _environment_to_df def test_sanitize_name(): diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py index a3b5f3ee29..f1008a10dc 100644 --- a/tests/core/test_state_sync.py +++ b/tests/core/test_state_sync.py @@ -106,7 +106,7 @@ def promote_snapshots( def delete_versions(state_sync: EngineAdapterStateSync) -> None: - state_sync.engine_adapter.drop_table(state_sync.versions_table) + state_sync.engine_adapter.drop_table(state_sync.version_state.versions_table) def test_push_snapshots( @@ -143,7 +143,7 @@ def test_push_snapshots( snapshot_b.snapshot_id: snapshot_b, } - logger = logging.getLogger("sqlmesh.core.state_sync.engine_adapter") + logger = logging.getLogger("sqlmesh.core.state_sync.db.facade") with patch.object(logger, "error") as mock_logger: state_sync.push_snapshots([snapshot_a]) assert str({snapshot_a.snapshot_id}) == mock_logger.call_args[0][1] @@ -195,10 +195,10 @@ def test_duplicates(state_sync: EngineAdapterStateSync, make_snapshot: t.Callabl snapshot_b.updated_ts = snapshot_a.updated_ts + 1 snapshot_c.updated_ts = 0 state_sync.push_snapshots([snapshot_a]) - state_sync._push_snapshots([snapshot_a]) - state_sync._push_snapshots([snapshot_b]) - state_sync._push_snapshots([snapshot_c]) - state_sync._snapshot_cache.clear() + state_sync.snapshot_state.push_snapshots([snapshot_a]) + state_sync.snapshot_state.push_snapshots([snapshot_b]) + state_sync.snapshot_state.push_snapshots([snapshot_c]) + state_sync.snapshot_state.clear_cache() assert ( state_sync.get_snapshots([snapshot_a])[snapshot_a.snapshot_id].updated_ts == snapshot_b.updated_ts @@ -214,7 +214,7 @@ def test_snapshots_exists(state_sync: EngineAdapterStateSync, snapshots: t.List[ @pytest.fixture def get_snapshot_intervals(state_sync) -> t.Callable[[Snapshot], t.Optional[SnapshotIntervals]]: def _get_snapshot_intervals(snapshot: Snapshot) -> 
t.Optional[SnapshotIntervals]: - intervals = state_sync._get_snapshot_intervals([snapshot])[-1] + intervals = state_sync.interval_state.get_snapshot_intervals([snapshot]) return intervals[0] if intervals else None return _get_snapshot_intervals @@ -403,7 +403,7 @@ def test_refresh_snapshot_intervals( def test_get_snapshot_intervals( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, get_snapshot_intervals ) -> None: - state_sync.SNAPSHOT_BATCH_SIZE = 1 + state_sync.interval_state.SNAPSHOT_BATCH_SIZE = 1 snapshot_a = make_snapshot( SqlModel( @@ -500,7 +500,7 @@ def test_compact_intervals_delete_batches( ) delete_from_mock = mocker.patch.object(state_sync.engine_adapter, "delete_from") - state_sync.INTERVAL_BATCH_SIZE = 2 + state_sync.interval_state.INTERVAL_BATCH_SIZE = 2 state_sync.push_snapshots([snapshot]) @@ -512,7 +512,9 @@ def test_compact_intervals_delete_batches( state_sync.compact_intervals() - delete_from_mock.assert_has_calls([call(state_sync.intervals_table, mocker.ANY)] * 3) + delete_from_mock.assert_has_calls( + [call(state_sync.interval_state.intervals_table, mocker.ANY)] * 3 + ) def test_promote_snapshots(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): @@ -1162,7 +1164,7 @@ def test_delete_expired_snapshots_seed( def test_delete_expired_snapshots_batching( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable ): - state_sync.SNAPSHOT_BATCH_SIZE = 1 + state_sync.snapshot_state.SNAPSHOT_BATCH_SIZE = 1 now_ts = now_timestamp() snapshot_a = make_snapshot( @@ -1234,7 +1236,7 @@ def test_delete_expired_snapshots_promoted( env.snapshots_ = [] state_sync.promote(env) - now_timestamp_mock = mocker.patch("sqlmesh.core.state_sync.engine_adapter.now_timestamp") + now_timestamp_mock = mocker.patch("sqlmesh.core.state_sync.db.snapshot.now_timestamp") now_timestamp_mock.return_value = now_timestamp() + 11000 assert state_sync.delete_expired_snapshots() == [ @@ -1494,7 +1496,7 @@ def 
test_delete_expired_snapshots_cleanup_intervals_shared_version( # Check all intervals assert sorted( - state_sync._get_snapshot_intervals([snapshot, new_snapshot])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), key=lambda x: x.identifier or "", ) == [ SnapshotIntervals( @@ -1529,7 +1531,7 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_version( # Check all intervals assert sorted( - state_sync._get_snapshot_intervals([snapshot, new_snapshot])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), key=lambda x: x.identifier or "", ) == [ # The intervals of the old snapshot is preserved with the null identifier @@ -1607,7 +1609,7 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version( # Check all intervals assert sorted( - state_sync._get_snapshot_intervals([snapshot, new_snapshot])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), key=lambda x: x.identifier or "", ) == [ SnapshotIntervals( @@ -1642,7 +1644,7 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version( # Check all intervals assert sorted( - state_sync._get_snapshot_intervals([snapshot, new_snapshot])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), key=lambda x: x.identifier or "", ) == [ SnapshotIntervals( @@ -1754,7 +1756,7 @@ def test_compact_intervals_after_cleanup( assert ( sorted( - state_sync._get_snapshot_intervals([snapshot_a, snapshot_b, snapshot_c])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b, snapshot_c]), key=lambda x: (x.identifier or "", x.dev_version or ""), ) == expected_intervals @@ -1765,7 +1767,7 @@ def test_compact_intervals_after_cleanup( assert state_sync.engine_adapter.fetchone("SELECT COUNT(*) FROM sqlmesh._intervals")[0] == 4 # type: ignore assert ( sorted( - state_sync._get_snapshot_intervals([snapshot_a, snapshot_b, snapshot_c])[1], + 
state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b, snapshot_c]), key=lambda x: (x.identifier or "", x.dev_version or ""), ) == expected_intervals @@ -2049,7 +2051,7 @@ def test_version_schema(state_sync: EngineAdapterStateSync, tmp_path) -> None: state_sync.migrate(default_catalog=None) # migration version is behind, always raise - state_sync._update_versions(schema_version=SCHEMA_VERSION + 1) + state_sync.version_state.update_versions(schema_version=SCHEMA_VERSION + 1) error = ( rf"SQLMesh \(local\) is using version '{SCHEMA_VERSION}' which is behind '{SCHEMA_VERSION + 1}' \(remote\). " rf"""Please upgrade SQLMesh \('pip install --upgrade "sqlmesh=={SQLMESH_VERSION}"' command\).""" @@ -2062,7 +2064,7 @@ def test_version_schema(state_sync: EngineAdapterStateSync, tmp_path) -> None: state_sync.get_versions(validate=False) # migration version is ahead, only raise when validate is true - state_sync._update_versions(schema_version=SCHEMA_VERSION - 1) + state_sync.version_state.update_versions(schema_version=SCHEMA_VERSION - 1) with pytest.raises( SQLMeshError, match=rf"SQLMesh \(local\) is using version '{SCHEMA_VERSION}' which is ahead of '{SCHEMA_VERSION - 1}'", @@ -2083,7 +2085,7 @@ def test_version_sqlmesh(state_sync: EngineAdapterStateSync) -> None: else f"{int(patch) + 1}" ) sqlmesh_version_patch_bump = f"{major}.{minor}.{new_patch}" - state_sync._update_versions(sqlmesh_version=sqlmesh_version_patch_bump) + state_sync.version_state.update_versions(sqlmesh_version=sqlmesh_version_patch_bump) state_sync.get_versions(validate=False) # sqlmesh version is behind @@ -2092,7 +2094,7 @@ def test_version_sqlmesh(state_sync: EngineAdapterStateSync) -> None: rf"SQLMesh \(local\) is using version '{SQLMESH_VERSION}' which is behind '{sqlmesh_version_minor_bump}' \(remote\). 
" rf"""Please upgrade SQLMesh \('pip install --upgrade "sqlmesh=={sqlmesh_version_minor_bump}"' command\).""" ) - state_sync._update_versions(sqlmesh_version=sqlmesh_version_minor_bump) + state_sync.version_state.update_versions(sqlmesh_version=sqlmesh_version_minor_bump) with pytest.raises(SQLMeshError, match=error): state_sync.get_versions() state_sync.get_versions(validate=False) @@ -2100,7 +2102,7 @@ def test_version_sqlmesh(state_sync: EngineAdapterStateSync) -> None: # sqlmesh version is ahead sqlmesh_version_minor_decrease = f"{major}.{int(minor) - 1}.{patch}" error = rf"SQLMesh \(local\) is using version '{SQLMESH_VERSION}' which is ahead of '{sqlmesh_version_minor_decrease}'" - state_sync._update_versions(sqlmesh_version=sqlmesh_version_minor_decrease) + state_sync.version_state.update_versions(sqlmesh_version=sqlmesh_version_minor_decrease) with pytest.raises(SQLMeshError, match=error): state_sync.get_versions() state_sync.get_versions(validate=False) @@ -2110,7 +2112,7 @@ def test_version_sqlglot(state_sync: EngineAdapterStateSync) -> None: # patch version sqlglot doesn't matter major, minor, patch, *_ = SQLGLOT_VERSION.split(".") sqlglot_version = f"{major}.{minor}.{int(patch) + 1}" - state_sync._update_versions(sqlglot_version=sqlglot_version) + state_sync.version_state.update_versions(sqlglot_version=sqlglot_version) state_sync.get_versions(validate=False) # sqlglot version is behind @@ -2119,7 +2121,7 @@ def test_version_sqlglot(state_sync: EngineAdapterStateSync) -> None: rf"SQLGlot \(local\) is using version '{SQLGLOT_VERSION}' which is behind '{sqlglot_version}' \(remote\). 
" rf"""Please upgrade SQLGlot \('pip install --upgrade "sqlglot=={sqlglot_version}"' command\).""" ) - state_sync._update_versions(sqlglot_version=sqlglot_version) + state_sync.version_state.update_versions(sqlglot_version=sqlglot_version) with pytest.raises(SQLMeshError, match=error): state_sync.get_versions() state_sync.get_versions(validate=False) @@ -2127,7 +2129,7 @@ def test_version_sqlglot(state_sync: EngineAdapterStateSync) -> None: # sqlglot version is ahead sqlglot_version = f"{major}.{int(minor) - 1}.{patch}" error = rf"SQLGlot \(local\) is using version '{SQLGLOT_VERSION}' which is ahead of '{sqlglot_version}'" - state_sync._update_versions(sqlglot_version=sqlglot_version) + state_sync.version_state.update_versions(sqlglot_version=sqlglot_version) with pytest.raises(SQLMeshError, match=error): state_sync.get_versions() state_sync.get_versions(validate=False) @@ -2146,8 +2148,12 @@ def test_empty_versions() -> None: def test_migrate(state_sync: EngineAdapterStateSync, mocker: MockerFixture, tmp_path) -> None: from sqlmesh import __version__ as SQLMESH_VERSION - migrate_rows_mock = mocker.patch("sqlmesh.core.state_sync.EngineAdapterStateSync._migrate_rows") - backup_state_mock = mocker.patch("sqlmesh.core.state_sync.EngineAdapterStateSync._backup_state") + migrate_rows_mock = mocker.patch( + "sqlmesh.core.state_sync.db.migrator.StateMigrator._migrate_rows" + ) + backup_state_mock = mocker.patch( + "sqlmesh.core.state_sync.db.migrator.StateMigrator._backup_state" + ) state_sync.migrate(default_catalog=None) migrate_rows_mock.assert_not_called() backup_state_mock.assert_not_called() @@ -2181,8 +2187,8 @@ def test_rollback(state_sync: EngineAdapterStateSync, mocker: MockerFixture) -> ): state_sync.rollback() - restore_table_spy = mocker.spy(state_sync, "_restore_table") - state_sync._backup_state() + restore_table_spy = mocker.spy(state_sync.migrator, "_restore_table") + state_sync.migrator._backup_state() state_sync.rollback() calls = {(a.sql(), b.sql()) 
for (a, b), _ in restore_table_spy.call_args_list} @@ -2207,16 +2213,18 @@ def test_first_migration_failure(duck_conn, mocker: MockerFixture, tmp_path) -> state_sync = EngineAdapterStateSync( create_engine_adapter(lambda: duck_conn, "duckdb"), schema=c.SQLMESH, context_path=tmp_path ) - mocker.patch.object(state_sync, "_migrate_rows", side_effect=Exception("mocked error")) + mocker.patch.object(state_sync.migrator, "_migrate_rows", side_effect=Exception("mocked error")) with pytest.raises( SQLMeshError, match="SQLMesh migration failed.", ): state_sync.migrate(default_catalog=None) - assert not state_sync.engine_adapter.table_exists(state_sync.snapshots_table) - assert not state_sync.engine_adapter.table_exists(state_sync.environments_table) - assert not state_sync.engine_adapter.table_exists(state_sync.versions_table) - assert not state_sync.engine_adapter.table_exists(state_sync.intervals_table) + assert not state_sync.engine_adapter.table_exists(state_sync.snapshot_state.snapshots_table) + assert not state_sync.engine_adapter.table_exists( + state_sync.environment_state.environments_table + ) + assert not state_sync.engine_adapter.table_exists(state_sync.version_state.versions_table) + assert not state_sync.engine_adapter.table_exists(state_sync.interval_state.intervals_table) def test_migrate_rows(state_sync: EngineAdapterStateSync, mocker: MockerFixture) -> None: @@ -2313,7 +2321,7 @@ def test_backup_state(state_sync: EngineAdapterStateSync, mocker: MockerFixture) }, ) - state_sync._backup_state() + state_sync.migrator._backup_state() pd.testing.assert_frame_equal( state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots"), state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots_backup"), @@ -2338,12 +2346,12 @@ def test_restore_snapshots_table(state_sync: EngineAdapterStateSync) -> None: "select count(*) from sqlmesh._snapshots" ) assert old_snapshots_count == (12,) - state_sync._backup_state() + state_sync.migrator._backup_state() 
state_sync.engine_adapter.delete_from("sqlmesh._snapshots", "TRUE") snapshots_count = state_sync.engine_adapter.fetchone("select count(*) from sqlmesh._snapshots") assert snapshots_count == (0,) - state_sync._restore_table( + state_sync.migrator._restore_table( table_name="sqlmesh._snapshots", backup_table_name="sqlmesh._snapshots_backup", ) @@ -2375,7 +2383,7 @@ def test_seed_hydration( assert snapshot.model.is_hydrated assert snapshot.model.seed.content == "header\n1\n2" - state_sync._snapshot_cache.clear() + state_sync.snapshot_state.clear_cache() stored_snapshot = state_sync.get_snapshots([snapshot.snapshot_id])[snapshot.snapshot_id] assert isinstance(stored_snapshot.model, SeedModel) assert not stored_snapshot.model.is_hydrated @@ -2766,8 +2774,8 @@ def test_get_snapshots(mocker): def test_snapshot_batching(state_sync, mocker, make_snapshot): mock = mocker.Mock() - state_sync.SNAPSHOT_BATCH_SIZE = 2 - state_sync.engine_adapter = mock + state_sync.snapshot_state.SNAPSHOT_BATCH_SIZE = 2 + state_sync.snapshot_state.engine_adapter = mock snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1")), "1") snapshot_b = make_snapshot(SqlModel(name="a", query=parse_one("select 2")), "2") @@ -2838,13 +2846,12 @@ def test_snapshot_batching(state_sync, mocker, make_snapshot): ], ] - snapshots = state_sync._get_snapshots( + snapshots = state_sync.snapshot_state.get_snapshots( ( SnapshotId(name="a", identifier="1"), SnapshotId(name="a", identifier="2"), SnapshotId(name="a", identifier="3"), ), - hydrate_intervals=False, ) assert len(snapshots) == 3 calls = mock.fetchall.call_args_list @@ -2883,13 +2890,12 @@ def test_snapshot_cache( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, mocker: MockerFixture ): cache_mock = mocker.Mock() - state_sync._snapshot_cache = cache_mock + state_sync.snapshot_state._snapshot_cache = cache_mock snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 1"))) cache_mock.get_or_load.return_value = 
({snapshot.snapshot_id: snapshot}, {snapshot.snapshot_id}) - # Use _push_snapshots to bypass cache. - state_sync._push_snapshots([snapshot]) + state_sync.snapshot_state.push_snapshots([snapshot]) assert state_sync.get_snapshots([snapshot.snapshot_id]) == {snapshot.snapshot_id: snapshot} cache_mock.get_or_load.assert_called_once_with({snapshot.snapshot_id}, mocker.ANY) @@ -2897,7 +2903,9 @@ def test_snapshot_cache( # Update the snapshot in the state and make sure this update is reflected on the cached instance. assert snapshot.unpaused_ts is None assert not snapshot.unrestorable - state_sync._update_snapshots([snapshot.snapshot_id], unpaused_ts=1, unrestorable=True) + state_sync.snapshot_state._update_snapshots( + [snapshot.snapshot_id], unpaused_ts=1, unrestorable=True + ) new_snapshot = state_sync.get_snapshots([snapshot.snapshot_id])[snapshot.snapshot_id] assert new_snapshot.unpaused_ts == 1 assert new_snapshot.unrestorable @@ -2912,7 +2920,7 @@ def test_update_auto_restatements(state_sync: EngineAdapterStateSync, make_snaps snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("select 2")), version="2") snapshot_c = make_snapshot(SqlModel(name="c", query=parse_one("select 3")), version="3") - state_sync._push_snapshots([snapshot_a, snapshot_b, snapshot_c]) + state_sync.snapshot_state.push_snapshots([snapshot_a, snapshot_b, snapshot_c]) next_auto_restatement_ts: t.Dict[SnapshotNameVersion, t.Optional[int]] = { snapshot_a.name_version: 1, @@ -3125,7 +3133,7 @@ def test_compact_intervals_pending_restatement_shared_version( state_sync.add_interval(snapshot_b, "2020-01-03", "2020-01-03") assert ( sorted( - state_sync._get_snapshot_intervals([snapshot_a, snapshot_b])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b]), key=lambda x: (x.name, x.identifier or ""), ) == expected_intervals @@ -3210,7 +3218,7 @@ def test_compact_intervals_pending_restatement_shared_version( state_sync.add_interval(snapshot_a, "2020-01-04", "2020-01-04") 
assert ( sorted( - state_sync._get_snapshot_intervals([snapshot_a, snapshot_b])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b]), key=lambda x: (x.name, x.identifier or ""), ) == expected_intervals @@ -3269,7 +3277,7 @@ def test_compact_intervals_pending_restatement_shared_version( state_sync.add_interval(snapshot_b, "2020-01-05", "2020-01-05") assert ( sorted( - state_sync._get_snapshot_intervals([snapshot_a, snapshot_b])[1], + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b]), key=lambda x: (x.name, x.identifier or ""), ) == expected_intervals diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 8c0b0631c8..7de86e72b3 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -27,7 +27,7 @@ ViewKind, ) from sqlmesh.core.model.kind import SCDType2ByColumnKind, SCDType2ByTimeKind -from sqlmesh.core.state_sync.engine_adapter import _snapshot_to_json +from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json from sqlmesh.dbt.builtin import _relation_info_to_relation from sqlmesh.dbt.column import ( ColumnConfig,