Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"requests",
"rich[jupyter]",
"ruamel.yaml",
"sqlglot[rs]~=27.17.0",
"sqlglot[rs]~=27.19.0",
"tenacity",
"time-machine",
"json-stream"
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/engine_adapter/base_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _get_data_objects(
for row in df.itertuples()
]

def get_current_schema(self) -> str:
def _get_current_schema(self) -> str:
"""Returns the current default schema for the connection."""
result = self.fetchone(exp.select(self.CURRENT_SCHEMA_EXPRESSION))
if result and result[0]:
Expand Down
111 changes: 16 additions & 95 deletions sqlmesh/core/engine_adapter/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlglot import exp

from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter.mixins import GrantsFromInfoSchemaMixin
from sqlmesh.core.engine_adapter.shared import (
CatalogSupport,
DataObject,
Expand All @@ -24,18 +25,19 @@
import pandas as pd

from sqlmesh.core._typing import SchemaName, TableName, SessionProperties
from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query, GrantsConfig, DCL
from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query

logger = logging.getLogger(__name__)


class DatabricksEngineAdapter(SparkEngineAdapter):
class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):
DIALECT = "databricks"
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE
SUPPORTS_CLONING = True
SUPPORTS_MATERIALIZED_VIEWS = True
SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
SUPPORTS_GRANTS = True
USE_CATALOG_IN_GRANTS = True
SCHEMA_DIFFER_KWARGS = {
"support_positional_add": True,
"nested_support": NestedSupport.ALL,
Expand Down Expand Up @@ -159,100 +161,19 @@ def _grant_object_kind(table_type: DataObjectType) -> str:
return "MATERIALIZED VIEW"
return "TABLE"

def _dcl_grants_config_expr(
self,
dcl_cmd: t.Type[DCL],
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
expressions: t.List[exp.Expression] = []
if not grant_config:
return expressions

object_kind = self._grant_object_kind(table_type)
for privilege, principals in grant_config.items():
for principal in principals:
args: t.Dict[str, t.Any] = {
"privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))],
"securable": table.copy(),
"principals": [exp.to_identifier(principal.lower())],
}

if object_kind:
args["kind"] = exp.Var(this=object_kind)

expressions.append(dcl_cmd(**args)) # type: ignore[arg-type]

return expressions

def _apply_grants_config_expr(
self,
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type)

def _revoke_grants_config_expr(
self,
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type)

def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig:
if schema_identifier := table.args.get("db"):
schema_name = schema_identifier.this
else:
schema_name = self.get_current_database()
if catalog_identifier := table.args.get("catalog"):
catalog_name = catalog_identifier.this
else:
catalog_name = self.get_current_catalog()
table_name = table.args.get("this").this # type: ignore

grant_expr = (
exp.select("privilege_type", "grantee")
.from_(
exp.table_(
"table_privileges",
db="information_schema",
catalog=catalog_name,
)
)
.where(
exp.and_(
exp.column("table_catalog").eq(exp.Literal.string(catalog_name.lower())),
exp.column("table_schema").eq(exp.Literal.string(schema_name.lower())),
exp.column("table_name").eq(exp.Literal.string(table_name.lower())),
exp.column("grantor").eq(exp.func("current_user")),
exp.column("grantee").neq(exp.func("current_user")),
# We only care about explicitly granted privileges and not inherited ones
# if this is removed you would see grants inherited from the catalog get returned
exp.column("inherited_from").eq(exp.Literal.string("NONE")),
)
)
def _get_grant_expression(self, table: exp.Table) -> exp.Expression:
# We only care about explicitly granted privileges and not inherited ones
# if this is removed you would see grants inherited from the catalog get returned
expression = super()._get_grant_expression(table)
expression.args["where"].set(
"this",
exp.and_(
expression.args["where"].this,
exp.column("inherited_from").eq(exp.Literal.string("NONE")),
wrap=False,
),
)

results = self.fetchall(grant_expr)

grants_dict: GrantsConfig = {}
for privilege_raw, grantee_raw in results:
if privilege_raw is None or grantee_raw is None:
continue

privilege = str(privilege_raw)
grantee = str(grantee_raw)
if not privilege or not grantee:
continue

grantees = grants_dict.setdefault(privilege, [])
if grantee not in grantees:
grantees.append(grantee)

return grants_dict
return expression

def _begin_session(self, properties: SessionProperties) -> t.Any:
"""Begin a new session."""
Expand Down
143 changes: 142 additions & 1 deletion sqlmesh/core/engine_adapter/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,23 @@

from sqlglot import exp, parse_one
from sqlglot.helper import seq_get
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

from sqlmesh.core.engine_adapter.base import EngineAdapter
from sqlmesh.core.engine_adapter.shared import DataObjectType
from sqlmesh.core.node import IntervalUnit
from sqlmesh.core.dialect import schema_
from sqlmesh.core.schema_diff import TableAlterOperation
from sqlmesh.utils.errors import SQLMeshError

if t.TYPE_CHECKING:
from sqlmesh.core._typing import TableName
from sqlmesh.core.engine_adapter._typing import DF
from sqlmesh.core.engine_adapter._typing import (
DCL,
DF,
GrantsConfig,
QueryOrDF,
)
from sqlmesh.core.engine_adapter.base import QueryOrDF

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -548,3 +555,137 @@ def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.

def _normalize_boolean_value(self, expr: exp.Expression) -> exp.Expression:
return exp.cast(expr, "INT")


class GrantsFromInfoSchemaMixin(EngineAdapter):
CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("current_user")
SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = False
USE_CATALOG_IN_GRANTS = False
GRANT_INFORMATION_SCHEMA_TABLE_NAME = "table_privileges"

@staticmethod
@abc.abstractmethod
def _grant_object_kind(table_type: DataObjectType) -> t.Optional[str]:
pass

@abc.abstractmethod
def _get_current_schema(self) -> str:
pass

def _dcl_grants_config_expr(
self,
dcl_cmd: t.Type[DCL],
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
expressions: t.List[exp.Expression] = []
if not grant_config:
return expressions

object_kind = self._grant_object_kind(table_type)
for privilege, principals in grant_config.items():
args: t.Dict[str, t.Any] = {
"privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))],
"securable": table.copy(),
}
if object_kind:
args["kind"] = exp.Var(this=object_kind)
if self.SUPPORTS_MULTIPLE_GRANT_PRINCIPALS:
args["principals"] = [
normalize_identifiers(
parse_one(principal, into=exp.GrantPrincipal, dialect=self.dialect),
dialect=self.dialect,
)
for principal in principals
]
expressions.append(dcl_cmd(**args)) # type: ignore[arg-type]
else:
for principal in principals:
args["principals"] = [
normalize_identifiers(
parse_one(principal, into=exp.GrantPrincipal, dialect=self.dialect),
dialect=self.dialect,
)
]
expressions.append(dcl_cmd(**args)) # type: ignore[arg-type]

return expressions

def _apply_grants_config_expr(
self,
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type)

def _revoke_grants_config_expr(
self,
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type)

def _get_grant_expression(self, table: exp.Table) -> exp.Expression:
schema_identifier = table.args.get("db") or normalize_identifiers(
exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect
)
schema_name = schema_identifier.this
table_name = table.args.get("this").this # type: ignore

grant_conditions = [
exp.column("table_schema").eq(exp.Literal.string(schema_name)),
exp.column("table_name").eq(exp.Literal.string(table_name)),
exp.column("grantor").eq(self.CURRENT_USER_OR_ROLE_EXPRESSION),
exp.column("grantee").neq(self.CURRENT_USER_OR_ROLE_EXPRESSION),
]

info_schema_table = normalize_identifiers(
exp.table_(self.GRANT_INFORMATION_SCHEMA_TABLE_NAME, db="information_schema"),
dialect=self.dialect,
)
if self.USE_CATALOG_IN_GRANTS:
catalog_identifier = table.args.get("catalog")
if not catalog_identifier:
catalog_name = self.get_current_catalog()
if not catalog_name:
raise SQLMeshError(
"Current catalog could not be determined for fetching grants. This is unexpected."
)
catalog_identifier = normalize_identifiers(
exp.to_identifier(catalog_name, quoted=True), dialect=self.dialect
)
catalog_name = catalog_identifier.this
info_schema_table.set("catalog", catalog_identifier.copy())
grant_conditions.insert(
0, exp.column("table_catalog").eq(exp.Literal.string(catalog_name))
)

return (
exp.select("privilege_type", "grantee")
.from_(info_schema_table)
.where(exp.and_(*grant_conditions))
)

def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig:
grant_expr = self._get_grant_expression(table)

results = self.fetchall(grant_expr)

grants_dict: GrantsConfig = {}
for privilege_raw, grantee_raw in results:
if privilege_raw is None or grantee_raw is None:
continue

privilege = str(privilege_raw)
grantee = str(grantee_raw)
if not privilege or not grantee:
continue

grantees = grants_dict.setdefault(privilege, [])
if grantee not in grantees:
grantees.append(grantee)

return grants_dict
Loading