@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Discard, Instruction, Keep, ROW_ID, Split}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta}
@@ -93,7 +93,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper

val notMatchedInstructions = notMatchedActions.map {
case InsertAction(cond, assignments) =>
Keep(cond.getOrElse(TrueLiteral), assignments.map(_.value))
Keep(Insert, cond.getOrElse(TrueLiteral), assignments.map(_.value))
case other =>
throw new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_3053",
@@ -199,7 +199,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
// as the last MATCHED and NOT MATCHED BY SOURCE instruction
// this logic is specific to data sources that replace groups of data
val carryoverRowsOutput = Literal(WRITE_WITH_METADATA_OPERATION) +: targetTable.output
val keepCarryoverRowsInstruction = Copy(carryoverRowsOutput)
val keepCarryoverRowsInstruction = Keep(Copy, TrueLiteral, carryoverRowsOutput)

val matchedInstructions = matchedActions.map { action =>
toInstruction(action, metadataAttrs)
@@ -436,7 +436,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
val rowValues = assignments.map(_.value)
val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues ++ metadataValues
Keep(cond.getOrElse(TrueLiteral), output)
Keep(Update, cond.getOrElse(TrueLiteral), output)

case DeleteAction(cond) =>
Discard(cond.getOrElse(TrueLiteral))
@@ -445,7 +445,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
val rowValues = assignments.map(_.value)
val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType))
val output = Seq(Literal(WRITE_OPERATION)) ++ rowValues ++ metadataValues
Keep(cond.getOrElse(TrueLiteral), output)
Keep(Insert, cond.getOrElse(TrueLiteral), output)

case other =>
throw new AnalysisException(
@@ -471,15 +471,15 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper

case UpdateAction(cond, assignments) =>
val output = deltaUpdateOutput(assignments, metadataAttrs, originalRowIdValues)
Keep(cond.getOrElse(TrueLiteral), output)
Keep(Update, cond.getOrElse(TrueLiteral), output)

case DeleteAction(cond) =>
val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues)
Keep(cond.getOrElse(TrueLiteral), output)
Keep(Delete, cond.getOrElse(TrueLiteral), output)

case InsertAction(cond, assignments) =>
val output = deltaInsertOutput(assignments, metadataAttrs, originalRowIdValues)
Keep(cond.getOrElse(TrueLiteral), output)
Keep(Insert, cond.getOrElse(TrueLiteral), output)

case other =>
throw new AnalysisException(
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, ROW_ID}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.truncatedString
@@ -88,19 +87,14 @@ object MergeRows {
override def dataType: DataType = NullType
}

// A special case of Keep where the row is kept as is.
case class Copy(output: Seq[Expression]) extends Instruction {
override def condition: Expression = TrueLiteral
override def outputs: Seq[Seq[Expression]] = Seq(output)
override def children: Seq[Expression] = output

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = {
copy(output = newChildren)
}
}
sealed trait Context
case object Copy extends Context
case object Delete extends Context
case object Insert extends Context
case object Update extends Context

case class Keep(
context: Context,
condition: Expression,
output: Seq[Expression])
extends Instruction {
@@ -26,11 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeSet
import org.apache.spark.sql.catalyst.expressions.BasePredicate
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.expressions.Projection
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Discard, Instruction, Keep, ROW_ID, Split}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Context, Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SparkPlan
@@ -49,7 +48,21 @@ case class MergeRowsExec(

override lazy val metrics: Map[String, SQLMetric] = Map(
"numTargetRowsCopied" -> SQLMetrics.createMetric(sparkContext,
"Number of target rows copied unmodified because they did not match any action."))
"Number of target rows copied unmodified because they did not match any action"),
Contributor: Minor: It seems like we call it "action" here and "clause" in all other metrics. It would be nice to align.

"numTargetRowsInserted" -> SQLMetrics.createMetric(sparkContext,
"Number of target rows inserted"),
"numTargetRowsDeleted" -> SQLMetrics.createMetric(sparkContext,
"Number of target rows deleted"),
"numTargetRowsUpdated" -> SQLMetrics.createMetric(sparkContext,
"Number of target rows updated"),
"numTargetRowsMatchedUpdated" -> SQLMetrics.createMetric(sparkContext,
"Number of target rows updated by a matched clause"),
"numTargetRowsMatchedDeleted" -> SQLMetrics.createMetric(sparkContext,
"Number of target rows deleted by a matched clause"),
"numTargetRowsNotMatchedBySourceUpdated" -> SQLMetrics.createMetric(sparkContext,
"Number of target rows updated by a not matched by source clause"),
"numTargetRowsNotMatchedBySourceDeleted" -> SQLMetrics.createMetric(sparkContext,
"Number of target rows deleted by a not matched by source clause"))

@transient override lazy val producedAttributes: AttributeSet = {
AttributeSet(output.filterNot(attr => inputSet.contains(attr)))
@@ -113,11 +126,8 @@ case class MergeRowsExec(

private def planInstructions(instructions: Seq[Instruction]): Seq[InstructionExec] = {
instructions.map {
case Copy(output) =>
CopyExec(createProjection(output))

case Keep(cond, output) =>
KeepExec(createPredicate(cond), createProjection(output))
case Keep(context, cond, output) =>
KeepExec(context, createPredicate(cond), createProjection(output))

case Discard(cond) =>
DiscardExec(createPredicate(cond))
@@ -136,12 +146,8 @@ case class MergeRowsExec(
def condition: BasePredicate
}

case class CopyExec(projection: Projection) extends InstructionExec {
override lazy val condition: BasePredicate = createPredicate(TrueLiteral)
def apply(row: InternalRow): InternalRow = projection.apply(row)
}

case class KeepExec(
context: Context,
condition: BasePredicate,
projection: Projection) extends InstructionExec {
def apply(row: InternalRow): InternalRow = projection.apply(row)
@@ -219,9 +225,9 @@ case class MergeRowsExec(

if (isTargetRowPresent && isSourceRowPresent) {
cardinalityValidator.validate(row)
applyInstructions(row, matchedInstructions)
applyInstructions(row, matchedInstructions, sourcePresent = true)
} else if (isSourceRowPresent) {
applyInstructions(row, notMatchedInstructions)
applyInstructions(row, notMatchedInstructions, sourcePresent = true)
} else if (isTargetRowPresent) {
applyInstructions(row, notMatchedBySourceInstructions)
} else {
@@ -231,23 +237,32 @@ case class MergeRowsExec(

private def applyInstructions(
row: InternalRow,
instructions: Seq[InstructionExec]): InternalRow = {
instructions: Seq[InstructionExec],
sourcePresent: Boolean = false): InternalRow = {

for (instruction <- instructions) {
if (instruction.condition.eval(row)) {
instruction match {
case copy: CopyExec =>
// group-based operations copy over target rows that didn't match any actions
longMetric("numTargetRowsCopied") += 1
return copy.apply(row)

case keep: KeepExec =>
keep.context match {
case Copy => incrementCopyMetric()
case Update => incrementUpdateMetric(sourcePresent)
case Insert => incrementInsertMetric()
case Delete => incrementDeleteMetric(sourcePresent)
case _ => throw new IllegalArgumentException(
s"Unexpected context for KeepExec: ${keep.context}")
}
Contributor: Minor: Shall we drop the empty line?

return keep.apply(row)

case _: DiscardExec =>
incrementDeleteMetric(sourcePresent)

Contributor: Minor: Shall we drop the empty line?

return null

case split: SplitExec =>
incrementUpdateMetric(sourcePresent)

Contributor: Minor: Shall we drop the empty line?

cachedExtraRow = split.projectExtraRow(row)
return split.projectRow(row)
}
@@ -257,4 +272,27 @@ case class MergeRowsExec(
null
}
}

// For group based merge, copy is inserted if row matches no other case
private def incrementCopyMetric(): Unit = longMetric("numTargetRowsCopied") += 1
Contributor: What's the cost of doing this per each row? I know we implement some tricks for regular writes to update metrics only once per 100 rows. Do we need to worry about it here?

Contributor: Oh, good point. We should do the same here.

Member Author (@szehon-ho, Jul 15, 2025): For that one, I took a closer look. It looks like it came from this discussion: #31451 (comment). In that case, the concern was getting the metric from the DSV2 connector, to avoid calling currentMetricsValue so many times, because the external implementation can be heavy.

In our case, getting the metric is in memory, so it should be quick.

On the sending end, it looks like SQLMetric is an accumulator, and updating it just sets an in-memory value as well.

So at first glance, I think adding complexity to update the metric per 100 rows may not be worth it. But I may be missing something.

Contributor: That seems to make sense to me. @cloud-fan, do you agree?

Contributor: Yeah, I agree.
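
For context, a minimal sketch of the "update the metric only once per 100 rows" trick the thread refers to, written against MergeRowsExec's existing longMetric helper. The names pendingCopies and flushThreshold are made up for illustration; this is not part of the PR, and the thread concludes it is unnecessary here because SQLMetric is an in-memory accumulator.

    // Hypothetical batched variant of incrementCopyMetric: accumulate locally and
    // flush to the SQLMetric every 100 rows instead of on every row.
    private var pendingCopies = 0L
    private val flushThreshold = 100

    private def incrementCopyMetricBatched(): Unit = {
      pendingCopies += 1
      if (pendingCopies >= flushThreshold) {
        longMetric("numTargetRowsCopied") += pendingCopies
        pendingCopies = 0L
      }
    }

    // A final flush would still be needed after the last row of each partition,
    // which is part of the extra bookkeeping the thread decides is not worth it.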


private def incrementInsertMetric(): Unit = longMetric("numTargetRowsInserted") += 1

private def incrementDeleteMetric(sourcePresent: Boolean): Unit = {
longMetric("numTargetRowsDeleted") += 1
if (sourcePresent) {
longMetric("numTargetRowsMatchedDeleted") += 1
} else {
longMetric("numTargetRowsNotMatchedBySourceDeleted") += 1
}
}

private def incrementUpdateMetric(sourcePresent: Boolean): Unit = {
longMetric("numTargetRowsUpdated") += 1
if (sourcePresent) {
longMetric("numTargetRowsMatchedUpdated") += 1
} else {
longMetric("numTargetRowsNotMatchedBySourceUpdated") += 1
}
}
}