From 85862592abd7dce4430f5e5a146ee9270c2111b3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 15 Aug 2024 10:46:53 +0800 Subject: [PATCH 1/3] loadTable should indicate if it's for writing --- .../sql/connector/catalog/TableCatalog.java | 11 +++ .../sql/catalyst/analysis/Analyzer.scala | 3 +- .../sql/catalyst/analysis/unresolved.scala | 11 +++ .../sql/catalyst/parser/AstBuilder.scala | 38 ++++++---- .../sql/connector/catalog/CatalogV2Util.scala | 15 ++-- .../apache/spark/sql/DataFrameWriter.scala | 11 +-- .../apache/spark/sql/DataFrameWriterV2.scala | 7 +- .../apache/spark/sql/MergeIntoWriter.scala | 2 +- .../analysis/ResolveSessionCatalog.scala | 4 +- .../v2/WriteToDataSourceV2Exec.scala | 4 +- .../sql/connector/DataSourceV2SQLSuite.scala | 71 +++++++++++++++++++ 11 files changed, 147 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index facfc0d774e89..bd032b297e6f4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -110,6 +110,17 @@ public interface TableCatalog extends CatalogPlugin { */ Table loadTable(Identifier ident) throws NoSuchTableException; + /** + * A variant of {@link #loadTable(Identifier)} that indicates it's for data writing. + * Implementations can override this method to do additional handling for data writing, such as + * checking write permissions. + * + * @since 4.0.0 + */ + default Table loadTableForWrite(Identifier ident) throws NoSuchTableException { + return loadTable(ident); + } + /** * Load table metadata of a specific version by {@link Identifier identifier} from the catalog. *

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b2c36dda03143..4c3d33f93f0a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1321,7 +1321,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor cachedConnectRelation }.getOrElse(cachedRelation) }.orElse { - val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec) + val forWrite = "true".equalsIgnoreCase(u.options.get(UnresolvedRelation.FOR_WRITE)) + val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec, forWrite) val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming) loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index f7a987368ec09..66e0fe77f93de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -124,10 +124,21 @@ case class UnresolvedRelation( override def name: String = tableName + def forWrite: UnresolvedRelation = { + val newOptions = new java.util.HashMap[String, String] + newOptions.put(UnresolvedRelation.FOR_WRITE, "true") + newOptions.putAll(options) + copy(options = new CaseInsensitiveStringMap(newOptions)) + } + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION) } object UnresolvedRelation { + // An internal option of `UnresolvedRelation` to indicate that we look up this relation for data + // writing. 
+ val FOR_WRITE = "__for_write__" + def apply( tableIdentifier: TableIdentifier, extraOptions: CaseInsensitiveStringMap, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c6e0467b3aff2..0536f880af632 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -455,7 +455,7 @@ class AstBuilder extends DataTypeAstBuilder = visitInsertIntoTable(table) withIdentClause(relationCtx, ident => { val insertIntoStatement = InsertIntoStatement( - createUnresolvedRelation(relationCtx, ident, options), + createUnresolvedRelation(relationCtx, ident, options, forWrite = true), partition, cols, query, @@ -473,7 +473,7 @@ class AstBuilder extends DataTypeAstBuilder = visitInsertOverwriteTable(table) withIdentClause(relationCtx, ident => { InsertIntoStatement( - createUnresolvedRelation(relationCtx, ident, options), + createUnresolvedRelation(relationCtx, ident, options, forWrite = true), partition, cols, query, @@ -482,9 +482,10 @@ class AstBuilder extends DataTypeAstBuilder byName) }) case ctx: InsertIntoReplaceWhereContext => + val options = Option(ctx.optionsClause()) withIdentClause(ctx.identifierReference, ident => { OverwriteByExpression.byPosition( - createUnresolvedRelation(ctx.identifierReference, ident, Option(ctx.optionsClause())), + createUnresolvedRelation(ctx.identifierReference, ident, options, forWrite = true), query, expression(ctx.whereClause().booleanExpression())) }) @@ -569,7 +570,7 @@ class AstBuilder extends DataTypeAstBuilder override def visitDeleteFromTable( ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) { - val table = createUnresolvedRelation(ctx.identifierReference) + val table = createUnresolvedRelation(ctx.identifierReference, forWrite = true) val tableAlias = 
getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE") val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) val predicate = if (ctx.whereClause() != null) { @@ -581,7 +582,7 @@ class AstBuilder extends DataTypeAstBuilder } override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) { - val table = createUnresolvedRelation(ctx.identifierReference) + val table = createUnresolvedRelation(ctx.identifierReference, forWrite = true) val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE") val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) val assignments = withAssignments(ctx.setClause().assignmentList()) @@ -604,7 +605,7 @@ class AstBuilder extends DataTypeAstBuilder override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) { val withSchemaEvolution = ctx.EVOLUTION() != null - val targetTable = createUnresolvedRelation(ctx.target) + val targetTable = createUnresolvedRelation(ctx.target, forWrite = true) val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE") val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable) @@ -3115,10 +3116,17 @@ class AstBuilder extends DataTypeAstBuilder */ private def createUnresolvedRelation( ctx: IdentifierReferenceContext, - optionsClause: Option[OptionsClauseContext] = None): LogicalPlan = withOrigin(ctx) { + optionsClause: Option[OptionsClauseContext] = None, + forWrite: Boolean = false): LogicalPlan = withOrigin(ctx) { val options = resolveOptions(optionsClause) - withIdentClause(ctx, parts => - new UnresolvedRelation(parts, options, isStreaming = false)) + withIdentClause(ctx, parts => { + val relation = new UnresolvedRelation(parts, options, isStreaming = false) + if (forWrite) { + relation.forWrite + } else { + relation + } + }) } /** @@ -3127,9 +3135,15 @@ class AstBuilder extends DataTypeAstBuilder private def createUnresolvedRelation( ctx: 
ParserRuleContext, ident: Seq[String], - optionsClause: Option[OptionsClauseContext]): UnresolvedRelation = withOrigin(ctx) { + optionsClause: Option[OptionsClauseContext], + forWrite: Boolean): UnresolvedRelation = withOrigin(ctx) { val options = resolveOptions(optionsClause) - new UnresolvedRelation(ident, options, isStreaming = false) + val relation = new UnresolvedRelation(ident, options, isStreaming = false) + if (forWrite) { + relation.forWrite + } else { + relation + } } private def resolveOptions( @@ -5005,7 +5019,7 @@ class AstBuilder extends DataTypeAstBuilder if (query.isDefined) { CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options) } else { - CacheTable(createUnresolvedRelation(ctx.identifierReference, ident, None), + CacheTable(createUnresolvedRelation(ctx.identifierReference, ident, None, forWrite = false), ident, isLazy, options) } }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 283c550c4556f..21c8c733fb09c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -403,9 +403,10 @@ private[sql] object CatalogV2Util { def loadTable( catalog: CatalogPlugin, ident: Identifier, - timeTravelSpec: Option[TimeTravelSpec] = None): Option[Table] = + timeTravelSpec: Option[TimeTravelSpec] = None, + forWrite: Boolean = false): Option[Table] = try { - Option(getTable(catalog, ident, timeTravelSpec)) + Option(getTable(catalog, ident, timeTravelSpec, forWrite)) } catch { case _: NoSuchTableException => None case _: NoSuchDatabaseException => None @@ -414,8 +415,10 @@ private[sql] object CatalogV2Util { def getTable( catalog: CatalogPlugin, ident: Identifier, - timeTravelSpec: Option[TimeTravelSpec] = None): Table = { + timeTravelSpec: 
Option[TimeTravelSpec] = None, + forWrite: Boolean = false): Table = { if (timeTravelSpec.nonEmpty) { + assert(!forWrite, "Should not write to a table with time travel") timeTravelSpec.get match { case v: AsOfVersion => catalog.asTableCatalog.loadTable(ident, v.version) @@ -423,7 +426,11 @@ private[sql] object CatalogV2Util { catalog.asTableCatalog.loadTable(ident, ts.timestamp) } } else { - catalog.asTableCatalog.loadTable(ident) + if (forWrite) { + catalog.asTableCatalog.loadTableForWrite(ident) + } else { + catalog.asTableCatalog.loadTable(ident) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 991487170f177..944085e0deb31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType @@ -473,7 +474,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - val table = catalog.asTableCatalog.loadTable(ident) match { + val table = catalog.asTableCatalog.loadTableForWrite(ident) match { case _: V1Table => return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption)) case t => @@ -504,7 +505,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def insertInto(tableIdent: TableIdentifier): 
Unit = { runCommand(df.sparkSession) { InsertIntoStatement( - table = UnresolvedRelation(tableIdent), + table = UnresolvedRelation(tableIdent).forWrite, partitionSpec = Map.empty[String, Option[String]], Nil, query = df.logicalPlan, @@ -588,7 +589,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ val session = df.sparkSession - val canUseV2 = lookupV2Provider().isDefined + val canUseV2 = lookupV2Provider().isDefined || + df.sparkSession.sessionState.conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) => @@ -609,7 +611,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def saveAsTable( catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): Unit = { - val tableOpt = try Option(catalog.loadTable(ident)) catch { + val tableOpt = try Option(catalog.loadTableForWrite(ident)) catch { case _: NoSuchTableException => None } @@ -670,7 +672,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val catalog = df.sparkSession.sessionState.catalog val qualifiedIdent = catalog.qualifyIdentifier(tableIdent) val tableExists = catalog.tableExists(qualifiedIdent) - val tableName = qualifiedIdent.unquotedString (tableExists, mode) match { case (true, SaveMode.Ignore) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 9b824074533af..1de6fa376d6dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -168,7 +168,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) */ @throws(classOf[NoSuchTableException]) def append(): Unit = { - val append = 
AppendData.byName(UnresolvedRelation(tableName), logicalPlan, options.toMap) + val append = AppendData.byName( + UnresolvedRelation(tableName).forWrite, logicalPlan, options.toMap) runCommand(append) } @@ -185,7 +186,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) @throws(classOf[NoSuchTableException]) def overwrite(condition: Column): Unit = { val overwrite = OverwriteByExpression.byName( - UnresolvedRelation(tableName), logicalPlan, condition.expr, options.toMap) + UnresolvedRelation(tableName).forWrite, logicalPlan, condition.expr, options.toMap) runCommand(overwrite) } @@ -205,7 +206,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) @throws(classOf[NoSuchTableException]) def overwritePartitions(): Unit = { val dynamicOverwrite = OverwritePartitionsDynamic.byName( - UnresolvedRelation(tableName), logicalPlan, options.toMap) + UnresolvedRelation(tableName).forWrite, logicalPlan, options.toMap) runCommand(dynamicOverwrite) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala index b7f9c96f82e04..857965f8a2f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala @@ -188,7 +188,7 @@ class MergeIntoWriter[T] private[sql] ( } val merge = MergeIntoTable( - UnresolvedRelation(tableName), + UnresolvedRelation(tableName).forWrite, logicalPlan, on.expr, matchedActions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 9053fb9cc73f6..20e3b4e980f2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -683,7 +683,7 @@ 
class ResolveSessionCatalog(val catalogManager: CatalogManager) } private def supportsV1Command(catalog: CatalogPlugin): Boolean = { - catalog.name().equalsIgnoreCase(CatalogManager.SESSION_CATALOG_NAME) && - !SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined + isSessionCatalog(catalog) && + SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 5632595de7cf8..67fb1fe95c9f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -84,7 +84,7 @@ case class CreateTableAsSelectExec( } val table = Option(catalog.createTable( ident, getV2Columns(query.schema, catalog.useNullableQuerySchema), - partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTable(ident)) + partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTableForWrite(ident)) writeToTable(catalog, table, writeOptions, ident, query) } } @@ -164,7 +164,7 @@ case class ReplaceTableAsSelectExec( } val table = Option(catalog.createTable( ident, getV2Columns(query.schema, catalog.useNullableQuerySchema), - partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTable(ident)) + partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTableForWrite(ident)) writeToTable(catalog, table, writeOptions, ident, query) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 4a627727b1ed9..b36843f5449c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -3703,6 +3703,63 @@ class DataSourceV2SQLSuiteV1Filter } } + test("SPARK-49246: read-only catalog") { + def checkWriteOperations(catalog: String): Unit = { + withSQLConf(s"spark.sql.catalog.$catalog" -> classOf[ReadOnlyCatalog].getName) { + val input = sql("SELECT 1") + val tbl = s"$catalog.default.t" + withTable(tbl) { + sql(s"CREATE TABLE $tbl (i INT)") + val df = sql(s"SELECT * FROM $tbl") + assert(df.collect().isEmpty) + assert(df.schema == new StructType().add("i", "int")) + + intercept[RuntimeException](sql(s"INSERT INTO $tbl SELECT 1")) + intercept[RuntimeException](sql(s"INSERT INTO $tbl REPLACE WHERE i = 0 SELECT 1")) + intercept[RuntimeException] (sql(s"INSERT OVERWRITE $tbl SELECT 1")) + intercept[RuntimeException] (sql(s"DELETE FROM $tbl WHERE i = 0")) + intercept[RuntimeException] (sql(s"UPDATE $tbl SET i = 0")) + intercept[RuntimeException] { + sql( + s""" + |MERGE INTO $tbl USING (SELECT 1 i) AS source + |ON source.i = $tbl.i + |WHEN NOT MATCHED THEN INSERT * + |""".stripMargin) + } + + intercept[RuntimeException](input.write.insertInto(tbl)) + intercept[RuntimeException](input.write.mode("append").saveAsTable(tbl)) + intercept[RuntimeException](input.writeTo(tbl).append()) + intercept[RuntimeException](input.writeTo(tbl).overwrite(df.col("i") === 1)) + intercept[RuntimeException](input.writeTo(tbl).overwritePartitions()) + } + + // Test CTAS + withTable(tbl) { + intercept[RuntimeException](sql(s"CREATE TABLE $tbl AS SELECT 1 i")) + } + withTable(tbl) { + intercept[RuntimeException](sql(s"CREATE OR REPLACE TABLE $tbl AS SELECT 1 i")) + } + withTable(tbl) { + intercept[RuntimeException](input.write.saveAsTable(tbl)) + } + withTable(tbl) { + intercept[RuntimeException](input.writeTo(tbl).create()) + } + withTable(tbl) { + intercept[RuntimeException](input.writeTo(tbl).createOrReplace()) + } + } + } + // Reset CatalogManager to clear the materialized `spark_catalog` 
instance, so that we can + // configure a new implementation. + spark.sessionState.catalogManager.reset() + checkWriteOperations(SESSION_CATALOG_NAME) + checkWriteOperations("read_only_cat") + } + private def testNotSupportedV2Command( sqlCommand: String, sqlParams: String, @@ -3771,3 +3828,17 @@ class V2CatalogSupportBuiltinDataSource extends InMemoryCatalog { } } +class ReadOnlyCatalog extends InMemoryCatalog { + override def createTable( + ident: Identifier, + columns: Array[ColumnV2], + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + super.createTable(ident, columns, partitions, properties) + null + } + + override def loadTableForWrite(ident: Identifier): Table = { + throw new RuntimeException("cannot write") + } +} From 6bb83d58bbde0bda772e3f2521cf3bc6b9f50176 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 16 Aug 2024 16:03:28 +0800 Subject: [PATCH 2/3] address comments --- .../sql/connector/catalog/TableCatalog.java | 17 +++- .../catalog/TableWritePrivilege.java | 40 ++++++++ .../sql/catalyst/analysis/Analyzer.scala | 9 +- .../sql/catalyst/analysis/unresolved.scala | 32 +++++-- .../sql/catalyst/parser/AstBuilder.scala | 56 +++++------ .../catalyst/plans/logical/v2Commands.scala | 15 +++ .../sql/connector/catalog/CatalogV2Util.scala | 14 +-- .../sql/catalyst/parser/DDLParserSuite.scala | 29 ++++-- .../sql/catalyst/parser/PlanParserSuite.scala | 93 ++++++++++--------- .../apache/spark/sql/DataFrameWriter.scala | 13 ++- .../apache/spark/sql/DataFrameWriterV2.scala | 10 +- .../apache/spark/sql/MergeIntoWriter.scala | 3 +- .../v2/WriteToDataSourceV2Exec.scala | 8 +- .../analyzer-results/explain-aqe.sql.out | 2 +- .../analyzer-results/explain.sql.out | 2 +- .../sql-tests/results/explain-aqe.sql.out | 2 +- .../sql-tests/results/explain.sql.out | 2 +- .../sql/connector/DataSourceV2SQLSuite.scala | 63 ++++++++----- .../command/AlignAssignmentsSuiteBase.scala | 4 +- .../command/PlanResolutionSuite.scala | 6 +- 20 files changed, 
277 insertions(+), 143 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index bd032b297e6f4..6951d2b18765c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -111,13 +111,22 @@ public interface TableCatalog extends CatalogPlugin { Table loadTable(Identifier ident) throws NoSuchTableException; /** - * A variant of {@link #loadTable(Identifier)} that indicates it's for data writing. - * Implementations can override this method to do additional handling for data writing, such as - * checking write permissions. + * Load table metadata by {@link Identifier identifier} from the catalog. Spark will write data + * into this table later. + *

+ * If the catalog supports views and contains a view for the identifier and not a table, this + * must throw {@link NoSuchTableException}. + * + * @param ident a table identifier + * @param writePrivileges + * @return the table's metadata + * @throws NoSuchTableException If the table doesn't exist or is a view * * @since 4.0.0 */ - default Table loadTableForWrite(Identifier ident) throws NoSuchTableException { + default Table loadTable( + Identifier ident, + Set writePrivileges) throws NoSuchTableException { return loadTable(ident); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java new file mode 100644 index 0000000000000..2e6a6a724f11c --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +/** + * The table write privileges that will be provided when loading a table. + * + * @since 4.0.0 + */ +public enum TableWritePrivilege { + /** + * The privilege for adding rows to the table. 
+ */ + INSERT, + + /** + * The privilege for changing existing rows in th table. + */ + UPDATE, + + /** + * The privilege for deleting rows from the table. + */ + DELETE +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4c3d33f93f0a9..e94c650bae5c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1321,9 +1321,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor cachedConnectRelation }.getOrElse(cachedRelation) }.orElse { - val forWrite = "true".equalsIgnoreCase(u.options.get(UnresolvedRelation.FOR_WRITE)) - val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec, forWrite) - val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming) + val writePrivilegesString = + Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) + val table = CatalogV2Util.loadTable( + catalog, ident, finalTimeTravelSpec, writePrivilegesString) + val loaded = createRelation( + catalog, ident, table, u.clearWritePrivileges.options, u.isStreaming) loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => loaded.map { loadedRelation => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 66e0fe77f93de..4c532b5cae618 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LeafNode, Lo import 
org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.connector.catalog.TableWritePrivilege import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{DataType, Metadata, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -124,20 +125,35 @@ case class UnresolvedRelation( override def name: String = tableName - def forWrite: UnresolvedRelation = { - val newOptions = new java.util.HashMap[String, String] - newOptions.put(UnresolvedRelation.FOR_WRITE, "true") - newOptions.putAll(options) - copy(options = new CaseInsensitiveStringMap(newOptions)) + def requireWritePrivileges(privileges: Seq[TableWritePrivilege]): UnresolvedRelation = { + if (privileges.nonEmpty) { + val newOptions = new java.util.HashMap[String, String] + newOptions.putAll(options) + newOptions.put(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES, privileges.mkString(",")) + copy(options = new CaseInsensitiveStringMap(newOptions)) + } else { + this + } + } + + def clearWritePrivileges: UnresolvedRelation = { + if (options.containsKey(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) { + val newOptions = new java.util.HashMap[String, String] + newOptions.putAll(options) + newOptions.remove(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES) + copy(options = new CaseInsensitiveStringMap(newOptions)) + } else { + this + } } final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION) } object UnresolvedRelation { - // An internal option of `UnresolvedRelation` to indicate that we look up this relation for data - // writing. - val FOR_WRITE = "__for_write__" + // An internal option of `UnresolvedRelation` to specify the required write privileges when + // writing data to this relation. 
+ val REQUIRED_WRITE_PRIVILEGES = "__required_write_privileges__" def apply( tableIdentifier: TableIdentifier, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0536f880af632..5f2c6fe14bfa3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone} -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors, QueryParsingErrors, SqlScriptingErrors} @@ -455,7 +455,7 @@ class AstBuilder extends DataTypeAstBuilder = visitInsertIntoTable(table) withIdentClause(relationCtx, ident => { val insertIntoStatement = InsertIntoStatement( - createUnresolvedRelation(relationCtx, ident, options, forWrite = true), + createUnresolvedRelation(relationCtx, ident, options, Seq(TableWritePrivilege.INSERT)), partition, cols, query, @@ -473,7 +473,8 @@ class AstBuilder 
extends DataTypeAstBuilder = visitInsertOverwriteTable(table) withIdentClause(relationCtx, ident => { InsertIntoStatement( - createUnresolvedRelation(relationCtx, ident, options, forWrite = true), + createUnresolvedRelation(relationCtx, ident, options, + Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)), partition, cols, query, @@ -485,7 +486,8 @@ class AstBuilder extends DataTypeAstBuilder val options = Option(ctx.optionsClause()) withIdentClause(ctx.identifierReference, ident => { OverwriteByExpression.byPosition( - createUnresolvedRelation(ctx.identifierReference, ident, options, forWrite = true), + createUnresolvedRelation(ctx.identifierReference, ident, options, + Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)), query, expression(ctx.whereClause().booleanExpression())) }) @@ -570,7 +572,8 @@ class AstBuilder extends DataTypeAstBuilder override def visitDeleteFromTable( ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) { - val table = createUnresolvedRelation(ctx.identifierReference, forWrite = true) + val table = createUnresolvedRelation( + ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.DELETE)) val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE") val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) val predicate = if (ctx.whereClause() != null) { @@ -582,7 +585,8 @@ class AstBuilder extends DataTypeAstBuilder } override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) { - val table = createUnresolvedRelation(ctx.identifierReference, forWrite = true) + val table = createUnresolvedRelation( + ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.UPDATE)) val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE") val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) val assignments = withAssignments(ctx.setClause().assignmentList()) @@ -605,9 +609,6 @@ class AstBuilder extends 
DataTypeAstBuilder override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) { val withSchemaEvolution = ctx.EVOLUTION() != null - val targetTable = createUnresolvedRelation(ctx.target, forWrite = true) - val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE") - val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable) val sourceTableOrQuery = if (ctx.source != null) { createUnresolvedRelation(ctx.source) @@ -638,7 +639,7 @@ class AstBuilder extends DataTypeAstBuilder s"Unrecognized matched action: ${clause.matchedAction().getText}") } } - } + }.toSeq val notMatchedActions = ctx.notMatchedClause().asScala.map { clause => { if (clause.notMatchedAction().INSERT() != null) { @@ -659,7 +660,7 @@ class AstBuilder extends DataTypeAstBuilder s"Unrecognized matched action: ${clause.notMatchedAction().getText}") } } - } + }.toSeq val notMatchedBySourceActions = ctx.notMatchedBySourceClause().asScala.map { clause => { val notMatchedBySourceAction = clause.notMatchedBySourceAction() @@ -674,7 +675,7 @@ class AstBuilder extends DataTypeAstBuilder s"Unrecognized matched action: ${clause.notMatchedBySourceAction().getText}") } } - } + }.toSeq if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) { throw QueryParsingErrors.mergeStatementWithoutWhenClauseError(ctx) } @@ -693,13 +694,19 @@ class AstBuilder extends DataTypeAstBuilder throw QueryParsingErrors.nonLastNotMatchedBySourceClauseOmitConditionError(ctx) } + val targetTable = createUnresolvedRelation( + ctx.target, + writePrivileges = MergeIntoTable.getWritePrivileges( + matchedActions, notMatchedActions, notMatchedBySourceActions)) + val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE") + val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable) MergeIntoTable( aliasedTarget, aliasedSource, mergeCondition, - 
matchedActions.toSeq, - notMatchedActions.toSeq, - notMatchedBySourceActions.toSeq, + matchedActions, + notMatchedActions, + notMatchedBySourceActions, withSchemaEvolution) } @@ -3117,15 +3124,11 @@ class AstBuilder extends DataTypeAstBuilder private def createUnresolvedRelation( ctx: IdentifierReferenceContext, optionsClause: Option[OptionsClauseContext] = None, - forWrite: Boolean = false): LogicalPlan = withOrigin(ctx) { + writePrivileges: Seq[TableWritePrivilege] = Nil): LogicalPlan = withOrigin(ctx) { val options = resolveOptions(optionsClause) withIdentClause(ctx, parts => { val relation = new UnresolvedRelation(parts, options, isStreaming = false) - if (forWrite) { - relation.forWrite - } else { - relation - } + relation.requireWritePrivileges(writePrivileges) }) } @@ -3136,14 +3139,10 @@ class AstBuilder extends DataTypeAstBuilder ctx: ParserRuleContext, ident: Seq[String], optionsClause: Option[OptionsClauseContext], - forWrite: Boolean): UnresolvedRelation = withOrigin(ctx) { + writePrivileges: Seq[TableWritePrivilege]): UnresolvedRelation = withOrigin(ctx) { val options = resolveOptions(optionsClause) val relation = new UnresolvedRelation(ident, options, isStreaming = false) - if (forWrite) { - relation.forWrite - } else { - relation - } + relation.requireWritePrivileges(writePrivileges) } private def resolveOptions( @@ -5019,7 +5018,8 @@ class AstBuilder extends DataTypeAstBuilder if (query.isDefined) { CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options) } else { - CacheTable(createUnresolvedRelation(ctx.identifierReference, ident, None, forWrite = false), + CacheTable( + createUnresolvedRelation(ctx.identifierReference, ident, None, writePrivileges = Nil), ident, isLazy, options) } }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 6339a18796fa0..05628d7b1c98e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -794,6 +794,21 @@ case class MergeIntoTable( copy(targetTable = newLeft, sourceTable = newRight) } +object MergeIntoTable { + def getWritePrivileges( + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction], + notMatchedBySourceActions: Seq[MergeAction]): Seq[TableWritePrivilege] = { + val privileges = scala.collection.mutable.HashSet.empty[TableWritePrivilege] + (matchedActions.iterator ++ notMatchedActions ++ notMatchedBySourceActions).foreach { + case _: DeleteAction => privileges.add(TableWritePrivilege.DELETE) + case _: UpdateAction | _: UpdateStarAction => privileges.add(TableWritePrivilege.UPDATE) + case _: InsertAction | _: InsertStarAction => privileges.add(TableWritePrivilege.INSERT) + } + privileges.toSeq + } +} + sealed abstract class MergeAction extends Expression with Unevaluable { def condition: Option[Expression] override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 21c8c733fb09c..6698f0a021400 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -404,9 +404,9 @@ private[sql] object CatalogV2Util { catalog: CatalogPlugin, ident: Identifier, timeTravelSpec: Option[TimeTravelSpec] = None, - forWrite: Boolean = false): Option[Table] = + writePrivilegesString: Option[String] = None): Option[Table] = try { - Option(getTable(catalog, ident, timeTravelSpec, forWrite)) + Option(getTable(catalog, ident, timeTravelSpec, writePrivilegesString)) } catch { case _: NoSuchTableException => None case _: NoSuchDatabaseException => None @@ -416,9 
+416,9 @@ private[sql] object CatalogV2Util { catalog: CatalogPlugin, ident: Identifier, timeTravelSpec: Option[TimeTravelSpec] = None, - forWrite: Boolean = false): Table = { + writePrivilegesString: Option[String] = None): Table = { if (timeTravelSpec.nonEmpty) { - assert(!forWrite, "Should not write to a table with time travel") + assert(writePrivilegesString.isEmpty, "Should not write to a table with time travel") timeTravelSpec.get match { case v: AsOfVersion => catalog.asTableCatalog.loadTable(ident, v.version) @@ -426,8 +426,10 @@ private[sql] object CatalogV2Util { catalog.asTableCatalog.loadTable(ident, ts.timestamp) } } else { - if (forWrite) { - catalog.asTableCatalog.loadTableForWrite(ident) + if (writePrivilegesString.isDefined) { + val writePrivileges = writePrivilegesString.get.split(",").map(_.trim) + .map(TableWritePrivilege.valueOf).toSet.asJava + catalog.asTableCatalog.loadTable(ident, writePrivileges) } else { catalog.asTableCatalog.loadTable(ident) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 59602a4c77d08..60daa1e87a8e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -39,7 +39,16 @@ class DDLParserSuite extends AnalysisTest { } private def parseCompare(sql: String, expected: LogicalPlan): Unit = { - comparePlans(parsePlan(sql), expected, checkAnalysis = false) + // We don't care about the write privileges in this suite. 
+ val parsed = parsePlan(sql).transform { + case u: UnresolvedRelation => u.clearWritePrivileges + case i: InsertIntoStatement => + i.table match { + case u: UnresolvedRelation => i.copy(table = u.clearWritePrivileges) + case _ => i + } + } + comparePlans(parsed, expected, checkAnalysis = false) } private def internalException(sqlText: String): SparkThrowable = { @@ -2635,20 +2644,20 @@ class DDLParserSuite extends AnalysisTest { withSQLConf( SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key -> optimizeInsertIntoValues.toString) { - comparePlans(parsePlan(dateTypeSql), insertPartitionPlan( + parseCompare(dateTypeSql, insertPartitionPlan( "2019-01-02", optimizeInsertIntoValues)) withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") { - comparePlans(parsePlan(intervalTypeSql), insertPartitionPlan( + parseCompare(intervalTypeSql, insertPartitionPlan( interval, optimizeInsertIntoValues)) } - comparePlans(parsePlan(ymIntervalTypeSql), insertPartitionPlan( + parseCompare(ymIntervalTypeSql, insertPartitionPlan( "INTERVAL '1-2' YEAR TO MONTH", optimizeInsertIntoValues)) - comparePlans(parsePlan(dtIntervalTypeSql), + parseCompare(dtIntervalTypeSql, insertPartitionPlan( "INTERVAL '1 02:03:04.128462' DAY TO SECOND", optimizeInsertIntoValues)) - comparePlans(parsePlan(timestampTypeSql), insertPartitionPlan( + parseCompare(timestampTypeSql, insertPartitionPlan( timestamp, optimizeInsertIntoValues)) - comparePlans(parsePlan(binaryTypeSql), insertPartitionPlan( + parseCompare(binaryTypeSql, insertPartitionPlan( binaryStr, optimizeInsertIntoValues)) } } @@ -2748,12 +2757,12 @@ class DDLParserSuite extends AnalysisTest { // In each of the following cases, the DEFAULT reference parses as an unresolved attribute // reference. We can handle these cases after the parsing stage, at later phases of analysis. 
- comparePlans(parsePlan("VALUES (1, 2, DEFAULT) AS val"), + parseCompare("VALUES (1, 2, DEFAULT) AS val", SubqueryAlias("val", UnresolvedInlineTable(Seq("col1", "col2", "col3"), Seq(Seq(Literal(1), Literal(2), UnresolvedAttribute("DEFAULT")))))) - comparePlans(parsePlan( - "INSERT INTO t PARTITION(part = date'2019-01-02') VALUES ('a', DEFAULT)"), + parseCompare( + "INSERT INTO t PARTITION(part = date'2019-01-02') VALUES ('a', DEFAULT)", InsertIntoStatement( UnresolvedRelation(Seq("t")), Map("part" -> Some("2019-01-02")), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 8d01040563361..ba7e7d2c0e427 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -40,7 +40,16 @@ class PlanParserSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.dsl.plans._ private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - comparePlans(parsePlan(sqlCommand), plan, checkAnalysis = false) + // We don't care about the write privileges in this suite. 
+ val parsed = parsePlan(sqlCommand).transform { + case u: UnresolvedRelation => u.clearWritePrivileges + case i: InsertIntoStatement => + i.table match { + case u: UnresolvedRelation => i.copy(table = u.clearWritePrivileges) + case _ => i + } + } + comparePlans(parsed, plan, checkAnalysis = false) } private def parseException(sqlText: String): SparkThrowable = { @@ -1034,57 +1043,56 @@ class PlanParserSuite extends AnalysisTest { errorClass = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'b'", "hint" -> "")) - comparePlans( - parsePlan("SELECT /*+ HINT */ * FROM t"), + assertEqual( + "SELECT /*+ HINT */ * FROM t", UnresolvedHint("HINT", Seq.empty, table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"), + assertEqual( + "SELECT /*+ BROADCASTJOIN(u) */ * FROM t", UnresolvedHint("BROADCASTJOIN", Seq($"u"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"), + assertEqual( + "SELECT /*+ MAPJOIN(u) */ * FROM t", UnresolvedHint("MAPJOIN", Seq($"u"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"), + assertEqual( + "SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t", UnresolvedHint("STREAMTABLE", Seq($"a", $"b", $"c"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t"), + assertEqual( + "SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t", UnresolvedHint("INDEX", Seq($"t", $"emp_job_ix"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"), + assertEqual( + "SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`", UnresolvedHint("MAPJOIN", Seq(UnresolvedAttribute.quoted("default.t")), table("default.t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), + assertEqual( + "SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a", UnresolvedHint("MAPJOIN", 
Seq($"t"), table("t").where(Literal(true)).groupBy($"a")($"a")).orderBy($"a".asc)) - comparePlans( - parsePlan("SELECT /*+ COALESCE(10) */ * FROM t"), + assertEqual( + "SELECT /*+ COALESCE(10) */ * FROM t", UnresolvedHint("COALESCE", Seq(Literal(10)), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(100) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100)), table("t").select(star()))) - comparePlans( - parsePlan( - "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"), + assertEqual( + "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t", InsertIntoStatement(table("s"), Map.empty, Nil, UnresolvedHint("REPARTITION", Seq(Literal(100)), UnresolvedHint("COALESCE", Seq(Literal(500)), UnresolvedHint("COALESCE", Seq(Literal(10)), table("t").select(star())))), overwrite = false, ifPartitionNotExists = false)) - comparePlans( - parsePlan("SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t"), + assertEqual( + "SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t", UnresolvedHint("BROADCASTJOIN", Seq($"u"), UnresolvedHint("REPARTITION", Seq(Literal(100)), table("t").select(star())))) @@ -1095,49 +1103,48 @@ class PlanParserSuite extends AnalysisTest { errorClass = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'+'", "hint" -> "")) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(c) */ * FROM t", UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("c")), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100, c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(100, c) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t"), + assertEqual( + "SELECT /*+ 
REPARTITION(100, c), COALESCE(50) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), UnresolvedHint("COALESCE", Seq(Literal(50)), table("t").select(star())))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), UnresolvedHint("BROADCASTJOIN", Seq($"u"), UnresolvedHint("COALESCE", Seq(Literal(50)), table("t").select(star()))))) - comparePlans( - parsePlan( - """ - |SELECT - |/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */ - |* FROM t - """.stripMargin), + assertEqual( + """ + |SELECT + |/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */ + |* FROM t + """.stripMargin, UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), UnresolvedHint("BROADCASTJOIN", Seq($"u"), UnresolvedHint("COALESCE", Seq(Literal(50)), UnresolvedHint("REPARTITION", Seq(Literal(300), UnresolvedAttribute("c")), table("t").select(star())))))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t", UnresolvedHint("REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("c")), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t", UnresolvedHint("REPARTITION_BY_RANGE", Seq(Literal(100), UnresolvedAttribute("c")), table("t").select(star()))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 944085e0deb31..60734efbf5bba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSel import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.catalog.TableWritePrivilege +import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution @@ -474,7 +476,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - val table = catalog.asTableCatalog.loadTableForWrite(ident) match { + val table = catalog.asTableCatalog.loadTable(ident, getWritePrivileges.toSet.asJava) match { case _: V1Table => return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption)) case t => @@ -505,7 +507,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def insertInto(tableIdent: TableIdentifier): Unit = { runCommand(df.sparkSession) { InsertIntoStatement( - table = UnresolvedRelation(tableIdent).forWrite, + table = UnresolvedRelation(tableIdent).requireWritePrivileges(getWritePrivileges), partitionSpec = Map.empty[String, Option[String]], Nil, query = df.logicalPlan, @@ -514,6 +516,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } + private def getWritePrivileges: Seq[TableWritePrivilege] = mode match { + case SaveMode.Overwrite => Seq(INSERT, DELETE) + case _ => Seq(INSERT) + } + private def 
getBucketSpec: Option[BucketSpec] = { if (sortColumnNames.isDefined && numBuckets.isEmpty) { throw QueryCompilationErrors.sortByWithoutBucketingError() @@ -611,7 +618,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def saveAsTable( catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): Unit = { - val tableOpt = try Option(catalog.loadTableForWrite(ident)) catch { + val tableOpt = try Option(catalog.loadTable(ident, getWritePrivileges.toSet.asJava)) catch { case _: NoSuchTableException => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 1de6fa376d6dd..96756d57358ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OptionList, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, UnresolvedTableSpec} +import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, LogicalExpressions, NamedReference, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution @@ -169,7 +170,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) @throws(classOf[NoSuchTableException]) def append(): Unit = { val append = AppendData.byName( - UnresolvedRelation(tableName).forWrite, logicalPlan, 
options.toMap) + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)), + logicalPlan, options.toMap) runCommand(append) } @@ -186,7 +188,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) @throws(classOf[NoSuchTableException]) def overwrite(condition: Column): Unit = { val overwrite = OverwriteByExpression.byName( - UnresolvedRelation(tableName).forWrite, logicalPlan, condition.expr, options.toMap) + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), + logicalPlan, condition.expr, options.toMap) runCommand(overwrite) } @@ -206,7 +209,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) @throws(classOf[NoSuchTableException]) def overwritePartitions(): Unit = { val dynamicOverwrite = OverwritePartitionsDynamic.byName( - UnresolvedRelation(tableName).forWrite, logicalPlan, options.toMap) + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), + logicalPlan, options.toMap) runCommand(dynamicOverwrite) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala index 857965f8a2f35..5e67d1bff0805 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala @@ -188,7 +188,8 @@ class MergeIntoWriter[T] private[sql] ( } val merge = MergeIntoTable( - UnresolvedRelation(tableName).forWrite, + UnresolvedRelation(tableName).requireWritePrivileges(MergeIntoTable.getWritePrivileges( + matchedActions, notMatchedActions, notMatchedBySourceActions)), logicalPlan, on.expr, matchedActions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 67fb1fe95c9f1..89372017257dd 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, TableSpec, UnaryNode} import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, UPDATE_OPERATION} -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage} @@ -84,7 +84,8 @@ case class CreateTableAsSelectExec( } val table = Option(catalog.createTable( ident, getV2Columns(query.schema, catalog.useNullableQuerySchema), - partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTableForWrite(ident)) + partitioning.toArray, properties.asJava) + ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) writeToTable(catalog, table, writeOptions, ident, query) } } @@ -164,7 +165,8 @@ case class ReplaceTableAsSelectExec( } val table = Option(catalog.createTable( ident, getV2Columns(query.schema, catalog.useNullableQuerySchema), - partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTableForWrite(ident)) + partitioning.toArray, properties.asJava) + ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) 
writeToTable(catalog, table, writeOptions, ident, query) } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out index 3aea86b232cba..f9a282c2b927b 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out @@ -196,7 +196,7 @@ ExplainCommand 'Aggregate ['key], ['key, unresolvedalias('MIN('val))], Formatted -- !query EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 -- !query analysis -ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false, ExtendedMode +ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false, ExtendedMode -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out index 3aea86b232cba..f9a282c2b927b 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out @@ -196,7 +196,7 @@ ExplainCommand 'Aggregate ['key], ['key, unresolvedalias('MIN('val))], Formatted -- !query EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 -- !query analysis -ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false, ExtendedMode +ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false, ExtendedMode -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 3830b47ba8a6d..16077a78f3892 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -1139,7 +1139,7 @@ EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 struct -- !query output == Parsed Logical Plan == -'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false +'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false +- 'Project [*] +- 'UnresolvedRelation [explain_temp4], [], false diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index c0dee38e6d07a..9d25b829e03fc 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -1031,7 +1031,7 @@ EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 struct -- !query output == Parsed Logical Plan == -'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false +'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false +- 'Project [*] +- 'UnresolvedRelation [explain_temp4], [], false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index b36843f5449c6..a63eaddc2206c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -3704,6 +3704,11 @@ class DataSourceV2SQLSuiteV1Filter } test("SPARK-49246: read-only catalog") { + def assertPrivilegeError(f: => Unit, privilege: String): Unit = { + val e = intercept[RuntimeException](f) + assert(e.getMessage.contains(privilege)) + } + def 
checkWriteOperations(catalog: String): Unit = { withSQLConf(s"spark.sql.catalog.$catalog" -> classOf[ReadOnlyCatalog].getName) { val input = sql("SELECT 1") @@ -3714,42 +3719,47 @@ class DataSourceV2SQLSuiteV1Filter assert(df.collect().isEmpty) assert(df.schema == new StructType().add("i", "int")) - intercept[RuntimeException](sql(s"INSERT INTO $tbl SELECT 1")) - intercept[RuntimeException](sql(s"INSERT INTO $tbl REPLACE WHERE i = 0 SELECT 1")) - intercept[RuntimeException] (sql(s"INSERT OVERWRITE $tbl SELECT 1")) - intercept[RuntimeException] (sql(s"DELETE FROM $tbl WHERE i = 0")) - intercept[RuntimeException] (sql(s"UPDATE $tbl SET i = 0")) - intercept[RuntimeException] { - sql( - s""" - |MERGE INTO $tbl USING (SELECT 1 i) AS source - |ON source.i = $tbl.i - |WHEN NOT MATCHED THEN INSERT * - |""".stripMargin) - } + assertPrivilegeError(sql(s"INSERT INTO $tbl SELECT 1"), "INSERT") + assertPrivilegeError( + sql(s"INSERT INTO $tbl REPLACE WHERE i = 0 SELECT 1"), "DELETE,INSERT") + assertPrivilegeError(sql(s"INSERT OVERWRITE $tbl SELECT 1"), "DELETE,INSERT") + assertPrivilegeError(sql(s"DELETE FROM $tbl WHERE i = 0"), "DELETE") + assertPrivilegeError(sql(s"UPDATE $tbl SET i = 0"), "UPDATE") + assertPrivilegeError( + sql(s""" + |MERGE INTO $tbl USING (SELECT 1 i) AS source + |ON source.i = $tbl.i + |WHEN MATCHED THEN UPDATE SET * + |WHEN NOT MATCHED THEN INSERT * + |WHEN NOT MATCHED BY SOURCE THEN DELETE + |""".stripMargin), + "DELETE,INSERT,UPDATE" + ) - intercept[RuntimeException](input.write.insertInto(tbl)) - intercept[RuntimeException](input.write.mode("append").saveAsTable(tbl)) - intercept[RuntimeException](input.writeTo(tbl).append()) - intercept[RuntimeException](input.writeTo(tbl).overwrite(df.col("i") === 1)) - intercept[RuntimeException](input.writeTo(tbl).overwritePartitions()) + assertPrivilegeError(input.write.insertInto(tbl), "INSERT") + assertPrivilegeError(input.write.mode("overwrite").insertInto(tbl), "DELETE,INSERT") + 
assertPrivilegeError(input.write.mode("append").saveAsTable(tbl), "INSERT") + assertPrivilegeError(input.write.mode("overwrite").saveAsTable(tbl), "DELETE,INSERT") + assertPrivilegeError(input.writeTo(tbl).append(), "INSERT") + assertPrivilegeError(input.writeTo(tbl).overwrite(df.col("i") === 1), "DELETE,INSERT") + assertPrivilegeError(input.writeTo(tbl).overwritePartitions(), "DELETE,INSERT") } // Test CTAS withTable(tbl) { - intercept[RuntimeException](sql(s"CREATE TABLE $tbl AS SELECT 1 i")) + assertPrivilegeError(sql(s"CREATE TABLE $tbl AS SELECT 1 i"), "INSERT") } withTable(tbl) { - intercept[RuntimeException](sql(s"CREATE OR REPLACE TABLE $tbl AS SELECT 1 i")) + assertPrivilegeError(sql(s"CREATE OR REPLACE TABLE $tbl AS SELECT 1 i"), "INSERT") } withTable(tbl) { - intercept[RuntimeException](input.write.saveAsTable(tbl)) + assertPrivilegeError(input.write.saveAsTable(tbl), "INSERT") } withTable(tbl) { - intercept[RuntimeException](input.writeTo(tbl).create()) + assertPrivilegeError(input.writeTo(tbl).create(), "INSERT") } withTable(tbl) { - intercept[RuntimeException](input.writeTo(tbl).createOrReplace()) + assertPrivilegeError(input.writeTo(tbl).createOrReplace(), "INSERT") } } } @@ -3838,7 +3848,10 @@ class ReadOnlyCatalog extends InMemoryCatalog { null } - override def loadTableForWrite(ident: Identifier): Table = { - throw new RuntimeException("cannot write") + override def loadTable( + ident: Identifier, + writePrivileges: util.Set[TableWritePrivilege]): Table = { + throw new RuntimeException("cannot write with " + + writePrivileges.asScala.toSeq.map(_.toString).sorted.mkString(",")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala index ebb719a35a8bf..75837c59945fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, Column, ColumnDefaultValue, Identifier, SupportsRowLevelOperations, TableCapability, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, Column, ColumnDefaultValue, Identifier, SupportsRowLevelOperations, TableCapability, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog @@ -161,6 +161,8 @@ abstract class AlignAssignmentsSuiteBase extends AnalysisTest { case name => throw new NoSuchTableException(Seq(name)) } }) + when(newCatalog.loadTable(any(), any[java.util.Set[TableWritePrivilege]]())) + .thenCallRealMethod() when(newCatalog.name()).thenReturn("cat") newCatalog } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 13735658e6fdf..73bcde1e6e5be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCom import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId import org.apache.spark.sql.connector.FakeV2Provider -import org.apache.spark.sql.connector.catalog.{CatalogManager, Column, 
ColumnDefaultValue, Identifier, SupportsDelete, Table, TableCapability, TableCatalog, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogManager, Column, ColumnDefaultValue, Identifier, SupportsDelete, Table, TableCapability, TableCatalog, TableWritePrivilege, V1Table} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform} import org.apache.spark.sql.errors.QueryExecutionErrors @@ -158,6 +158,8 @@ class PlanResolutionSuite extends AnalysisTest { case name => throw new NoSuchTableException(Seq(name)) } }) + when(newCatalog.loadTable(any(), any[java.util.Set[TableWritePrivilege]]())) + .thenCallRealMethod() when(newCatalog.name()).thenReturn("testcat") newCatalog } @@ -175,6 +177,8 @@ class PlanResolutionSuite extends AnalysisTest { case name => throw new NoSuchTableException(Seq(name)) } }) + when(newCatalog.loadTable(any(), any[java.util.Set[TableWritePrivilege]]())) + .thenCallRealMethod() when(newCatalog.name()).thenReturn(CatalogManager.SESSION_CATALOG_NAME) newCatalog } From 1a352fe8032c55a32cac799d124276f133826256 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 20 Aug 2024 11:23:13 +0800 Subject: [PATCH 3/3] Apply suggestions from code review --- .../org/apache/spark/sql/connector/catalog/TableCatalog.java | 2 +- .../apache/spark/sql/connector/catalog/TableWritePrivilege.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index 6951d2b18765c..ad4fe743218fd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -122,7 +122,7 @@ public interface TableCatalog extends CatalogPlugin { * @return the table's metadata * @throws 
NoSuchTableException If the table doesn't exist or is a view * - * @since 4.0.0 + * @since 3.5.3 */ default Table loadTable( Identifier ident, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java index 2e6a6a724f11c..ca2d4ba9e7b4e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java @@ -20,7 +20,7 @@ /** * The table write privileges that will be provided when loading a table. * - * @since 4.0.0 + * @since 3.5.3 */ public enum TableWritePrivilege { /**