Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlmesh/core/engine_adapter/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@

QueryOrDF = t.Union[Query, DF]
GrantsConfig = t.Dict[str, t.List[str]]
DCL = t.TypeVar("DCL", exp.Grant, exp.Revoke)
6 changes: 2 additions & 4 deletions sqlmesh/core/engine_adapter/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@

if t.TYPE_CHECKING:
from sqlmesh.core._typing import TableName
from sqlmesh.core.engine_adapter._typing import DF, GrantsConfig, QueryOrDF

DCL = t.TypeVar("DCL", exp.Grant, exp.Revoke)
from sqlmesh.core.engine_adapter._typing import DCL, DF, GrantsConfig, QueryOrDF

logger = logging.getLogger(__name__)

Expand All @@ -38,7 +36,7 @@ class PostgresEngineAdapter(
HAS_VIEW_BINDING = True
CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
SUPPORTS_REPLACE_TABLE = False
MAX_IDENTIFIER_LENGTH = 63
MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63
SUPPORTS_QUERY_EXECUTION_TRACKING = True
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/core/engine_adapter/risingwave.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class RisingwaveEngineAdapter(PostgresEngineAdapter):
SUPPORTS_MATERIALIZED_VIEWS = True
SUPPORTS_TRANSACTIONS = False
MAX_IDENTIFIER_LENGTH = None
SUPPORTS_GRANTS = False

def columns(
self, table_name: TableName, include_pseudo_columns: bool = False
Expand Down
122 changes: 121 additions & 1 deletion sqlmesh/core/engine_adapter/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@
import pandas as pd

from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
from sqlmesh.core.engine_adapter._typing import DF, Query, QueryOrDF, SnowparkSession
from sqlmesh.core.engine_adapter._typing import (
DCL,
DF,
GrantsConfig,
Query,
QueryOrDF,
SnowparkSession,
)
from sqlmesh.core.node import IntervalUnit


Expand Down Expand Up @@ -73,6 +80,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
MANAGED_TABLE_KIND = "DYNAMIC TABLE"
SNOWPARK = "snowpark"
SUPPORTS_QUERY_EXECUTION_TRACKING = True
SUPPORTS_GRANTS = True

@contextlib.contextmanager
def session(self, properties: SessionProperties) -> t.Iterator[None]:
Expand Down Expand Up @@ -127,6 +135,118 @@ def snowpark(self) -> t.Optional[SnowparkSession]:
def catalog_support(self) -> CatalogSupport:
return CatalogSupport.FULL_SUPPORT

@staticmethod
def _grant_object_kind(table_type: DataObjectType) -> str:
if table_type == DataObjectType.VIEW:
return "VIEW"
if table_type == DataObjectType.MATERIALIZED_VIEW:
return "MATERIALIZED VIEW"
if table_type == DataObjectType.MANAGED_TABLE:
return "DYNAMIC TABLE"
return "TABLE"

def _get_current_schema(self) -> str:
"""Returns the current default schema for the connection."""
result = self.fetchone("SELECT CURRENT_SCHEMA()")
if not result or not result[0]:
raise SQLMeshError("Unable to determine current schema")
return str(result[0])

def _dcl_grants_config_expr(
self,
dcl_cmd: t.Type[DCL],
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
expressions: t.List[exp.Expression] = []
if not grant_config:
return expressions

object_kind = self._grant_object_kind(table_type)
for privilege, principals in grant_config.items():
for principal in principals:
args: t.Dict[str, t.Any] = {
"privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))],
"securable": table.copy(),
"principals": [principal],
}

if object_kind:
args["kind"] = exp.Var(this=object_kind)

expressions.append(dcl_cmd(**args)) # type: ignore[arg-type]

return expressions

def _apply_grants_config_expr(
self,
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type)

def _revoke_grants_config_expr(
self,
table: exp.Table,
grant_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type)

def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig:
schema_identifier = table.args.get("db") or normalize_identifiers(
exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect
)
catalog_identifier = table.args.get("catalog")
if not catalog_identifier:
current_catalog = self.get_current_catalog()
if not current_catalog:
raise SQLMeshError("Unable to determine current catalog for fetching grants")
catalog_identifier = normalize_identifiers(
exp.to_identifier(current_catalog, quoted=True), dialect=self.dialect
)
catalog_identifier.set("quoted", True)
table_identifier = table.args.get("this")

grant_expr = (
exp.select("privilege_type", "grantee")
.from_(
exp.table_(
"TABLE_PRIVILEGES",
db="INFORMATION_SCHEMA",
catalog=catalog_identifier,
)
)
.where(
exp.and_(
exp.column("table_schema").eq(exp.Literal.string(schema_identifier.this)),
exp.column("table_name").eq(exp.Literal.string(table_identifier.this)), # type: ignore
exp.column("grantor").eq(exp.func("CURRENT_ROLE")),
exp.column("grantee").neq(exp.func("CURRENT_ROLE")),
)
)
)

results = self.fetchall(grant_expr)

grants_dict: GrantsConfig = {}
for privilege_raw, grantee_raw in results:
if privilege_raw is None or grantee_raw is None:
continue

privilege = str(privilege_raw)
grantee = str(grantee_raw)
if not privilege or not grantee:
continue

grantees = grants_dict.setdefault(privilege, [])
if grantee not in grantees:
grantees.append(grantee)

return grants_dict

def _create_catalog(self, catalog_name: exp.Identifier) -> None:
props = exp.Properties(
expressions=[exp.SchemaCommentProperty(this=exp.Literal.string(c.SQLMESH_MANAGED))]
Expand Down
51 changes: 51 additions & 0 deletions tests/core/engine_adapter/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import sys
import typing as t
import time
from contextlib import contextmanager

import pandas as pd # noqa: TID253
import pytest
from sqlglot import exp, parse_one
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

from sqlmesh import Config, Context, EngineAdapter
from sqlmesh.core.config import load_config_from_paths
Expand Down Expand Up @@ -744,6 +746,55 @@ def upsert_sql_model(self, model_definition: str) -> t.Tuple[Context, SqlModel]:
self._context.upsert_model(model)
return self._context, model

def _get_create_user_or_role(self, username: str, password: t.Optional[str] = None) -> str:
password = password or random_id()
if self.dialect == "postgres":
return f"CREATE USER \"{username}\" WITH PASSWORD '{password}'"
if self.dialect == "snowflake":
return f"CREATE ROLE {username}"
raise ValueError(f"User creation not supported for dialect: {self.dialect}")

def _create_user_or_role(self, username: str, password: t.Optional[str] = None) -> None:
create_user_sql = self._get_create_user_or_role(username, password)
self.engine_adapter.execute(create_user_sql)

@contextmanager
def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str]]:
created_users = []
roles = {}

try:
for role_name in role_names:
user_name = normalize_identifiers(
self.add_test_suffix(f"test_{role_name}"), dialect=self.dialect
).sql(dialect=self.dialect)
password = random_id()
self._create_user_or_role(user_name, password)
created_users.append(user_name)
roles[role_name] = user_name

yield roles

finally:
for user_name in created_users:
self._cleanup_user_or_role(user_name)

def _cleanup_user_or_role(self, user_name: str) -> None:
"""Helper function to clean up a PostgreSQL user and all their dependencies."""
try:
if self.dialect == "postgres":
self.engine_adapter.execute(f"""
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE usename = '{user_name}' AND pid <> pg_backend_pid()
""")
self.engine_adapter.execute(f'DROP OWNED BY "{user_name}"')
self.engine_adapter.execute(f'DROP USER IF EXISTS "{user_name}"')
elif self.dialect == "snowflake":
self.engine_adapter.execute(f"DROP ROLE IF EXISTS {user_name}")
except Exception:
pass


def wait_until(fn: t.Callable[..., bool], attempts=3, wait=5) -> None:
current_attempt = 0
Expand Down
Loading