diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index 6753e66c96c5e..5ac853b858ccf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta}
-import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Discard, Instruction, Keep, ROW_ID, Split}
+import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update}
 import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
 import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
 import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta}
@@ -93,7 +93,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
 
           val notMatchedInstructions = notMatchedActions.map {
             case InsertAction(cond, assignments) =>
-              Keep(cond.getOrElse(TrueLiteral), assignments.map(_.value))
+              Keep(Insert, cond.getOrElse(TrueLiteral), assignments.map(_.value))
             case other =>
               throw new AnalysisException(
                 errorClass = "_LEGACY_ERROR_TEMP_3053",
@@ -199,7 +199,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
     // as the last MATCHED and NOT MATCHED BY SOURCE instruction
     // this logic is specific to data sources that replace groups of data
     val carryoverRowsOutput = Literal(WRITE_WITH_METADATA_OPERATION) +: targetTable.output
-    val keepCarryoverRowsInstruction = Copy(carryoverRowsOutput)
+    val keepCarryoverRowsInstruction = Keep(Copy, TrueLiteral, carryoverRowsOutput)
 
     val matchedInstructions = matchedActions.map { action =>
       toInstruction(action, metadataAttrs)
@@ -436,7 +436,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
         val rowValues = assignments.map(_.value)
         val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
         val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues ++ metadataValues
-        Keep(cond.getOrElse(TrueLiteral), output)
+        Keep(Update, cond.getOrElse(TrueLiteral), output)
 
       case DeleteAction(cond) =>
         Discard(cond.getOrElse(TrueLiteral))
@@ -445,7 +445,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
         val rowValues = assignments.map(_.value)
         val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType))
         val output = Seq(Literal(WRITE_OPERATION)) ++ rowValues ++ metadataValues
-        Keep(cond.getOrElse(TrueLiteral), output)
+        Keep(Insert, cond.getOrElse(TrueLiteral), output)
 
       case other =>
         throw new AnalysisException(
@@ -471,15 +471,15 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
 
       case UpdateAction(cond, assignments) =>
         val output = deltaUpdateOutput(assignments, metadataAttrs, originalRowIdValues)
-        Keep(cond.getOrElse(TrueLiteral), output)
+        Keep(Update, cond.getOrElse(TrueLiteral), output)
 
       case DeleteAction(cond) =>
         val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues)
-        Keep(cond.getOrElse(TrueLiteral), output)
+        Keep(Delete, cond.getOrElse(TrueLiteral), output)
 
       case InsertAction(cond, assignments) =>
         val output = deltaInsertOutput(assignments, metadataAttrs, originalRowIdValues)
-        Keep(cond.getOrElse(TrueLiteral), output)
+        Keep(Insert, cond.getOrElse(TrueLiteral), output)
 
       case other =>
         throw new AnalysisException(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
index a1e9058d97bcc..3730e3d16e471 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, Unevaluable}
-import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, ROW_ID}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.truncatedString
@@ -88,19 +87,15 @@ object MergeRows {
     override def dataType: DataType = NullType
   }
 
