Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
override def validateJoinTypeAndBuildSide(): ValidationResult = {
joinType match {
case _: InnerLike =>
case ExistenceJoin(_) =>
return ValidationResult.failed("ExistenceJoin is not supported for CH backend.")
case _ =>
if (joinType == LeftSemi || condition.isDefined) {
return ValidationResult.failed(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, JoinType}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.StringValue

case class VeloxBroadcastNestedLoopJoinExecTransformer(
left: SparkPlan,
right: SparkPlan,
Expand Down Expand Up @@ -51,4 +55,16 @@ case class VeloxBroadcastNestedLoopJoinExecTransformer(
newRight: SparkPlan): VeloxBroadcastNestedLoopJoinExecTransformer =
copy(left = newLeft, right = newRight)

/**
 * Builds the backend-specific join parameters carried to the native plan via the
 * Substrait advanced extension.
 *
 * The produced text is a newline-terminated "key=value" list under a
 * "JoinParameters:" header; the native side looks up the "isExistenceJoin="
 * marker (see the corresponding CrossRel handling in SubstraitToVeloxPlan.cc)
 * to decide whether to plan a left-semi-project join.
 *
 * @return the parameters packed as a protobuf [[com.google.protobuf.Any]]
 */
override def genJoinParameters(): com.google.protobuf.Any = {
// Scala's StringBuilder replaces java.lang.StringBuffer: same API, no
// per-call synchronization overhead for this purely local builder.
val joinParametersStr = new StringBuilder("JoinParameters:")
joinParametersStr
.append("isExistenceJoin=")
// 1/0 rather than true/false: the native parser matches "isExistenceJoin=1".
.append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0)
.append("\n")
val message = StringValue
.newBuilder()
.setValue(joinParametersStr.toString)
.build()
// Pack through the backend API so the wrapping matches other transformers.
BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.joins

import org.apache.gluten.execution.{VeloxBroadcastNestedLoopJoinExecTransformer, VeloxWholeStageTransformerSuite}

import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.ExistenceJoin
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._

/**
 * Verifies that a query containing an [[ExistenceJoin]] with a non-equi
 * condition is planned as, and correctly executed by,
 * [[VeloxBroadcastNestedLoopJoinExecTransformer]].
 */
class GlutenExistenceJoinSuite extends VeloxWholeStageTransformerSuite with SQLTestUtils {

// Placeholders required by the base suite; this test builds DataFrames
// in memory and reads no files.
override protected val resourcePath: String = "N/A"
override protected val fileFormat: String = "N/A"

test("existence join with broadcast nested loop join") {

// NOTE(review): -1 disables the auto-broadcast threshold and sort-merge is
// de-preferred; together with the non-equi condition below this presumably
// steers the planner to a broadcast nested loop join — confirm.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")

// Probe side: three rows with ids 1..3.
val left: DataFrame = spark.createDataFrame(
sparkContext.parallelize(
Seq(
Row(1, "a"),
Row(2, "b"),
Row(3, "c")
)),
new StructType().add("id", IntegerType).add("val", StringType)
)

// Build side: ids 1 and 3 only.
val right: DataFrame = spark.createDataFrame(
sparkContext.parallelize(
Seq(
Row(1, "x"),
Row(3, "y")
)),
new StructType().add("id", IntegerType).add("val", StringType)
)

val leftPlan = left.logicalPlan
val rightPlan = right.logicalPlan

// The extra boolean attribute an ExistenceJoin appends to the left output.
val existsAttr = AttributeReference("exists", BooleanType, nullable = false)()

// Non-equi condition (left.id < right.id) — rules out hash-based joins.
val joinCondition: Expression = LessThan(leftPlan.output(0), rightPlan.output(0))

// Build the ExistenceJoin node directly; it is not expressible in plain SQL
// (Spark normally synthesizes it when rewriting EXISTS/IN subqueries).
val existenceJoin = Join(
left = leftPlan,
right = rightPlan,
joinType = ExistenceJoin(existsAttr),
condition = Some(joinCondition),
hint = JoinHint.NONE
)

// Project the left columns plus the synthesized exists flag.
val project = Project(
projectList = leftPlan.output :+ existsAttr,
child = existenceJoin
)

val df = Dataset.ofRows(spark, project)

assert(existenceJoin.joinType == ExistenceJoin(existsAttr))
assert(existenceJoin.condition.contains(joinCondition))
// id=1: right id 3 satisfies 1<3 -> true; id=2: 2<3 -> true; id=3: no
// right id greater than 3 -> false.
// NOTE(review): relies on collect() preserving left-side row order — confirm.
val expected = Seq(
Row(1, "a", true),
Row(2, "b", true),
Row(3, "c", false)
)
assert(df.collect() === expected)
// Exactly one Velox BNLJ transformer must appear in the executed plan,
// proving the join was offloaded rather than falling back to vanilla Spark.
val count = collect(df.queryExecution.executedPlan) {
case _: VeloxBroadcastNestedLoopJoinExecTransformer => true
}.size

assert(count == 1, s"Expected 1 VeloxBroadcastNestedLoopJoinExecTransformer, but found $count")
}
}
8 changes: 8 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,14 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
case ::substrait::CrossRel_JoinType::CrossRel_JoinType_JOIN_TYPE_LEFT:
joinType = core::JoinType::kLeft;
break;
case ::substrait::CrossRel_JoinType::CrossRel_JoinType_JOIN_TYPE_LEFT_SEMI:
if (crossRel.has_advanced_extension() &&
SubstraitParser::configSetInOptimization(crossRel.advanced_extension(), "isExistenceJoin=")) {
joinType = core::JoinType::kLeftSemiProject;
} else {
VELOX_NYI("Unsupported Join type: {}", std::to_string(crossRel.type()));
}
break;
case ::substrait::CrossRel_JoinType::CrossRel_JoinType_JOIN_TYPE_OUTER:
joinType = core::JoinType::kFull;
break;
Expand Down
1 change: 1 addition & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::CrossRel& crossR
switch (crossRel.type()) {
case ::substrait::CrossRel_JoinType_JOIN_TYPE_INNER:
case ::substrait::CrossRel_JoinType_JOIN_TYPE_LEFT:
case ::substrait::CrossRel_JoinType_JOIN_TYPE_LEFT_SEMI:
break;
case ::substrait::CrossRel_JoinType_JOIN_TYPE_OUTER:
if (crossRel.has_expression()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.gluten.utils.SubstraitUtil

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan}
import org.apache.spark.sql.execution.joins.BaseJoinExec
Expand Down Expand Up @@ -87,6 +87,8 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
case j: ExistenceJoin =>
left.output :+ j.exists
case LeftExistence(_) =>
left.output
case FullOuter =>
Expand All @@ -108,7 +110,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
case BuildRight =>
joinType match {
case _: InnerLike => left.outputPartitioning
case LeftOuter => left.outputPartitioning
case LeftOuter | ExistenceJoin(_) => left.outputPartitioning
case x =>
throw new IllegalArgumentException(
s"BroadcastNestedLoopJoin should not take $x as the JoinType with building right side")
Expand Down Expand Up @@ -177,6 +179,8 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
ValidationResult.failed(
s"FullOuter join with join condition is not supported with BroadcastNestedLoopJoin")
}
case ExistenceJoin(_) =>
ValidationResult.succeeded
case _ =>
ValidationResult.failed(s"$joinType join is not supported with BroadcastNestedLoopJoin")
}
Expand All @@ -186,9 +190,10 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
}

(joinType, buildSide) match {
case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) =>
case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) | (ExistenceJoin(_), BuildLeft) =>
ValidationResult.failed(s"$joinType join is not supported with $buildSide")
case _ => ValidationResult.succeeded // continue
case _ =>
ValidationResult.succeeded
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.ExpressionNode

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}

import io.substrait.proto.{CrossRel, JoinRel, NamedStruct, Type}

Expand Down Expand Up @@ -59,7 +59,7 @@ object SubstraitUtil {
// the left and right relations are exchanged and the
// join type is reverted.
CrossRel.JoinType.JOIN_TYPE_LEFT
case LeftSemi =>
case LeftSemi | ExistenceJoin(_) =>
CrossRel.JoinType.JOIN_TYPE_LEFT_SEMI
case FullOuter =>
CrossRel.JoinType.JOIN_TYPE_OUTER
Expand Down
Loading