From 69f9efe663c88fa34c227ff9d433f69b59b605b6 Mon Sep 17 00:00:00 2001
From: wayneli-vt
Date: Sun, 28 Dec 2025 18:10:48 +0800
Subject: [PATCH] Fix racy plan capture in RowTrackingTestBase merge-into tests

The MERGE INTO plan assertions registered a QueryExecutionListener and
collected every analyzed plan into a shared buffer. Listener callbacks
fire asynchronously and for every query, so the assertion could run
before the relevant plan arrived, or pick up plans from unrelated
queries. Capture only the plan that contains the operator of interest
(Deduplicate for the split lookup, MergeRows for the update action) and
wait on a CountDownLatch before asserting.
---
 .../spark/sql/RowTrackingTestBase.scala       | 50 ++++++++++++-------
 1 file changed, 32 insertions(+), 18 deletions(-)

diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala
index 137bbe76d759..204ca7fcd431 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala
@@ -22,12 +22,11 @@ import org.apache.paimon.Snapshot.CommitKind
 import org.apache.paimon.spark.PaimonSparkTestBase
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, RepartitionByExpression, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Join, LogicalPlan, MergeRows, RepartitionByExpression, Sort}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
 
-import scala.collection.JavaConverters._
-import scala.collection.mutable
+import java.util.concurrent.{CountDownLatch, TimeUnit}
 
 abstract class RowTrackingTestBase extends PaimonSparkTestBase {
 
@@ -397,13 +396,20 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase {
       sql(
         "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
 
-      val capturedPlans: mutable.ListBuffer[LogicalPlan] = mutable.ListBuffer.empty
+      var findSplitsPlan: LogicalPlan = null
+      val latch = new CountDownLatch(1)
       val listener = new QueryExecutionListener {
         override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-          capturedPlans += qe.analyzed
+          if (qe.analyzed.collectFirst { case _: Deduplicate => true }.nonEmpty) {
+            findSplitsPlan = qe.analyzed
+            latch.countDown()
+          }
         }
         override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-          capturedPlans += qe.analyzed
+          if (qe.analyzed.collectFirst { case _: Deduplicate => true }.nonEmpty) {
+            findSplitsPlan = qe.analyzed
+            latch.countDown()
+          }
         }
       }
       spark.listenerManager.register(listener)
@@ -416,9 +422,10 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase {
           |WHEN NOT MATCHED AND c > 'c9' THEN INSERT (a, b, c) VALUES (target_ROW_ID, b * 1.1, c)
           |WHEN NOT MATCHED THEN INSERT (a, b, c) VALUES (target_ROW_ID, b, c)
           |""".stripMargin)
+      assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
       // Assert that no Join operator was used during
       // `org.apache.paimon.spark.commands.MergeIntoPaimonDataEvolutionTable.targetRelatedSplits`
-      assert(capturedPlans.head.collect { case plan: Join => plan }.isEmpty)
+      assert(findSplitsPlan != null && findSplitsPlan.collect { case plan: Join => plan }.isEmpty)
       spark.listenerManager.unregister(listener)
 
       checkAnswer(
@@ -442,13 +449,20 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase {
       sql(
         "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
 
-      val capturedPlans = new java.util.concurrent.CopyOnWriteArrayList[LogicalPlan]()
+      var updatePlan: LogicalPlan = null
+      val latch = new CountDownLatch(1)
       val listener = new QueryExecutionListener {
         override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-          capturedPlans.add(qe.analyzed)
+          if (qe.analyzed.collectFirst { case _: MergeRows => true }.nonEmpty) {
+            updatePlan = qe.analyzed
+            latch.countDown()
+          }
         }
         override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-          capturedPlans.add(qe.analyzed)
+          if (qe.analyzed.collectFirst { case _: MergeRows => true }.nonEmpty) {
+            updatePlan = qe.analyzed
+            latch.countDown()
+          }
         }
       }
       spark.listenerManager.register(listener)
@@ -460,17 +474,17 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase {
           |WHEN MATCHED AND source.c > 'c2' THEN UPDATE SET b = source.b * 3,
           |c = concat(target.c, source.c)
           |""".stripMargin).collect()
+      assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
       // Assert no shuffle/join/sort was used in
       // 'org.apache.paimon.spark.commands.MergeIntoPaimonDataEvolutionTable.updateActionInvoke'
       assert(
-        capturedPlans.asScala.forall(
-          plan =>
-            plan.collectFirst {
-              case p: Join => p
-              case p: Sort => p
-              case p: RepartitionByExpression => p
-            }.isEmpty),
-        s"Found unexpected Join/Sort/Exchange in plan:\n$capturedPlans"
+        updatePlan != null &&
+          updatePlan.collectFirst {
+            case p: Join => p
+            case p: Sort => p
+            case p: RepartitionByExpression => p
+          }.isEmpty,
+        s"Found unexpected Join/Sort/Exchange in plan: $updatePlan"
       )
       spark.listenerManager.unregister(listener)
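
Note (illustration, not part of the patch): the fix leans on two standard
pieces, Spark's QueryExecutionListener API and
java.util.concurrent.CountDownLatch, to make an asynchronous listener
callback observable from the test thread. Below is a minimal standalone
sketch of that pattern, assuming only a local SparkSession; the object
name PlanCaptureSketch and the Join-based filter are made up for
illustration (the patch keys on Deduplicate / MergeRows instead).

    import java.util.concurrent.{CountDownLatch, TimeUnit}

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
    import org.apache.spark.sql.execution.QueryExecution
    import org.apache.spark.sql.util.QueryExecutionListener

    object PlanCaptureSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()

        var captured: LogicalPlan = null
        val latch = new CountDownLatch(1)
        val listener = new QueryExecutionListener {
          override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
            // Callbacks fire on a listener thread, for every query; keep only
            // the plan that contains the operator of interest.
            if (qe.analyzed.collectFirst { case j: Join => j }.nonEmpty) {
              captured = qe.analyzed // write before countDown() for visibility
              latch.countDown()
            }
          }
          override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
        }

        spark.listenerManager.register(listener)
        try {
          val left = spark.range(10).toDF("id")
          val right = spark.range(10).toDF("id")
          left.join(right, "id").collect()

          // Bounded wait: fail loudly instead of hanging if the callback never fires.
          assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
          assert(captured != null && captured.collectFirst { case j: Join => j }.nonEmpty)
        } finally {
          spark.listenerManager.unregister(listener)
          spark.stop()
        }
      }
    }

The assign-before-countDown() ordering matters: CountDownLatch guarantees
that writes made before countDown() are visible to a thread that returns
from await(), so the plain var needs no further synchronization.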