From 40415499c005bc87a1fbd1f8584b6f75c091e7d3 Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Tue, 23 Sep 2025 16:10:09 -0700 Subject: [PATCH] feat: add snowflake grant support --- sqlmesh/core/engine_adapter/_typing.py | 1 + sqlmesh/core/engine_adapter/postgres.py | 6 +- sqlmesh/core/engine_adapter/risingwave.py | 1 + sqlmesh/core/engine_adapter/snowflake.py | 122 ++++++++++- .../engine_adapter/integration/__init__.py | 51 +++++ .../integration/test_integration.py | 192 +++++++++++++++++ .../integration/test_integration_postgres.py | 166 --------------- tests/core/engine_adapter/test_snowflake.py | 196 ++++++++++++++++++ 8 files changed, 564 insertions(+), 171 deletions(-) diff --git a/sqlmesh/core/engine_adapter/_typing.py b/sqlmesh/core/engine_adapter/_typing.py index a8c52eef47..77bcf2c015 100644 --- a/sqlmesh/core/engine_adapter/_typing.py +++ b/sqlmesh/core/engine_adapter/_typing.py @@ -31,3 +31,4 @@ QueryOrDF = t.Union[Query, DF] GrantsConfig = t.Dict[str, t.List[str]] + DCL = t.TypeVar("DCL", exp.Grant, exp.Revoke) diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index 873191992f..b052e6d829 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -18,9 +18,7 @@ if t.TYPE_CHECKING: from sqlmesh.core._typing import TableName - from sqlmesh.core.engine_adapter._typing import DF, GrantsConfig, QueryOrDF - - DCL = t.TypeVar("DCL", exp.Grant, exp.Revoke) + from sqlmesh.core.engine_adapter._typing import DCL, DF, GrantsConfig, QueryOrDF logger = logging.getLogger(__name__) @@ -38,7 +36,7 @@ class PostgresEngineAdapter( HAS_VIEW_BINDING = True CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog") SUPPORTS_REPLACE_TABLE = False - MAX_IDENTIFIER_LENGTH = 63 + MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63 SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { diff --git a/sqlmesh/core/engine_adapter/risingwave.py b/sqlmesh/core/engine_adapter/risingwave.py index fdcee90f0f..61b44f5bbb 100644 --- a/sqlmesh/core/engine_adapter/risingwave.py +++ b/sqlmesh/core/engine_adapter/risingwave.py @@ -32,6 +32,7 @@ class RisingwaveEngineAdapter(PostgresEngineAdapter): SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_TRANSACTIONS = False MAX_IDENTIFIER_LENGTH = None + SUPPORTS_GRANTS = False def columns( self, table_name: TableName, include_pseudo_columns: bool = False diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 9c27b45115..557467ad9e 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -34,7 +34,14 @@ import pandas as pd from sqlmesh.core._typing import SchemaName, SessionProperties, TableName - from sqlmesh.core.engine_adapter._typing import DF, Query, QueryOrDF, SnowparkSession + from sqlmesh.core.engine_adapter._typing import ( + DCL, + DF, + GrantsConfig, + Query, + QueryOrDF, + SnowparkSession, + ) from sqlmesh.core.node import IntervalUnit @@ -73,6 +80,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi MANAGED_TABLE_KIND = "DYNAMIC TABLE" SNOWPARK = "snowpark" SUPPORTS_QUERY_EXECUTION_TRACKING = True + SUPPORTS_GRANTS = True @contextlib.contextmanager def session(self, properties: SessionProperties) -> t.Iterator[None]: @@ -127,6 +135,118 @@ def snowpark(self) -> t.Optional[SnowparkSession]: def catalog_support(self) -> CatalogSupport: return CatalogSupport.FULL_SUPPORT + @staticmethod + def _grant_object_kind(table_type: DataObjectType) -> str: + if table_type == DataObjectType.VIEW: + return "VIEW" + if table_type == DataObjectType.MATERIALIZED_VIEW: + return "MATERIALIZED VIEW" + if table_type == DataObjectType.MANAGED_TABLE: + return "DYNAMIC TABLE" + return "TABLE" + + def _get_current_schema(self) -> str: + """Returns the current default schema for the connection.""" + result = self.fetchone("SELECT CURRENT_SCHEMA()") + if not result or not result[0]: + raise SQLMeshError("Unable to determine current schema") + return str(result[0]) + + def _dcl_grants_config_expr( + self, + dcl_cmd: t.Type[DCL], + table: exp.Table, + grant_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + expressions: t.List[exp.Expression] = [] + if not grant_config: + return expressions + + object_kind = self._grant_object_kind(table_type) + for privilege, principals in grant_config.items(): + for principal in principals: + args: t.Dict[str, t.Any] = { + "privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))], + "securable": table.copy(), + "principals": [principal], + } + + if object_kind: + args["kind"] = exp.Var(this=object_kind) + + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] + + return expressions + + def _apply_grants_config_expr( + self, + table: exp.Table, + grant_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type) + + def _revoke_grants_config_expr( + self, + table: exp.Table, + grant_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type) + + def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: + schema_identifier = table.args.get("db") or normalize_identifiers( + exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect + ) + catalog_identifier = table.args.get("catalog") + if not catalog_identifier: + current_catalog = self.get_current_catalog() + if not current_catalog: + raise SQLMeshError("Unable to determine current catalog for fetching grants") + catalog_identifier = normalize_identifiers( + exp.to_identifier(current_catalog, quoted=True), dialect=self.dialect + ) + catalog_identifier.set("quoted", True) + table_identifier = table.args.get("this") + + grant_expr = ( + exp.select("privilege_type", "grantee") + .from_( + exp.table_( + "TABLE_PRIVILEGES", + db="INFORMATION_SCHEMA", + catalog=catalog_identifier, + ) + ) + .where( + exp.and_( + exp.column("table_schema").eq(exp.Literal.string(schema_identifier.this)), + exp.column("table_name").eq(exp.Literal.string(table_identifier.this)), # type: ignore + exp.column("grantor").eq(exp.func("CURRENT_ROLE")), + exp.column("grantee").neq(exp.func("CURRENT_ROLE")), + ) + ) + ) + + results = self.fetchall(grant_expr) + + grants_dict: GrantsConfig = {} + for privilege_raw, grantee_raw in results: + if privilege_raw is None or grantee_raw is None: + continue + + privilege = str(privilege_raw) + grantee = str(grantee_raw) + if not privilege or not grantee: + continue + + grantees = grants_dict.setdefault(privilege, []) + if grantee not in grantees: + grantees.append(grantee) + + return grants_dict + def _create_catalog(self, catalog_name: exp.Identifier) -> None: props = exp.Properties( expressions=[exp.SchemaCommentProperty(this=exp.Literal.string(c.SQLMESH_MANAGED))] diff --git a/tests/core/engine_adapter/integration/__init__.py b/tests/core/engine_adapter/integration/__init__.py index c5377e309a..7e6dae2f1b 100644 --- a/tests/core/engine_adapter/integration/__init__.py +++ b/tests/core/engine_adapter/integration/__init__.py @@ -5,10 +5,12 @@ import sys import typing as t import time +from contextlib import contextmanager import pandas as pd # noqa: TID253 import pytest from sqlglot import exp, parse_one +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh import Config, Context, EngineAdapter from sqlmesh.core.config import load_config_from_paths @@ -744,6 +746,55 @@ def upsert_sql_model(self, model_definition: str) -> t.Tuple[Context, SqlModel]: self._context.upsert_model(model) return self._context, model + def _get_create_user_or_role(self, username: str, password: t.Optional[str] = None) -> str: + password = password or random_id() + if self.dialect == "postgres": + return f"CREATE USER \"{username}\" WITH PASSWORD '{password}'" + if self.dialect == "snowflake": + return f"CREATE ROLE {username}" + raise ValueError(f"User creation not supported for dialect: {self.dialect}") + + def _create_user_or_role(self, username: str, password: t.Optional[str] = None) -> None: + create_user_sql = self._get_create_user_or_role(username, password) + self.engine_adapter.execute(create_user_sql) + + @contextmanager + def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str]]: + created_users = [] + roles = {} + + try: + for role_name in role_names: + user_name = normalize_identifiers( + self.add_test_suffix(f"test_{role_name}"), dialect=self.dialect + ).sql(dialect=self.dialect) + password = random_id() + self._create_user_or_role(user_name, password) + created_users.append(user_name) + roles[role_name] = user_name + + yield roles + + finally: + for user_name in created_users: + self._cleanup_user_or_role(user_name) + + def _cleanup_user_or_role(self, user_name: str) -> None: + """Helper function to clean up a PostgreSQL user and all their dependencies.""" + try: + if self.dialect == "postgres": + self.engine_adapter.execute(f""" + SELECT pg_terminate_backend(pid) + FROM pg_stat_activity + WHERE usename = '{user_name}' AND pid <> pg_backend_pid() + """) + self.engine_adapter.execute(f'DROP OWNED BY "{user_name}"') + self.engine_adapter.execute(f'DROP USER IF EXISTS "{user_name}"') + elif self.dialect == "snowflake": + self.engine_adapter.execute(f"DROP ROLE IF EXISTS {user_name}") + except Exception: + pass + def wait_until(fn: t.Callable[..., bool], attempts=3, wait=5) -> None: current_attempt = 0 diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 5a708e1e4c..f48a95d39b 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -3834,3 +3834,195 @@ def _assert_mview_value(value: int): assert any("Replacing view" in call[0][0] for call in mock_logger.call_args_list) _assert_mview_value(value=2) + + +def test_sync_grants_config(ctx: TestContext) -> None: + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + table = ctx.table("sync_grants_integration") + + with ctx.create_users_or_roles("reader", "writer", "admin") as roles: + ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) + + initial_grants = { + "SELECT": [roles["reader"]], + "INSERT": [roles["writer"]], + } + ctx.engine_adapter.sync_grants_config(table, initial_grants) + + current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert set(current_grants.get("SELECT", [])) == {roles["reader"]} + assert set(current_grants.get("INSERT", [])) == {roles["writer"]} + + target_grants = { + "SELECT": [roles["writer"], roles["admin"]], + "UPDATE": [roles["admin"]], + } + ctx.engine_adapter.sync_grants_config(table, target_grants) + + synced_grants = ctx.engine_adapter._get_current_grants_config(table) + assert set(synced_grants.get("SELECT", [])) == { + roles["writer"], + roles["admin"], + } + assert set(synced_grants.get("UPDATE", [])) == {roles["admin"]} + assert synced_grants.get("INSERT", []) == [] + + +def test_grants_sync_empty_config(ctx: TestContext): + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + table = ctx.table("grants_empty_test") + + with ctx.create_users_or_roles("user") as roles: + ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) + + initial_grants = { + "SELECT": [roles["user"]], + "INSERT": [roles["user"]], + } + ctx.engine_adapter.sync_grants_config(table, initial_grants) + + initial_current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert roles["user"] in initial_current_grants.get("SELECT", []) + assert roles["user"] in initial_current_grants.get("INSERT", []) + + ctx.engine_adapter.sync_grants_config(table, {}) + + final_grants = ctx.engine_adapter._get_current_grants_config(table) + assert final_grants == {} + + +def test_grants_case_insensitive_grantees(ctx: TestContext): + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + with ctx.create_users_or_roles("test_reader", "test_writer") as roles: + table = ctx.table("grants_quoted_test") + ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) + + test_schema = table.db + for role_credentials in roles.values(): + ctx.engine_adapter.execute( + f'GRANT USAGE ON SCHEMA "{test_schema}" TO "{role_credentials}"' + ) + + reader = roles["test_reader"] + writer = roles["test_writer"] + + grants_config = {"SELECT": [reader, writer.upper()]} + ctx.engine_adapter.sync_grants_config(table, grants_config) + + # Grantees are still in lowercase + current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert reader in current_grants.get("SELECT", []) + assert writer in current_grants.get("SELECT", []) + + # Revoke writer + grants_config = {"SELECT": [reader.upper()]} + ctx.engine_adapter.sync_grants_config(table, grants_config) + + current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert reader in current_grants.get("SELECT", []) + assert writer not in current_grants.get("SELECT", []) + + +def test_grants_plan(ctx: TestContext, tmp_path: Path): + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + table = ctx.table("grant_model").sql(dialect=ctx.dialect) + with ctx.create_users_or_roles("analyst", "etl_user") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = f""" + MODEL ( + name {table}, + kind FULL, + grants ( + 'select' = ['{roles["analyst"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, CURRENT_DATE as created_date + """ + + (tmp_path / "models" / "grant_model.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path) + plan_result = context.plan(auto_apply=True, no_prompts=True) + + assert len(plan_result.new_snapshots) == 1 + snapshot = plan_result.new_snapshots[0] + + # Physical layer w/ grants + table_name = snapshot.table_name() + view_name = snapshot.qualified_view_name.for_environment( + plan_result.environment_naming_info, dialect=ctx.dialect + ) + current_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=ctx.dialect) + ) + assert current_grants == {"SELECT": [roles["analyst"]]} + + # Virtual layer (view) w/ grants + virtual_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=ctx.dialect) + ) + assert virtual_grants == {"SELECT": [roles["analyst"]]} + + # Update model with query change and new grants + updated_model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name {table}, + kind FULL, + grants ( + 'select' = ['{roles["analyst"]}', '{roles["etl_user"]}'], + 'insert' = ['{roles["etl_user"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, CURRENT_DATE as created_date, 'v2' as version + """, + default_dialect=context.default_dialect, + ), + dialect=context.default_dialect, + ) + context.upsert_model(updated_model) + + plan = context.plan(auto_apply=True, no_prompts=True) + plan_result = PlanResults.create(plan, ctx, ctx.add_test_suffix(TEST_SCHEMA)) + assert len(plan_result.plan.directly_modified) == 1 + + new_snapshot = plan_result.snapshot_for(updated_model) + assert new_snapshot is not None + + new_table_name = new_snapshot.table_name() + final_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(new_table_name, dialect=ctx.dialect) + ) + expected_final_grants = { + "SELECT": [roles["analyst"], roles["etl_user"]], + "INSERT": [roles["etl_user"]], + } + assert set(final_grants.get("SELECT", [])) == set(expected_final_grants["SELECT"]) + assert final_grants.get("INSERT", []) == expected_final_grants["INSERT"] + + # Virtual layer should also have the updated grants + updated_virtual_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=ctx.dialect) + ) + assert set(updated_virtual_grants.get("SELECT", [])) == set(expected_final_grants["SELECT"]) + assert updated_virtual_grants.get("INSERT", []) == expected_final_grants["INSERT"] diff --git a/tests/core/engine_adapter/integration/test_integration_postgres.py b/tests/core/engine_adapter/integration/test_integration_postgres.py index 635aeb474a..68686fbceb 100644 --- a/tests/core/engine_adapter/integration/test_integration_postgres.py +++ b/tests/core/engine_adapter/integration/test_integration_postgres.py @@ -375,172 +375,6 @@ def _mutate_config(gateway: str, config: Config): # Grants Integration Tests -def test_grants_sync(engine_adapter: PostgresEngineAdapter, ctx: TestContext, config: Config): - with create_users(engine_adapter, "user1", "user2", "user3") as roles: - table = ctx.table("grants_sync_test") - engine_adapter.create_table( - table, {"id": exp.DataType.build("INT"), "data": exp.DataType.build("TEXT")} - ) - - initial_grants = { - "SELECT": [roles["user1"]["username"], roles["user2"]["username"]], - "INSERT": [roles["user1"]["username"]], - } - engine_adapter.sync_grants_config(table, initial_grants) - - initial_current_grants = engine_adapter._get_current_grants_config(table) - assert roles["user1"]["username"] in initial_current_grants.get("SELECT", []) - assert roles["user2"]["username"] in initial_current_grants.get("SELECT", []) - assert roles["user1"]["username"] in initial_current_grants.get("INSERT", []) - - target_grants = { - "SELECT": [roles["user2"]["username"], roles["user3"]["username"]], - "UPDATE": [roles["user3"]["username"]], - } - engine_adapter.sync_grants_config(table, target_grants) - - final_grants = engine_adapter._get_current_grants_config(table) - - assert set(final_grants.get("SELECT", [])) == { - roles["user2"]["username"], - roles["user3"]["username"], - } - assert set(final_grants.get("UPDATE", [])) == {roles["user3"]["username"]} - assert final_grants.get("INSERT", []) == [] - - -def test_grants_sync_empty_config( - engine_adapter: PostgresEngineAdapter, ctx: TestContext, config: Config -): - with create_users(engine_adapter, "user") as roles: - table = ctx.table("grants_empty_test") - engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) - - initial_grants = { - "SELECT": [roles["user"]["username"]], - "INSERT": [roles["user"]["username"]], - } - engine_adapter.sync_grants_config(table, initial_grants) - - initial_current_grants = engine_adapter._get_current_grants_config(table) - assert roles["user"]["username"] in initial_current_grants.get("SELECT", []) - assert roles["user"]["username"] in initial_current_grants.get("INSERT", []) - - engine_adapter.sync_grants_config(table, {}) - - final_grants = engine_adapter._get_current_grants_config(table) - assert final_grants == {} - - -def test_grants_case_insensitive_grantees( - engine_adapter: PostgresEngineAdapter, ctx: TestContext, config: Config -): - with create_users(engine_adapter, "test_reader", "test_writer") as roles: - table = ctx.table("grants_quoted_test") - engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) - - test_schema = table.db - for role_credentials in roles.values(): - engine_adapter.execute( - f'GRANT USAGE ON SCHEMA "{test_schema}" TO "{role_credentials["username"]}"' - ) - - reader = roles["test_reader"]["username"] - writer = roles["test_writer"]["username"] - - grants_config = {"SELECT": [reader, writer.upper()]} - engine_adapter.sync_grants_config(table, grants_config) - - # Grantees are still in lowercase - current_grants = engine_adapter._get_current_grants_config(table) - assert reader in current_grants.get("SELECT", []) - assert writer in current_grants.get("SELECT", []) - - # Revoke writer - grants_config = {"SELECT": [reader.upper()]} - engine_adapter.sync_grants_config(table, grants_config) - - current_grants = engine_adapter._get_current_grants_config(table) - assert reader in current_grants.get("SELECT", []) - assert writer not in current_grants.get("SELECT", []) - - -def test_grants_plan(engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path): - with create_users(engine_adapter, "analyst", "etl_user") as roles: - (tmp_path / "models").mkdir(exist_ok=True) - - model_def = """ - MODEL ( - name test_schema.grant_model, - kind FULL, - grants ( - 'select' = ['test_analyst'] - ), - grants_target_layer 'all' - ); - SELECT 1 as id, CURRENT_DATE as created_date - """ - - (tmp_path / "models" / "grant_model.sql").write_text(model_def) - - context = ctx.create_context(path=tmp_path) - plan_result = context.plan(auto_apply=True, no_prompts=True) - - assert len(plan_result.new_snapshots) == 1 - snapshot = plan_result.new_snapshots[0] - - # Physical layer w/ grants - table_name = snapshot.table_name() - current_grants = engine_adapter._get_current_grants_config( - exp.to_table(table_name, dialect=engine_adapter.dialect) - ) - assert current_grants == {"SELECT": [roles["analyst"]["username"]]} - - # Virtual layer (view) w/ grants - virtual_view_name = f"test_schema.grant_model" - virtual_grants = engine_adapter._get_current_grants_config( - exp.to_table(virtual_view_name, dialect=engine_adapter.dialect) - ) - assert virtual_grants == {"SELECT": [roles["analyst"]["username"]]} - - # Update model with query change and new grants - existing_model = context.get_model("test_schema.grant_model") - from sqlglot import parse_one - - updated_query = parse_one("SELECT 1 as id, CURRENT_DATE as created_date, 'v2' as version") - context.upsert_model( - existing_model, - query=updated_query, - grants={ - "select": [roles["analyst"]["username"], roles["etl_user"]["username"]], - "insert": [roles["etl_user"]["username"]], - }, - ) - plan_result = context.plan(auto_apply=True, no_prompts=True) - - assert len(plan_result.new_snapshots) == 1 - new_snapshot = plan_result.new_snapshots[0] - - assert new_snapshot is not None - new_table_name = new_snapshot.table_name() - final_grants = engine_adapter._get_current_grants_config( - exp.to_table(new_table_name, dialect=engine_adapter.dialect) - ) - expected_final_grants = { - "SELECT": [roles["analyst"]["username"], roles["etl_user"]["username"]], - "INSERT": [roles["etl_user"]["username"]], - } - assert set(final_grants.get("SELECT", [])) == set(expected_final_grants["SELECT"]) - assert final_grants.get("INSERT", []) == expected_final_grants["INSERT"] - - # Virtual layer should also have the updated grants - updated_virtual_grants = engine_adapter._get_current_grants_config( - exp.to_table(virtual_view_name, dialect=engine_adapter.dialect) - ) - assert set(updated_virtual_grants.get("SELECT", [])) == set(expected_final_grants["SELECT"]) - assert updated_virtual_grants.get("INSERT", []) == expected_final_grants["INSERT"] - - def test_grants_plan_target_layer_physical_only( engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path ): diff --git a/tests/core/engine_adapter/test_snowflake.py b/tests/core/engine_adapter/test_snowflake.py index 62c4a4f3eb..a84a6666ad 100644 --- a/tests/core/engine_adapter/test_snowflake.py +++ b/tests/core/engine_adapter/test_snowflake.py @@ -4,6 +4,7 @@ import pytest from pytest_mock.plugin import MockerFixture from sqlglot import exp, parse_one +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers import sqlmesh.core.dialect as d from sqlmesh.core.dialect import normalize_model_name @@ -245,6 +246,201 @@ def test_multiple_column_comments(make_mocked_engine_adapter: t.Callable, mocker ] +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table("test_db.test_schema.test_table", dialect="snowflake"), dialect="snowflake" + ) + new_grants_config = {"SELECT": ["ROLE role1", "ROLE role2"], "INSERT": ["ROLE role3"]} + + current_grants = [ + ("SELECT", "ROLE old_role"), + ("UPDATE", "ROLE legacy_role"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + 'SELECT privilege_type, grantee FROM "TEST_DB".INFORMATION_SCHEMA.TABLE_PRIVILEGES ' + "WHERE table_schema = 'TEST_SCHEMA' AND table_name = 'TEST_TABLE' " + "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE role1' in sql_calls + assert 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE role2' in sql_calls + assert 'GRANT INSERT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE role3' in sql_calls + assert ( + 'REVOKE SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE old_role' + in sql_calls + ) + assert ( + 'REVOKE UPDATE ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE legacy_role' + in sql_calls + ) + + +def test_sync_grants_config_with_overlaps( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table("test_db.test_schema.test_table", dialect="snowflake"), dialect="snowflake" + ) + new_grants_config = { + "SELECT": ["ROLE shared", "ROLE new_role"], + "INSERT": ["ROLE shared", "ROLE writer"], + } + + current_grants = [ + ("SELECT", "ROLE shared"), + ("SELECT", "ROLE legacy"), + ("INSERT", "ROLE shared"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + """SELECT privilege_type, grantee FROM "TEST_DB".INFORMATION_SCHEMA.TABLE_PRIVILEGES """ + "WHERE table_schema = 'TEST_SCHEMA' AND table_name = 'TEST_TABLE' " + "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 3 + + assert ( + 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE new_role' in sql_calls + ) + assert 'GRANT INSERT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE writer' in sql_calls + assert ( + 'REVOKE SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE legacy' in sql_calls + ) + + +@pytest.mark.parametrize( + "table_type, expected_keyword", + [ + (DataObjectType.TABLE, "TABLE"), + (DataObjectType.VIEW, "VIEW"), + (DataObjectType.MATERIALIZED_VIEW, "MATERIALIZED VIEW"), + (DataObjectType.MANAGED_TABLE, "DYNAMIC TABLE"), + ], +) +def test_sync_grants_config_object_kind( + make_mocked_engine_adapter: t.Callable, + mocker: MockerFixture, + table_type: DataObjectType, + expected_keyword: str, +) -> None: + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table("test_db.test_schema.test_object", dialect="snowflake"), dialect="snowflake" + ) + + mocker.patch.object(adapter, "fetchall", return_value=[]) + + adapter.sync_grants_config(relation, {"SELECT": ["ROLE test"]}, table_type) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + f'GRANT SELECT ON {expected_keyword} "TEST_DB"."TEST_SCHEMA"."TEST_OBJECT" TO ROLE test' + ] + + +def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table('"test_db"."test_schema"."test_table"', dialect="snowflake"), + dialect="snowflake", + ) + new_grants_config = {"SELECT": ["ROLE role1", "ROLE role2"], "INSERT": ["ROLE role3"]} + + current_grants = [ + ("SELECT", "ROLE old_role"), + ("UPDATE", "ROLE legacy_role"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + """SELECT privilege_type, grantee FROM "test_db".INFORMATION_SCHEMA.TABLE_PRIVILEGES """ + "WHERE table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert 'GRANT SELECT ON TABLE "test_db"."test_schema"."test_table" TO ROLE role1' in sql_calls + assert 'GRANT SELECT ON TABLE "test_db"."test_schema"."test_table" TO ROLE role2' in sql_calls + assert 'GRANT INSERT ON TABLE "test_db"."test_schema"."test_table" TO ROLE role3' in sql_calls + assert ( + 'REVOKE SELECT ON TABLE "test_db"."test_schema"."test_table" FROM ROLE old_role' + in sql_calls + ) + assert ( + 'REVOKE UPDATE ON TABLE "test_db"."test_schema"."test_table" FROM ROLE legacy_role' + in sql_calls + ) + + +def test_sync_grants_config_no_catalog_or_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table('"TesT_Table"', dialect="snowflake"), dialect="snowflake" + ) + new_grants_config = {"SELECT": ["ROLE role1", "ROLE role2"], "INSERT": ["ROLE role3"]} + + current_grants = [ + ("SELECT", "ROLE old_role"), + ("UPDATE", "ROLE legacy_role"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + mocker.patch.object(adapter, "get_current_catalog", return_value="caTalog") + mocker.patch.object(adapter, "_get_current_schema", return_value="sChema") + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + """SELECT privilege_type, grantee FROM "caTalog".INFORMATION_SCHEMA.TABLE_PRIVILEGES """ + "WHERE table_schema = 'sChema' AND table_name = 'TesT_Table' " + "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert 'GRANT SELECT ON TABLE "TesT_Table" TO ROLE role1' in sql_calls + assert 'GRANT SELECT ON TABLE "TesT_Table" TO ROLE role2' in sql_calls + assert 'GRANT INSERT ON TABLE "TesT_Table" TO ROLE role3' in sql_calls + assert 'REVOKE SELECT ON TABLE "TesT_Table" FROM ROLE old_role' in sql_calls + assert 'REVOKE UPDATE ON TABLE "TesT_Table" FROM ROLE legacy_role' in sql_calls + + def test_df_to_source_queries_use_schema( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture ):