From 3b60eba7f69bbab614c802fa22f02f9bf8537303 Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Wed, 24 Sep 2025 15:11:46 -0700 Subject: [PATCH] feat: databricks grant support --- sqlmesh/core/engine_adapter/databricks.py | 107 ++++++++++- .../engine_adapter/integration/__init__.py | 36 +++- .../integration/test_integration.py | 45 +++-- tests/core/engine_adapter/test_databricks.py | 177 ++++++++++++++++++ 4 files changed, 334 insertions(+), 31 deletions(-) diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index 946a7bdf74..3836f326c1 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -5,6 +5,7 @@ from functools import partial from sqlglot import exp + from sqlmesh.core.dialect import to_schema from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, @@ -23,7 +24,7 @@ import pandas as pd from sqlmesh.core._typing import SchemaName, TableName, SessionProperties - from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query + from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query, GrantsConfig, DCL logger = logging.getLogger(__name__) @@ -34,6 +35,7 @@ class DatabricksEngineAdapter(SparkEngineAdapter): SUPPORTS_CLONING = True SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True + SUPPORTS_GRANTS = True SCHEMA_DIFFER_KWARGS = { "support_positional_add": True, "nested_support": NestedSupport.ALL, @@ -149,6 +151,109 @@ def spark(self) -> PySparkSession: def catalog_support(self) -> CatalogSupport: return CatalogSupport.FULL_SUPPORT + @staticmethod + def _grant_object_kind(table_type: DataObjectType) -> str: + if table_type == DataObjectType.VIEW: + return "VIEW" + if table_type == DataObjectType.MATERIALIZED_VIEW: + return "MATERIALIZED VIEW" + return "TABLE" + + def _dcl_grants_config_expr( + self, + dcl_cmd: t.Type[DCL], + table: exp.Table, + grant_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + expressions: t.List[exp.Expression] = [] + if not grant_config: + return expressions + + object_kind = self._grant_object_kind(table_type) + for privilege, principals in grant_config.items(): + for principal in principals: + args: t.Dict[str, t.Any] = { + "privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))], + "securable": table.copy(), + "principals": [exp.to_identifier(principal.lower())], + } + + if object_kind: + args["kind"] = exp.Var(this=object_kind) + + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] + + return expressions + + def _apply_grants_config_expr( + self, + table: exp.Table, + grant_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type) + + def _revoke_grants_config_expr( + self, + table: exp.Table, + grant_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type) + + def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: + if schema_identifier := table.args.get("db"): + schema_name = schema_identifier.this + else: + schema_name = self.get_current_database() + if catalog_identifier := table.args.get("catalog"): + catalog_name = catalog_identifier.this + else: + catalog_name = self.get_current_catalog() + table_name = table.args.get("this").this # type: ignore + + grant_expr = ( + exp.select("privilege_type", "grantee") + .from_( + exp.table_( + "table_privileges", + db="information_schema", + catalog=catalog_name, + ) + ) + .where( + exp.and_( + exp.column("table_catalog").eq(exp.Literal.string(catalog_name.lower())), + exp.column("table_schema").eq(exp.Literal.string(schema_name.lower())), + exp.column("table_name").eq(exp.Literal.string(table_name.lower())), + exp.column("grantor").eq(exp.func("current_user")), + exp.column("grantee").neq(exp.func("current_user")), + # We only care about explicitly granted privileges and not inherited ones + # if this is removed you would see grants inherited from the catalog get returned + exp.column("inherited_from").eq(exp.Literal.string("NONE")), + ) + ) + ) + + results = self.fetchall(grant_expr) + + grants_dict: GrantsConfig = {} + for privilege_raw, grantee_raw in results: + if privilege_raw is None or grantee_raw is None: + continue + + privilege = str(privilege_raw) + grantee = str(grantee_raw) + if not privilege or not grantee: + continue + + grantees = grants_dict.setdefault(privilege, []) + if grantee not in grantees: + grantees.append(grantee) + + return grants_dict + def _begin_session(self, properties: SessionProperties) -> t.Any: """Begin a new session.""" # Align the different possible connectors to a single catalog diff --git a/tests/core/engine_adapter/integration/__init__.py b/tests/core/engine_adapter/integration/__init__.py index 7e6dae2f1b..0c53edd405 100644 --- a/tests/core/engine_adapter/integration/__init__.py +++ b/tests/core/engine_adapter/integration/__init__.py @@ -746,17 +746,25 @@ def upsert_sql_model(self, model_definition: str) -> t.Tuple[Context, SqlModel]: self._context.upsert_model(model) return self._context, model - def _get_create_user_or_role(self, username: str, password: t.Optional[str] = None) -> str: + def _get_create_user_or_role( + self, username: str, password: t.Optional[str] = None + ) -> t.Tuple[str, t.Optional[str]]: password = password or random_id() if self.dialect == "postgres": - return f"CREATE USER \"{username}\" WITH PASSWORD '{password}'" + return username, f"CREATE USER \"{username}\" WITH PASSWORD '{password}'" if self.dialect == "snowflake": - return f"CREATE ROLE {username}" + return username, f"CREATE ROLE {username}" + if self.dialect == "databricks": + # Creating an account-level group in Databricks requires making REST API calls so we are going to + # use a pre-created group instead. We assume the suffix on the name is the unique id + return "_".join(username.split("_")[:-1]), None raise ValueError(f"User creation not supported for dialect: {self.dialect}") - def _create_user_or_role(self, username: str, password: t.Optional[str] = None) -> None: - create_user_sql = self._get_create_user_or_role(username, password) - self.engine_adapter.execute(create_user_sql) + def _create_user_or_role(self, username: str, password: t.Optional[str] = None) -> str: + username, create_user_sql = self._get_create_user_or_role(username, password) + if create_user_sql: + self.engine_adapter.execute(create_user_sql) + return username @contextmanager def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str]]: @@ -769,7 +777,7 @@ def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str] self.add_test_suffix(f"test_{role_name}"), dialect=self.dialect ).sql(dialect=self.dialect) password = random_id() - self._create_user_or_role(user_name, password) + user_name = self._create_user_or_role(user_name, password) created_users.append(user_name) roles[role_name] = user_name @@ -779,6 +787,18 @@ def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str] for user_name in created_users: self._cleanup_user_or_role(user_name) + def get_insert_privilege(self) -> str: + if self.dialect == "databricks": + # This would really be "MODIFY" but for the purposes of having this be unique from UPDATE + # we return "MANAGE" instead + return "MANAGE" + return "INSERT" + + def get_update_privilege(self) -> str: + if self.dialect == "databricks": + return "MODIFY" + return "UPDATE" + def _cleanup_user_or_role(self, user_name: str) -> None: """Helper function to clean up a PostgreSQL user and all their dependencies.""" try: @@ -792,6 +812,8 @@ def _cleanup_user_or_role(self, user_name: str) -> None: self.engine_adapter.execute(f'DROP USER IF EXISTS "{user_name}"') elif self.dialect == "snowflake": self.engine_adapter.execute(f"DROP ROLE IF EXISTS {user_name}") + elif self.dialect == "databricks": + pass except Exception: pass diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index f48a95d39b..3b3ac12a26 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -3843,23 +3843,24 @@ def test_sync_grants_config(ctx: TestContext) -> None: ) table = ctx.table("sync_grants_integration") - + insert_privilege = ctx.get_insert_privilege() + update_privilege = ctx.get_update_privilege() with ctx.create_users_or_roles("reader", "writer", "admin") as roles: ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) initial_grants = { "SELECT": [roles["reader"]], - "INSERT": [roles["writer"]], + insert_privilege: [roles["writer"]], } ctx.engine_adapter.sync_grants_config(table, initial_grants) current_grants = ctx.engine_adapter._get_current_grants_config(table) assert set(current_grants.get("SELECT", [])) == {roles["reader"]} - assert set(current_grants.get("INSERT", [])) == {roles["writer"]} + assert set(current_grants.get(insert_privilege, [])) == {roles["writer"]} target_grants = { "SELECT": [roles["writer"], roles["admin"]], - "UPDATE": [roles["admin"]], + update_privilege: [roles["admin"]], } ctx.engine_adapter.sync_grants_config(table, target_grants) @@ -3868,8 +3869,8 @@ def test_sync_grants_config(ctx: TestContext) -> None: roles["writer"], roles["admin"], } - assert set(synced_grants.get("UPDATE", [])) == {roles["admin"]} - assert synced_grants.get("INSERT", []) == [] + assert set(synced_grants.get(update_privilege, [])) == {roles["admin"]} + assert synced_grants.get(insert_privilege, []) == [] def test_grants_sync_empty_config(ctx: TestContext): @@ -3879,19 +3880,19 @@ def test_grants_sync_empty_config(ctx: TestContext): ) table = ctx.table("grants_empty_test") - + insert_privilege = ctx.get_insert_privilege() with ctx.create_users_or_roles("user") as roles: ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) initial_grants = { "SELECT": [roles["user"]], - "INSERT": [roles["user"]], + insert_privilege: [roles["user"]], } ctx.engine_adapter.sync_grants_config(table, initial_grants) initial_current_grants = ctx.engine_adapter._get_current_grants_config(table) assert roles["user"] in initial_current_grants.get("SELECT", []) - assert roles["user"] in initial_current_grants.get("INSERT", []) + assert roles["user"] in initial_current_grants.get(insert_privilege, []) ctx.engine_adapter.sync_grants_config(table, {}) @@ -3905,18 +3906,12 @@ def test_grants_case_insensitive_grantees(ctx: TestContext): f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" ) - with ctx.create_users_or_roles("test_reader", "test_writer") as roles: + with ctx.create_users_or_roles("reader", "writer") as roles: table = ctx.table("grants_quoted_test") ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) - test_schema = table.db - for role_credentials in roles.values(): - ctx.engine_adapter.execute( - f'GRANT USAGE ON SCHEMA "{test_schema}" TO "{role_credentials}"' - ) - - reader = roles["test_reader"] - writer = roles["test_writer"] + reader = roles["reader"] + writer = roles["writer"] grants_config = {"SELECT": [reader, writer.upper()]} ctx.engine_adapter.sync_grants_config(table, grants_config) @@ -3941,7 +3936,8 @@ def test_grants_plan(ctx: TestContext, tmp_path: Path): f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" ) - table = ctx.table("grant_model").sql(dialect=ctx.dialect) + table = ctx.table("grant_model").sql(dialect="duckdb") + insert_privilege = ctx.get_insert_privilege() with ctx.create_users_or_roles("analyst", "etl_user") as roles: (tmp_path / "models").mkdir(exist_ok=True) @@ -3990,7 +3986,7 @@ def test_grants_plan(ctx: TestContext, tmp_path: Path): kind FULL, grants ( 'select' = ['{roles["analyst"]}', '{roles["etl_user"]}'], - 'insert' = ['{roles["etl_user"]}'] + '{insert_privilege}' = ['{roles["etl_user"]}'] ), grants_target_layer 'all' ); @@ -4015,14 +4011,17 @@ def test_grants_plan(ctx: TestContext, tmp_path: Path): ) expected_final_grants = { "SELECT": [roles["analyst"], roles["etl_user"]], - "INSERT": [roles["etl_user"]], + insert_privilege: [roles["etl_user"]], } assert set(final_grants.get("SELECT", [])) == set(expected_final_grants["SELECT"]) - assert final_grants.get("INSERT", []) == expected_final_grants["INSERT"] + assert final_grants.get(insert_privilege, []) == expected_final_grants[insert_privilege] # Virtual layer should also have the updated grants updated_virtual_grants = ctx.engine_adapter._get_current_grants_config( exp.to_table(view_name, dialect=ctx.dialect) ) assert set(updated_virtual_grants.get("SELECT", [])) == set(expected_final_grants["SELECT"]) - assert updated_virtual_grants.get("INSERT", []) == expected_final_grants["INSERT"] + assert ( + updated_virtual_grants.get(insert_privilege, []) + == expected_final_grants[insert_privilege] + ) diff --git a/tests/core/engine_adapter/test_databricks.py b/tests/core/engine_adapter/test_databricks.py index f482361c3c..eb320e03ca 100644 --- a/tests/core/engine_adapter/test_databricks.py +++ b/tests/core/engine_adapter/test_databricks.py @@ -139,6 +139,183 @@ def test_get_current_database(mocker: MockFixture, make_mocked_engine_adapter: t assert to_sql_calls(adapter) == ["SELECT CURRENT_DATABASE()"] +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockFixture): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main") + relation = exp.to_table("main.test_schema.test_table", dialect="databricks") + new_grants_config = { + "SELECT": ["group1", "group2"], + "MODIFY": ["writers"], + } + + current_grants = [ + ("SELECT", "legacy"), + ("REFRESH", "stale"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="databricks") + expected_sql = ( + "SELECT privilege_type, grantee FROM main.information_schema.table_privileges " + "WHERE table_catalog = 'main' AND table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert "GRANT SELECT ON TABLE `main`.`test_schema`.`test_table` TO `group1`" in sql_calls + assert "GRANT SELECT ON TABLE `main`.`test_schema`.`test_table` TO `group2`" in sql_calls + assert "GRANT MODIFY ON TABLE `main`.`test_schema`.`test_table` TO `writers`" in sql_calls + assert "REVOKE SELECT ON TABLE `main`.`test_schema`.`test_table` FROM `legacy`" in sql_calls + assert "REVOKE REFRESH ON TABLE `main`.`test_schema`.`test_table` FROM `stale`" in sql_calls + + +def test_sync_grants_config_with_overlaps( + make_mocked_engine_adapter: t.Callable, mocker: MockFixture +): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main") + relation = exp.to_table("main.test_schema.test_table", dialect="databricks") + new_grants_config = { + "SELECT": ["shared", "new_role"], + "MODIFY": ["shared", "writer"], + } + + current_grants = [ + ("SELECT", "shared"), + ("SELECT", "legacy"), + ("MODIFY", "shared"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="databricks") + expected_sql = ( + "SELECT privilege_type, grantee FROM main.information_schema.table_privileges " + "WHERE table_catalog = 'main' AND table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 3 + + assert "GRANT SELECT ON TABLE `main`.`test_schema`.`test_table` TO `new_role`" in sql_calls + assert "GRANT MODIFY ON TABLE `main`.`test_schema`.`test_table` TO `writer`" in sql_calls + assert "REVOKE SELECT ON TABLE `main`.`test_schema`.`test_table` FROM `legacy`" in sql_calls + + +@pytest.mark.parametrize( + "table_type, expected_keyword", + [ + (DataObjectType.TABLE, "TABLE"), + (DataObjectType.VIEW, "VIEW"), + (DataObjectType.MATERIALIZED_VIEW, "MATERIALIZED VIEW"), + (DataObjectType.MANAGED_TABLE, "TABLE"), + ], +) +def test_sync_grants_config_object_kind( + make_mocked_engine_adapter: t.Callable, + mocker: MockFixture, + table_type: DataObjectType, + expected_keyword: str, +) -> None: + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main") + relation = exp.to_table("main.test_schema.test_object", dialect="databricks") + + mocker.patch.object(adapter, "fetchall", return_value=[]) + + adapter.sync_grants_config(relation, {"SELECT": ["test"]}, table_type) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + f"GRANT SELECT ON {expected_keyword} `main`.`test_schema`.`test_object` TO `test`" + ] + + +def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocker: MockFixture): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="`test_db`") + relation = exp.to_table("`test_db`.`test_schema`.`test_table`", dialect="databricks") + new_grants_config = { + "SELECT": ["group1", "group2"], + "MODIFY": ["writers"], + } + + current_grants = [ + ("SELECT", "legacy"), + ("REFRESH", "stale"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="databricks") + expected_sql = ( + "SELECT privilege_type, grantee FROM test_db.information_schema.table_privileges " + "WHERE table_catalog = 'test_db' AND table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert "GRANT SELECT ON TABLE `test_db`.`test_schema`.`test_table` TO `group1`" in sql_calls + assert "GRANT SELECT ON TABLE `test_db`.`test_schema`.`test_table` TO `group2`" in sql_calls + assert "GRANT MODIFY ON TABLE `test_db`.`test_schema`.`test_table` TO `writers`" in sql_calls + assert "REVOKE SELECT ON TABLE `test_db`.`test_schema`.`test_table` FROM `legacy`" in sql_calls + assert "REVOKE REFRESH ON TABLE `test_db`.`test_schema`.`test_table` FROM `stale`" in sql_calls + + +def test_sync_grants_config_no_catalog_or_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockFixture +): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main_catalog") + relation = exp.to_table("test_table", dialect="databricks") + new_grants_config = { + "SELECT": ["group1", "group2"], + "MODIFY": ["writers"], + } + + current_grants = [ + ("SELECT", "legacy"), + ("REFRESH", "stale"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + mocker.patch.object(adapter, "get_current_database", return_value="schema") + mocker.patch.object(adapter, "get_current_catalog", return_value="main_catalog") + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="databricks") + expected_sql = ( + "SELECT privilege_type, grantee FROM main_catalog.information_schema.table_privileges " + "WHERE table_catalog = 'main_catalog' AND table_schema = 'schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert "GRANT SELECT ON TABLE `test_table` TO `group1`" in sql_calls + assert "GRANT SELECT ON TABLE `test_table` TO `group2`" in sql_calls + assert "GRANT MODIFY ON TABLE `test_table` TO `writers`" in sql_calls + assert "REVOKE SELECT ON TABLE `test_table` FROM `legacy`" in sql_calls + assert "REVOKE REFRESH ON TABLE `test_table` FROM `stale`" in sql_calls + + def test_insert_overwrite_by_partition_query( make_mocked_engine_adapter: t.Callable, mocker: MockFixture, make_temp_table_name: t.Callable ):