From b66a40c94508e9858dee71dbcacc314396a2f085 Mon Sep 17 00:00:00 2001 From: arnavb Date: Sat, 10 May 2025 15:15:14 +0000 Subject: [PATCH] update --- ...oadcastNestedLoopJoinExecTransformer.scala | 26 +++++++- cpp/velox/substrait/SubstraitToVeloxPlan.cc | 8 +++ .../SubstraitToVeloxPlanValidator.cc | 1 + ...oadcastNestedLoopJoinExecTransformer.scala | 17 +++-- .../apache/gluten/utils/SubstraitUtil.scala | 4 +- .../joins/GlutenExistenceJoinSuite.scala | 65 ++++++++++++++++++- 6 files changed, 113 insertions(+), 8 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala index 8517422698d9..0d883f9750fd 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastNestedLoopJoinExecTransformer.scala @@ -16,14 +16,19 @@ */ package org.apache.gluten.execution +import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.extension.ValidationResult + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.BuildSide -import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, JoinType} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch +import com.google.protobuf.StringValue + case class VeloxBroadcastNestedLoopJoinExecTransformer( left: SparkPlan, right: SparkPlan, @@ -51,4 +56,23 @@ case class VeloxBroadcastNestedLoopJoinExecTransformer( newRight: SparkPlan): VeloxBroadcastNestedLoopJoinExecTransformer = copy(left = newLeft, right = newRight) + override def genJoinParameters(): 
com.google.protobuf.Any = { + val joinParametersStr = new StringBuffer("JoinParameters:") + joinParametersStr + .append("isExistenceJoin=") + .append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0) + .append("\n") + val message = StringValue + .newBuilder() + .setValue(joinParametersStr.toString) + .build() + BackendsApiManager.getTransformerApiInstance.packPBMessage(message) + } + + override def backendSpecificJoinValidation(): Option[ValidationResult] = { + joinType match { + case ExistenceJoin(_) => Some(ValidationResult.succeeded) + case _ => None + } + } } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index 1304cb511b55..d4ec14d08584 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -346,6 +346,14 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: case ::substrait::CrossRel_JoinType::CrossRel_JoinType_JOIN_TYPE_LEFT: joinType = core::JoinType::kLeft; break; + case ::substrait::CrossRel_JoinType::CrossRel_JoinType_JOIN_TYPE_LEFT_SEMI: + if (crossRel.has_advanced_extension() && + SubstraitParser::configSetInOptimization(crossRel.advanced_extension(), "isExistenceJoin=")) { + joinType = core::JoinType::kLeftSemiProject; + } else { + VELOX_NYI("Unsupported Join type: {}", std::to_string(crossRel.type())); + } + break; case ::substrait::CrossRel_JoinType::CrossRel_JoinType_JOIN_TYPE_OUTER: joinType = core::JoinType::kFull; break; diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index 32997d9fd780..06be83856840 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -1083,6 +1083,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::CrossRel& crossR switch (crossRel.type()) { case ::substrait::CrossRel_JoinType_JOIN_TYPE_INNER: case 
::substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: + case ::substrait::CrossRel_JoinType_JOIN_TYPE_LEFT_SEMI: break; case ::substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: if (crossRel.has_expression()) { diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index 5f3d2889eeda..d0cbfd4aa84e 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -25,7 +25,7 @@ import org.apache.gluten.utils.SubstraitUtil import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} -import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan} import org.apache.spark.sql.execution.joins.BaseJoinExec @@ -87,6 +87,8 @@ abstract class BroadcastNestedLoopJoinExecTransformer( left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case j: ExistenceJoin => + left.output :+ j.exists case LeftExistence(_) => left.output case FullOuter => @@ -108,7 +110,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer( case BuildRight => joinType match { case _: InnerLike => left.outputPartitioning - case LeftOuter => left.outputPartitioning + case LeftOuter | ExistenceJoin(_) => left.outputPartitioning case x => throw new IllegalArgumentException( s"BroadcastNestedLoopJoin should not take 
$x as the JoinType with building right side") @@ -177,6 +179,10 @@ abstract class BroadcastNestedLoopJoinExecTransformer( ValidationResult.failed( s"FullOuter join with join condition is not supported with BroadcastNestedLoopJoin") } + case ExistenceJoin(_) => + backendSpecificJoinValidation().getOrElse( + ValidationResult.failed("ExistenceJoin is not supported for this backend.") + ) + case _ => ValidationResult.failed(s"$joinType join is not supported with BroadcastNestedLoopJoin") } @@ -186,12 +192,15 @@ } (joinType, buildSide) match { - case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) => + case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) | (ExistenceJoin(_), BuildLeft) => ValidationResult.failed(s"$joinType join is not supported with $buildSide") - case _ => ValidationResult.succeeded // continue + case _ => + ValidationResult.succeeded // continue } } + protected def backendSpecificJoinValidation(): Option[ValidationResult] = None + override protected def doValidateInternal(): ValidationResult = { if (!GlutenConfig.get.broadcastNestedLoopJoinTransformerTransformerEnabled) { return ValidationResult.failed( diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala b/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala index 80b3a365b85a..9e1117085214 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala @@ -23,7 +23,7 @@ import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.ExpressionNode import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, 
LeftOuter, LeftSemi, RightOuter} import io.substrait.proto.{CrossRel, JoinRel, NamedStruct, Type} @@ -59,7 +59,7 @@ object SubstraitUtil { // the left and right relations are exchanged and the // join type is reverted. CrossRel.JoinType.JOIN_TYPE_LEFT - case LeftSemi => + case LeftSemi | ExistenceJoin(_) => CrossRel.JoinType.JOIN_TYPE_LEFT_SEMI case FullOuter => CrossRel.JoinType.JOIN_TYPE_OUTER diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/joins/GlutenExistenceJoinSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/joins/GlutenExistenceJoinSuite.scala index 309af61a43ae..f44ffddeebc1 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/joins/GlutenExistenceJoinSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/joins/GlutenExistenceJoinSuite.scala @@ -16,6 +16,69 @@ */ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.GlutenSQLTestsBaseTrait +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.ExistenceJoin +import org.apache.spark.sql.catalyst.plans.logical.{JoinHint, _} +import org.apache.spark.sql.types._ -class GlutenExistenceJoinSuite extends ExistenceJoinSuite with GlutenSQLTestsBaseTrait {} +class GlutenExistenceJoinSuite extends ExistenceJoinSuite with GlutenSQLTestsBaseTrait { + + test("test existence join with broadcast nested loop join") { + + spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") + spark.conf.set("spark.sql.join.preferSortMergeJoin", "false") + + val left: DataFrame = spark.createDataFrame( + sparkContext.parallelize( + Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c") + )), + new StructType().add("id", IntegerType).add("val", StringType) + ) + + val right: DataFrame = spark.createDataFrame( + sparkContext.parallelize( + Seq( + Row(1, "x"), + Row(3, "y") + )), + new StructType().add("id", 
IntegerType).add("val", StringType) + ) + + val leftPlan = left.logicalPlan + val rightPlan = right.logicalPlan + + val existsAttr = AttributeReference("exists", BooleanType, nullable = false)() + + val joinCondition: Expression = LessThan(leftPlan.output(0), rightPlan.output(0)) + + val existenceJoin = Join( + left = leftPlan, + right = rightPlan, + joinType = ExistenceJoin(existsAttr), + condition = Some(joinCondition), + hint = JoinHint.NONE + ) + + val project = Project( + projectList = leftPlan.output :+ existsAttr, + child = existenceJoin + ) + + val df = Dataset.ofRows(spark, project) + + assert(existenceJoin.joinType == ExistenceJoin(existsAttr)) + assert(existenceJoin.condition.contains(joinCondition)) + val expected = Seq( + Row(1, "a", true), + Row(2, "b", true), + Row(3, "c", false) + ) + assert(df.collect() === expected) + + } +}