From 4564c98d7b8fea473a4f2b700cb0af57efdea5a6 Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Fri, 26 Sep 2025 14:08:14 -0700 Subject: [PATCH] chore: create grant mixin and normalize --- pyproject.toml | 2 +- sqlmesh/core/engine_adapter/base_postgres.py | 2 +- sqlmesh/core/engine_adapter/databricks.py | 111 ++------------ sqlmesh/core/engine_adapter/mixins.py | 143 +++++++++++++++++- sqlmesh/core/engine_adapter/postgres.py | 60 +------- sqlmesh/core/engine_adapter/redshift.py | 86 +---------- sqlmesh/core/engine_adapter/snowflake.py | 131 ++++------------ sqlmesh/core/engine_adapter/spark.py | 4 +- .../core/engine_adapter/test_base_postgres.py | 6 +- tests/core/engine_adapter/test_databricks.py | 10 +- tests/core/engine_adapter/test_postgres.py | 18 +-- tests/core/engine_adapter/test_redshift.py | 32 ++-- tests/core/engine_adapter/test_snowflake.py | 53 ++++--- tests/core/engine_adapter/test_spark.py | 4 +- 14 files changed, 265 insertions(+), 397 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 59880e61c5..25e98248a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "requests", "rich[jupyter]", "ruamel.yaml", - "sqlglot[rs]~=27.17.0", + "sqlglot[rs]~=27.19.0", "tenacity", "time-machine", "json-stream" diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py index 9d4fa68ac3..80c5bef2d0 100644 --- a/sqlmesh/core/engine_adapter/base_postgres.py +++ b/sqlmesh/core/engine_adapter/base_postgres.py @@ -191,7 +191,7 @@ def _get_data_objects( for row in df.itertuples() ] - def get_current_schema(self) -> str: + def _get_current_schema(self) -> str: """Returns the current default schema for the connection.""" result = self.fetchone(exp.select(self.CURRENT_SCHEMA_EXPRESSION)) if result and result[0]: diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index 3836f326c1..5d786ba823 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -7,6 +7,7 @@ from sqlglot import exp from sqlmesh.core.dialect import to_schema +from sqlmesh.core.engine_adapter.mixins import GrantsFromInfoSchemaMixin from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, DataObject, @@ -24,18 +25,19 @@ import pandas as pd from sqlmesh.core._typing import SchemaName, TableName, SessionProperties - from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query, GrantsConfig, DCL + from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query logger = logging.getLogger(__name__) -class DatabricksEngineAdapter(SparkEngineAdapter): +class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin): DIALECT = "databricks" INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE SUPPORTS_CLONING = True SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True SUPPORTS_GRANTS = True + USE_CATALOG_IN_GRANTS = True SCHEMA_DIFFER_KWARGS = { "support_positional_add": True, "nested_support": NestedSupport.ALL, @@ -159,100 +161,19 @@ def _grant_object_kind(table_type: DataObjectType) -> str: return "MATERIALIZED VIEW" return "TABLE" - def _dcl_grants_config_expr( - self, - dcl_cmd: t.Type[DCL], - table: exp.Table, - grant_config: GrantsConfig, - table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: - expressions: t.List[exp.Expression] = [] - if not grant_config: - return expressions - - object_kind = self._grant_object_kind(table_type) - for privilege, principals in grant_config.items(): - for principal in principals: - args: t.Dict[str, t.Any] = { - "privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))], - "securable": table.copy(), - "principals": [exp.to_identifier(principal.lower())], - } - - if object_kind: - args["kind"] = exp.Var(this=object_kind) - - expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] - - return expressions - - def _apply_grants_config_expr( - self, - table: exp.Table, - grant_config: GrantsConfig, - table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: - return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type) - - def _revoke_grants_config_expr( - self, - table: exp.Table, - grant_config: GrantsConfig, - table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: - return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type) - - def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: - if schema_identifier := table.args.get("db"): - schema_name = schema_identifier.this - else: - schema_name = self.get_current_database() - if catalog_identifier := table.args.get("catalog"): - catalog_name = catalog_identifier.this - else: - catalog_name = self.get_current_catalog() - table_name = table.args.get("this").this # type: ignore - - grant_expr = ( - exp.select("privilege_type", "grantee") - .from_( - exp.table_( - "table_privileges", - db="information_schema", - catalog=catalog_name, - ) - ) - .where( - exp.and_( - exp.column("table_catalog").eq(exp.Literal.string(catalog_name.lower())), - exp.column("table_schema").eq(exp.Literal.string(schema_name.lower())), - exp.column("table_name").eq(exp.Literal.string(table_name.lower())), - exp.column("grantor").eq(exp.func("current_user")), - exp.column("grantee").neq(exp.func("current_user")), - # We only care about explicitly granted privileges and not inherited ones - # if this is removed you would see grants inherited from the catalog get returned - exp.column("inherited_from").eq(exp.Literal.string("NONE")), - ) - ) + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + # We only care about explicitly granted privileges and not inherited ones + # if this is removed you would see grants inherited from the catalog get returned + expression = super()._get_grant_expression(table) + expression.args["where"].set( + "this", + exp.and_( + expression.args["where"].this, + exp.column("inherited_from").eq(exp.Literal.string("NONE")), + wrap=False, + ), ) - - results = self.fetchall(grant_expr) - - grants_dict: GrantsConfig = {} - for privilege_raw, grantee_raw in results: - if privilege_raw is None or grantee_raw is None: - continue - - privilege = str(privilege_raw) - grantee = str(grantee_raw) - if not privilege or not grantee: - continue - - grantees = grants_dict.setdefault(privilege, []) - if grantee not in grantees: - grantees.append(grantee) - - return grants_dict + return expression def _begin_session(self, properties: SessionProperties) -> t.Any: """Begin a new session.""" diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py index 1d66da0607..1a6fdea8c2 100644 --- a/sqlmesh/core/engine_adapter/mixins.py +++ b/sqlmesh/core/engine_adapter/mixins.py @@ -7,8 +7,10 @@ from sqlglot import exp, parse_one from sqlglot.helper import seq_get +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh.core.engine_adapter.base import EngineAdapter +from sqlmesh.core.engine_adapter.shared import DataObjectType from sqlmesh.core.node import IntervalUnit from sqlmesh.core.dialect import schema_ from sqlmesh.core.schema_diff import TableAlterOperation @@ -16,7 +18,12 @@ if t.TYPE_CHECKING: from sqlmesh.core._typing import TableName - from sqlmesh.core.engine_adapter._typing import DF + from sqlmesh.core.engine_adapter._typing import ( + DCL, + DF, + GrantsConfig, + QueryOrDF, + ) from sqlmesh.core.engine_adapter.base import QueryOrDF logger = logging.getLogger(__name__) @@ -548,3 +555,137 @@ def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp. def _normalize_boolean_value(self, expr: exp.Expression) -> exp.Expression: return exp.cast(expr, "INT") + + +class GrantsFromInfoSchemaMixin(EngineAdapter): + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("current_user") + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = False + USE_CATALOG_IN_GRANTS = False + GRANT_INFORMATION_SCHEMA_TABLE_NAME = "table_privileges" + + @staticmethod + @abc.abstractmethod + def _grant_object_kind(table_type: DataObjectType) -> t.Optional[str]: + pass + + @abc.abstractmethod + def _get_current_schema(self) -> str: + pass + + def _dcl_grants_config_expr( + self, + dcl_cmd: t.Type[DCL], + table: exp.Table, + grant_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + expressions: t.List[exp.Expression] = [] + if not grant_config: + return expressions + + object_kind = self._grant_object_kind(table_type) + for privilege, principals in grant_config.items(): + args: t.Dict[str, t.Any] = { + "privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))], + "securable": table.copy(), + } + if object_kind: + args["kind"] = exp.Var(this=object_kind) + if self.SUPPORTS_MULTIPLE_GRANT_PRINCIPALS: + args["principals"] = [ + normalize_identifiers( + parse_one(principal, into=exp.GrantPrincipal, dialect=self.dialect), + dialect=self.dialect, + ) + for principal in principals + ] + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] + else: + for principal in principals: + args["principals"] = [ + normalize_identifiers( + parse_one(principal, into=exp.GrantPrincipal, dialect=self.dialect), + dialect=self.dialect, + ) + ] + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] + + return expressions + + def _apply_grants_config_expr( + self, + table: exp.Table, + grant_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type) + + def _revoke_grants_config_expr( + self, + table: exp.Table, + grant_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type) + + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + schema_identifier = table.args.get("db") or normalize_identifiers( + exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect + ) + schema_name = schema_identifier.this + table_name = table.args.get("this").this # type: ignore + + grant_conditions = [ + exp.column("table_schema").eq(exp.Literal.string(schema_name)), + exp.column("table_name").eq(exp.Literal.string(table_name)), + exp.column("grantor").eq(self.CURRENT_USER_OR_ROLE_EXPRESSION), + exp.column("grantee").neq(self.CURRENT_USER_OR_ROLE_EXPRESSION), + ] + + info_schema_table = normalize_identifiers( + exp.table_(self.GRANT_INFORMATION_SCHEMA_TABLE_NAME, db="information_schema"), + dialect=self.dialect, + ) + if self.USE_CATALOG_IN_GRANTS: + catalog_identifier = table.args.get("catalog") + if not catalog_identifier: + catalog_name = self.get_current_catalog() + if not catalog_name: + raise SQLMeshError( + "Current catalog could not be determined for fetching grants. This is unexpected." + ) + catalog_identifier = normalize_identifiers( + exp.to_identifier(catalog_name, quoted=True), dialect=self.dialect + ) + catalog_name = catalog_identifier.this + info_schema_table.set("catalog", catalog_identifier.copy()) + grant_conditions.insert( + 0, exp.column("table_catalog").eq(exp.Literal.string(catalog_name)) + ) + + return ( + exp.select("privilege_type", "grantee") + .from_(info_schema_table) + .where(exp.and_(*grant_conditions)) + ) + + def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: + grant_expr = self._get_grant_expression(table) + + results = self.fetchall(grant_expr) + + grants_dict: GrantsConfig = {} + for privilege_raw, grantee_raw in results: + if privilege_raw is None or grantee_raw is None: + continue + + privilege = str(privilege_raw) + grantee = str(grantee_raw) + if not privilege or not grantee: + continue + + grantees = grants_dict.setdefault(privilege, []) + if grantee not in grantees: + grantees.append(grantee) + + return grants_dict diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index b052e6d829..feb0f47908 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -13,12 +13,13 @@ PandasNativeFetchDFSupportMixin, RowDiffMixin, logical_merge, + GrantsFromInfoSchemaMixin, ) from sqlmesh.core.engine_adapter.shared import set_catalog if t.TYPE_CHECKING: from sqlmesh.core._typing import TableName - from sqlmesh.core.engine_adapter._typing import DCL, DF, GrantsConfig, QueryOrDF + from sqlmesh.core.engine_adapter._typing import DF, GrantsConfig, QueryOrDF logger = logging.getLogger(__name__) @@ -29,6 +30,7 @@ class PostgresEngineAdapter( PandasNativeFetchDFSupportMixin, GetCurrentCatalogFromFunctionMixin, RowDiffMixin, + GrantsFromInfoSchemaMixin, ): DIALECT = "postgres" SUPPORTS_GRANTS = True @@ -38,6 +40,9 @@ class PostgresEngineAdapter( SUPPORTS_REPLACE_TABLE = False MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63 SUPPORTS_QUERY_EXECUTION_TRACKING = True + GRANT_INFORMATION_SCHEMA_TABLE_NAME = "role_table_grants" + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.column("current_role") + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { # DECIMAL without precision is "up to 131072 digits before the decimal point; up to 16383 digits after the decimal point" @@ -138,26 +143,6 @@ def server_version(self) -> t.Tuple[int, int]: return int(match.group(1)), int(match.group(2)) return 0, 0 - def _dcl_grants_config_expr( - self, - dcl_cmd: t.Type[DCL], - relation: exp.Expression, - grant_config: GrantsConfig, - ) -> t.List[exp.Expression]: - expressions = [] - for privilege, principals in grant_config.items(): - if not principals: - continue - - grant = dcl_cmd( - privileges=[exp.GrantPrivilege(this=exp.Var(this=privilege))], - securable=relation, - principals=principals, # use original strings; no quoting - ) - expressions.append(grant) - - return t.cast(t.List[exp.Expression], expressions) - def _apply_grants_config_expr( self, table: exp.Table, @@ -175,36 +160,3 @@ def _revoke_grants_config_expr( ) -> t.List[exp.Expression]: # https://www.postgresql.org/docs/current/sql-revoke.html return self._dcl_grants_config_expr(exp.Revoke, table, grant_config) - - def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: - """Returns current grants for a Postgres table as a dictionary.""" - table_schema = table.db or self.get_current_schema() - table_name = table.name - - # https://www.postgresql.org/docs/current/infoschema-role-table-grants.html - grant_expr = ( - exp.select("privilege_type", "grantee") - .from_(exp.table_("role_table_grants", db="information_schema")) - .where( - exp.and_( - exp.column("table_schema").eq(exp.Literal.string(table_schema)), - exp.column("table_name").eq(exp.Literal.string(table_name)), - exp.column("grantor").eq(exp.column("current_role")), - exp.column("grantee").neq(exp.column("current_role")), - ) - ) - ) - results = self.fetchall(grant_expr) - - grants_dict: t.Dict[str, t.List[str]] = {} - for row in results: - privilege = str(row[0]) - grantee = str(row[1]) - - if privilege not in grants_dict: - grants_dict[privilege] = [] - - if grantee not in grants_dict[privilege]: - grants_dict[privilege].append(grantee) - - return grants_dict diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index 5c23b4b8e6..34b64503b3 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -14,6 +14,7 @@ VarcharSizeWorkaroundMixin, RowDiffMixin, logical_merge, + GrantsFromInfoSchemaMixin, ) from sqlmesh.core.engine_adapter.shared import ( CommentCreationView, @@ -28,7 +29,6 @@ import pandas as pd from sqlmesh.core._typing import SchemaName, TableName - from sqlmesh.core.engine_adapter._typing import DCL, GrantsConfig from sqlmesh.core.engine_adapter.base import QueryOrDF, Query logger = logging.getLogger(__name__) @@ -41,6 +41,7 @@ class RedshiftEngineAdapter( NonTransactionalTruncateMixin, VarcharSizeWorkaroundMixin, RowDiffMixin, + GrantsFromInfoSchemaMixin, ): DIALECT = "redshift" CURRENT_CATALOG_EXPRESSION = exp.func("current_database") @@ -48,6 +49,7 @@ class RedshiftEngineAdapter( COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED SUPPORTS_REPLACE_TABLE = False SUPPORTS_GRANTS = True + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { @@ -173,88 +175,6 @@ def _grant_object_kind(table_type: DataObjectType) -> str: return "MATERIALIZED VIEW" return "TABLE" - def _dcl_grants_config_expr( - self, - dcl_cmd: t.Type[DCL], - table: exp.Table, - grant_config: GrantsConfig, - table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: - expressions: t.List[exp.Expression] = [] - if not grant_config: - return expressions - - object_kind = self._grant_object_kind(table_type) - for privilege, principals in grant_config.items(): - if not principals: - continue - - args: t.Dict[str, t.Any] = { - "privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))], - "securable": table.copy(), - "principals": principals, - } - - if object_kind: - args["kind"] = exp.Var(this=object_kind) - - expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] - - return expressions - - def _apply_grants_config_expr( - self, - table: exp.Table, - grant_config: GrantsConfig, - table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: - return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type) - - def _revoke_grants_config_expr( - self, - table: exp.Table, - grant_config: GrantsConfig, - table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: - return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type) - - def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: - """Returns current grants for a Redshift table as a dictionary.""" - table_schema = table.db or self.get_current_schema() - table_name = table.name - current_user = exp.func("current_user") - - grant_expr = ( - exp.select("privilege_type", "grantee") - .from_(exp.table_("table_privileges", db="information_schema")) - .where( - exp.and_( - exp.column("table_schema").eq(exp.Literal.string(table_schema)), - exp.column("table_name").eq(exp.Literal.string(table_name)), - exp.column("grantor").eq(current_user), - exp.column("grantee").neq(current_user), - ) - ) - ) - - results = self.fetchall(grant_expr) - - grants_dict: GrantsConfig = {} - for privilege_raw, grantee_raw in results: - if privilege_raw is None or grantee_raw is None: - continue - - privilege = str(privilege_raw) - grantee = str(grantee_raw) - if not privilege or not grantee: - continue - - grants_dict.setdefault(privilege, []) - if grantee not in grants_dict[privilege]: - grants_dict[privilege].append(grantee) - - return grants_dict - def _create_table_from_source_queries( self, table_name: TableName, diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 557467ad9e..a86bc63037 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -15,6 +15,7 @@ GetCurrentCatalogFromFunctionMixin, ClusteredByMixin, RowDiffMixin, + GrantsFromInfoSchemaMixin, ) from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, @@ -35,9 +36,7 @@ from sqlmesh.core._typing import SchemaName, SessionProperties, TableName from sqlmesh.core.engine_adapter._typing import ( - DCL, DF, - GrantsConfig, Query, QueryOrDF, SnowparkSession, @@ -53,7 +52,9 @@ "drop_catalog": CatalogSupport.REQUIRES_SET_CATALOG, # needs a catalog to issue a query to information_schema.databases even though the result is global } ) -class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixin, RowDiffMixin): +class SnowflakeEngineAdapter( + GetCurrentCatalogFromFunctionMixin, ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchemaMixin +): DIALECT = "snowflake" SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True @@ -81,6 +82,8 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi SNOWPARK = "snowpark" SUPPORTS_QUERY_EXECUTION_TRACKING = True SUPPORTS_GRANTS = True + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("CURRENT_ROLE") + USE_CATALOG_IN_GRANTS = True @contextlib.contextmanager def session(self, properties: SessionProperties) -> t.Iterator[None]: @@ -152,101 +155,6 @@ def _get_current_schema(self) -> str: raise SQLMeshError("Unable to determine current schema") return str(result[0]) - def _dcl_grants_config_expr( - self, - dcl_cmd: t.Type[DCL], - table: exp.Table, - grant_config: GrantsConfig, - table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: - expressions: t.List[exp.Expression] = [] - if not grant_config: - return expressions - - object_kind = self._grant_object_kind(table_type) - for privilege, principals in grant_config.items(): - for principal in principals: - args: t.Dict[str, t.Any] = { - "privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))], - "securable": table.copy(), - "principals": [principal], - } - - if object_kind: - args["kind"] = exp.Var(this=object_kind) - - expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] - - return expressions - - def _apply_grants_config_expr( - self, - table: exp.Table, - grant_config: GrantsConfig, - table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: - return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type) - - def _revoke_grants_config_expr( - self, - table: exp.Table, - grant_config: GrantsConfig, - table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: - return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type) - - def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: - schema_identifier = table.args.get("db") or normalize_identifiers( - exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect - ) - catalog_identifier = table.args.get("catalog") - if not catalog_identifier: - current_catalog = self.get_current_catalog() - if not current_catalog: - raise SQLMeshError("Unable to determine current catalog for fetching grants") - catalog_identifier = normalize_identifiers( - exp.to_identifier(current_catalog, quoted=True), dialect=self.dialect - ) - catalog_identifier.set("quoted", True) - table_identifier = table.args.get("this") - - grant_expr = ( - exp.select("privilege_type", "grantee") - .from_( - exp.table_( - "TABLE_PRIVILEGES", - db="INFORMATION_SCHEMA", - catalog=catalog_identifier, - ) - ) - .where( - exp.and_( - exp.column("table_schema").eq(exp.Literal.string(schema_identifier.this)), - exp.column("table_name").eq(exp.Literal.string(table_identifier.this)), # type: ignore - exp.column("grantor").eq(exp.func("CURRENT_ROLE")), - exp.column("grantee").neq(exp.func("CURRENT_ROLE")), - ) - ) - ) - - results = self.fetchall(grant_expr) - - grants_dict: GrantsConfig = {} - for privilege_raw, grantee_raw in results: - if privilege_raw is None or grantee_raw is None: - continue - - privilege = str(privilege_raw) - grantee = str(grantee_raw) - if not privilege or not grantee: - continue - - grantees = grants_dict.setdefault(privilege, []) - if grantee not in grantees: - grantees.append(grantee) - - return grants_dict - def _create_catalog(self, catalog_name: exp.Identifier) -> None: props = exp.Properties( expressions=[exp.SchemaCommentProperty(this=exp.Literal.string(c.SQLMESH_MANAGED))] @@ -652,13 +560,32 @@ def _get_data_objects( for row in df.rename(columns={col: col.lower() for col in df.columns}).itertuples() ] + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + # Upon execute the catalog in table expressions are properly normalized to handle the case where a user provides + # the default catalog in their connection config. This doesn't though update catalogs in strings like when querying + # the information schema. So we need to manually replace those here. + expression = super()._get_grant_expression(table) + for col_exp in expression.find_all(exp.Column): + if col_exp.this.name == "table_catalog": + and_exp = col_exp.parent + assert and_exp is not None, "Expected column expression to have a parent" + assert and_exp.expression, "Expected AND expression to have an expression" + normalized_catalog = self._normalize_catalog( + exp.table_("placeholder", db="placeholder", catalog=and_exp.expression.this) + ) + and_exp.set( + "expression", + exp.Literal.string(normalized_catalog.args["catalog"].alias_or_name), + ) + return expression + def set_current_catalog(self, catalog: str) -> None: self.execute(exp.Use(this=exp.to_identifier(catalog))) def set_current_schema(self, schema: str) -> None: self.execute(exp.Use(kind="SCHEMA", this=to_schema(schema))) - def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str: + def _normalize_catalog(self, expression: exp.Expression) -> exp.Expression: # note: important to use self._default_catalog instead of the self.default_catalog property # otherwise we get RecursionError: maximum recursion depth exceeded # because it calls get_current_catalog(), which executes a query, which needs the default catalog, which calls get_current_catalog()... etc @@ -691,8 +618,12 @@ def catalog_rewriter(node: exp.Expression) -> exp.Expression: # Snowflake connection config. This is because the catalog present on the model gets normalized and quoted to match # the source dialect, which isnt always compatible with Snowflake expression = expression.transform(catalog_rewriter) + return expression - return super()._to_sql(expression=expression, quote=quote, **kwargs) + def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str: + return super()._to_sql( + expression=self._normalize_catalog(expression), quote=quote, **kwargs + ) def _create_column_comments( self, diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 18ba6ea106..7b3e23cd8e 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -397,7 +397,7 @@ def get_current_catalog(self) -> t.Optional[str]: def set_current_catalog(self, catalog_name: str) -> None: self.connection.set_current_catalog(catalog_name) - def get_current_database(self) -> str: + def _get_current_schema(self) -> str: if self._use_spark_session: return self.spark.catalog.currentDatabase() return self.fetchone(exp.select(exp.func("current_database")))[0] # type: ignore @@ -537,7 +537,7 @@ def _ensure_fqn(self, table_name: TableName) -> exp.Table: if not table.catalog: table.set("catalog", self.get_current_catalog()) if not table.db: - table.set("db", self.get_current_database()) + table.set("db", self._get_current_schema()) return table def _build_create_comment_column_exp( diff --git a/tests/core/engine_adapter/test_base_postgres.py b/tests/core/engine_adapter/test_base_postgres.py index 1c410693b7..f286c47c56 100644 --- a/tests/core/engine_adapter/test_base_postgres.py +++ b/tests/core/engine_adapter/test_base_postgres.py @@ -82,7 +82,7 @@ def test_get_current_schema(make_mocked_engine_adapter: t.Callable, mocker: Mock adapter = make_mocked_engine_adapter(BasePostgresEngineAdapter) fetchone_mock = mocker.patch.object(adapter, "fetchone", return_value=("test_schema",)) - result = adapter.get_current_schema() + result = adapter._get_current_schema() assert result == "test_schema" fetchone_mock.assert_called_once() @@ -92,10 +92,10 @@ def test_get_current_schema(make_mocked_engine_adapter: t.Callable, mocker: Mock fetchone_mock.reset_mock() fetchone_mock.return_value = None - result = adapter.get_current_schema() + result = adapter._get_current_schema() assert result == "public" fetchone_mock.reset_mock() fetchone_mock.return_value = (None,) # search_path = '' or 'nonexistent_schema' - result = adapter.get_current_schema() + result = adapter._get_current_schema() assert result == "public" diff --git a/tests/core/engine_adapter/test_databricks.py b/tests/core/engine_adapter/test_databricks.py index eb320e03ca..52145740fc 100644 --- a/tests/core/engine_adapter/test_databricks.py +++ b/tests/core/engine_adapter/test_databricks.py @@ -128,14 +128,14 @@ def test_get_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t. assert to_sql_calls(adapter) == ["SELECT CURRENT_CATALOG()"] -def test_get_current_database(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): +def test_get_current_schema(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): mocker.patch( "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" ) adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") adapter.cursor.fetchone.return_value = ("test_database",) - assert adapter.get_current_database() == "test_database" + assert adapter._get_current_schema() == "test_database" assert to_sql_calls(adapter) == ["SELECT CURRENT_DATABASE()"] @@ -260,7 +260,7 @@ def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocke executed_query = fetchall_mock.call_args[0][0] executed_sql = executed_query.sql(dialect="databricks") expected_sql = ( - "SELECT privilege_type, grantee FROM test_db.information_schema.table_privileges " + "SELECT privilege_type, grantee FROM `test_db`.information_schema.table_privileges " "WHERE table_catalog = 'test_db' AND table_schema = 'test_schema' AND table_name = 'test_table' " "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" ) @@ -291,7 +291,7 @@ def test_sync_grants_config_no_catalog_or_schema( ("REFRESH", "stale"), ] fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) - mocker.patch.object(adapter, "get_current_database", return_value="schema") + mocker.patch.object(adapter, "_get_current_schema", return_value="schema") mocker.patch.object(adapter, "get_current_catalog", return_value="main_catalog") adapter.sync_grants_config(relation, new_grants_config) @@ -300,7 +300,7 @@ def test_sync_grants_config_no_catalog_or_schema( executed_query = fetchall_mock.call_args[0][0] executed_sql = executed_query.sql(dialect="databricks") expected_sql = ( - "SELECT privilege_type, grantee FROM main_catalog.information_schema.table_privileges " + "SELECT privilege_type, grantee FROM `main_catalog`.information_schema.table_privileges " "WHERE table_catalog = 'main_catalog' AND table_schema = 'schema' AND table_name = 'test_table' " "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" ) diff --git a/tests/core/engine_adapter/test_postgres.py b/tests/core/engine_adapter/test_postgres.py index f75bd594e9..ebcdd03f55 100644 --- a/tests/core/engine_adapter/test_postgres.py +++ b/tests/core/engine_adapter/test_postgres.py @@ -202,10 +202,10 @@ def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: Mock sql_calls = to_sql_calls(adapter) assert len(sql_calls) == 4 - assert 'GRANT SELECT ON "test_schema"."test_table" TO user1, user2' in sql_calls - assert 'GRANT INSERT ON "test_schema"."test_table" TO user3' in sql_calls - assert 'REVOKE SELECT ON "test_schema"."test_table" FROM old_user' in sql_calls - assert 'REVOKE UPDATE ON "test_schema"."test_table" FROM admin_user' in sql_calls + assert 'GRANT SELECT ON "test_schema"."test_table" TO "user1", "user2"' in sql_calls + assert 'GRANT INSERT ON "test_schema"."test_table" TO "user3"' in sql_calls + assert 'REVOKE SELECT ON "test_schema"."test_table" FROM "old_user"' in sql_calls + assert 'REVOKE UPDATE ON "test_schema"."test_table" FROM "admin_user"' in sql_calls def test_sync_grants_config_with_overlaps( @@ -238,10 +238,10 @@ def test_sync_grants_config_with_overlaps( sql_calls = to_sql_calls(adapter) assert len(sql_calls) == 4 - assert 'GRANT SELECT ON "test_schema"."test_table" TO user2, user3' in sql_calls - assert 'GRANT INSERT ON "test_schema"."test_table" TO user4' in sql_calls - assert 'REVOKE SELECT ON "test_schema"."test_table" FROM user5' in sql_calls - assert 'REVOKE UPDATE ON "test_schema"."test_table" FROM user3' in sql_calls + assert 'GRANT SELECT ON "test_schema"."test_table" TO "user2", "user3"' in sql_calls + assert 'GRANT INSERT ON "test_schema"."test_table" TO "user4"' in sql_calls + assert 'REVOKE SELECT ON "test_schema"."test_table" FROM "user5"' in sql_calls + assert 'REVOKE UPDATE ON "test_schema"."test_table" FROM "user3"' in sql_calls def test_diff_grants_configs(make_mocked_engine_adapter: t.Callable): @@ -267,7 +267,7 @@ def test_sync_grants_config_with_default_schema( currrent_grants = [("UPDATE", "old_user")] fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=currrent_grants) - get_schema_mock = mocker.patch.object(adapter, "get_current_schema", return_value="public") + get_schema_mock = mocker.patch.object(adapter, "_get_current_schema", return_value="public") adapter.sync_grants_config(relation, new_grants_config) diff --git a/tests/core/engine_adapter/test_redshift.py b/tests/core/engine_adapter/test_redshift.py index 27a2adb1ea..d77ee67b86 100644 --- a/tests/core/engine_adapter/test_redshift.py +++ b/tests/core/engine_adapter/test_redshift.py @@ -105,10 +105,10 @@ def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: Mock sql_calls = to_sql_calls(adapter) assert len(sql_calls) == 4 - assert 'REVOKE SELECT ON TABLE "test_schema"."test_table" FROM old_user' in sql_calls - assert 'REVOKE UPDATE ON TABLE "test_schema"."test_table" FROM legacy_user' in sql_calls - assert 'GRANT SELECT ON TABLE "test_schema"."test_table" TO user1, user2' in sql_calls - assert 'GRANT INSERT ON TABLE "test_schema"."test_table" TO user3' in sql_calls + assert 'REVOKE SELECT ON TABLE "test_schema"."test_table" FROM "old_user"' in sql_calls + assert 'REVOKE UPDATE ON TABLE "test_schema"."test_table" FROM "legacy_user"' in sql_calls + assert 'GRANT SELECT ON TABLE "test_schema"."test_table" TO "user1", "user2"' in sql_calls + assert 'GRANT INSERT ON TABLE "test_schema"."test_table" TO "user3"' in sql_calls def test_sync_grants_config_with_overlaps( @@ -142,9 +142,9 @@ def test_sync_grants_config_with_overlaps( sql_calls = to_sql_calls(adapter) assert len(sql_calls) == 3 - assert 'REVOKE SELECT ON TABLE "test_schema"."test_table" FROM user_legacy' in sql_calls - assert 'GRANT SELECT ON TABLE "test_schema"."test_table" TO user_new' in sql_calls - assert 'GRANT INSERT ON TABLE "test_schema"."test_table" TO user_writer' in sql_calls + assert 'REVOKE SELECT ON TABLE "test_schema"."test_table" FROM "user_legacy"' in sql_calls + assert 'GRANT SELECT ON TABLE "test_schema"."test_table" TO "user_new"' in sql_calls + assert 'GRANT INSERT ON TABLE "test_schema"."test_table" TO "user_writer"' in sql_calls @pytest.mark.parametrize( @@ -170,7 +170,7 @@ def test_sync_grants_config_object_kind( sql_calls = to_sql_calls(adapter) assert sql_calls == [ - f'GRANT SELECT ON {expected_keyword} "test_schema"."test_object" TO user_test' + f'GRANT SELECT ON {expected_keyword} "test_schema"."test_object" TO "user_test"' ] @@ -196,10 +196,10 @@ def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocke sql_calls = to_sql_calls(adapter) assert len(sql_calls) == 4 - assert 'REVOKE SELECT ON TABLE "TestSchema"."TestTable" FROM user_old' in sql_calls - assert 'REVOKE UPDATE ON TABLE "TestSchema"."TestTable" FROM user_legacy' in sql_calls - assert 'GRANT SELECT ON TABLE "TestSchema"."TestTable" TO user1, user2' in sql_calls - assert 'GRANT INSERT ON TABLE "TestSchema"."TestTable" TO user3' in sql_calls + assert 'REVOKE SELECT ON TABLE "TestSchema"."TestTable" FROM "user_old"' in sql_calls + assert 'REVOKE UPDATE ON TABLE "TestSchema"."TestTable" FROM "user_legacy"' in sql_calls + assert 'GRANT SELECT ON TABLE "TestSchema"."TestTable" TO "user1", "user2"' in sql_calls + assert 'GRANT INSERT ON TABLE "TestSchema"."TestTable" TO "user3"' in sql_calls def test_sync_grants_config_no_schema( @@ -211,7 +211,7 @@ def test_sync_grants_config_no_schema( current_grants = [("UPDATE", "user_old")] fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) - get_schema_mock = mocker.patch.object(adapter, "get_current_schema", return_value="public") + get_schema_mock = mocker.patch.object(adapter, "_get_current_schema", return_value="public") adapter.sync_grants_config(relation, new_grants_config) @@ -228,9 +228,9 @@ def test_sync_grants_config_no_schema( sql_calls = to_sql_calls(adapter) assert len(sql_calls) == 3 - assert 'REVOKE UPDATE ON TABLE "test_table" FROM user_old' in sql_calls - assert 'GRANT SELECT ON TABLE "test_table" TO user1' in sql_calls - assert 'GRANT INSERT ON TABLE "test_table" TO user2' in sql_calls + assert 'REVOKE UPDATE ON TABLE "test_table" FROM "user_old"' in sql_calls + assert 'GRANT SELECT ON TABLE "test_table" TO "user1"' in sql_calls + assert 'GRANT INSERT ON TABLE "test_table" TO "user2"' in sql_calls def test_create_table_from_query_exists_no_if_not_exists( diff --git a/tests/core/engine_adapter/test_snowflake.py b/tests/core/engine_adapter/test_snowflake.py index a84a6666ad..62d9dc3e6f 100644 --- a/tests/core/engine_adapter/test_snowflake.py +++ b/tests/core/engine_adapter/test_snowflake.py @@ -265,8 +265,8 @@ def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: Mock executed_query = fetchall_mock.call_args[0][0] executed_sql = executed_query.sql(dialect="snowflake") expected_sql = ( - 'SELECT privilege_type, grantee FROM "TEST_DB".INFORMATION_SCHEMA.TABLE_PRIVILEGES ' - "WHERE table_schema = 'TEST_SCHEMA' AND table_name = 'TEST_TABLE' " + "SELECT privilege_type, grantee FROM TEST_DB.INFORMATION_SCHEMA.TABLE_PRIVILEGES " + "WHERE table_catalog = 'TEST_DB' AND table_schema = 'TEST_SCHEMA' AND table_name = 'TEST_TABLE' " "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" ) assert executed_sql == expected_sql @@ -274,15 +274,15 @@ def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: Mock sql_calls = to_sql_calls(adapter) assert len(sql_calls) == 5 - assert 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE role1' in sql_calls - assert 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE role2' in sql_calls - assert 'GRANT INSERT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE role3' in sql_calls + assert 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "ROLE1"' in sql_calls + assert 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "ROLE2"' in sql_calls + assert 'GRANT INSERT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "ROLE3"' in sql_calls assert ( - 'REVOKE SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE old_role' + 'REVOKE SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE "OLD_ROLE"' in sql_calls ) assert ( - 'REVOKE UPDATE ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE legacy_role' + 'REVOKE UPDATE ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE "LEGACY_ROLE"' in sql_calls ) @@ -312,8 +312,8 @@ def test_sync_grants_config_with_overlaps( executed_query = fetchall_mock.call_args[0][0] executed_sql = executed_query.sql(dialect="snowflake") expected_sql = ( - """SELECT privilege_type, grantee FROM "TEST_DB".INFORMATION_SCHEMA.TABLE_PRIVILEGES """ - "WHERE table_schema = 'TEST_SCHEMA' AND table_name = 'TEST_TABLE' " + """SELECT privilege_type, grantee FROM TEST_DB.INFORMATION_SCHEMA.TABLE_PRIVILEGES """ + "WHERE table_catalog = 'TEST_DB' AND table_schema = 'TEST_SCHEMA' AND table_name = 'TEST_TABLE' " "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" ) assert executed_sql == expected_sql @@ -322,11 +322,14 @@ def test_sync_grants_config_with_overlaps( assert len(sql_calls) == 3 assert ( - 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE new_role' in sql_calls + 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "NEW_ROLE"' in sql_calls ) - assert 'GRANT INSERT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE writer' in sql_calls assert ( - 'REVOKE SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE legacy' in sql_calls + 'GRANT INSERT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "WRITER"' in sql_calls + ) + assert ( + 'REVOKE SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE "LEGACY"' + in sql_calls ) @@ -356,7 +359,7 @@ def test_sync_grants_config_object_kind( sql_calls = to_sql_calls(adapter) assert sql_calls == [ - f'GRANT SELECT ON {expected_keyword} "TEST_DB"."TEST_SCHEMA"."TEST_OBJECT" TO ROLE test' + f'GRANT SELECT ON {expected_keyword} "TEST_DB"."TEST_SCHEMA"."TEST_OBJECT" TO ROLE "TEST"' ] @@ -381,7 +384,7 @@ def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocke executed_sql = executed_query.sql(dialect="snowflake") expected_sql = ( """SELECT privilege_type, grantee FROM "test_db".INFORMATION_SCHEMA.TABLE_PRIVILEGES """ - "WHERE table_schema = 'test_schema' AND table_name = 'test_table' " + "WHERE table_catalog = 'test_db' AND table_schema = 'test_schema' AND table_name = 'test_table' " "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" ) assert executed_sql == expected_sql @@ -389,15 +392,15 @@ def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocke sql_calls = to_sql_calls(adapter) assert len(sql_calls) == 5 - assert 'GRANT SELECT ON TABLE "test_db"."test_schema"."test_table" TO ROLE role1' in sql_calls - assert 'GRANT SELECT ON TABLE "test_db"."test_schema"."test_table" TO ROLE role2' in sql_calls - assert 'GRANT INSERT ON TABLE "test_db"."test_schema"."test_table" TO ROLE role3' in sql_calls + assert 'GRANT SELECT ON TABLE "test_db"."test_schema"."test_table" TO ROLE "ROLE1"' in sql_calls + assert 'GRANT SELECT ON TABLE "test_db"."test_schema"."test_table" TO ROLE "ROLE2"' in sql_calls + assert 'GRANT INSERT ON TABLE "test_db"."test_schema"."test_table" TO ROLE "ROLE3"' in sql_calls assert ( - 'REVOKE SELECT ON TABLE "test_db"."test_schema"."test_table" FROM ROLE old_role' + 'REVOKE SELECT ON TABLE "test_db"."test_schema"."test_table" FROM ROLE "OLD_ROLE"' in sql_calls ) assert ( - 'REVOKE UPDATE ON TABLE "test_db"."test_schema"."test_table" FROM ROLE legacy_role' + 'REVOKE UPDATE ON TABLE "test_db"."test_schema"."test_table" FROM ROLE "LEGACY_ROLE"' in sql_calls ) @@ -426,7 +429,7 @@ def test_sync_grants_config_no_catalog_or_schema( executed_sql = executed_query.sql(dialect="snowflake") expected_sql = ( """SELECT privilege_type, grantee FROM "caTalog".INFORMATION_SCHEMA.TABLE_PRIVILEGES """ - "WHERE table_schema = 'sChema' AND table_name = 'TesT_Table' " + "WHERE table_catalog = 'caTalog' AND table_schema = 'sChema' AND table_name = 'TesT_Table' " "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" ) assert executed_sql == expected_sql @@ -434,11 +437,11 @@ def test_sync_grants_config_no_catalog_or_schema( sql_calls = to_sql_calls(adapter) assert len(sql_calls) == 5 - assert 'GRANT SELECT ON TABLE "TesT_Table" TO ROLE role1' in sql_calls - assert 'GRANT SELECT ON TABLE "TesT_Table" TO ROLE role2' in sql_calls - assert 'GRANT INSERT ON TABLE "TesT_Table" TO ROLE role3' in sql_calls - assert 'REVOKE SELECT ON TABLE "TesT_Table" FROM ROLE old_role' in sql_calls - assert 'REVOKE UPDATE ON TABLE "TesT_Table" FROM ROLE legacy_role' in sql_calls + assert 'GRANT SELECT ON TABLE "TesT_Table" TO ROLE "ROLE1"' in sql_calls + assert 'GRANT SELECT ON TABLE "TesT_Table" TO ROLE "ROLE2"' in sql_calls + assert 'GRANT INSERT ON TABLE "TesT_Table" TO ROLE "ROLE3"' in sql_calls + assert 'REVOKE SELECT ON TABLE "TesT_Table" FROM ROLE "OLD_ROLE"' in sql_calls + assert 'REVOKE UPDATE ON TABLE "TesT_Table" FROM ROLE "LEGACY_ROLE"' in sql_calls def test_df_to_source_queries_use_schema( diff --git a/tests/core/engine_adapter/test_spark.py b/tests/core/engine_adapter/test_spark.py index bc4e352bd7..d7c3127f05 100644 --- a/tests/core/engine_adapter/test_spark.py +++ b/tests/core/engine_adapter/test_spark.py @@ -224,7 +224,7 @@ def test_replace_query_self_ref_not_exists( lambda self: "spark_catalog", ) mocker.patch( - "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.get_current_database", + "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._get_current_schema", side_effect=lambda: "default", ) @@ -283,7 +283,7 @@ def test_replace_query_self_ref_exists( return_value="spark_catalog", ) mocker.patch( - "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.get_current_database", + "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._get_current_schema", return_value="default", )