Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 106 additions & 1 deletion sqlmesh/core/engine_adapter/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import partial

from sqlglot import exp

from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter.shared import (
CatalogSupport,
Expand All @@ -23,7 +24,7 @@
import pandas as pd

from sqlmesh.core._typing import SchemaName, TableName, SessionProperties
from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query
from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query, GrantsConfig, DCL

logger = logging.getLogger(__name__)

Expand All @@ -34,6 +35,7 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
SUPPORTS_CLONING = True
SUPPORTS_MATERIALIZED_VIEWS = True
SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
SUPPORTS_GRANTS = True
SCHEMA_DIFFER_KWARGS = {
"support_positional_add": True,
"nested_support": NestedSupport.ALL,
Expand Down Expand Up @@ -149,6 +151,109 @@ def spark(self) -> PySparkSession:
def catalog_support(self) -> CatalogSupport:
return CatalogSupport.FULL_SUPPORT

@staticmethod
def _grant_object_kind(table_type: DataObjectType) -> str:
if table_type == DataObjectType.VIEW:
return "VIEW"
if table_type == DataObjectType.MATERIALIZED_VIEW:
return "MATERIALIZED VIEW"
return "TABLE"

def _dcl_grants_config_expr(
self,
dcl_cmd: t.Type[DCL],
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
expressions: t.List[exp.Expression] = []
if not grant_config:
return expressions

object_kind = self._grant_object_kind(table_type)
for privilege, principals in grant_config.items():
for principal in principals:
args: t.Dict[str, t.Any] = {
"privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))],
"securable": table.copy(),
"principals": [exp.to_identifier(principal.lower())],
}

if object_kind:
args["kind"] = exp.Var(this=object_kind)

expressions.append(dcl_cmd(**args)) # type: ignore[arg-type]

return expressions

def _apply_grants_config_expr(
self,
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type)

def _revoke_grants_config_expr(
self,
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type)

def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig:
if schema_identifier := table.args.get("db"):
schema_name = schema_identifier.this
else:
schema_name = self.get_current_database()
if catalog_identifier := table.args.get("catalog"):
catalog_name = catalog_identifier.this
else:
catalog_name = self.get_current_catalog()
table_name = table.args.get("this").this # type: ignore

grant_expr = (
exp.select("privilege_type", "grantee")
.from_(
exp.table_(
"table_privileges",
db="information_schema",
catalog=catalog_name,
)
)
.where(
exp.and_(
exp.column("table_catalog").eq(exp.Literal.string(catalog_name.lower())),
exp.column("table_schema").eq(exp.Literal.string(schema_name.lower())),
exp.column("table_name").eq(exp.Literal.string(table_name.lower())),
Comment on lines +227 to +229
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume the reason we lower case everything is because we are assuming Databricks users will be using Unity catalog? Is that always the case? I know some deployments still use hive metastore. If I remember correctly, it table names etc. can be case sensitive there.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The engine adapter itself assumes Unity Catalog. So we don't need to worry about hive metastore compatibility.

exp.column("grantor").eq(exp.func("current_user")),
exp.column("grantee").neq(exp.func("current_user")),
# We only care about explicitly granted privileges and not inherited ones
# if this is removed you would see grants inherited from the catalog get returned
exp.column("inherited_from").eq(exp.Literal.string("NONE")),
)
)
)

results = self.fetchall(grant_expr)

grants_dict: GrantsConfig = {}
for privilege_raw, grantee_raw in results:
if privilege_raw is None or grantee_raw is None:
continue

privilege = str(privilege_raw)
grantee = str(grantee_raw)
if not privilege or not grantee:
continue

grantees = grants_dict.setdefault(privilege, [])
if grantee not in grantees:
grantees.append(grantee)

return grants_dict

def _begin_session(self, properties: SessionProperties) -> t.Any:
"""Begin a new session."""
# Align the different possible connectors to a single catalog
Expand Down
36 changes: 29 additions & 7 deletions tests/core/engine_adapter/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,17 +746,25 @@ def upsert_sql_model(self, model_definition: str) -> t.Tuple[Context, SqlModel]:
self._context.upsert_model(model)
return self._context, model

def _get_create_user_or_role(self, username: str, password: t.Optional[str] = None) -> str:
def _get_create_user_or_role(
self, username: str, password: t.Optional[str] = None
) -> t.Tuple[str, t.Optional[str]]:
password = password or random_id()
if self.dialect == "postgres":
return f"CREATE USER \"{username}\" WITH PASSWORD '{password}'"
return username, f"CREATE USER \"{username}\" WITH PASSWORD '{password}'"
if self.dialect == "snowflake":
return f"CREATE ROLE {username}"
return username, f"CREATE ROLE {username}"
if self.dialect == "databricks":
# Creating an account-level group in Databricks requires making REST API calls so we are going to
# use a pre-created group instead. We assume the suffix on the name is the unique id
return "_".join(username.split("_")[:-1]), None
raise ValueError(f"User creation not supported for dialect: {self.dialect}")

def _create_user_or_role(self, username: str, password: t.Optional[str] = None) -> None:
create_user_sql = self._get_create_user_or_role(username, password)
self.engine_adapter.execute(create_user_sql)
def _create_user_or_role(self, username: str, password: t.Optional[str] = None) -> str:
username, create_user_sql = self._get_create_user_or_role(username, password)
if create_user_sql:
self.engine_adapter.execute(create_user_sql)
return username

@contextmanager
def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str]]:
Expand All @@ -769,7 +777,7 @@ def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str]
self.add_test_suffix(f"test_{role_name}"), dialect=self.dialect
).sql(dialect=self.dialect)
password = random_id()
self._create_user_or_role(user_name, password)
user_name = self._create_user_or_role(user_name, password)
created_users.append(user_name)
roles[role_name] = user_name