-  // A special case of Keep where the row is kept as is.
-  case class Copy(output: Seq[Expression]) extends Instruction {
-    override def condition: Expression = TrueLiteral
-    override def outputs: Seq[Seq[Expression]] = Seq(output)
-    override def children: Seq[Expression] = output
-
-    override protected def withNewChildrenInternal(
-        newChildren: IndexedSeq[Expression]): Expression = {
-      copy(output = newChildren)
-    }
-  }
+  // The kind of change a Keep instruction represents, used by MergeRowsExec for metrics
+  sealed trait Context
+  case object Copy extends Context
+  case object Delete extends Context
+  case object Insert extends Context
+  case object Update extends Context
 
   case class Keep(
+      context: Context,
       condition: Expression,
       output: Seq[Expression]) extends Instruction {
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
index eb135a48ed7a5..cc3c74c9c88e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
@@ -26,11 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.AttributeSet
 import org.apache.spark.sql.catalyst.expressions.BasePredicate
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.expressions.Projection
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
-import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Discard, Instruction, Keep, ROW_ID, Split}
+import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Context, Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
@@ -49,7 +48,21 @@ case class MergeRowsExec(
 
   override lazy val metrics: Map[String, SQLMetric] = Map(
     "numTargetRowsCopied" -> SQLMetrics.createMetric(sparkContext,
-      "Number of target rows copied unmodified because they did not match any action."))
+      "Number of target rows copied unmodified because they did not match any action"),
+    "numTargetRowsInserted" -> SQLMetrics.createMetric(sparkContext,
+      "Number of target rows inserted"),
+    "numTargetRowsDeleted" -> SQLMetrics.createMetric(sparkContext,
+      "Number of target rows deleted"),
+    "numTargetRowsUpdated" -> SQLMetrics.createMetric(sparkContext,
+      "Number of target rows updated"),
+    "numTargetRowsMatchedUpdated" -> SQLMetrics.createMetric(sparkContext,
+      "Number of target rows updated by a matched clause"),
+    "numTargetRowsMatchedDeleted" -> SQLMetrics.createMetric(sparkContext,
+      "Number of target rows deleted by a matched clause"),
+    "numTargetRowsNotMatchedBySourceUpdated" -> SQLMetrics.createMetric(sparkContext,
+      "Number of target rows updated by a not matched by source clause"),
+    "numTargetRowsNotMatchedBySourceDeleted" -> SQLMetrics.createMetric(sparkContext,
+      "Number of target rows deleted by a not matched by source clause"))
 
   @transient override lazy val producedAttributes: AttributeSet = {
     AttributeSet(output.filterNot(attr => inputSet.contains(attr)))
@@ -113,11 +126,8 @@ case class MergeRowsExec(
 
   private def planInstructions(instructions: Seq[Instruction]): Seq[InstructionExec] = {
     instructions.map {
-      case Copy(output) =>
-        CopyExec(createProjection(output))
-
-      case Keep(cond, output) =>
-        KeepExec(createPredicate(cond), createProjection(output))
+      case Keep(context, cond, output) =>
+        KeepExec(context, createPredicate(cond), createProjection(output))
 
       case Discard(cond) =>
         DiscardExec(createPredicate(cond))
@@ -136,12 +146,8 @@ case class MergeRowsExec(
     def condition: BasePredicate
   }
 
-  case class CopyExec(projection: Projection) extends InstructionExec {
-    override lazy val condition: BasePredicate = createPredicate(TrueLiteral)
-    def apply(row: InternalRow): InternalRow = projection.apply(row)
-  }
-
   case class KeepExec(
+      context: Context,
       condition: BasePredicate,
       projection: Projection) extends InstructionExec {
     def apply(row: InternalRow): InternalRow = projection.apply(row)
@@ -219,9 +225,9 @@ case class MergeRowsExec(
 
       if (isTargetRowPresent && isSourceRowPresent) {
         cardinalityValidator.validate(row)
-        applyInstructions(row, matchedInstructions)
+        applyInstructions(row, matchedInstructions, sourcePresent = true)
       } else if (isSourceRowPresent) {
-        applyInstructions(row, notMatchedInstructions)
+        applyInstructions(row, notMatchedInstructions, sourcePresent = true)
       } else if (isTargetRowPresent) {
-        applyInstructions(row, notMatchedBySourceInstructions)
+        applyInstructions(row, notMatchedBySourceInstructions, sourcePresent = false)
       } else {
@@ -231,23 +237,31 @@ case class MergeRowsExec(
 
     private def applyInstructions(
         row: InternalRow,
-        instructions: Seq[InstructionExec]): InternalRow = {
+        instructions: Seq[InstructionExec],
+        sourcePresent: Boolean): InternalRow = {
 
       for (instruction <- instructions) {
         if (instruction.condition.eval(row)) {
           instruction match {
-            case copy: CopyExec =>
-              // group-based operations copy over target rows that didn't match any actions
-              longMetric("numTargetRowsCopied") += 1
-              return copy.apply(row)
-
             case keep: KeepExec =>
+              // Context is sealed, so the compiler checks that every kind is handled here
+              keep.context match {
+                case Copy => incrementCopyMetric()
+                case Update => incrementUpdateMetric(sourcePresent)
+                case Insert => incrementInsertMetric()
+                case Delete => incrementDeleteMetric(sourcePresent)
+              }
+
               return keep.apply(row)
 
             case _: DiscardExec =>
+              incrementDeleteMetric(sourcePresent)
+
               return null
 
             case split: SplitExec =>
+              incrementUpdateMetric(sourcePresent)
+
               cachedExtraRow = split.projectExtraRow(row)
               return split.projectRow(row)
           }
@@ -257,4 +271,29 @@ case class MergeRowsExec(
       null
     }
   }
+
+  // Group-based plans append a Copy instruction that keeps target rows matching no other action
+  private def incrementCopyMetric(): Unit = longMetric("numTargetRowsCopied") += 1
+
+  private def incrementInsertMetric(): Unit = longMetric("numTargetRowsInserted") += 1
+
+  // sourcePresent = true counts the row under MATCHED, otherwise under NOT MATCHED BY SOURCE
+  private def incrementDeleteMetric(sourcePresent: Boolean): Unit = {
+    longMetric("numTargetRowsDeleted") += 1
+    if (sourcePresent) {
+      longMetric("numTargetRowsMatchedDeleted") += 1
+    } else {
+      longMetric("numTargetRowsNotMatchedBySourceDeleted") += 1
+    }
+  }
+
+  // sourcePresent = true counts the row under MATCHED, otherwise under NOT MATCHED BY SOURCE
+  private def incrementUpdateMetric(sourcePresent: Boolean): Unit = {
+    longMetric("numTargetRowsUpdated") += 1
+    if (sourcePresent) {
+      longMetric("numTargetRowsMatchedUpdated") += 1
+    } else {
+      longMetric("numTargetRowsNotMatchedBySourceUpdated") += 1
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index c39ea7a0e8620..21b171bee9614 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -1797,6 +1797,13 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
       }
 
       assertMetric(mergeExec, "numTargetRowsCopied", if (deltaMerge) 0 else 2)
+      assertMetric(mergeExec, "numTargetRowsInserted", 0)
+      assertMetric(mergeExec, "numTargetRowsUpdated", 1)
+      assertMetric(mergeExec, "numTargetRowsDeleted", 0)
+      assertMetric(mergeExec, "numTargetRowsMatchedUpdated", 1)
+      assertMetric(mergeExec, "numTargetRowsMatchedDeleted", 0)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceUpdated", 0)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceDeleted", 0)
 
       checkAnswer(
         sql(s"SELECT * FROM $tableNameAsString"),
@@ -1834,6 +1841,13 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
       }
 
       assertMetric(mergeExec, "numTargetRowsCopied", 0)
+      assertMetric(mergeExec, "numTargetRowsInserted", 1)
+      assertMetric(mergeExec, "numTargetRowsUpdated", 0)
+      assertMetric(mergeExec, "numTargetRowsDeleted", 0)
+      assertMetric(mergeExec, "numTargetRowsMatchedUpdated", 0)
+      assertMetric(mergeExec, "numTargetRowsMatchedDeleted", 0)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceUpdated", 0)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceDeleted", 0)
 
       checkAnswer(
         sql(s"SELECT * FROM $tableNameAsString"),
@@ -1845,7 +1859,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }
 
-  test("Merge metrics with matched and not matched by source clauses") {
+  test("Merge metrics with matched and not matched by source clauses: update") {
     withTempView("source") {
       createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
         """{ "pk": 1, "salary": 100, "dep": "hr" }
@@ -1871,6 +1885,13 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
 
 
       assertMetric(mergeExec, "numTargetRowsCopied", if (deltaMerge) 0 else 3)
+      assertMetric(mergeExec, "numTargetRowsInserted", 0)
+      assertMetric(mergeExec, "numTargetRowsUpdated", 2)
+      assertMetric(mergeExec, "numTargetRowsDeleted", 0)
+      assertMetric(mergeExec, "numTargetRowsMatchedUpdated", 1)
+      assertMetric(mergeExec, "numTargetRowsMatchedDeleted", 0)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceUpdated", 1)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceDeleted", 0)
 
       checkAnswer(
         sql(s"SELECT * FROM $tableNameAsString"),
@@ -1883,7 +1904,53 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }
 
-  test("Merge metrics with matched, not matched, and not matched by source clauses") {
+  test("Merge metrics with matched and not matched by source clauses: delete") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "software" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |{ "pk": 4, "salary": 400, "dep": "marketing" }
+          |{ "pk": 5, "salary": 500, "dep": "executive" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(1, 2, 10).toDF("pk")
+      sourceDF.createOrReplaceTempView("source")
+
+      val mergeExec = findMergeExec {
+        s"""MERGE INTO $tableNameAsString t
+           |USING source s
+           |ON t.pk = s.pk
+           |WHEN MATCHED AND salary < 200 THEN
+           | DELETE
+           |WHEN NOT MATCHED BY SOURCE AND salary > 400 THEN
+           | DELETE
+           |""".stripMargin
+      }
+
+
+      assertMetric(mergeExec, "numTargetRowsCopied", if (deltaMerge) 0 else 3)
+      assertMetric(mergeExec, "numTargetRowsInserted", 0)
+      assertMetric(mergeExec, "numTargetRowsUpdated", 0)
+      assertMetric(mergeExec, "numTargetRowsDeleted", 2)
+      assertMetric(mergeExec, "numTargetRowsMatchedUpdated", 0)
+      assertMetric(mergeExec, "numTargetRowsMatchedDeleted", 1)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceUpdated", 0)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceDeleted", 1)
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          // Row(1, 100, "hr") deleted
+          Row(2, 200, "software"),
+          Row(3, 300, "hr"),
+          Row(4, 400, "marketing"))
+          // Row(5, 500, "executive") deleted
+      )
+    }
+  }
+
+  test("Merge metrics with matched, not matched, and not matched by source clauses: update") {
     withTempView("source") {
       createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
         """{ "pk": 1, "salary": 100, "dep": "hr" }
@@ -1910,6 +1977,13 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
       }
 
       assertMetric(mergeExec, "numTargetRowsCopied", if (deltaMerge) 0 else 3)
+      assertMetric(mergeExec, "numTargetRowsInserted", 1)
+      assertMetric(mergeExec, "numTargetRowsUpdated", 2)
+      assertMetric(mergeExec, "numTargetRowsDeleted", 0)
+      assertMetric(mergeExec, "numTargetRowsMatchedUpdated", 1)
+      assertMetric(mergeExec, "numTargetRowsMatchedDeleted", 0)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceUpdated", 1)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceDeleted", 0)
 
       checkAnswer(
         sql(s"SELECT * FROM $tableNameAsString"),
@@ -1923,6 +1997,54 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }
 
+  test("Merge metrics with matched, not matched, and not matched by source clauses: delete") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "software" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |{ "pk": 4, "salary": 400, "dep": "marketing" }
+          |{ "pk": 5, "salary": 500, "dep": "executive" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(1, 2, 6, 10).toDF("pk")
+      sourceDF.createOrReplaceTempView("source")
+
+      val mergeExec = findMergeExec {
+        s"""MERGE INTO $tableNameAsString t
+           |USING source s
+           |ON t.pk = s.pk
+           |WHEN MATCHED AND salary < 200 THEN
+           | DELETE
+           |WHEN NOT MATCHED AND s.pk < 10 THEN
+           | INSERT (pk, salary, dep) VALUES (s.pk, -1, "dummy")
+           |WHEN NOT MATCHED BY SOURCE AND salary > 400 THEN
+           | DELETE
+           |""".stripMargin
+      }
+
+      assertMetric(mergeExec, "numTargetRowsCopied", if (deltaMerge) 0 else 3)
+      assertMetric(mergeExec, "numTargetRowsInserted", 1)
+      assertMetric(mergeExec, "numTargetRowsUpdated", 0)
+      assertMetric(mergeExec, "numTargetRowsDeleted", 2)
+      assertMetric(mergeExec, "numTargetRowsMatchedUpdated", 0)
+      assertMetric(mergeExec, "numTargetRowsMatchedDeleted", 1)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceUpdated", 0)
+      assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceDeleted", 1)
+
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          // Row(1, 100, "hr") deleted
+          Row(2, 200, "software"),
+          Row(3, 300, "hr"),
+          Row(4, 400, "marketing"),
+          // Row(5, 500, "executive") deleted
+          Row(6, -1, "dummy"))) // inserted
+    }
+  }
+
   private def findMergeExec(query: String): MergeRowsExec = {
     val plan = executeAndKeepPlan {
       sql(query)