From 7e4fd48da6c3163bd103af9ae412f2fa3eaee5f3 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Mon, 31 Mar 2025 02:47:32 +0000 Subject: [PATCH 1/2] Feat: state import/export --- docs/concepts/state.md | 272 +++++++++ docs/reference/cli.md | 51 ++ mkdocs.yml | 1 + pyproject.toml | 2 + sqlmesh/cli/main.py | 82 +++ sqlmesh/core/console.py | 370 ++++++++++++ sqlmesh/core/context.py | 61 ++ sqlmesh/core/state_sync/base.py | 19 + sqlmesh/core/state_sync/common.py | 35 ++ sqlmesh/core/state_sync/db/facade.py | 101 +++- sqlmesh/core/state_sync/export_import.py | 229 +++++++ sqlmesh/schedulers/airflow/state_sync.py | 7 + tests/cli/test_cli.py | 374 ++++++++++++ tests/core/state_sync/test_export_import.py | 568 ++++++++++++++++++ .../core/{ => state_sync}/test_state_sync.py | 0 15 files changed, 2171 insertions(+), 1 deletion(-) create mode 100644 docs/concepts/state.md create mode 100644 sqlmesh/core/state_sync/export_import.py create mode 100644 tests/core/state_sync/test_export_import.py rename tests/core/{ => state_sync}/test_state_sync.py (100%) diff --git a/docs/concepts/state.md b/docs/concepts/state.md new file mode 100644 index 0000000000..29e1feed59 --- /dev/null +++ b/docs/concepts/state.md @@ -0,0 +1,272 @@ +# State + +SQLMesh stores information about your project in a state database that is usually separate from your main warehouse. 
+
+The SQLMesh state database contains:
+
+- Information about every [Model Version](./models/overview.md) in your project (query, loaded intervals, dependencies)
+- A list of every [Virtual Data Environment](./environments.md) in the project
+- Which model versions are [promoted](./plans.md#plan-application) into each [Virtual Data Environment](./environments.md)
+- Information about any [auto restatements](./models/overview.md#auto_restatement_cron) present in your project
+- Other metadata about your project such as current SQLMesh / SQLGlot version
+
+The state database is how SQLMesh "remembers" what it's done before so it can compute a minimum set of operations to apply changes instead of rebuilding everything every time. It's also how SQLMesh tracks what historical data has already been backfilled for [incremental models](./models/model_kinds.md#incremental_by_time_range) so you don't need to add branching logic into the model query to handle this.
+
+!!! info "State database performance"
+
+    The workload against the state database is an OLTP workload that requires transaction support in order to work correctly.
+
+    For the best experience, we recommend [Tobiko Cloud](../cloud/cloud_index.md) or databases designed for OLTP workloads such as [PostgreSQL](../integrations/engines/postgres.md).
+
+    Using your warehouse OLAP database to store state is supported for proof-of-concept projects but is not suitable for production and **will** lead to poor performance and consistency.
+
+    For more information on engines suitable for the SQLMesh state database, see the [configuration guide](../guides/configuration.md#state-connection).
+
+## Exporting / Importing State
+
+SQLMesh supports exporting the state database to a `.json` file. From there, you can inspect the file with any tool that can read text files. You can also pass the file around and import it back into a SQLMesh project running elsewhere. 
+ +### Exporting state + +SQLMesh can export the state database to a file like so: + +```bash +$ sqlmesh state export -o state.json +Exporting state to 'state.json' from the following connection: + +Gateway: dev +State Connection: +├── Type: postgres +├── Catalog: sushi_dev +└── Dialect: postgres + +Continue? [y/n]: y + + Exporting versions ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + Exporting snapshots ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 17/17 • 0:00:00 +Exporting environments ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:00 + +State exported successfully to 'state.json' +``` + +This will produce a file `state.json` in the current directory containing the SQLMesh state. + +The state file is a simple `json` file that looks like: + +```json +{ + /* State export metadata */ + "metadata": { + "timestamp": "2025-03-16 23:09:00+00:00", /* UTC timestamp of when the file was produced */ + "file_version": 1, /* state export file format version */ + "importable": true /* whether or not this file can be imported with `sqlmesh state import` */ + }, + /* Library versions used to produce this state export file */ + "versions": { + "schema_version": 76 /* sqlmesh state database schema version */, + "sqlglot_version": "26.10.1" /* version of SQLGlot used to produce the state file */, + "sqlmesh_version": "0.165.1" /* version of SQLMesh used to produce the state file */, + }, + /* array of objects containing every Snapshot (physical table) tracked by the SQLMesh project */ + "snapshots": [ + { "name": "..." } + ], + /* object for every Virtual Data Environment in the project. key = environment name, value = environment details */ + "environments": { + "prod": { + "..." 
+    }
+  }
+}
+```
+
+#### Specific environments
+
+You can export a specific environment like so:
+
+```sh
+$ sqlmesh state export --environment my_dev -o my_dev_state.json
+```
+
+Note that every snapshot that is part of the environment will be exported, not just the differences from `prod`. The reason for this is so that the environment can be fully imported elsewhere without any assumptions about which snapshots are already present in state.
+
+#### Local state
+
+You can export local state like so:
+
+```bash
+$ sqlmesh state export --local -o local_state.json
+```
+
+This essentially just exports the state of the local context which includes local changes that have not been applied to any virtual data environments.
+
+Therefore, a local state export will only have `snapshots` populated. `environments` will be empty because virtual data environments are only present in the warehouse / remote state. In addition, the file is marked as **not importable** so it cannot be used with a subsequent `sqlmesh state import` command.
+
+### Importing state
+
+!!! warning "Back up your state database first!"
+
+    Please ensure you have created an independent backup of your state database in case something goes wrong during the state import.
+
+    SQLMesh tries to wrap the state import in a transaction but some database engines do not support transactions against DDL which means
+    an import error has the potential to leave the state database in an inconsistent state.
+
+SQLMesh can import a state file into the state database like so:
+
+```bash
+$ sqlmesh state import -i state.json --replace
+Loading state from 'state.json' into the following connection:
+
+Gateway: dev
+State Connection:
+├── Type: postgres
+├── Catalog: sushi_dev
+└── Dialect: postgres
+
+[WARNING] This destructive operation will delete all existing state against the 'dev' gateway
+and replace it with what's in the 'state.json' file.
+
+Are you sure? 
[y/n]: y
+
+State File Information:
+├── Creation Timestamp: 2025-03-31 02:15:00+00:00
+├── File Version: 1
+├── SQLMesh version: 0.170.1.dev0
+├── SQLMesh migration version: 76
+└── SQLGlot version: 26.12.0
+
+    Importing versions ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00
+   Importing snapshots ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 17/17 • 0:00:00
+Importing environments ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:00
+
+State imported successfully from 'state.json'
+```
+
+Note that the state database structure needs to be present and up to date, so run `sqlmesh migrate` before running `sqlmesh state import` if you get a version mismatch error.
+
+If you have a partial state export, perhaps for a single environment, you can merge it in by omitting the `--replace` parameter:
+
+```bash
+$ sqlmesh state import -i state.json
+...
+
+[WARNING] This operation will merge the contents of the state file to the state located at the 'dev' gateway.
+Matching snapshots or environments will be replaced.
+Non-matching snapshots or environments will be ignored.
+
+Are you sure? [y/n]: y
+
+...
+State imported successfully from 'state.json'
+```
+
+
+### Specific gateways
+
+If your project has [multiple gateways](../guides/configuration.md#gateways) with different state connections per gateway, you can target the [state_connection](../guides/configuration.md#state-connection) of a specific gateway like so:
+
+```bash
+# state export
+$ sqlmesh --gateway <gateway> state export -o state.json
+
+# state import
+$ sqlmesh --gateway <gateway> state import -i state.json
+```
+
+## Version Compatibility
+
+When importing state, the state file must have been produced with the same major and minor version of SQLMesh that is being used to import it.
+
+If you attempt to import state with an incompatible version, you will get the following error:
+
+```bash
+$ sqlmesh state import -i state.json
+...SNIP...
+
+State import failed! 
+Error: SQLMesh version mismatch. You are running '0.165.1' but the state file was created with '0.164.1'. +Please upgrade/downgrade your SQLMesh version to match the state file before performing the import. +``` + +### Upgrading a state file + +You can upgrade a state file produced by an old SQLMesh version to be compatible with a newer SQLMesh version by: + +- Loading it into a local database using the older SQLMesh version +- Installing the newer SQLMesh version +- Running `sqlmesh migrate` to upgrade the state within the local database +- Running `sqlmesh state export` to export it back out again. The new export is now compatible with the newer version of SQLMesh. + +Below is an example of how to upgrade a state file created with SQLMesh `0.164.1` to be compatible with SQLMesh `0.165.1`. + +First, create and activate a virtual environment to isolate the SQLMesh versions from your main environment: + +```bash +$ python -m venv migration-env + +$ . ./migration-env/bin/activate + +(migration-env)$ +``` + +Install the SQLMesh version compatible with your state file. The correct version to use is printed in the error message, eg `the state file was created with '0.164.1'` means you need to install SQLMesh `0.164.1`: + +```bash +(migration-env)$ pip install "sqlmesh==0.164.1" +``` + +Add a gateway to your `config.yaml` like so: + +```yaml +gateways: + migration: + connection: + type: duckdb + database: ./state-migration.duckdb +``` + +The goal here is to define just enough config for SQLMesh to be able to use a local database to run the state export/import commands. SQLMesh still needs to inherit things like the `model_defaults` from your project in order to migrate state correctly which is why we have not used an isolated directory. + +!!! 
warning + + From here on, be sure to specify `--gateway migration` to all SQLMesh commands or you run the risk of accidentally clobbering any state on your main gateway + +You can now import your state export using the same version of SQLMesh it was created with: + +```bash +(migration-env)$ sqlmesh --gateway migration migrate + +(migration-env)$ sqlmesh --gateway migration state import -i state.json +... +State imported successfully from 'state.json' +``` + +Now we have the state imported, we can upgrade SQLMesh and export the state from the new version. +The new version was printed in the original error message, eg `You are running '0.165.1'` + +To upgrade SQLMesh, simply install the new version: + +```bash +(migration-env)$ pip install --upgrade "sqlmesh==0.165.1" +``` + +Migrate the state to the new version: + +```bash +(migration-env)$ sqlmesh --gateway migration migrate +``` + +And finally, create a new state file which is now compatible with the new SQLMesh version: + +```bash + (migration-env)$ sqlmesh --gateway migration state export -o state-migrated.json +``` + +The `state-migrated.json` file is now compatible with the newer version of SQLMesh. +You can then transfer it to the place you originally needed it and import it in: + +```bash +$ sqlmesh state import -i state-migrated.json +... +State imported successfully from 'state-migrated.json' +``` \ No newline at end of file diff --git a/docs/reference/cli.md b/docs/reference/cli.md index 951d2db65b..cd110e1b42 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -40,6 +40,7 @@ Commands: rewrite Rewrite a SQL expression with semantic... rollback Rollback SQLMesh to the previous migration. run Evaluate missing intervals for the target... + state Commands for interacting with state table_diff Show the diff between two tables. table_name Prints the name of the physical table for the... test Run model unit tests. @@ -455,6 +456,56 @@ Options: --help Show this message and exit. 
``` +## state + +``` +Usage: sqlmesh state [OPTIONS] COMMAND [ARGS]... + + Commands for interacting with state + +Options: + --help Show this message and exit. + +Commands: + export Export the state database to a file + import Import a state export file back into the state database +``` + +### export + +``` +Usage: sqlmesh state export [OPTIONS] + + Export the state database to a file + +Options: + -o, --output-file FILE Path to write the state export to [required] + --environment TEXT Name of environment to export. Specify multiple + --environment arguments to export multiple + environments + --local Export local state only. Note that the resulting + file will not be importable + --no-confirm Do not prompt for confirmation before exporting + existing state + --help Show this message and exit. +``` + +### import + +``` +Usage: sqlmesh state import [OPTIONS] + + Import a state export file back into the state database + +Options: + -i, --input-file FILE Path to the state file [required] + --replace Clear the remote state before loading the file. If + omitted, a merge is performed instead + --no-confirm Do not prompt for confirmation before updating + existing state + --help Show this message and exit. 
+``` + ## table_diff ``` diff --git a/mkdocs.yml b/mkdocs.yml index 3730443730..a21d1f1fa5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -41,6 +41,7 @@ nav: - concepts/environments.md - concepts/tests.md - concepts/audits.md + - concepts/state.md - Models: - concepts/models/overview.md - concepts/models/model_kinds.md diff --git a/pyproject.toml b/pyproject.toml index 08d85fbeaf..74bc72956e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "sqlglot[rs]~=26.12.1", "tenacity", "time-machine", + "json-stream" ] classifiers = [ "Intended Audience :: Developers", @@ -203,5 +204,6 @@ module = [ "pydantic_core.*", "dlt.*", "bigframes.*", + "json_stream.*" ] ignore_missing_imports = true diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index 3aec539020..fd3e592fb5 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -18,6 +18,7 @@ from sqlmesh.core.context import Context from sqlmesh.utils.date import TimeLike from sqlmesh.utils.errors import MissingDependencyError +from pathlib import Path logger = logging.getLogger(__name__) @@ -1030,3 +1031,84 @@ def lint( ) -> None: """Run the linter for the target model(s).""" obj.lint_models(models) + + +@cli.group(no_args_is_help=True) +def state() -> None: + """Commands for interacting with state""" + pass + + +@state.command("export") +@click.option( + "-o", + "--output-file", + required=True, + help="Path to write the state export to", + type=click.Path(dir_okay=False, writable=True, path_type=Path), +) +@click.option( + "--environment", + multiple=True, + help="Name of environment to export. Specify multiple --environment arguments to export multiple environments", +) +@click.option( + "--local", + is_flag=True, + help="Export local state only. 
Note that the resulting file will not be importable", +) +@click.option( + "--no-confirm", + is_flag=True, + help="Do not prompt for confirmation before exporting existing state", +) +@click.pass_obj +@error_handler +@cli_analytics +def state_export( + obj: Context, + output_file: Path, + environment: t.Optional[t.Tuple[str]], + local: bool, + no_confirm: bool, +) -> None: + """Export the state database to a file""" + confirm = not no_confirm + + if environment and local: + raise click.ClickException("Cannot specify both --environment and --local") + + environment_names = list(environment) if environment else None + obj.export_state( + output_file=output_file, + environment_names=environment_names, + local_only=local, + confirm=confirm, + ) + + +@state.command("import") +@click.option( + "-i", + "--input-file", + help="Path to the state file", + required=True, + type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path), +) +@click.option( + "--replace", + is_flag=True, + help="Clear the remote state before loading the file. 
If omitted, a merge is performed instead", +) +@click.option( + "--no-confirm", + is_flag=True, + help="Do not prompt for confirmation before updating existing state", +) +@click.pass_obj +@error_handler +@cli_analytics +def state_import(obj: Context, input_file: Path, replace: bool, no_confirm: bool) -> None: + """Import a state export file back into the state database""" + confirm = not no_confirm + obj.import_state(input_file=input_file, clear=replace, confirm=confirm) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index ea5a4bb825..0c78d982aa 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -7,6 +7,7 @@ import uuid import logging import textwrap +from pathlib import Path from hyperscript import h from rich.console import Console as RichConsole @@ -55,6 +56,8 @@ from sqlmesh.core.context_diff import ContextDiff from sqlmesh.core.plan import Plan, EvaluatablePlan, PlanBuilder, SnapshotIntervals from sqlmesh.core.table_diff import TableDiff, RowDiff, SchemaDiff + from sqlmesh.core.config.connection import ConnectionConfig + from sqlmesh.core.state_sync import Versions LayoutWidget = t.TypeVar("LayoutWidget", bound=t.Union[widgets.VBox, widgets.HBox]) @@ -207,6 +210,62 @@ def update_env_migration_progress(self, num_tasks: int) -> None: def stop_env_migration_progress(self, success: bool = True) -> None: """Stop the environment migration progress.""" + @abc.abstractmethod + def start_state_export( + self, + output_file: Path, + gateway: t.Optional[str] = None, + state_connection_config: t.Optional[ConnectionConfig] = None, + environment_names: t.Optional[t.List[str]] = None, + local_only: bool = False, + confirm: bool = True, + ) -> bool: + """State a state export""" + + @abc.abstractmethod + def update_state_export_progress( + self, + version_count: t.Optional[int] = None, + versions_complete: bool = False, + snapshot_count: t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, 
+ environments_complete: bool = False, + ) -> None: + """Update the state export progress""" + + @abc.abstractmethod + def stop_state_export(self, success: bool, output_file: Path) -> None: + """Finish a state export""" + + @abc.abstractmethod + def start_state_import( + self, + input_file: Path, + gateway: str, + state_connection_config: ConnectionConfig, + clear: bool = False, + confirm: bool = True, + ) -> bool: + """Start a state import""" + + @abc.abstractmethod + def update_state_import_progress( + self, + timestamp: t.Optional[str] = None, + state_file_version: t.Optional[int] = None, + versions: t.Optional[Versions] = None, + snapshot_count: t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, + environments_complete: bool = False, + ) -> None: + """Update the state import process""" + + @abc.abstractmethod + def stop_state_import(self, success: bool, input_file: Path) -> None: + """Finish a state import""" + @abc.abstractmethod def show_model_difference_summary( self, @@ -322,6 +381,10 @@ def show_row_diff( def print_environments(self, environments_summary: t.Dict[str, int]) -> None: """Prints all environment names along with expiry datetime.""" + @abc.abstractmethod + def print_connection_config(self, config: ConnectionConfig, title: str = "Connection") -> None: + """Print connection config information""" + def _limit_model_names(self, tree: Tree, verbosity: Verbosity = Verbosity.DEFAULT) -> Tree: """Trim long indirectly modified model lists below threshold.""" modified_length = len(tree.children) @@ -433,6 +496,56 @@ def update_env_migration_progress(self, num_tasks: int) -> None: def stop_env_migration_progress(self, success: bool = True) -> None: pass + def start_state_export( + self, + output_file: Path, + gateway: t.Optional[str] = None, + state_connection_config: t.Optional[ConnectionConfig] = None, + environment_names: t.Optional[t.List[str]] = None, + local_only: bool = False, + confirm: bool = 
True, + ) -> bool: + return confirm + + def update_state_export_progress( + self, + version_count: t.Optional[int] = None, + versions_complete: bool = False, + snapshot_count: t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, + environments_complete: bool = False, + ) -> None: + pass + + def stop_state_export(self, success: bool, output_file: Path) -> None: + pass + + def start_state_import( + self, + input_file: Path, + gateway: str, + state_connection_config: ConnectionConfig, + clear: bool = False, + confirm: bool = True, + ) -> bool: + return confirm + + def update_state_import_progress( + self, + timestamp: t.Optional[str] = None, + state_file_version: t.Optional[int] = None, + versions: t.Optional[Versions] = None, + snapshot_count: t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, + environments_complete: bool = False, + ) -> None: + pass + + def stop_state_import(self, success: bool, input_file: Path) -> None: + pass + def show_model_difference_summary( self, context_diff: ContextDiff, @@ -515,6 +628,11 @@ def show_linter_violations( ) -> None: pass + def print_connection_config( + self, config: ConnectionConfig, title: t.Optional[str] = "Connection" + ) -> None: + pass + def make_progress_bar( message: str, @@ -579,6 +697,16 @@ def __init__( self.loading_status: t.Dict[uuid.UUID, Status] = {} + self.state_export_progress: t.Optional[Progress] = None + self.state_export_version_task: t.Optional[TaskID] = None + self.state_export_snapshot_task: t.Optional[TaskID] = None + self.state_export_environment_task: t.Optional[TaskID] = None + + self.state_import_progress: t.Optional[Progress] = None + self.state_import_version_task: t.Optional[TaskID] = None + self.state_import_snapshot_task: t.Optional[TaskID] = None + self.state_import_environment_task: t.Optional[TaskID] = None + self.verbosity = verbosity self.dialect = dialect self.ignore_warnings = 
ignore_warnings @@ -897,6 +1025,238 @@ def stop_env_migration_progress(self, success: bool = True) -> None: if success: self.log_success("Environments migrated successfully") + def start_state_export( + self, + output_file: Path, + gateway: t.Optional[str] = None, + state_connection_config: t.Optional[ConnectionConfig] = None, + environment_names: t.Optional[t.List[str]] = None, + local_only: bool = False, + confirm: bool = True, + ) -> bool: + self.state_export_progress = None + + if local_only: + self.log_status_update(f"Exporting [b]local[/b] state to '{output_file.as_posix()}'\n") + self.log_warning( + "Local state exports just contain the model versions in your local context. Therefore, the resulting file cannot be imported." + ) + else: + self.log_status_update( + f"Exporting state to '{output_file.as_posix()}' from the following connection:\n" + ) + if gateway: + self.log_status_update(f"[b]Gateway[/b]: [green]{gateway}[/green]") + if state_connection_config: + self.print_connection_config(state_connection_config, title="State Connection") + if environment_names: + heading = "Environments" if len(environment_names) > 1 else "Environment" + self.log_status_update( + f"[b]{heading}[/b]: [yellow]{', '.join(environment_names)}[/yellow]" + ) + + should_continue = True + if confirm: + should_continue = self._confirm("\nContinue?") + self.log_status_update("") + + if should_continue: + self.state_export_progress = make_progress_bar("{task.description}", self.console) + assert isinstance(self.state_export_progress, Progress) + + self.state_export_version_task = self.state_export_progress.add_task( + "Exporting versions", start=False + ) + self.state_export_snapshot_task = self.state_export_progress.add_task( + "Exporting snapshots", start=False + ) + self.state_export_environment_task = self.state_export_progress.add_task( + "Exporting environments", start=False + ) + + self.state_export_progress.start() + + return should_continue + + def 
update_state_export_progress( + self, + version_count: t.Optional[int] = None, + versions_complete: bool = False, + snapshot_count: t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, + environments_complete: bool = False, + ) -> None: + if self.state_export_progress: + if self.state_export_version_task is not None: + if version_count is not None: + self.state_export_progress.start_task(self.state_export_version_task) + self.state_export_progress.update( + self.state_export_version_task, + total=version_count, + completed=version_count, + refresh=True, + ) + if versions_complete: + self.state_export_progress.stop_task(self.state_export_version_task) + + if self.state_export_snapshot_task is not None: + if snapshot_count is not None: + self.state_export_progress.start_task(self.state_export_snapshot_task) + self.state_export_progress.update( + self.state_export_snapshot_task, + total=snapshot_count, + completed=snapshot_count, + refresh=True, + ) + if snapshots_complete: + self.state_export_progress.stop_task(self.state_export_snapshot_task) + + if self.state_export_environment_task is not None: + if environment_count is not None: + self.state_export_progress.start_task(self.state_export_environment_task) + self.state_export_progress.update( + self.state_export_environment_task, + total=environment_count, + completed=environment_count, + refresh=True, + ) + if environments_complete: + self.state_export_progress.stop_task(self.state_export_environment_task) + + def stop_state_export(self, success: bool, output_file: Path) -> None: + if self.state_export_progress: + self.state_export_progress.stop() + self.state_export_progress = None + + if success: + self.log_success(f"State exported successfully to '{output_file.as_posix()}'") + else: + self.log_error("State export failed!") + + def start_state_import( + self, + input_file: Path, + gateway: str, + state_connection_config: ConnectionConfig, + clear: bool = False, 
+ confirm: bool = True, + ) -> bool: + self.log_status_update( + f"Loading state from '{input_file.as_posix()}' into the following connection:\n" + ) + self.log_status_update(f"[b]Gateway[/b]: [green]{gateway}[/green]") + self.print_connection_config(state_connection_config, title="State Connection") + self.log_status_update("") + + if clear: + self.log_warning( + f"This [b]destructive[/b] operation will delete all existing state against the '{gateway}' gateway \n" + f"and replace it with what's in the '{input_file.as_posix()}' file.\n" + ) + else: + self.log_warning( + f"This operation will [b]merge[/b] the contents of the state file to the state located at the '{gateway}' gateway.\n" + "Matching snapshots or environments will be replaced.\n" + "Non-matching snapshots or environments will be ignored.\n" + ) + + should_continue = True + if confirm: + should_continue = self._confirm("[red]Are you sure?[/red]") + self.log_status_update("") + + if should_continue: + self.state_import_progress = make_progress_bar("{task.description}", self.console) + + self.state_import_info = Tree("[bold]State File Information:") + + self.state_import_version_task = self.state_import_progress.add_task( + "Importing versions", start=False + ) + self.state_import_snapshot_task = self.state_import_progress.add_task( + "Importing snapshots", start=False + ) + self.state_import_environment_task = self.state_import_progress.add_task( + "Importing environments", start=False + ) + + self.state_import_progress.start() + + return should_continue + + def update_state_import_progress( + self, + timestamp: t.Optional[str] = None, + state_file_version: t.Optional[int] = None, + versions: t.Optional[Versions] = None, + snapshot_count: t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, + environments_complete: bool = False, + ) -> None: + if self.state_import_progress: + if self.state_import_info: + if timestamp: + 
self.state_import_info.add(f"Creation Timestamp: {timestamp}") + if state_file_version: + self.state_import_info.add(f"File Version: {state_file_version}") + if versions: + self.state_import_info.add(f"SQLMesh version: {versions.sqlmesh_version}") + self.state_import_info.add( + f"SQLMesh migration version: {versions.schema_version}" + ) + self.state_import_info.add(f"SQLGlot version: {versions.sqlglot_version}\n") + + self._print(self.state_import_info) + + version_count = len(versions.model_dump()) + + if self.state_import_version_task is not None: + self.state_import_progress.start_task(self.state_import_version_task) + self.state_import_progress.update( + self.state_import_version_task, + total=version_count, + completed=version_count, + ) + self.state_import_progress.stop_task(self.state_import_version_task) + + if self.state_import_snapshot_task is not None: + if snapshot_count is not None: + self.state_import_progress.start_task(self.state_import_snapshot_task) + self.state_import_progress.update( + self.state_import_snapshot_task, + completed=snapshot_count, + total=snapshot_count, + refresh=True, + ) + + if snapshots_complete: + self.state_import_progress.stop_task(self.state_import_snapshot_task) + + if self.state_import_environment_task is not None: + if environment_count is not None: + self.state_import_progress.start_task(self.state_import_environment_task) + self.state_import_progress.update( + self.state_import_environment_task, + completed=environment_count, + total=environment_count, + refresh=True, + ) + + if environments_complete: + self.state_import_progress.stop_task(self.state_import_environment_task) + + def stop_state_import(self, success: bool, input_file: Path) -> None: + if self.state_import_progress: + self.state_import_progress.stop() + self.state_import_progress = None + + if success: + self.log_success(f"State imported successfully from '{input_file.as_posix()}'") + else: + self.log_error("State import failed!") + def 
show_model_difference_summary( self, context_diff: ContextDiff, @@ -1585,6 +1945,16 @@ def print_environments(self, environments_summary: t.Dict[str, int]) -> None: output_str = "\n".join([str(len(output)), *output]) self.log_status_update(f"Number of SQLMesh environments are: {output_str}") + def print_connection_config(self, config: ConnectionConfig, title: str = "Connection") -> None: + engine_adapter_type = config._engine_adapter + + tree = Tree(f"[b]{title}:[/b]") + tree.add(f"Type: [bold cyan]{config.type_}[/bold cyan]") + tree.add(f"Catalog: [bold cyan]{config.get_catalog()}[/bold cyan]") + tree.add(f"Dialect: [bold cyan]{engine_adapter_type.DIALECT}[/bold cyan]") + + self._print(tree) + def _get_snapshot_change_category( self, snapshot: Snapshot, diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 004e5c1e28..87217fe64e 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -392,6 +392,9 @@ def __init__( ] self._connection_config = self.config.get_connection(self.gateway) + self._state_connection_config = ( + self.config.get_state_connection(self.gateway) or self._connection_config + ) self.concurrent_tasks = concurrent_tasks or self._connection_config.concurrent_tasks self._engine_adapters: t.Dict[str, EngineAdapter] = { @@ -2089,6 +2092,64 @@ def clear_caches(self) -> None: for path in self.configs: rmtree(path / c.CACHE) + def export_state( + self, + output_file: Path, + environment_names: t.Optional[t.List[str]] = None, + local_only: bool = False, + confirm: bool = True, + ) -> None: + from sqlmesh.core.state_sync.export_import import export_state + + # trigger a connection to the StateSync so we can fail early if there is a problem + # note we still need to do this even if we are doing a local export so we know what 'versions' to write + self.state_sync.get_versions(validate=True) + + local_snapshots = self.snapshots if local_only else None + + if self.console.start_state_export( + output_file=output_file, + 
gateway=self.selected_gateway, + state_connection_config=self._state_connection_config, + environment_names=environment_names, + local_only=local_only, + confirm=confirm, + ): + try: + export_state( + state_sync=self.state_sync, + output_file=output_file, + local_snapshots=local_snapshots, + environment_names=environment_names, + console=self.console, + ) + self.console.stop_state_export(success=True, output_file=output_file) + except: + self.console.stop_state_export(success=False, output_file=output_file) + raise + + def import_state(self, input_file: Path, clear: bool = False, confirm: bool = True) -> None: + from sqlmesh.core.state_sync.export_import import import_state + + if self.console.start_state_import( + input_file=input_file, + gateway=self.selected_gateway, + state_connection_config=self._state_connection_config, + clear=clear, + confirm=confirm, + ): + try: + import_state( + state_sync=self.state_sync, + input_file=input_file, + clear=clear, + console=self.console, + ) + self.console.stop_state_import(success=True, input_file=input_file) + except: + self.console.stop_state_import(success=False, input_file=input_file) + raise + def _run_tests( self, verbosity: Verbosity = Verbosity.DEFAULT ) -> t.Tuple[unittest.result.TestResult, str]: diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 967e5f9571..771dd94172 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -24,6 +24,7 @@ from sqlmesh.utils.date import TimeLike from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator +from sqlmesh.core.state_sync.common import StateStream logger = logging.getLogger(__name__) @@ -267,6 +268,14 @@ def _get_versions(self) -> Versions: The versions object. 
""" + @abc.abstractmethod + def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStream: + """Export the contents of this StateSync as a StateStream + + Args: + environment_names: An optional list of environment names to export. If not specified, all environments will be exported. + """ + class StateSync(StateReader, abc.ABC): """Abstract base class for snapshot and environment state management.""" @@ -459,6 +468,16 @@ def add_interval( ) self.add_snapshots_intervals([snapshot_intervals]) + @abc.abstractmethod + def import_(self, stream: StateStream, clear: bool = True) -> None: + """ + Replace the existing state with the state contained in the StateStream + + Args: + stream: The stream of new state + clear: Whether or not to clear existing state before inserting state from the stream + """ + class DelegatingStateSync(StateSync): def __init__(self, state_sync: StateSync) -> None: diff --git a/sqlmesh/core/state_sync/common.py b/sqlmesh/core/state_sync/common.py index 8c3bd51b7d..1d0778a4d8 100644 --- a/sqlmesh/core/state_sync/common.py +++ b/sqlmesh/core/state_sync/common.py @@ -3,14 +3,18 @@ import logging import typing as t from functools import wraps +import itertools +import abc from sqlmesh.core.console import Console from sqlmesh.core.dialect import schema_ from sqlmesh.core.environment import Environment from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.core.snapshot import Snapshot if t.TYPE_CHECKING: from sqlmesh.core.engine_adapter.base import EngineAdapter + from sqlmesh.core.state_sync.base import Versions logger = logging.getLogger(__name__) @@ -81,3 +85,34 @@ def wrapper(self: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: return wrapper return decorator + + +T = t.TypeVar("T") + + +def chunk_iterable(iterable: t.Iterable[T], size: int = 10) -> t.Iterable[t.Iterable[T]]: + iterator = iter(iterable) + for first in iterator: + yield itertools.chain([first], itertools.islice(iterator, size - 1)) + + +class 
StateStream(abc.ABC): + """ + Represents a stream of state either going into the StateSync (perhaps loaded from a file) + or out of the StateSync (perhaps being dumped to a file) + """ + + @property + @abc.abstractmethod + def versions(self) -> Versions: + """The versions of the objects contained in this StateStream""" + + @property + @abc.abstractmethod + def snapshots(self) -> t.Iterable[Snapshot]: + """A stream of Snapshot objects. Note that they should be fully populated with any relevant Intervals""" + + @property + @abc.abstractmethod + def environments(self) -> t.Iterable[Environment]: + """A stream of Environment objects""" diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index cb3dd10685..8b6e373d3c 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -19,6 +19,7 @@ import contextlib import logging import typing as t +import itertools from pathlib import Path from datetime import datetime @@ -46,7 +47,11 @@ StateSync, Versions, ) -from sqlmesh.core.state_sync.common import transactional +from sqlmesh.core.state_sync.common import ( + transactional, + StateStream, + chunk_iterable, +) from sqlmesh.core.state_sync.db.interval import IntervalState from sqlmesh.core.state_sync.db.environment import EnvironmentState from sqlmesh.core.state_sync.db.snapshot import SnapshotState @@ -439,6 +444,100 @@ def rollback(self) -> None: """Rollback to the previous migration.""" self.migrator.rollback() + @transactional() + def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStream: + state_sync = self + + snapshot_ids_to_export: t.Set[SnapshotId] = set() + selected_environments: t.List[Environment] = [] + if environment_names: + for env_name in environment_names: + environment = self.get_environment(env_name) + if not environment: + raise SQLMeshError(f"No such environment: {env_name}") + selected_environments.append(environment) + + for env in selected_environments: 
+ snapshot_ids_to_export |= set([s.snapshot_id for s in env.snapshots or []]) + + def _include_snapshot(s_id: SnapshotId) -> bool: + if environment_names: + return s_id in snapshot_ids_to_export + return True + + class _DumpStateStream(StateStream): + @property + def versions(self) -> Versions: + return state_sync.get_versions() + + @property + def snapshots(self) -> t.Iterable[Snapshot]: + all_snapshot_ids = { + s.snapshot_id + for e in state_sync.get_environments() + for s in e.snapshots + if _include_snapshot(s.snapshot_id) + } + for chunk in chunk_iterable(all_snapshot_ids, SnapshotState.SNAPSHOT_BATCH_SIZE): + yield from state_sync.get_snapshots(chunk).values() + + @property + def environments(self) -> t.Iterable[Environment]: + if environment_names: + yield from selected_environments + else: + yield from state_sync.get_environments() + + return _DumpStateStream() + + @transactional() + def import_(self, stream: StateStream, clear: bool = True) -> None: + existing_versions = self.get_versions() + + # SQLMesh major/minor version must match so that we can be sure the JSON contained in the state file + is compatible with our Pydantic model definitions. Patch versions don't need to match because the assumption + is that they don't contain any breaking changes + incoming_versions = stream.versions + if incoming_versions.minor_sqlmesh_version != existing_versions.minor_sqlmesh_version: + raise SQLMeshError( + f"SQLMesh version mismatch. You are running '{existing_versions.sqlmesh_version}' but the state file was created with '{incoming_versions.sqlmesh_version}'.\n" + "Please upgrade/downgrade your SQLMesh version to match the state file before performing the import." 
+ ) + + if clear: + self.reset(default_catalog=None) + + auto_restatements: t.Dict[SnapshotNameVersion, t.Optional[int]] = {} + + for snapshot_chunk in chunk_iterable(stream.snapshots, SnapshotState.SNAPSHOT_BATCH_SIZE): + snapshot_iterator, intervals_iterator, auto_restatments_iterator = itertools.tee( + snapshot_chunk, 3 + ) + overwrite_existing_snapshots = ( + not clear + ) # if clear=True, all existing snapshots were dropped anyway + self.snapshot_state.push_snapshots( + snapshot_iterator, overwrite=overwrite_existing_snapshots + ) + self.add_snapshots_intervals((s.snapshot_intervals for s in intervals_iterator)) + + auto_restatements.update( + { + s.name_version: s.next_auto_restatement_ts + for s in auto_restatments_iterator + if s.next_auto_restatement_ts + } + ) + + existing_environments = set(self.get_environments_summary().keys()) if not clear else set() + for environment in stream.environments: + if not clear and environment.name in existing_environments: + self.environment_state.update_environment(environment) + else: + self.promote(environment) + + self.update_auto_restatements(auto_restatements) + def state_type(self) -> str: return self.engine_adapter.dialect diff --git a/sqlmesh/core/state_sync/export_import.py b/sqlmesh/core/state_sync/export_import.py new file mode 100644 index 0000000000..0c390831a4 --- /dev/null +++ b/sqlmesh/core/state_sync/export_import.py @@ -0,0 +1,229 @@ +import json +import typing as t +from sqlmesh.core.state_sync import StateSync +from sqlmesh.core.snapshot import Snapshot +from sqlmesh.utils.date import now, to_tstz +from sqlmesh.core.environment import Environment +from sqlmesh.utils.pydantic import _expression_encoder +from sqlmesh.core.state_sync import Versions +from sqlmesh.core.state_sync.common import StateStream +from sqlmesh.core.console import Console +from pathlib import Path +from sqlmesh.core.console import NoopConsole + +import json_stream +from json_stream import streamable_dict, to_standard_types, 
streamable_list +from json_stream.writer import StreamableDict +from json_stream.base import StreamingJSONObject +from json_stream.dump import JSONStreamEncoder +from sqlmesh.utils.errors import SQLMeshError +from sqlglot import exp +from sqlmesh.utils.pydantic import DEFAULT_ARGS as PYDANTIC_DEFAULT_ARGS, PydanticModel + + +class SQLMeshJSONStreamEncoder(JSONStreamEncoder): + def default(self, obj: t.Any) -> t.Any: + if isinstance(obj, exp.Expression): + return _expression_encoder(obj) + + return super().default(obj) + + +def _dump_pydantic_model(model: PydanticModel) -> t.Dict[str, t.Any]: + dump_args: t.Dict[str, t.Any] = PYDANTIC_DEFAULT_ARGS + return model.model_dump(mode="json", **dump_args) + + +def _create_local_state_stream(versions: Versions, snapshots: t.Dict[str, Snapshot]) -> StateStream: + class _LocalStateStream(StateStream): + @property + def versions(self) -> Versions: + return versions + + @property + def snapshots(self) -> t.Iterable[Snapshot]: + return iter(snapshots.values()) + + @property + def environments(self) -> t.Iterable[Environment]: + return [] + + return _LocalStateStream() + + +def _export(state_stream: StateStream, importable: bool, console: Console) -> StreamableDict: + """ + Return the state in a format 'json_stream' can stream to a file + + Args: + state_stream: A stream of state to export + console: A Console instance to print progress to + """ + + @streamable_list + def _dump_snapshots( + snapshot_stream: t.Iterable[Snapshot], + ) -> t.Iterator[t.Dict[str, t.Any]]: + console.update_state_export_progress(snapshot_count=0) + for idx, snapshot in enumerate(snapshot_stream): + yield _dump_pydantic_model(snapshot) + console.update_state_export_progress(snapshot_count=idx + 1) + + @streamable_dict + def _dump_environments( + environment_stream: t.Iterable[Environment], + ) -> t.Iterator[t.Tuple[str, t.Any]]: + console.update_state_export_progress(environment_count=0) + for idx, env in enumerate(environment_stream): + yield env.name, 
_dump_pydantic_model(env) + console.update_state_export_progress(environment_count=idx + 1) + + @streamable_dict + def _do_export() -> t.Iterator[t.Tuple[str, t.Any]]: + yield "metadata", {"timestamp": to_tstz(now()), "file_version": 1, "importable": importable} + + versions = _dump_pydantic_model(state_stream.versions) + yield "versions", versions + console.update_state_export_progress(version_count=len(versions), versions_complete=True) + + yield "snapshots", _dump_snapshots(state_stream.snapshots) + console.update_state_export_progress(snapshots_complete=True) + + yield "environments", _dump_environments(state_stream.environments) + console.update_state_export_progress(environments_complete=True) + + return _do_export() + + +def _import( + state_sync: StateSync, data: t.Callable[[], StreamingJSONObject], clear: bool, console: Console +) -> None: + """ + Load the state defined by the :data into the supplied :state_sync. The data is in the same format as written by dump() + + Args: + state_sync: The StateSync that the user has requested to dump state from + data: A factory function that produces new streaming JSON reader attached to the file we are loading state from. 
+ This is so each section of the file can have its own reader which allows it to be read in isolation / out-of-order + This puts less reliance on downstream consumers performing operations in a certain order + clear: Whether or not to clear the existing state before writing the new state + console: A Console instance to print progress to + """ + + class _FileStateStream(StateStream): + @property + def versions(self) -> Versions: + versions_raw = to_standard_types(data()["versions"]) + return Versions.model_validate(versions_raw) + + @property + def snapshots(self) -> t.Iterable[Snapshot]: + stream = data()["snapshots"] + + console.update_state_import_progress(snapshot_count=0) + for idx, raw_snapshot in enumerate(stream): + snapshot = Snapshot.model_validate(to_standard_types(raw_snapshot)) + yield snapshot + console.update_state_import_progress(snapshot_count=idx + 1) + + console.update_state_import_progress(snapshots_complete=True) + + @property + def environments(self) -> t.Iterable[Environment]: + stream = data()["environments"] + + console.update_state_import_progress(environment_count=0) + for idx, (_, raw_environment) in enumerate(stream.items()): + environment = Environment.model_validate(to_standard_types(raw_environment)) + yield environment + console.update_state_import_progress(environment_count=idx + 1) + + console.update_state_import_progress(environments_complete=True) + + metadata = to_standard_types(data()["metadata"]) + + timestamp = metadata["timestamp"] + if not isinstance(timestamp, str): + raise ValueError(f"'timestamp' contains an invalid value. 
Expecting str, got: {timestamp}") + console.update_state_import_progress( + timestamp=timestamp, state_file_version=metadata["file_version"] + ) + + stream = _FileStateStream() + + console.update_state_import_progress(versions=stream.versions) + + state_sync.import_(stream, clear=clear) + + +def export_state( + state_sync: StateSync, + output_file: Path, + local_snapshots: t.Optional[t.Dict[str, Snapshot]] = None, + environment_names: t.Optional[t.List[str]] = None, + console: t.Optional[Console] = None, +) -> None: + console = console or NoopConsole() + + state_stream = ( + _create_local_state_stream(state_sync.get_versions(), local_snapshots) + if local_snapshots + else state_sync.export(environment_names=environment_names) + ) + + importable = False if local_snapshots else True + + json_stream = _export(state_stream=state_stream, importable=importable, console=console) + with output_file.open(mode="w", encoding="utf8") as fh: + json.dump(json_stream, fh, indent=2, cls=SQLMeshJSONStreamEncoder) + + +def import_state( + state_sync: StateSync, + input_file: Path, + clear: bool = False, + console: t.Optional[Console] = None, +) -> None: + console = console or NoopConsole() + + # we need to peek into the file to figure out what state version we are dealing with + with input_file.open("r", encoding="utf8") as fh: + stream = json_stream.load(fh) + if not isinstance(stream, StreamingJSONObject): + raise SQLMeshError(f"Expected JSON object, got: {type(stream)}") + + try: + metadata = stream["metadata"].persistent() + except KeyError: + raise SQLMeshError("Expecting a 'metadata' key to be present") + + if not isinstance(metadata, StreamingJSONObject): + raise SQLMeshError("Expecting the 'metadata' key to contain an object") + + file_version = metadata.get("file_version") + if file_version is None: + raise SQLMeshError("Unable to determine state file format version from the input file") + + try: + int(file_version) + except ValueError: + raise SQLMeshError(f"Unable to 
parse state file format version: {file_version}") + + if not metadata.get("importable", False): + # this can happen if the state file was created from local unversioned snapshots that were not sourced from the project state database + raise SQLMeshError("State file is marked as not importable. Aborting") + + handles: t.List[t.TextIO] = [] + + def _new_handle() -> StreamingJSONObject: + handle = input_file.open("r", encoding="utf8") + handles.append(handle) + stream = json_stream.load(handle) + assert isinstance(stream, StreamingJSONObject) + return stream + + try: + _import(state_sync=state_sync, data=_new_handle, clear=clear, console=console) + finally: + for handle in handles: + handle.close() diff --git a/sqlmesh/schedulers/airflow/state_sync.py b/sqlmesh/schedulers/airflow/state_sync.py index 345584f6b3..fecf15199a 100644 --- a/sqlmesh/schedulers/airflow/state_sync.py +++ b/sqlmesh/schedulers/airflow/state_sync.py @@ -17,6 +17,7 @@ from sqlmesh.core.state_sync import StateSync, Versions from sqlmesh.core.state_sync.base import PromotionResult from sqlmesh.schedulers.airflow.client import AirflowClient +from sqlmesh.core.state_sync.common import StateStream if t.TYPE_CHECKING: from sqlmesh.utils.date import TimeLike @@ -361,3 +362,9 @@ def close(self) -> None: def state_type(self) -> str: """Returns the type of state sync.""" return "airflow_http" + + def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStream: + raise NotImplementedError("State export is not supported by the Airflow state sync") + + def import_(self, stream: StateStream, clear: bool = True) -> None: + raise NotImplementedError("State import is not supported by the Airflow state sync") diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index fdcdf4486b..6ab850c91c 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -7,6 +7,7 @@ import pytest from click.testing import CliRunner import time_machine +import json from sqlmesh.cli.example_project import 
ProjectTemplate, init_example_project from sqlmesh.cli.main import cli @@ -1240,3 +1241,376 @@ def test_lint(runner, tmp_path): ) assert result.output.count("Linter errors for") == 2 assert result.exit_code == 1 + + +def test_state_export(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # create some state + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + # export it + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "state", "export", "-o", str(state_export_file), "--no-confirm"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + # verify output + assert "Gateway: local" in result.output + assert "Type: duckdb" in result.output + assert "Exporting versions" in result.output + assert "Exporting snapshots" in result.output + assert "Exporting environments" in result.output + assert "State exported successfully" in result.output + + assert state_export_file.exists() + assert len(state_export_file.read_text()) > 0 + + +def test_state_export_specific_environments(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # create prod + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + (tmp_path / "models" / "new_model.sql").write_text( + """ + MODEL ( + name sqlmesh_example.new_model, + kind FULL + ); + + SELECT 1; + """ + ) + + # create dev env with new model + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "dev", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + # export non existent env - should fail + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--environment", 
+ "nonexist", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 1 + assert "No such environment: nonexist" in result.output + + # export dev, should contain original snapshots + new one + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--environment", + "dev", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert "Environment: dev" in result.output + assert "State exported successfully" in result.output + + state = json.loads(state_export_file.read_text(encoding="utf8")) + assert len(state["snapshots"]) == 4 + assert any("new_model" in s["name"] for s in state["snapshots"]) + assert len(state["environments"]) == 1 + assert "dev" in state["environments"] + assert "prod" not in state["environments"] + + +def test_state_export_local(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # note: we have not plan+applied at all, we are just exporting local state + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--local", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert "Exporting local state" in result.output + assert "the resulting file cannot be imported" in result.output + assert "State exported successfully" in result.output + + state = json.loads(state_export_file.read_text(encoding="utf8")) + assert len(state["snapshots"]) == 3 + assert not state["metadata"]["importable"] + assert len(state["environments"]) == 0 + + # test mutually exclusive with --environment + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--environment", + "foo", + "--local", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert 
result.exit_code == 1 + + assert "Cannot specify both --environment and --local" in result.output + + +def test_state_import(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # create some state + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + # export it + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "state", "export", "-o", str(state_export_file), "--no-confirm"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + # import it back + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "state", "import", "-i", str(state_export_file), "--no-confirm"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + assert "Gateway: local" in result.output + assert "Type: duckdb" in result.output + assert "Importing versions" in result.output + assert "Importing snapshots" in result.output + assert "Importing environments" in result.output + assert "State imported successfully" in result.output + + # plan should have no changes + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + ], + ) + assert result.exit_code == 0 + assert "No changes to plan" in result.output + + +def test_state_import_replace(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # prod + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + (tmp_path / "models" / "new_model.sql").write_text( + """ + MODEL ( + name sqlmesh_example.new_model, + kind FULL + ); + + SELECT 1; + """ + ) + + # create dev with new model + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "dev", + "--no-prompts", + "--auto-apply", + ], + ) + assert 
result.exit_code == 0 + + # prove both dev and prod exist + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "environments", + ], + ) + assert result.exit_code == 0 + assert "dev -" in result.output + assert "prod -" in result.output + + # export just prod + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--environment", + "prod", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + # import it back with --replace + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "import", + "-i", + str(state_export_file), + "--replace", + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert "State imported successfully" in result.output + + # prove only prod exists now + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "environments", + ], + ) + assert result.exit_code == 0 + assert "dev -" not in result.output + assert "prod -" in result.output + + +def test_state_import_local(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # local state export + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--local", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + # import should fail - local state is not importable + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "state", "import", "-i", str(state_export_file), "--no-confirm"], + catch_exceptions=False, + ) + assert result.exit_code == 1 + assert "State file is marked as not importable" in result.output + assert "Aborting" in result.output diff --git a/tests/core/state_sync/test_export_import.py b/tests/core/state_sync/test_export_import.py new file mode 100644 index 0000000000..de4fb2cdca --- /dev/null +++ 
b/tests/core/state_sync/test_export_import.py @@ -0,0 +1,568 @@ +import pytest +from pathlib import Path +from sqlmesh.core.state_sync import StateSync, EngineAdapterStateSync +from sqlmesh.core.state_sync.export_import import export_state, import_state +from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.core import constants as c +from sqlmesh.cli.example_project import init_example_project +from sqlmesh.core.context import Context +from sqlmesh.core.environment import Environment +from sqlmesh.core.config import Config, GatewayConfig, DuckDBConnectionConfig, ModelDefaultsConfig + +import json + + +@pytest.fixture +def example_project_config(tmp_path: Path) -> Config: + return Config( + gateways={ + "main": GatewayConfig( + connection=DuckDBConnectionConfig(database=str(tmp_path / "warehouse.db")), + state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")), + ) + }, + default_gateway="main", + model_defaults=ModelDefaultsConfig( + dialect="duckdb", + ), + ) + + +@pytest.fixture +def state_sync(tmp_path: Path, example_project_config: Config) -> StateSync: + return EngineAdapterStateSync( + engine_adapter=example_project_config.get_state_connection("main").create_engine_adapter(), # type: ignore + schema=c.SQLMESH, + context_path=tmp_path, + ) + + +def test_export_empty_state(tmp_path: Path, state_sync: StateSync) -> None: + output_file = tmp_path / "state_dump.json" + + # Cannot dump an un-migrated state database + with pytest.raises(SQLMeshError, match=r"Please run a migration"): + export_state(state_sync, output_file) + + state_sync.migrate(default_catalog=None) + + export_state(state_sync, output_file) + + state = json.loads(output_file.read_text(encoding="utf8")) + + assert "metadata" in state + metadata = state["metadata"] + assert "timestamp" in metadata + assert "file_version" in metadata + assert "importable" in metadata + + assert "versions" in state + versions = state["versions"] + assert "schema_version" in versions + assert 
"sqlglot_version" in versions + assert "sqlmesh_version" in versions + + assert "snapshots" in state + assert isinstance(state["snapshots"], list) + assert len(state["snapshots"]) == 0 + + assert "environments" in state + assert isinstance(state["environments"], dict) + assert len(state["environments"]) == 0 + + +def test_export_entire_project( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + init_example_project(path=tmp_path, dialect="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + # prod + plan = context.plan(auto_apply=True) + assert len(plan.modified_snapshots) > 0 + + # modify full_model + (tmp_path / c.MODELS / "full_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.full_model, + kind FULL, + cron '@daily' + ); + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, + '1' as modified + FROM sqlmesh_example.incremental_model + GROUP BY item_id; + """) + + # add a new model + (tmp_path / c.MODELS / "new_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.new_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + auto_restatement_cron '@daily' + ), + cron '@daily' + ); + + SELECT 1 as id; + """) + + # dev + context.load() + context.plan(environment="dev", auto_apply=True, skip_tests=True) + + output_file = tmp_path / "state_dump.json" + export_state(state_sync, output_file) + + state = json.loads(output_file.read_text(encoding="utf8")) + assert "metadata" in state + # full project dumps can always be imported back + assert state["metadata"]["importable"] + + assert "versions" in state + + assert len(state["snapshots"]) > 0 + snapshot_names = [s["name"] for s in state["snapshots"]] + assert len(snapshot_names) == 5 + assert '"warehouse"."sqlmesh_example"."full_model"' in snapshot_names # will be in here twice + assert '"warehouse"."sqlmesh_example"."incremental_model"' in snapshot_names + assert '"warehouse"."sqlmesh_example"."seed_model"' in 
snapshot_names + assert '"warehouse"."sqlmesh_example"."new_model"' in snapshot_names + + assert "prod" in state["environments"] + assert "dev" in state["environments"] + + prod = state["environments"]["prod"] + assert len(prod["snapshots"]) == 3 + prod_snapshot_ids = [s.snapshot_id for s in Environment.model_validate(prod).snapshots] + + dev = state["environments"]["dev"] + assert len(dev["snapshots"]) == 4 + dev_snapshot_ids = [s.snapshot_id for s in Environment.model_validate(dev).snapshots] + + full_model_id = next(s for s in dev_snapshot_ids if "full_model" in s.name) + incremental_model_id = next(s for s in dev_snapshot_ids if "incremental_model" in s.name) + seed_model_id = next(s for s in dev_snapshot_ids if "seed_model" in s.name) + new_model_id = next(s for s in dev_snapshot_ids if "new_model" in s.name) + + assert incremental_model_id in prod_snapshot_ids + assert seed_model_id in prod_snapshot_ids + assert full_model_id not in prod_snapshot_ids + assert new_model_id not in prod_snapshot_ids + + +def test_export_specific_environment( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + output_file = tmp_path / "state_dump.json" + init_example_project(path=tmp_path, dialect="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + # create prod + context.plan(auto_apply=True) + + with pytest.raises(SQLMeshError, match=r"No such environment"): + export_state(state_sync, output_file, environment_names=["FOO"]) + + # modify full_model + (tmp_path / c.MODELS / "full_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.full_model, + kind FULL, + cron '@daily' + ); + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, + '1' as modified + FROM sqlmesh_example.incremental_model + GROUP BY item_id; + """) + + # plan dev + context.load() + plan = context.plan(environment="dev", auto_apply=True, skip_tests=True) + assert len(plan.modified_snapshots) == 1 + + # export dev - all 
models should be included + export_state(state_sync, output_file, environment_names=["dev"]) + + dev_state = json.loads(output_file.read_text(encoding="utf8")) + + assert "metadata" in dev_state + assert "versions" in dev_state + + assert len(dev_state["snapshots"]) == 3 + snapshot_names = [s["name"] for s in dev_state["snapshots"]] + assert any("full_model" in name for name in snapshot_names) + assert any("incremental_model" in name for name in snapshot_names) + assert any("seed_model" in name for name in snapshot_names) + dev_full_model = next(s for s in dev_state["snapshots"] if "full_model" in s["name"]) + + assert len(dev_state["environments"]) == 1 + assert "dev" in dev_state["environments"] + + # this state dump is still importable even though its just a subset + assert dev_state["metadata"]["importable"] + + # export prod - prod full_model should be a different version to dev + export_state(state_sync, output_file, environment_names=["prod"]) + + prod_state = json.loads(output_file.read_text(encoding="utf8")) + snapshot_names = [s["name"] for s in prod_state["snapshots"]] + assert any("full_model" in name for name in snapshot_names) + assert any("incremental_model" in name for name in snapshot_names) + assert any("seed_model" in name for name in snapshot_names) + prod_full_model = next(s for s in prod_state["snapshots"] if "full_model" in s["name"]) + + assert len(prod_state["environments"]) == 1 + assert "prod" in prod_state["environments"] + assert prod_state["metadata"]["importable"] + + assert dev_full_model["fingerprint"]["data_hash"] != prod_full_model["fingerprint"]["data_hash"] + + +def test_export_local_state( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + output_file = tmp_path / "state_dump.json" + init_example_project(path=tmp_path, dialect="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + # create prod + context.plan(auto_apply=True) + + # modify 
full_model - create local change + (tmp_path / c.MODELS / "full_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.full_model, + kind FULL, + cron '@daily' + ); + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, + '1' as modified + FROM sqlmesh_example.incremental_model + GROUP BY item_id; + """) + + # add a new model + (tmp_path / c.MODELS / "new_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.new_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + auto_restatement_cron '@daily' + ), + cron '@daily' + ); + + SELECT 1 as id; + """) + + assert len(context.snapshots) == 3 + + context.load() + + assert len(context.snapshots) == 4 + + export_state(state_sync, output_file, context.snapshots) + state = json.loads(output_file.read_text(encoding="utf8")) + assert "metadata" in state + assert "versions" in state + + # this state dump cannot be imported because its just local state + assert not state["metadata"]["importable"] + + # no environments because local state is just snapshots + assert len(state["environments"]) == 0 + + snapshots = state["snapshots"] + assert len(snapshots) == 4 + full_model = next(s for s in snapshots if "full_model" in s["name"]) + new_model = next(s for s in snapshots if "new_model" in s["name"]) + + assert "'1' as modified" in full_model["node"]["query"] + assert "SELECT 1 as id" in new_model["node"]["query"] + + +def test_import_invalid_file(tmp_path: Path, state_sync: StateSync) -> None: + state_file = tmp_path / "state.json" + state_file.write_text("invalid json file") + + with pytest.raises(Exception, match=r"Invalid JSON character"): + import_state(state_sync, state_file) + + state_file.write_text("[]") + with pytest.raises(SQLMeshError, match=r"Expected JSON object"): + import_state(state_sync, state_file) + + state_file.write_text("{}") + with pytest.raises(SQLMeshError, match=r"Expecting a 'metadata' key"): + import_state(state_sync, state_file) + + state_file.write_text('{ "metadata": [] }') + with 
pytest.raises(SQLMeshError, match=r"Expecting the 'metadata' key to contain an object"): + import_state(state_sync, state_file) + + state_file.write_text('{ "metadata": {} }') + with pytest.raises(SQLMeshError, match=r"Unable to determine state file format version"): + import_state(state_sync, state_file) + + state_file.write_text('{ "metadata": { "file_version": "blah" } }') + with pytest.raises(SQLMeshError, match=r"Unable to parse state file format version"): + import_state(state_sync, state_file) + + state_file.write_text('{ "metadata": { "file_version": 1, "importable": false } }') + with pytest.raises(SQLMeshError, match=r"not importable"): + import_state(state_sync, state_file) + + +def test_import_from_older_version_export_fails(tmp_path: Path, state_sync: StateSync) -> None: + state_sync.migrate(default_catalog=None) + current_version = state_sync.get_versions() + + major, minor = current_version.minor_sqlmesh_version + older_version = current_version.copy(update=dict(sqlmesh_version=f"{major}.{minor - 1}.0")) + + assert older_version.minor_sqlmesh_version < current_version.minor_sqlmesh_version + + state_file = tmp_path / "state.json" + state_versions = older_version.model_dump(mode="json") + state_file.write_text( + json.dumps( + { + "metadata": { + "timestamp": "2024-01-01 00:00:00", + "file_version": 1, + "importable": True, + }, + "versions": state_versions, + } + ) + ) + + with pytest.raises(SQLMeshError, match=r"SQLMesh version mismatch"): + import_state(state_sync, state_file) + + +def test_import_from_newer_version_export_fails(tmp_path: Path, state_sync: StateSync) -> None: + state_sync.migrate(default_catalog=None) + current_version = state_sync.get_versions() + + major, minor = current_version.minor_sqlmesh_version + newer_version = current_version.copy(update=dict(sqlmesh_version=f"{major}.{minor + 1}.0")) + + assert newer_version.minor_sqlmesh_version > current_version.minor_sqlmesh_version + + state_file = tmp_path / "state.json" + 
state_versions = newer_version.model_dump(mode="json") + state_file.write_text( + json.dumps( + { + "versions": state_versions, + "metadata": { + "timestamp": "2024-01-01 00:00:00", + "file_version": 1, + "importable": True, + }, + } + ) + ) + + with pytest.raises(SQLMeshError, match=r"SQLMesh version mismatch"): + import_state(state_sync, state_file) + + +def test_import_local_state_fails( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + output_file = tmp_path / "state_dump.json" + init_example_project(path=tmp_path, dialect="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + export_state(state_sync, output_file, context.snapshots) + state = json.loads(output_file.read_text(encoding="utf8")) + assert len(state["snapshots"]) == 3 + + with pytest.raises(SQLMeshError, match=r"not importable"): + import_state(state_sync, output_file) + + +def test_import_partial( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + output_file = tmp_path / "state_dump.json" + init_example_project(path=tmp_path, dialect="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + # create prod + context.plan(auto_apply=True) + + # add a new model + (tmp_path / c.MODELS / "new_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.new_model, + kind FULL, + cron '@daily' + ); + + SELECT 1 as id; + """) + + # create dev + context.load() + context.plan(environment="dev", auto_apply=True, skip_tests=True) + + # export just dev + export_state(state_sync, output_file, environment_names=["dev"]) + + state = json.loads(output_file.read_text(encoding="utf8")) + # mess with the file to rename "dev" to "dev2" + dev = state["environments"].pop("dev") + dev["name"] = "dev2" + state["environments"]["dev2"] = dev + + assert list(state["environments"].keys()) == ["dev2"] + output_file.write_text(json.dumps(state), encoding="utf8") + + 
# import "dev2" + import_state(state_sync, output_file, clear=False) + + # StateSync should have "prod", "dev" and "dev2". + assert sorted(list(state_sync.get_environments_summary().keys())) == ["dev", "dev2", "prod"] + + assert not context.plan(environment="dev", skip_tests=True).has_changes + assert not context.plan(environment="dev2", skip_tests=True).has_changes + assert context.plan( + environment="prod", skip_tests=True + ).has_changes # prod has changes the 'new_model' model hasnt been applied + + +def test_roundtrip(tmp_path: Path, example_project_config: Config, state_sync: StateSync) -> None: + state_file = tmp_path / "state_dump.json" + + init_example_project(path=tmp_path, dialect="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + # populate initial state + plan = context.plan(auto_apply=True) + assert plan.has_changes + + # plan again to prove no changes + plan = context.plan(auto_apply=True) + assert not plan.has_changes + + export_state(state_sync, state_file) + assert len(state_file.read_text()) > 0 + + # destroy state + assert isinstance(state_sync, EngineAdapterStateSync) + state_sync.engine_adapter.drop_schema("sqlmesh", cascade=True) + + # state was destroyed, plan should have changes + state_sync.migrate(default_catalog=None) + plan = context.plan() + assert plan.has_changes + + # load in state dump + import_state(state_sync, state_file) + + # plan should have no changes now our state is back + plan = context.plan() + assert not plan.has_changes + + # add a new model + (tmp_path / c.MODELS / "new_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.new_model, + kind FULL, + cron '@daily' + ); + + SELECT 1 as id; + """) + + context.load() + plan = context.plan(environment="dev", auto_apply=True) + assert plan.has_changes + + plan = context.plan(environment="dev") + assert not plan.has_changes + + # dump new state that contains the 'dev' environment + export_state(state_sync, state_file) 
+ + # show state destroyed + state_sync.engine_adapter.drop_schema("sqlmesh", cascade=True) + with pytest.raises(SQLMeshError, match=r"Please run a migration"): + state_sync.get_versions(validate=True) + + state_sync.migrate(default_catalog=None) + import_state(state_sync, state_file) + + # should be no changes in dev + assert not context.plan(environment="dev").has_changes + + # prod should show a change for adding 'new_model' + prod_plan = context.plan(environment="prod") + assert prod_plan.new_snapshots == [] + assert len(prod_plan.modified_snapshots) == 1 + assert "new_model" in list(prod_plan.modified_snapshots.values())[0].name + + +def test_roundtrip_includes_auto_restatements( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + init_example_project(path=tmp_path, dialect="duckdb") + + # add a model with auto restatements + (tmp_path / c.MODELS / "new_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.new_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + auto_restatement_cron '@daily' + ), + cron '@daily' + ); + + SELECT 1 as id; + """) + + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + context.plan(auto_apply=True) + + # dump state + output_file = tmp_path / "state_dump.json" + export_state(state_sync, output_file) + state = json.loads(output_file.read_text(encoding="utf8")) + + snapshots = state["snapshots"] + assert len(snapshots) == 4 + + # auto restatements only work after a cadence run + new_model_snapshot = next(s for s in snapshots if "new_model" in s["name"]) + assert "next_auto_restatement_ts" not in new_model_snapshot + + # trigger cadence run and re-dump show auto restatement dumped + context.run() + + export_state(state_sync, output_file) + state = json.loads(output_file.read_text()) + + new_model_snapshot = next(s for s in state["snapshots"] if "new_model" in s["name"]) + assert new_model_snapshot["next_auto_restatement_ts"] > 0 + + # import the 
state again and run a plan to show there is no changes / the auto restatement was imported + import_state(state_sync, output_file) + + plan = context.plan(skip_tests=True) + assert not plan.has_changes diff --git a/tests/core/test_state_sync.py b/tests/core/state_sync/test_state_sync.py similarity index 100% rename from tests/core/test_state_sync.py rename to tests/core/state_sync/test_state_sync.py From 54146dc470dae4ecf1fb2ba93df94e15b6d8a373 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Wed, 2 Apr 2025 05:03:35 +0000 Subject: [PATCH 2/2] PR feedback --- docs/concepts/state.md | 9 +++- sqlmesh/core/console.py | 4 ++ sqlmesh/core/state_sync/common.py | 12 +++-- sqlmesh/core/state_sync/db/facade.py | 27 +++++----- sqlmesh/core/state_sync/export_import.py | 15 +++--- tests/core/state_sync/test_export_import.py | 57 +++++++++++++++++++-- 6 files changed, 97 insertions(+), 27 deletions(-) diff --git a/docs/concepts/state.md b/docs/concepts/state.md index 29e1feed59..ea5391ec20 100644 --- a/docs/concepts/state.md +++ b/docs/concepts/state.md @@ -74,7 +74,14 @@ The state file is a simple `json` file that looks like: /* object for every Virtual Data Environment in the project. key = environment name, value = environment details */ "environments": { "prod": { - "..." + /* information about the environment itself */ + "environment": { + "..." + }, + /* information about any before_all / after_all statements for this environment */ + "statements": [ + "..." 
+ ] } } } diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 0c78d982aa..8ba41ad3f6 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -1129,6 +1129,8 @@ def stop_state_export(self, success: bool, output_file: Path) -> None: self.state_export_progress.stop() self.state_export_progress = None + self.log_status_update("") + if success: self.log_success(f"State exported successfully to '{output_file.as_posix()}'") else: @@ -1252,6 +1254,8 @@ def stop_state_import(self, success: bool, input_file: Path) -> None: self.state_import_progress.stop() self.state_import_progress = None + self.log_status_update("") + if success: self.log_success(f"State imported successfully from '{input_file.as_posix()}'") else: diff --git a/sqlmesh/core/state_sync/common.py b/sqlmesh/core/state_sync/common.py index 1d0778a4d8..90ab67989c 100644 --- a/sqlmesh/core/state_sync/common.py +++ b/sqlmesh/core/state_sync/common.py @@ -8,7 +8,8 @@ from sqlmesh.core.console import Console from sqlmesh.core.dialect import schema_ -from sqlmesh.core.environment import Environment +from sqlmesh.utils.pydantic import PydanticModel +from sqlmesh.core.environment import Environment, EnvironmentStatements from sqlmesh.utils.errors import SQLMeshError from sqlmesh.core.snapshot import Snapshot @@ -96,6 +97,11 @@ def chunk_iterable(iterable: t.Iterable[T], size: int = 10) -> t.Iterable[t.Iter yield itertools.chain([first], itertools.islice(iterator, size - 1)) +class EnvironmentWithStatements(PydanticModel): + environment: Environment + statements: t.List[EnvironmentStatements] = [] + + class StateStream(abc.ABC): """ Represents a stream of state either going into the StateSync (perhaps loaded from a file) @@ -114,5 +120,5 @@ def snapshots(self) -> t.Iterable[Snapshot]: @property @abc.abstractmethod - def environments(self) -> t.Iterable[Environment]: - """A stream of Environment objects""" + def environments(self) -> t.Iterable[EnvironmentWithStatements]: + """A stream of 
Environments with any EnvironmentStatements attached""" diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 8b6e373d3c..884955d98e 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -51,6 +51,7 @@ transactional, StateStream, chunk_iterable, + EnvironmentWithStatements, ) from sqlmesh.core.state_sync.db.interval import IntervalState from sqlmesh.core.state_sync.db.environment import EnvironmentState @@ -482,11 +483,13 @@ def snapshots(self) -> t.Iterable[Snapshot]: yield from state_sync.get_snapshots(chunk).values() @property - def environments(self) -> t.Iterable[Environment]: - if environment_names: - yield from selected_environments - else: - yield from state_sync.get_environments() + def environments(self) -> t.Iterable[EnvironmentWithStatements]: + envs = selected_environments if environment_names else state_sync.get_environments() + + for env in envs: + yield EnvironmentWithStatements( + environment=env, statements=state_sync.get_environment_statements(env.name) + ) return _DumpStateStream() @@ -515,7 +518,7 @@ def import_(self, stream: StateStream, clear: bool = True) -> None: ) overwrite_existing_snapshots = ( not clear - ) # if clear=True, all existing snapshjots were dropped anyway + ) # if clear=True, all existing snapshots were dropped anyway self.snapshot_state.push_snapshots( snapshot_iterator, overwrite=overwrite_existing_snapshots ) @@ -529,12 +532,12 @@ def import_(self, stream: StateStream, clear: bool = True) -> None: } ) - existing_environments = set(self.get_environments_summary().keys()) if not clear else set() - for environment in stream.environments: - if not clear and environment.name in existing_environments: - self.environment_state.update_environment(environment) - else: - self.promote(environment) + for environment_with_statements in stream.environments: + environment = environment_with_statements.environment + 
self.environment_state.update_environment(environment) + self.environment_state.update_environment_statements( + environment.name, environment.plan_id, environment_with_statements.statements + ) self.update_auto_restatements(auto_restatements) diff --git a/sqlmesh/core/state_sync/export_import.py b/sqlmesh/core/state_sync/export_import.py index 0c390831a4..c2a43ada01 100644 --- a/sqlmesh/core/state_sync/export_import.py +++ b/sqlmesh/core/state_sync/export_import.py @@ -3,10 +3,9 @@ from sqlmesh.core.state_sync import StateSync from sqlmesh.core.snapshot import Snapshot from sqlmesh.utils.date import now, to_tstz -from sqlmesh.core.environment import Environment from sqlmesh.utils.pydantic import _expression_encoder from sqlmesh.core.state_sync import Versions -from sqlmesh.core.state_sync.common import StateStream +from sqlmesh.core.state_sync.common import StateStream, EnvironmentWithStatements from sqlmesh.core.console import Console from pathlib import Path from sqlmesh.core.console import NoopConsole @@ -45,7 +44,7 @@ def snapshots(self) -> t.Iterable[Snapshot]: return iter(snapshots.values()) @property - def environments(self) -> t.Iterable[Environment]: + def environments(self) -> t.Iterable[EnvironmentWithStatements]: return [] return _LocalStateStream() @@ -71,11 +70,11 @@ def _dump_snapshots( @streamable_dict def _dump_environments( - environment_stream: t.Iterable[Environment], + environment_stream: t.Iterable[EnvironmentWithStatements], ) -> t.Iterator[t.Tuple[str, t.Any]]: console.update_state_export_progress(environment_count=0) for idx, env in enumerate(environment_stream): - yield env.name, _dump_pydantic_model(env) + yield env.environment.name, _dump_pydantic_model(env) console.update_state_export_progress(environment_count=idx + 1) @streamable_dict @@ -129,12 +128,14 @@ def snapshots(self) -> t.Iterable[Snapshot]: console.update_state_import_progress(snapshots_complete=True) @property - def environments(self) -> t.Iterable[Environment]: + def 
environments(self) -> t.Iterable[EnvironmentWithStatements]: stream = data()["environments"] console.update_state_import_progress(environment_count=0) for idx, (_, raw_environment) in enumerate(stream.items()): - environment = Environment.model_validate(to_standard_types(raw_environment)) + environment = EnvironmentWithStatements.model_validate( + to_standard_types(raw_environment) + ) yield environment console.update_state_import_progress(environment_count=idx + 1) diff --git a/tests/core/state_sync/test_export_import.py b/tests/core/state_sync/test_export_import.py index de4fb2cdca..4e4aee5861 100644 --- a/tests/core/state_sync/test_export_import.py +++ b/tests/core/state_sync/test_export_import.py @@ -1,6 +1,6 @@ import pytest from pathlib import Path -from sqlmesh.core.state_sync import StateSync, EngineAdapterStateSync +from sqlmesh.core.state_sync import StateSync, EngineAdapterStateSync, CachingStateSync from sqlmesh.core.state_sync.export_import import export_state, import_state from sqlmesh.utils.errors import SQLMeshError from sqlmesh.core import constants as c @@ -136,11 +136,11 @@ def test_export_entire_project( assert "prod" in state["environments"] assert "dev" in state["environments"] - prod = state["environments"]["prod"] + prod = state["environments"]["prod"]["environment"] assert len(prod["snapshots"]) == 3 prod_snapshot_ids = [s.snapshot_id for s in Environment.model_validate(prod).snapshots] - dev = state["environments"]["dev"] + dev = state["environments"]["dev"]["environment"] assert len(dev["snapshots"]) == 4 dev_snapshot_ids = [s.snapshot_id for s in Environment.model_validate(dev).snapshots] @@ -427,7 +427,7 @@ def test_import_partial( state = json.loads(output_file.read_text(encoding="utf8")) # mess with the file to rename "dev" to "dev2" dev = state["environments"].pop("dev") - dev["name"] = "dev2" + dev["environment"]["name"] = "dev2" state["environments"]["dev2"] = dev assert list(state["environments"].keys()) == ["dev2"] @@ -566,3 
+566,52 @@ def test_roundtrip_includes_auto_restatements( plan = context.plan(skip_tests=True) assert not plan.has_changes + + +def test_roundtrip_includes_environment_statements(tmp_path: Path) -> None: + config = Config( + gateways={ + "main": GatewayConfig( + connection=DuckDBConnectionConfig(database=str(tmp_path / "warehouse.db")), + state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")), + ) + }, + default_gateway="main", + model_defaults=ModelDefaultsConfig( + dialect="duckdb", + ), + before_all=["select 1 as before_all"], + after_all=["select 2 as after_all"], + ) + + context = Context(paths=tmp_path, config=config) + context.plan(auto_apply=True) + + state_file = tmp_path / "state_dump.json" + context.export_state(state_file) + + environments = json.loads(state_file.read_text(encoding="utf8"))["environments"] + + assert environments["prod"]["statements"][0]["before_all"][0] == "select 1 as before_all" + assert environments["prod"]["statements"][0]["after_all"][0] == "select 2 as after_all" + + assert not context.plan().has_changes + + state_sync = context.state_sync + assert isinstance(state_sync, CachingStateSync) + assert isinstance(state_sync.state_sync, EngineAdapterStateSync) + + # show state destroyed + state_sync.state_sync.engine_adapter.drop_schema("sqlmesh", cascade=True) # type: ignore + with pytest.raises(SQLMeshError, match=r"Please run a migration"): + state_sync.get_versions(validate=True) + + state_sync.migrate(default_catalog=None) + import_state(state_sync, state_file) + + assert not context.plan().has_changes + + environment_statements = state_sync.get_environment_statements("prod") + assert len(environment_statements) == 1 + assert environment_statements[0].before_all[0] == "select 1 as before_all" + assert environment_statements[0].after_all[0] == "select 2 as after_all"