Expand All @@ -779,6 +787,18 @@ def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str]
for user_name in created_users:
self._cleanup_user_or_role(user_name)

def get_insert_privilege(self) -> str:
if self.dialect == "databricks":
# This would really be "MODIFY" but for the purposes of having this be unique from UPDATE
# we return "MANAGE" instead
return "MANAGE"
return "INSERT"

def get_update_privilege(self) -> str:
if self.dialect == "databricks":
return "MODIFY"
return "UPDATE"

def _cleanup_user_or_role(self, user_name: str) -> None:
"""Helper function to clean up a PostgreSQL user and all their dependencies."""
try:
Expand All @@ -792,6 +812,8 @@ def _cleanup_user_or_role(self, user_name: str) -> None:
self.engine_adapter.execute(f'DROP USER IF EXISTS "{user_name}"')
elif self.dialect == "snowflake":
self.engine_adapter.execute(f"DROP ROLE IF EXISTS {user_name}")
elif self.dialect == "databricks":
pass
except Exception:
pass

Expand Down
45 changes: 22 additions & 23 deletions tests/core/engine_adapter/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3843,23 +3843,24 @@ def test_sync_grants_config(ctx: TestContext) -> None:
)

table = ctx.table("sync_grants_integration")

insert_privilege = ctx.get_insert_privilege()
update_privilege = ctx.get_update_privilege()
with ctx.create_users_or_roles("reader", "writer", "admin") as roles:
ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")})

initial_grants = {
"SELECT": [roles["reader"]],
"INSERT": [roles["writer"]],
insert_privilege: [roles["writer"]],
}
ctx.engine_adapter.sync_grants_config(table, initial_grants)

current_grants = ctx.engine_adapter._get_current_grants_config(table)
assert set(current_grants.get("SELECT", [])) == {roles["reader"]}
assert set(current_grants.get("INSERT", [])) == {roles["writer"]}
assert set(current_grants.get(insert_privilege, [])) == {roles["writer"]}

target_grants = {
"SELECT": [roles["writer"], roles["admin"]],
"UPDATE": [roles["admin"]],
update_privilege: [roles["admin"]],
}
ctx.engine_adapter.sync_grants_config(table, target_grants)

Expand All @@ -3868,8 +3869,8 @@ def test_sync_grants_config(ctx: TestContext) -> None:
roles["writer"],
roles["admin"],
}
assert set(synced_grants.get("UPDATE", [])) == {roles["admin"]}
assert synced_grants.get("INSERT", []) == []
assert set(synced_grants.get(update_privilege, [])) == {roles["admin"]}
assert synced_grants.get(insert_privilege, []) == []


def test_grants_sync_empty_config(ctx: TestContext):
Expand All @@ -3879,19 +3880,19 @@ def test_grants_sync_empty_config(ctx: TestContext):
)

table = ctx.table("grants_empty_test")

insert_privilege = ctx.get_insert_privilege()
with ctx.create_users_or_roles("user") as roles:
ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")})

initial_grants = {
"SELECT": [roles["user"]],
"INSERT": [roles["user"]],
insert_privilege: [roles["user"]],
}
ctx.engine_adapter.sync_grants_config(table, initial_grants)

initial_current_grants = ctx.engine_adapter._get_current_grants_config(table)
assert roles["user"] in initial_current_grants.get("SELECT", [])
assert roles["user"] in initial_current_grants.get("INSERT", [])
assert roles["user"] in initial_current_grants.get(insert_privilege, [])

ctx.engine_adapter.sync_grants_config(table, {})

Expand All @@ -3905,18 +3906,12 @@ def test_grants_case_insensitive_grantees(ctx: TestContext):
f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants"
)

with ctx.create_users_or_roles("test_reader", "test_writer") as roles:
with ctx.create_users_or_roles("reader", "writer") as roles:
table = ctx.table("grants_quoted_test")
ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")})

test_schema = table.db
for role_credentials in roles.values():
ctx.engine_adapter.execute(
f'GRANT USAGE ON SCHEMA "{test_schema}" TO "{role_credentials}"'
)

reader = roles["test_reader"]
writer = roles["test_writer"]
reader = roles["reader"]
writer = roles["writer"]

grants_config = {"SELECT": [reader, writer.upper()]}
ctx.engine_adapter.sync_grants_config(table, grants_config)
Expand All @@ -3941,7 +3936,8 @@ def test_grants_plan(ctx: TestContext, tmp_path: Path):
f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants"
)

table = ctx.table("grant_model").sql(dialect=ctx.dialect)
table = ctx.table("grant_model").sql(dialect="duckdb")
insert_privilege = ctx.get_insert_privilege()
with ctx.create_users_or_roles("analyst", "etl_user") as roles:
(tmp_path / "models").mkdir(exist_ok=True)

Expand Down Expand Up @@ -3990,7 +3986,7 @@ def test_grants_plan(ctx: TestContext, tmp_path: Path):
kind FULL,
grants (
'select' = ['{roles["analyst"]}', '{roles["etl_user"]}'],
'insert' = ['{roles["etl_user"]}']
'{insert_privilege}' = ['{roles["etl_user"]}']
),
grants_target_layer 'all'
);
Expand All @@ -4015,14 +4011,17 @@ def test_grants_plan(ctx: TestContext, tmp_path: Path):
)
expected_final_grants = {
"SELECT": [roles["analyst"], roles["etl_user"]],
"INSERT": [roles["etl_user"]],
insert_privilege: [roles["etl_user"]],
}
assert set(final_grants.get("SELECT", [])) == set(expected_final_grants["SELECT"])
assert final_grants.get("INSERT", []) == expected_final_grants["INSERT"]
assert final_grants.get(insert_privilege, []) == expected_final_grants[insert_privilege]

# Virtual layer should also have the updated grants
updated_virtual_grants = ctx.engine_adapter._get_current_grants_config(
exp.to_table(view_name, dialect=ctx.dialect)
)
assert set(updated_virtual_grants.get("SELECT", [])) == set(expected_final_grants["SELECT"])
assert updated_virtual_grants.get("INSERT", []) == expected_final_grants["INSERT"]
assert (
updated_virtual_grants.get(insert_privilege, [])
== expected_final_grants[insert_privilege]
)
Loading