From 8078466aa24d09e64e7aa70f13da71a658d9018f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Mar 2024 10:20:01 -0700 Subject: [PATCH 01/13] feat: Support HashJoin --- core/src/execution/datafusion/planner.rs | 153 +++++++++++++++++- core/src/execution/operators/copy.rs | 2 +- core/src/execution/proto/operator.proto | 19 +++ .../comet/CometSparkSessionExtensions.scala | 26 ++- .../apache/comet/serde/QueryPlanSerde.scala | 48 +++++- .../apache/spark/sql/comet/operators.scala | 43 ++++- .../apache/comet/exec/CometExecSuite.scala | 44 +++++ 7 files changed, 327 insertions(+), 8 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index ef2787f83b..01c79e0c07 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -17,7 +17,7 @@ //! Converts Spark physical plan to DataFusion physical plan -use std::{str::FromStr, sync::Arc}; +use std::{collections::HashMap, str::FromStr, sync::Arc}; use arrow_schema::{DataType, Field, Schema, TimeUnit}; use datafusion::{ @@ -37,13 +37,17 @@ use datafusion::{ physical_plan::{ aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy}, filter::FilterExec, + joins::{utils::JoinFilter, HashJoinExec, PartitionMode}, limit::LocalLimitExec, projection::ProjectionExec, sorts::sort::SortExec, ExecutionPlan, Partitioning, }, }; -use datafusion_common::ScalarValue; +use datafusion_common::{ + tree_node::{TreeNode, VisitRecursion}, + JoinType as DFJoinType, ScalarValue, +}; use itertools::Itertools; use jni::objects::GlobalRef; use num::{BigInt, ToPrimitive}; @@ -76,7 +80,7 @@ use crate::{ agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, Expr, ScalarFunc, }, - spark_operator::{operator::OpStruct, Operator}, + spark_operator::{operator::OpStruct, JoinType, Operator}, spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }, }; @@ -858,6 +862,107 @@ impl PhysicalPlanner { Arc::new(CometExpandExec::new(projections, child, schema)), )) } + OpStruct::HashJoin(join) => { + assert!(children.len() == 2); + let (mut left_scans, left) = self.create_plan(&children[0], inputs)?; + let (mut right_scans, right) = self.create_plan(&children[1], inputs)?; + + left_scans.append(&mut right_scans); + + let left_join_exprs: Vec<_> = join + .left_join_keys + .iter() + .map(|expr| self.create_expr(expr, left.schema())) + .collect::, _>>()?; + let right_join_exprs: Vec<_> = join + .right_join_keys + .iter() + .map(|expr| self.create_expr(expr, right.schema())) + .collect::, _>>()?; + + let join_on = left_join_exprs + .into_iter() + .zip(right_join_exprs) + .collect::>(); + + let join_type = match join.join_type.try_into() { + Ok(JoinType::Inner) => DFJoinType::Inner, + Ok(JoinType::LeftOuter) => DFJoinType::Left, + Ok(JoinType::RightOuter) => DFJoinType::Right, + Ok(JoinType::FullOuter) => DFJoinType::Full, + Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi, + Ok(JoinType::RightSemi) => DFJoinType::RightSemi, + Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti, + Ok(JoinType::RightAnti) => DFJoinType::RightAnti, + Err(_) => { + return Err(ExecutionError::GeneralError(format!( + "Unsupported join type: {:?}", + join.join_type + ))); + } + }; + + // Handle join filter as DataFusion `JoinFilter` struct + let join_filter = if let Some(expr) = &join.condition { + let physical_expr = self.create_expr(expr, left.schema())?; + let (left_field_indices, right_field_indices) = expr_to_columns( + &physical_expr, + 
left.schema().fields.len(), + right.schema().fields.len(), + )?; + let column_indices = JoinFilter::build_column_indices( + left_field_indices.clone(), + right_field_indices.clone(), + ); + + let filter_fields: Vec = left_field_indices + .into_iter() + .map(|i| left.schema().field(i).clone()) + .chain( + right_field_indices + .into_iter() + .map(|i| right.schema().field(i).clone()), + ) + .collect_vec(); + + let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new()); + + Some(JoinFilter::new( + physical_expr, + column_indices, + filter_schema, + )) + } else { + None + }; + + // DataFusion `HashJoinExec` operator keeps the input batch internally. We need + // to copy the input batch to avoid the data corruption from reusing the input + // batch. + let left = if !is_op_do_copying(&left) { + Arc::new(CopyExec::new(left)) + } else { + left + }; + + let right = if !is_op_do_copying(&right) { + Arc::new(CopyExec::new(right)) + } else { + right + }; + + let join = Arc::new(HashJoinExec::try_new( + left, + right, + join_on, + join_filter, + &join_type, + PartitionMode::Partitioned, + false, + )?); + + Ok((left_scans, join)) + } } } @@ -1026,6 +1131,48 @@ impl From for DataFusionError { } } +/// Returns true if given operator copies input batch to avoid data corruption from reusing +/// input arrays. +fn is_op_do_copying(op: &Arc) -> bool { + op.as_any().downcast_ref::().is_some() + || op.as_any().downcast_ref::().is_some() + || op.as_any().downcast_ref::().is_some() +} + +/// Collects the indices of the columns in the input schema that are used in the expression +/// and returns them as a pair of vectors, one for the left side and one for the right side. +fn expr_to_columns( + expr: &Arc, + left_field_len: usize, + right_field_len: usize, +) -> Result<(Vec, Vec), ExecutionError> { + let mut left_field_indices: Vec = vec![]; + let mut right_field_indices: Vec = vec![]; + + expr.apply(&mut |expr| { + Ok({ + if let Some(column) = expr.as_any().downcast_ref::() { + if column.index() > left_field_len + right_field_len { + return Err(DataFusionError::Internal(format!( + "Column index {} out of range", + column.index() + ))); + } else if column.index() < left_field_len { + left_field_indices.push(column.index()); + } else { + right_field_indices.push(column.index() - left_field_len); + } + } + VisitRecursion::Continue + }) + })?; + + left_field_indices.sort(); + right_field_indices.sort(); + + Ok((left_field_indices, right_field_indices)) +} + #[cfg(test)] mod tests { use std::{sync::Arc, task::Poll}; diff --git a/core/src/execution/operators/copy.rs b/core/src/execution/operators/copy.rs index 996db2b470..699ccf7ae7 100644 --- a/core/src/execution/operators/copy.rs +++ b/core/src/execution/operators/copy.rs @@ -91,7 +91,7 @@ impl ExecutionPlan for CopyExec { } fn children(&self) -> Vec> { - self.input.children() + vec![self.input.clone()] } fn with_new_children( diff --git a/core/src/execution/proto/operator.proto b/core/src/execution/proto/operator.proto index 5b07cb30b1..ce58edd0f4 100644 --- a/core/src/execution/proto/operator.proto +++ b/core/src/execution/proto/operator.proto @@ -40,6 +40,7 @@ message Operator { Limit limit = 105; ShuffleWriter shuffle_writer = 106; Expand expand = 107; + HashJoin hash_join = 108; } } @@ -87,3 +88,21 @@ message Expand { repeated spark.spark_expression.Expr project_list = 1; int32 num_expr_per_project = 3; } + +message HashJoin { + repeated spark.spark_expression.Expr left_join_keys = 1; + repeated spark.spark_expression.Expr right_join_keys = 2; + 
JoinType join_type = 3; + optional spark.spark_expression.Expr condition = 4; +} + +enum JoinType { + Inner = 0; + LeftOuter = 1; + RightOuter = 2; + FullOuter = 3; + LeftSemi = 4; + RightSemi = 5; + LeftAnti = 6; + RightAnti = 7; +} diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 39c83ae53d..7eca995579 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -337,6 +338,27 @@ class CometSparkSessionExtensions op } + case op: ShuffledHashJoinExec + if isCometOperatorEnabled(conf, "hash_join") && + op.children.forall(isCometNative(_)) => + val newOp = transform1(op) + newOp match { + case Some(nativeOp) => + CometHashJoinExec( + nativeOp, + op, + op.leftKeys, + op.rightKeys, + op.joinType, + op.condition, + op.buildSide, + op.left, + op.right, + SerializedPlan(None)) + case None => + op + } + case c @ CoalesceExec(numPartitions, child) if isCometOperatorEnabled(conf, "coalesce") && isCometNative(child) => @@ -576,7 +598,9 @@ object CometSparkSessionExtensions extends Logging { private[comet] def isCometOperatorEnabled(conf: SQLConf, operator: String): Boolean = { val operatorFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.enabled" - conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) + val operatorDisabledFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.disabled" + conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) && + !conf.getConfString(operatorDisabledFlag, "false").toBoolean } private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 5da926e38e..b0692ad8e2 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -25,7 +25,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Last, Max, Min, Partial, Sum} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero +import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision} @@ -35,6 +36,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.aggregate.HashAggregateExec import 
org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -42,7 +44,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus} import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} -import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} +import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} import org.apache.comet.shims.ShimQueryPlanSerde /** @@ -1838,6 +1840,48 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } } + case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf, "hash_join") => + if (join.buildSide == BuildRight) { + // DataFusion HashJoin assumes build side is always left. + // TODO: support BuildRight + return None + } + + val condition = join.condition.map { cond => + val condProto = exprToProto(cond, join.left.output ++ join.right.output) + if (condProto.isEmpty) { + return None + } + condProto.get + } + + val joinType = join.joinType match { + case Inner => JoinType.Inner + case LeftOuter => JoinType.LeftOuter + case RightOuter => JoinType.RightOuter + case FullOuter => JoinType.FullOuter + case LeftSemi => JoinType.LeftSemi + case LeftAnti => JoinType.LeftAnti + case _ => return None // Spark doesn't support other join types + } + + val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) + val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) + + if (leftKeys.forall(_.isDefined) && + rightKeys.forall(_.isDefined) && + childOp.nonEmpty) { + val joinBuilder = OperatorOuterClass.HashJoin + .newBuilder() + .setJoinType(joinType) + .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) + .addAllRightJoinKeys(rightKeys.map(_.get).asJava) + condition.foreach(joinBuilder.setCondition) + Some(result.setHashJoin(joinBuilder).build()) + } else { + None + } + case op if isCometSink(op) => // These operators are source of Comet native execution chain val scanBuilder = OperatorOuterClass.Scan.newBuilder() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 5551ffdbcb..2d03b2e65d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -31,9 +31,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode} +import org.apache.spark.sql.catalyst.optimizer.BuildSide +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec} -import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} 
+import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -324,6 +326,8 @@ abstract class CometNativeExec extends CometExec { abstract class CometUnaryExec extends CometNativeExec with UnaryExecNode +abstract class CometBinaryExec extends CometNativeExec with BinaryExecNode + /** * Represents the serialized plan of Comet native operators. Only the first operator in a block of * continuous Comet native operators has defined plan bytes which contains the serialization of @@ -584,6 +588,43 @@ case class CometHashAggregateExec( Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child) } +case class CometHashJoinExec( + override val nativeOp: Operator, + override val originalPlan: SparkPlan, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + buildSide: BuildSide, + override val left: SparkPlan, + override val right: SparkPlan, + override val serializedPlanOpt: SerializedPlan) + extends CometBinaryExec { + override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = + this.copy(left = newLeft, right = newRight) + + override def stringArgs: Iterator[Any] = + Iterator(leftKeys, rightKeys, joinType, condition, left, right) + + override def equals(obj: Any): Boolean = { + obj match { + case other: CometHashJoinExec => + this.leftKeys == other.leftKeys && + this.rightKeys == other.rightKeys && + this.condition == other.condition && + this.buildSide == other.buildSide && + this.left == other.left && + this.right == other.right && + this.serializedPlanOpt == other.serializedPlanOpt + case _ => + false + } + } + + override def hashCode(): Int = + Objects.hashCode(leftKeys, rightKeys, condition, left, right) +} + case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan) extends CometNativeExec with LeafExecNode { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 6a34d4fe4a..c21fd109e4 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -58,6 +58,50 @@ class CometExecSuite extends CometTestBase { } } + test("HashJoin without join filter") { + withSQLConf( + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") { + val df1 = + sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df1) + + // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint. + // We need to investigate why this happens and fix it. 
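+          // Until that is resolved, the affected left-join variants are kept commented out below.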
+ /* + val df2 = + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df2) + + val df3 = + sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_b LEFT JOIN tbl_a ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df3) + */ + + val df4 = + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df4) + + val df5 = + sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_b RIGHT JOIN tbl_a ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df5) + + val df6 = + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df6) + + val df7 = + sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_b FULL JOIN tbl_a ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df7) + } + } + } + } + test("Fix corrupted AggregateMode when transforming plan parameters") { withParquetTable((0 until 5).map(i => (i, i + 1)), "table") { val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2")) From fefcf0145a6b0cc9d7ac15aa7d7858b127cd3c66 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Mar 2024 12:41:39 -0700 Subject: [PATCH 02/13] Add comment --- .../scala/org/apache/comet/exec/CometExecSuite.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index c21fd109e4..a8669ef6c3 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -71,7 +71,8 @@ class CometExecSuite extends CometTestBase { checkSparkAnswerAndOperator(df1) // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint. - // We need to investigate why this happens and fix it. + // Left join with build left and right join with build right in hash join is only supported + // in Spark 3.5 or above. See SPARK-36612. 
/* val df2 = sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1") @@ -82,6 +83,14 @@ class CometExecSuite extends CometTestBase { checkSparkAnswerAndOperator(df3) */ + val df2 = + sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df2) + + val df3 = + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_b LEFT JOIN tbl_a ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df3) + val df4 = sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") checkSparkAnswerAndOperator(df4) From cc136195152a071d7ed644870d7d0f59c7e22873 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Mar 2024 13:10:38 -0700 Subject: [PATCH 03/13] Clean up test --- .../apache/comet/exec/CometExecSuite.scala | 54 +++++++++---------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index a8669ef6c3..b8237c8353 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -65,47 +65,41 @@ class CometExecSuite extends CometTestBase { SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") { + // Inner join: build left val df1 = sql( "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") checkSparkAnswerAndOperator(df1) - // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint. - // Left join with build left and right join with build right in hash join is only supported - // in Spark 3.5 or above. See SPARK-36612. - /* - val df2 = - sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator(df2) - - val df3 = - sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_b LEFT JOIN tbl_a ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator(df3) - */ - + // Right join: build left val df2 = - sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") checkSparkAnswerAndOperator(df2) + // Full join: build left val df3 = - sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_b LEFT JOIN tbl_a ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator(df3) - - val df4 = - sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator(df4) - - val df5 = - sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_b RIGHT JOIN tbl_a ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator(df5) - - val df6 = sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator(df6) + checkSparkAnswerAndOperator(df3) - val df7 = - sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_b FULL JOIN tbl_a ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator(df7) + // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint. + // Left join with build left and right join with build right in hash join is only supported + // in Spark 3.5 or above. See SPARK-36612. 
+ // + // Left join: build left + // sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + + // TODO: DataFusion HashJoin doesn't support build right yet. + // Inner join: build right + // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") + // + // Left join: build right + // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + // + // Right join: build right + // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + // + // Full join: build right + // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1") } } } From 334d7d9f078cf20a38688f76cf8520f3515d5d00 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Mar 2024 15:23:32 -0700 Subject: [PATCH 04/13] Fix join filter --- core/src/execution/datafusion/planner.rs | 125 +++++++++++++++++- .../apache/comet/exec/CometExecSuite.scala | 32 +++++ 2 files changed, 154 insertions(+), 3 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 01c79e0c07..29bf309594 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -45,7 +45,7 @@ use datafusion::{ }, }; use datafusion_common::{ - tree_node::{TreeNode, VisitRecursion}, + tree_node::{TreeNode, TreeNodeRewriter, VisitRecursion}, JoinType as DFJoinType, ScalarValue, }; use itertools::Itertools; @@ -904,7 +904,18 @@ impl PhysicalPlanner { // Handle join filter as DataFusion `JoinFilter` struct let join_filter = if let Some(expr) = &join.condition { - let physical_expr = self.create_expr(expr, left.schema())?; + let left_schema = left.schema(); + let right_schema = right.schema(); + let left_fields = left_schema.fields(); + let right_fields = right_schema.fields(); + let all_fields: Vec<_> = left_fields + .into_iter() + .chain(right_fields.into_iter()) + .map(|f| f.clone()) + .collect(); + let full_schema = Arc::new(Schema::new(all_fields)); + + let physical_expr = self.create_expr(expr, full_schema)?; let (left_field_indices, right_field_indices) = expr_to_columns( &physical_expr, left.schema().fields.len(), @@ -916,10 +927,12 @@ impl PhysicalPlanner { ); let filter_fields: Vec = left_field_indices + .clone() .into_iter() .map(|i| left.schema().field(i).clone()) .chain( right_field_indices + .clone() .into_iter() .map(|i| right.schema().field(i).clone()), ) @@ -927,8 +940,21 @@ impl PhysicalPlanner { let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new()); - Some(JoinFilter::new( + // Rewrite the physical expression to use the new column indices. + // DataFusion's join filter is bound to intermediate schema which contains + // only the fields used in the filter expression. But the Spark's join filter + // expression is bound to the full schema. We need to rewrite the physical + // expression to use the new column indices. + let rewritten_physical_expr = rewrite_physical_expr( physical_expr, + left_schema.fields.len(), + right_schema.fields.len(), + &left_field_indices, + &right_field_indices, + )?; + + Some(JoinFilter::new( + rewritten_physical_expr, column_indices, filter_schema, )) @@ -1173,6 +1199,99 @@ fn expr_to_columns( Ok((left_field_indices, right_field_indices)) } +/// A physical join filter rewritter which rewrites the column indices in the expression +/// to use the new column indices. See `rewrite_physical_expr`. 
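+///
+/// As a hypothetical example: with a 3-column left input and a 2-column right input, a filter
+/// referencing left column 1 and right column 0 sees full-schema indices 1 and 3. The
+/// intermediate filter schema keeps only those two fields, so the rewriter maps index 1 to 0
+/// and index 3 to 1.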
+struct JoinFilterRewriter<'a> { + left_field_len: usize, + right_field_len: usize, + left_field_indices: &'a [usize], + right_field_indices: &'a [usize], +} + +impl JoinFilterRewriter<'_> { + fn new<'a>( + left_field_len: usize, + right_field_len: usize, + left_field_indices: &'a [usize], + right_field_indices: &'a [usize], + ) -> JoinFilterRewriter<'a> { + JoinFilterRewriter { + left_field_len, + right_field_len, + left_field_indices, + right_field_indices, + } + } +} + +impl TreeNodeRewriter for JoinFilterRewriter<'_> { + type N = Arc; + + fn mutate(&mut self, node: Self::N) -> datafusion_common::Result { + let new_expr: Arc = + if let Some(column) = node.as_any().downcast_ref::() { + if column.index() < self.left_field_len { + // left side + let new_index = self + .left_field_indices + .iter() + .position(|&x| x == column.index()) + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Column index {} not found in left field indices", + column.index() + )) + })?; + Arc::new(Column::new(column.name(), new_index)) + } else if column.index() < self.left_field_len + self.right_field_len { + // right side + let new_index = self + .right_field_indices + .iter() + .position(|&x| x + self.left_field_len == column.index()) + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Column index {} not found in right field indices", + column.index() + )) + })?; + Arc::new(Column::new( + column.name(), + new_index + self.left_field_indices.len(), + )) + } else { + return Err(DataFusionError::Internal(format!( + "Column index {} out of range", + column.index() + ))); + } + } else { + node.clone() + }; + Ok(new_expr) + } +} + +/// Rewrites the physical expression to use the new column indices. +/// This is necessary when the physical expression is used in a join filter, as the column +/// indices are different from the original schema. +fn rewrite_physical_expr( + expr: Arc, + left_field_len: usize, + right_field_len: usize, + left_field_indices: &[usize], + right_field_indices: &[usize], +) -> Result, ExecutionError> { + let mut rewriter = JoinFilterRewriter::new( + left_field_len, + right_field_len, + left_field_indices, + right_field_indices, + ); + + Ok(expr.rewrite(&mut rewriter)?) 
+} + #[cfg(test)] mod tests { use std::{sync::Arc, task::Poll}; diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index b8237c8353..09d3a6263a 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -105,6 +105,38 @@ class CometExecSuite extends CometTestBase { } } + test("HashJoin with join filter") { + withSQLConf( + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") { + // Inner join: build left + val df1 = + sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator(df1) + + // Right join: build left + val df2 = + sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator(df2) + + // Full join: build left + val df3 = + sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator(df3) + } + } + } + } + test("Fix corrupted AggregateMode when transforming plan parameters") { withParquetTable((0 until 5).map(i => (i, i + 1)), "table") { val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2")) From cbd87cf373272ca9e423db1a931b21906287e81e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Mar 2024 21:46:37 -0700 Subject: [PATCH 05/13] Fix clippy --- core/src/execution/datafusion/planner.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 29bf309594..752828ad20 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -910,8 +910,8 @@ impl PhysicalPlanner { let right_fields = right_schema.fields(); let all_fields: Vec<_> = left_fields .into_iter() - .chain(right_fields.into_iter()) - .map(|f| f.clone()) + .chain(right_fields) + .cloned() .collect(); let full_schema = Arc::new(Schema::new(all_fields)); From c95659c0f581abef680fd17e2845d958670d95d9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Mar 2024 22:28:43 -0700 Subject: [PATCH 06/13] Use consistent function with sort merge join --- core/src/execution/datafusion/planner.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 752828ad20..d19883730e 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -965,13 +965,13 @@ impl PhysicalPlanner { // DataFusion `HashJoinExec` operator keeps the input batch internally. We need // to copy the input batch to avoid the data corruption from reusing the input // batch. 
- let left = if !is_op_do_copying(&left) { + let left = if op_reuse_array(&left) { Arc::new(CopyExec::new(left)) } else { left }; - let right = if !is_op_do_copying(&right) { + let right = if op_reuse_array(&right) { Arc::new(CopyExec::new(right)) } else { right @@ -1157,12 +1157,14 @@ impl From for DataFusionError { } } -/// Returns true if given operator copies input batch to avoid data corruption from reusing -/// input arrays. -fn is_op_do_copying(op: &Arc) -> bool { - op.as_any().downcast_ref::().is_some() - || op.as_any().downcast_ref::().is_some() - || op.as_any().downcast_ref::().is_some() +/// Returns true if given operator can return input array as output array without +/// modification. This is used to determine if we need to copy the input batch to avoid +/// data corruption from reusing the input batch. +fn op_reuse_array(op: &Arc) -> bool { + op.as_any().downcast_ref::().is_some() + || op.as_any().downcast_ref::().is_some() + || op.as_any().downcast_ref::().is_some() + || op.as_any().downcast_ref::().is_some() } /// Collects the indices of the columns in the input schema that are used in the expression From 895160bc2ef27bca21aec0c3cbbce97d5d774fa2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 14 Mar 2024 13:13:47 -0700 Subject: [PATCH 07/13] Add note about left semi and left anti joins --- .../test/scala/org/apache/comet/exec/CometExecSuite.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 09d3a6263a..72e67bf5ce 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -100,6 +100,13 @@ class CometExecSuite extends CometTestBase { // // Full join: build right // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1") + // + // val left = sql("SELECT * FROM tbl_a") + // val right = sql("SELECT * FROM tbl_b") + // + // Left semi and anti joins are only supported with build right in Spark. + // left.join(right, left("_2") === right("_1"), "leftsemi") + // left.join(right, left("_2") === right("_1"), "leftanti") } } } From 772588b1e5ef97efe7a87d0979e33d11c24bad2a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 14 Mar 2024 14:40:49 -0700 Subject: [PATCH 08/13] feat: Support BroadcastHashJoin --- .../apache/comet/CometArrowStreamWriter.java | 51 +++++++++++++++ .../org/apache/comet/vector/NativeUtil.scala | 18 +++--- .../shuffle/ArrowReaderIterator.scala | 16 ++--- core/src/execution/operators/scan.rs | 2 +- .../comet/CometSparkSessionExtensions.scala | 35 +++++++++- .../apache/comet/serde/QueryPlanSerde.scala | 18 ++++-- .../comet/CometBroadcastExchangeExec.scala | 23 ++++++- .../apache/spark/sql/comet/operators.scala | 42 +++++++++++- .../apache/comet/exec/CometExecSuite.scala | 64 ++++++++++++++++++- 9 files changed, 240 insertions(+), 29 deletions(-) create mode 100644 common/src/main/java/org/apache/comet/CometArrowStreamWriter.java diff --git a/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java b/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java new file mode 100644 index 0000000000..a492ce887d --- /dev/null +++ b/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; + +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.compression.NoCompressionCodec; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; + +/** + * A custom `ArrowStreamWriter` that allows writing batches from different root to the same stream. + * Arrow `ArrowStreamWriter` cannot change the root after initialization. + */ +public class CometArrowStreamWriter extends ArrowStreamWriter { + public CometArrowStreamWriter( + VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + super(root, provider, out); + } + + public void writeMoreBatch(VectorSchemaRoot root) throws IOException { + VectorUnloader unloader = + new VectorUnloader( + root, /*includeNullCount*/ true, NoCompressionCodec.INSTANCE, /*alignBuffers*/ true); + + try (ArrowRecordBatch batch = unloader.getRecordBatch()) { + writeRecordBatch(batch); + } + } +} diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 1682295f72..cc726e3e8b 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -20,6 +20,7 @@ package org.apache.comet.vector import java.io.OutputStream +import java.nio.channels.Channels import scala.collection.JavaConverters._ import scala.collection.mutable @@ -28,10 +29,11 @@ import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictiona import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector._ import org.apache.arrow.vector.dictionary.DictionaryProvider -import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.SparkException import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.comet.CometArrowStreamWriter + class NativeUtil { private val allocator = new RootAllocator(Long.MaxValue) private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider @@ -46,29 +48,27 @@ class NativeUtil { * the output stream */ def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = { - var schemaRoot: Option[VectorSchemaRoot] = None - var writer: Option[ArrowStreamWriter] = None + var writer: Option[CometArrowStreamWriter] = None var rowCount = 0 batches.foreach { batch => val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch) - val root = schemaRoot.getOrElse(new VectorSchemaRoot(fieldVectors.asJava)) + val root = new VectorSchemaRoot(fieldVectors.asJava) val provider = 
batchProviderOpt.getOrElse(dictionaryProvider) if (writer.isEmpty) { - writer = Some(new ArrowStreamWriter(root, provider, out)) + writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out))) writer.get.start() + writer.get.writeBatch() + } else { + writer.get.writeMoreBatch(root) } - writer.get.writeBatch() root.clear() - schemaRoot = Some(root) - rowCount += batch.numRows() } writer.map(_.end()) - schemaRoot.map(_.close()) rowCount } diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala index e8dba93e70..304c3ce779 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala @@ -23,7 +23,7 @@ import java.nio.channels.ReadableByteChannel import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.vector.StreamReader +import org.apache.comet.vector._ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] { @@ -36,6 +36,13 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna return true } + // Release the previous batch. + // If it is not released, when closing the reader, arrow library will complain about + // memory leak. + if (currentBatch != null) { + currentBatch.close() + } + batch = nextBatch() if (batch.isEmpty) { return false @@ -50,13 +57,6 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna val nextBatch = batch.get - // Release the previous batch. - // If it is not released, when closing the reader, arrow library will complain about - // memory leak. 
- if (currentBatch != null) { - currentBatch.close() - } - currentBatch = nextBatch batch = None currentBatch diff --git a/core/src/execution/operators/scan.rs b/core/src/execution/operators/scan.rs index e31230c583..9581ac4054 100644 --- a/core/src/execution/operators/scan.rs +++ b/core/src/execution/operators/scan.rs @@ -26,7 +26,7 @@ use futures::Stream; use itertools::Itertools; use arrow::compute::{cast_with_options, CastOptions}; -use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions}; +use arrow_array::{make_array, Array, ArrayRef, RecordBatch, RecordBatchOptions}; use arrow_data::ArrayData; use arrow_schema::{DataType, Field, Schema, SchemaRef}; diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 7eca995579..09d92edd5c 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -31,14 +31,14 @@ import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -359,6 +359,27 @@ class CometSparkSessionExtensions op } + case op: BroadcastHashJoinExec + if isCometOperatorEnabled(conf, "broadcast_hash_join") && + op.children.forall(isCometNative(_)) => + val newOp = transform1(op) + newOp match { + case Some(nativeOp) => + CometBroadcastHashJoinExec( + nativeOp, + op, + op.leftKeys, + op.rightKeys, + op.joinType, + op.condition, + op.buildSide, + op.left, + op.right, + SerializedPlan(None)) + case None => + op + } + case c @ CoalesceExec(numPartitions, child) if isCometOperatorEnabled(conf, "coalesce") && isCometNative(child) => @@ -394,6 +415,16 @@ class CometSparkSessionExtensions u } + // For AQE broadcast stage on a Comet broadcast exchange + case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => + val newOp = transform1(s) + newOp match { + case Some(nativeOp) => + CometSinkPlaceHolder(nativeOp, s, s) + case None => + s + } + case b: BroadcastExchangeExec if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") && isCometBroadCastEnabled(conf) => diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b0692ad8e2..e1a348d83e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -29,14 +29,14 @@ import 
org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils -import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometSinkPlaceHolder, DecimalPrecision} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1840,7 +1840,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } } - case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf, "hash_join") => + case join: HashJoin => + // `HashJoin` has only two implementations in Spark, but we check the type of the join to + // make sure we are handling the correct join type. + if (!(isCometOperatorEnabled(op.conf, "hash_join") && + join.isInstanceOf[ShuffledHashJoinExec]) && + !(isCometOperatorEnabled(op.conf, "broadcast_hash_join") && + join.isInstanceOf[BroadcastHashJoinExec])) { + return None + } + if (join.buildSide == BuildRight) { // DataFusion HashJoin assumes build side is always left. 
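+        // Until that is supported, we return None so Spark executes this join itself.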
// TODO: support BuildRight @@ -1932,6 +1941,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true case _: TakeOrderedAndProjectExec => true + case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true case _: BroadcastExchangeExec => true case _ => false } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index f115b2ad92..e72088179a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -257,9 +257,28 @@ class CometBatchRDD( } override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { - val partition = split.asInstanceOf[CometBatchPartition] + new Iterator[ColumnarBatch] { + val partition = split.asInstanceOf[CometBatchPartition] + val batchesIter = partition.value.value.map(CometExec.decodeBatches(_)).toIterator + var iter: Iterator[ColumnarBatch] = null + + override def hasNext: Boolean = { + if (iter != null) { + if (iter.hasNext) { + return true + } + } + if (batchesIter.hasNext) { + iter = batchesIter.next() + return iter.hasNext + } + false + } - partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator + override def next(): ColumnarBatch = { + iter.next() + } + } } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 2d03b2e65d..4cf55d4acb 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf @@ -269,7 +269,8 @@ abstract class CometNativeExec extends CometExec { plan match { case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec | _: AQEShuffleReadExec | _: CometShuffleExchangeExec | _: CometUnionExec | - _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec => + _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec | + _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec => func(plan) case _: CometPlan => // Other Comet operators, continue to traverse the tree. 
@@ -625,6 +626,43 @@ case class CometHashJoinExec( Objects.hashCode(leftKeys, rightKeys, condition, left, right) } +case class CometBroadcastHashJoinExec( + override val nativeOp: Operator, + override val originalPlan: SparkPlan, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + buildSide: BuildSide, + override val left: SparkPlan, + override val right: SparkPlan, + override val serializedPlanOpt: SerializedPlan) + extends CometBinaryExec { + override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = + this.copy(left = newLeft, right = newRight) + + override def stringArgs: Iterator[Any] = + Iterator(leftKeys, rightKeys, joinType, condition, left, right) + + override def equals(obj: Any): Boolean = { + obj match { + case other: CometBroadcastHashJoinExec => + this.leftKeys == other.leftKeys && + this.rightKeys == other.rightKeys && + this.condition == other.condition && + this.buildSide == other.buildSide && + this.left == other.left && + this.right == other.right && + this.serializedPlanOpt == other.serializedPlanOpt + case _ => + false + } + } + + override def hashCode(): Int = + Objects.hashCode(leftKeys, rightKeys, condition, left, right) +} + case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan) extends CometNativeExec with LeafExecNode { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 72e67bf5ce..75974b6335 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Hex import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec} import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec @@ -58,6 +58,68 @@ class CometExecSuite extends CometTestBase { } } + test("Broadcast HashJoin without join filter") { + assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") + withSQLConf( + CometConf.COMET_BATCH_SIZE.key -> "100", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + "spark.comet.exec.broadcast.enabled" -> "true", + "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") { + // Inner join: build left + val df1 = + sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator( + df1, + 
Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + + // Right join: build left + val df2 = + sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator( + df2, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + } + } + } + } + + test("Broadcast HashJoin with join filter") { + assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") + withSQLConf( + CometConf.COMET_BATCH_SIZE.key -> "100", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + "spark.comet.exec.broadcast.enabled" -> "true", + "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") { + // Inner join: build left + val df1 = + sql( + "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator( + df1, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + + // Right join: build left + val df2 = + sql( + "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator( + df2, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + } + } + } + } + test("HashJoin without join filter") { withSQLConf( SQLConf.PREFER_SORTMERGEJOIN.key -> "false", From d329d3f6ff9638effa2c9d1eea0985ee5432f133 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 18 Mar 2024 20:03:51 -0700 Subject: [PATCH 09/13] Move tests --- .../apache/comet/exec/CometExecSuite.scala | 148 ----------------- .../apache/comet/exec/CometJoinSuite.scala | 150 ++++++++++++++++++ 2 files changed, 150 insertions(+), 148 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 4158f8acd2..83e1466c93 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -58,154 +58,6 @@ class CometExecSuite extends CometTestBase { } } - test("Broadcast HashJoin without join filter") { - assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") - withSQLConf( - CometConf.COMET_BATCH_SIZE.key -> "100", - SQLConf.PREFER_SORTMERGEJOIN.key -> "false", - "spark.comet.exec.broadcast.enabled" -> "true", - "spark.sql.join.forceApplyShuffledHashJoin" -> "true", - SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") { - withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") { - // Inner join: build left - val df1 = - sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator( - df1, - Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) - - // Right join: build left - val df2 = - sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator( - df2, - Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) - } - } - } - } - - test("Broadcast HashJoin with join filter") { - 
assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") - withSQLConf( - CometConf.COMET_BATCH_SIZE.key -> "100", - SQLConf.PREFER_SORTMERGEJOIN.key -> "false", - "spark.comet.exec.broadcast.enabled" -> "true", - "spark.sql.join.forceApplyShuffledHashJoin" -> "true", - SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") { - withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") { - // Inner join: build left - val df1 = - sql( - "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b " + - "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2") - checkSparkAnswerAndOperator( - df1, - Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) - - // Right join: build left - val df2 = - sql( - "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " + - "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2") - checkSparkAnswerAndOperator( - df2, - Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) - } - } - } - } - - test("HashJoin without join filter") { - withSQLConf( - SQLConf.PREFER_SORTMERGEJOIN.key -> "false", - SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { - withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") { - // Inner join: build left - val df1 = - sql( - "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator(df1) - - // Right join: build left - val df2 = - sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator(df2) - - // Full join: build left - val df3 = - sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1") - checkSparkAnswerAndOperator(df3) - - // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint. - // Left join with build left and right join with build right in hash join is only supported - // in Spark 3.5 or above. See SPARK-36612. - // - // Left join: build left - // sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1") - - // TODO: DataFusion HashJoin doesn't support build right yet. - // Inner join: build right - // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") - // - // Left join: build right - // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1") - // - // Right join: build right - // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") - // - // Full join: build right - // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1") - // - // val left = sql("SELECT * FROM tbl_a") - // val right = sql("SELECT * FROM tbl_b") - // - // Left semi and anti joins are only supported with build right in Spark. 
-          // left.join(right, left("_2") === right("_1"), "leftsemi")
-          // left.join(right, left("_2") === right("_1"), "leftanti")
-        }
-      }
-    }
-  }
-
-  test("HashJoin with join filter") {
-    withSQLConf(
-      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
-      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
-        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
-          // Inner join: build left
-          val df1 =
-            sql(
-              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b " +
-                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
-          checkSparkAnswerAndOperator(df1)
-
-          // Right join: build left
-          val df2 =
-            sql(
-              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " +
-                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
-          checkSparkAnswerAndOperator(df2)
-
-          // Full join: build left
-          val df3 =
-            sql(
-              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b " +
-                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
-          checkSparkAnswerAndOperator(df3)
-        }
-      }
-    }
-  }
-
   test("Fix corrupted AggregateMode when transforming plan parameters") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
       val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 73ce0e1fdc..6f479e3bba 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -23,9 +23,11 @@ import org.scalactic.source.Position
 import org.scalatest.Tag
 
 import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.CometConf
+import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
 
 class CometJoinSuite extends CometTestBase {
 
@@ -38,6 +40,154 @@ class CometJoinSuite extends CometTestBase {
     }
   }
 
+  test("Broadcast HashJoin without join filter") {
+    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
+    withSQLConf(
+      CometConf.COMET_BATCH_SIZE.key -> "100",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      "spark.comet.exec.broadcast.enabled" -> "true",
+      "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join: build left
+          val df1 =
+            sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(
+            df1,
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
+
+          // Right join: build left
+          val df2 =
+            sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(
+            df2,
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
+        }
+      }
+    }
+  }
+
+  test("Broadcast HashJoin with join filter") {
+    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
+    withSQLConf(
+      CometConf.COMET_BATCH_SIZE.key -> "100",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      "spark.comet.exec.broadcast.enabled" -> "true",
+      "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join: build left
+          val df1 =
+            sql(
+              "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(
+            df1,
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
+
+          // Right join: build left
+          val df2 =
+            sql(
+              "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(
+            df2,
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
+        }
+      }
+    }
+  }
+
+  test("HashJoin without join filter") {
+    withSQLConf(
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join: build left
+          val df1 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df1)
+
+          // Right join: build left
+          val df2 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df2)
+
+          // Full join: build left
+          val df3 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df3)
+
+          // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint.
+          // Left join with build left and right join with build right in hash join is only supported
+          // in Spark 3.5 or above. See SPARK-36612.
+          //
+          // Left join: build left
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+
+          // TODO: DataFusion HashJoin doesn't support build right yet.
+          // Inner join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // Left join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // Right join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // Full join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // val left = sql("SELECT * FROM tbl_a")
+          // val right = sql("SELECT * FROM tbl_b")
+          //
+          // Left semi and anti joins are only supported with build right in Spark.
+          // left.join(right, left("_2") === right("_1"), "leftsemi")
+          // left.join(right, left("_2") === right("_1"), "leftanti")
+        }
+      }
+    }
+  }
+
+  test("HashJoin with join filter") {
+    withSQLConf(
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join: build left
+          val df1 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df1)
+
+          // Right join: build left
+          val df2 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df2)
+
+          // Full join: build left
+          val df3 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df3)
+        }
+      }
+    }
+  }
+
+  // TODO: Add a test for SortMergeJoin with join filter after new DataFusion release
   test("SortMergeJoin without join filter") {
     withSQLConf(

From 85afec89ad49f9b114d78dcdcffa73f1ec8fc0e4 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 18 Mar 2024 23:25:04 -0700
Subject: [PATCH 10/13] Remove unused import

---
 spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 83e1466c93..4172c7caab 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec

From 64240fa8b084a9965ac4cc4eeedd09774c39a1bc Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 19 Mar 2024 17:39:53 -0700
Subject: [PATCH 11/13] Add function to parse join parameters

---
 core/src/execution/datafusion/planner.rs | 346 +++++++++++------------
 1 file changed, 168 insertions(+), 178 deletions(-)

diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 19180fd49a..c8869c5f3e 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -92,6 +92,14 @@ type PhyAggResult = Result>, ExecutionError>;
 type PhyExprResult = Result, String)>, ExecutionError>;
 type PartitionPhyExprResult = Result>,
ExecutionError>; +struct JoinParameters { + pub left: Arc, + pub right: Arc, + pub join_on: Vec<(Arc, Arc)>, + pub join_filter: Option, + pub join_type: DFJoinType, +} + pub const TEST_EXEC_CONTEXT_ID: i64 = -1; /// The query planner for converting Spark query plans to DataFusion query plans. @@ -873,50 +881,22 @@ impl PhysicalPlanner { )) } OpStruct::SortMergeJoin(join) => { - assert!(children.len() == 2); - let (mut left_scans, left) = self.create_plan(&children[0], inputs)?; - let (mut right_scans, right) = self.create_plan(&children[1], inputs)?; - - left_scans.append(&mut right_scans); - - let left_join_exprs = join - .left_join_keys - .iter() - .map(|expr| self.create_expr(expr, left.schema())) - .collect::, _>>()?; - let right_join_exprs = join - .right_join_keys - .iter() - .map(|expr| self.create_expr(expr, right.schema())) - .collect::, _>>()?; - - let join_on = left_join_exprs - .into_iter() - .zip(right_join_exprs) - .collect::>(); - - let join_type = match join.join_type.try_into() { - Ok(JoinType::Inner) => DFJoinType::Inner, - Ok(JoinType::LeftOuter) => DFJoinType::Left, - Ok(JoinType::RightOuter) => DFJoinType::Right, - Ok(JoinType::FullOuter) => DFJoinType::Full, - Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi, - Ok(JoinType::RightSemi) => DFJoinType::RightSemi, - Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti, - Ok(JoinType::RightAnti) => DFJoinType::RightAnti, - Err(_) => { - return Err(ExecutionError::GeneralError(format!( - "Unsupported join type: {:?}", - join.join_type - ))); - } - }; + let (join_params, scans) = self.parse_join_parameters( + inputs, + children, + &join.left_join_keys, + &join.right_join_keys, + join.join_type, + &None, + )?; let sort_options = join .sort_options .iter() .map(|sort_option| { - let sort_expr = self.create_sort_expr(sort_option, left.schema()).unwrap(); + let sort_expr = self + .create_sort_expr(sort_option, join_params.left.schema()) + .unwrap(); SortOptions { descending: sort_expr.options.descending, nulls_first: sort_expr.options.nulls_first, @@ -924,163 +904,173 @@ impl PhysicalPlanner { }) .collect(); - // DataFusion `SortMergeJoinExec` operator keeps the input batch internally. We need - // to copy the input batch to avoid the data corruption from reusing the input - // batch. - let left = if can_reuse_input_batch(&left) { - Arc::new(CopyExec::new(left)) - } else { - left - }; - - let right = if can_reuse_input_batch(&right) { - Arc::new(CopyExec::new(right)) - } else { - right - }; - let join = Arc::new(SortMergeJoinExec::try_new( - left, - right, - join_on, - None, - join_type, + join_params.left, + join_params.right, + join_params.join_on, + join_params.join_filter, + join_params.join_type, sort_options, // null doesn't equal to null in Spark join key. If the join key is // `EqualNullSafe`, Spark will rewrite it during planning. 
false, )?); - Ok((left_scans, join)) + Ok((scans, join)) } OpStruct::HashJoin(join) => { - assert!(children.len() == 2); - let (mut left_scans, left) = self.create_plan(&children[0], inputs)?; - let (mut right_scans, right) = self.create_plan(&children[1], inputs)?; - - left_scans.append(&mut right_scans); + let (join_params, scans) = self.parse_join_parameters( + inputs, + children, + &join.left_join_keys, + &join.right_join_keys, + join.join_type, + &join.condition, + )?; + let join = Arc::new(HashJoinExec::try_new( + join_params.left, + join_params.right, + join_params.join_on, + join_params.join_filter, + &join_params.join_type, + PartitionMode::Partitioned, + // null doesn't equal to null in Spark join key. If the join key is + // `EqualNullSafe`, Spark will rewrite it during planning. + false, + )?); + Ok((scans, join)) + } + } + } - let left_join_exprs: Vec<_> = join - .left_join_keys - .iter() - .map(|expr| self.create_expr(expr, left.schema())) - .collect::, _>>()?; - let right_join_exprs: Vec<_> = join - .right_join_keys - .iter() - .map(|expr| self.create_expr(expr, right.schema())) - .collect::, _>>()?; + fn parse_join_parameters( + &self, + inputs: &mut Vec>, + children: &[Operator], + left_join_keys: &[Expr], + right_join_keys: &[Expr], + join_type: i32, + condition: &Option, + ) -> Result<(JoinParameters, Vec), ExecutionError> { + assert!(children.len() == 2); + let (mut left_scans, left) = self.create_plan(&children[0], inputs)?; + let (mut right_scans, right) = self.create_plan(&children[1], inputs)?; + + left_scans.append(&mut right_scans); + + let left_join_exprs: Vec<_> = left_join_keys + .iter() + .map(|expr| self.create_expr(expr, left.schema())) + .collect::, _>>()?; + let right_join_exprs: Vec<_> = right_join_keys + .iter() + .map(|expr| self.create_expr(expr, right.schema())) + .collect::, _>>()?; - let join_on = left_join_exprs - .into_iter() - .zip(right_join_exprs) - .collect::>(); + let join_on = left_join_exprs + .into_iter() + .zip(right_join_exprs) + .collect::>(); - let join_type = match join.join_type.try_into() { - Ok(JoinType::Inner) => DFJoinType::Inner, - Ok(JoinType::LeftOuter) => DFJoinType::Left, - Ok(JoinType::RightOuter) => DFJoinType::Right, - Ok(JoinType::FullOuter) => DFJoinType::Full, - Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi, - Ok(JoinType::RightSemi) => DFJoinType::RightSemi, - Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti, - Ok(JoinType::RightAnti) => DFJoinType::RightAnti, - Err(_) => { - return Err(ExecutionError::GeneralError(format!( - "Unsupported join type: {:?}", - join.join_type - ))); - } - }; + let join_type = match join_type.try_into() { + Ok(JoinType::Inner) => DFJoinType::Inner, + Ok(JoinType::LeftOuter) => DFJoinType::Left, + Ok(JoinType::RightOuter) => DFJoinType::Right, + Ok(JoinType::FullOuter) => DFJoinType::Full, + Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi, + Ok(JoinType::RightSemi) => DFJoinType::RightSemi, + Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti, + Ok(JoinType::RightAnti) => DFJoinType::RightAnti, + Err(_) => { + return Err(ExecutionError::GeneralError(format!( + "Unsupported join type: {:?}", + join_type + ))); + } + }; - // Handle join filter as DataFusion `JoinFilter` struct - let join_filter = if let Some(expr) = &join.condition { - let left_schema = left.schema(); - let right_schema = right.schema(); - let left_fields = left_schema.fields(); - let right_fields = right_schema.fields(); - let all_fields: Vec<_> = left_fields - .into_iter() - .chain(right_fields) - .cloned() - .collect(); - let 
full_schema = Arc::new(Schema::new(all_fields)); - - let physical_expr = self.create_expr(expr, full_schema)?; - let (left_field_indices, right_field_indices) = expr_to_columns( - &physical_expr, - left.schema().fields.len(), - right.schema().fields.len(), - )?; - let column_indices = JoinFilter::build_column_indices( - left_field_indices.clone(), - right_field_indices.clone(), - ); - - let filter_fields: Vec = left_field_indices + // Handle join filter as DataFusion `JoinFilter` struct + let join_filter = if let Some(expr) = condition { + let left_schema = left.schema(); + let right_schema = right.schema(); + let left_fields = left_schema.fields(); + let right_fields = right_schema.fields(); + let all_fields: Vec<_> = left_fields + .into_iter() + .chain(right_fields) + .cloned() + .collect(); + let full_schema = Arc::new(Schema::new(all_fields)); + + let physical_expr = self.create_expr(expr, full_schema)?; + let (left_field_indices, right_field_indices) = + expr_to_columns(&physical_expr, left_fields.len(), right_fields.len())?; + let column_indices = JoinFilter::build_column_indices( + left_field_indices.clone(), + right_field_indices.clone(), + ); + + let filter_fields: Vec = left_field_indices + .clone() + .into_iter() + .map(|i| left.schema().field(i).clone()) + .chain( + right_field_indices .clone() .into_iter() - .map(|i| left.schema().field(i).clone()) - .chain( - right_field_indices - .clone() - .into_iter() - .map(|i| right.schema().field(i).clone()), - ) - .collect_vec(); - - let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new()); - - // Rewrite the physical expression to use the new column indices. - // DataFusion's join filter is bound to intermediate schema which contains - // only the fields used in the filter expression. But the Spark's join filter - // expression is bound to the full schema. We need to rewrite the physical - // expression to use the new column indices. - let rewritten_physical_expr = rewrite_physical_expr( - physical_expr, - left_schema.fields.len(), - right_schema.fields.len(), - &left_field_indices, - &right_field_indices, - )?; - - Some(JoinFilter::new( - rewritten_physical_expr, - column_indices, - filter_schema, - )) - } else { - None - }; - - // DataFusion `HashJoinExec` operator keeps the input batch internally. We need - // to copy the input batch to avoid the data corruption from reusing the input - // batch. - let left = if can_reuse_input_batch(&left) { - Arc::new(CopyExec::new(left)) - } else { - left - }; + .map(|i| right.schema().field(i).clone()), + ) + .collect_vec(); + + let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new()); + + // Rewrite the physical expression to use the new column indices. + // DataFusion's join filter is bound to intermediate schema which contains + // only the fields used in the filter expression. But the Spark's join filter + // expression is bound to the full schema. We need to rewrite the physical + // expression to use the new column indices. + let rewritten_physical_expr = rewrite_physical_expr( + physical_expr, + left_schema.fields.len(), + right_schema.fields.len(), + &left_field_indices, + &right_field_indices, + )?; + + Some(JoinFilter::new( + rewritten_physical_expr, + column_indices, + filter_schema, + )) + } else { + None + }; - let right = if can_reuse_input_batch(&right) { - Arc::new(CopyExec::new(right)) - } else { - right - }; + // DataFusion Join operators keep the input batch internally. 
We need + // to copy the input batch to avoid the data corruption from reusing the input + // batch. + let left = if can_reuse_input_batch(&left) { + Arc::new(CopyExec::new(left)) + } else { + left + }; - let join = Arc::new(HashJoinExec::try_new( - left, - right, - join_on, - join_filter, - &join_type, - PartitionMode::Partitioned, - false, - )?); + let right = if can_reuse_input_batch(&right) { + Arc::new(CopyExec::new(right)) + } else { + right + }; - Ok((left_scans, join)) - } - } + Ok(( + JoinParameters { + left, + right, + join_on, + join_type, + join_filter, + }, + left_scans, + )) } /// Create a DataFusion physical aggregate expression from Spark physical aggregate expression From 187ba365e58a5d9da103e1e36f1569723347c6d7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 21 Mar 2024 13:49:31 -0700 Subject: [PATCH 12/13] Remove duplicate code --- core/src/execution/operators/scan.rs | 2 +- .../comet/CometSparkSessionExtensions.scala | 21 ----------- .../apache/spark/sql/comet/operators.scala | 37 ------------------- 3 files changed, 1 insertion(+), 59 deletions(-) diff --git a/core/src/execution/operators/scan.rs b/core/src/execution/operators/scan.rs index 9581ac4054..e31230c583 100644 --- a/core/src/execution/operators/scan.rs +++ b/core/src/execution/operators/scan.rs @@ -26,7 +26,7 @@ use futures::Stream; use itertools::Itertools; use arrow::compute::{cast_with_options, CastOptions}; -use arrow_array::{make_array, Array, ArrayRef, RecordBatch, RecordBatchOptions}; +use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions}; use arrow_data::ArrayData; use arrow_schema::{DataType, Field, Schema, SchemaRef}; diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 366f4d7594..37ff8f7cd7 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -335,27 +335,6 @@ class CometSparkSessionExtensions op } - case op: ShuffledHashJoinExec - if isCometOperatorEnabled(conf, "hash_join") && - op.children.forall(isCometNative(_)) => - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => - CometHashJoinExec( - nativeOp, - op, - op.leftKeys, - op.rightKeys, - op.joinType, - op.condition, - op.buildSide, - op.left, - op.right, - SerializedPlan(None)) - case None => - op - } - case op: SortMergeJoinExec if isCometOperatorEnabled(conf, "sort_merge_join") && op.children.forall(isCometNative(_)) => diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 8d228bbbbe..4f9addf218 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -589,43 +589,6 @@ case class CometHashAggregateExec( Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child) } -case class CometHashJoinExec( - override val nativeOp: Operator, - override val originalPlan: SparkPlan, - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - buildSide: BuildSide, - override val left: SparkPlan, - override val right: SparkPlan, - override val serializedPlanOpt: SerializedPlan) - extends CometBinaryExec { - override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = - this.copy(left = newLeft, right = 
newRight) - - override def stringArgs: Iterator[Any] = - Iterator(leftKeys, rightKeys, joinType, condition, left, right) - - override def equals(obj: Any): Boolean = { - obj match { - case other: CometHashJoinExec => - this.leftKeys == other.leftKeys && - this.rightKeys == other.rightKeys && - this.condition == other.condition && - this.buildSide == other.buildSide && - this.left == other.left && - this.right == other.right && - this.serializedPlanOpt == other.serializedPlanOpt - case _ => - false - } - } - - override def hashCode(): Int = - Objects.hashCode(leftKeys, rightKeys, condition, left, right) -} - case class CometSortMergeJoinExec( override val nativeOp: Operator, override val originalPlan: SparkPlan, From ce9e1ef9f0f8c9df427dfe8f6bbbc684649b1e07 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 22 Mar 2024 22:03:39 -0700 Subject: [PATCH 13/13] For review --- .../comet/CometSparkSessionExtensions.scala | 20 ++++++++-------- .../comet/CometBroadcastExchangeExec.scala | 24 ++----------------- .../apache/spark/sql/comet/operators.scala | 20 ++++++++-------- 3 files changed, 22 insertions(+), 42 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 37ff8f7cd7..fcbf42f5b0 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -335,19 +335,20 @@ class CometSparkSessionExtensions op } - case op: SortMergeJoinExec - if isCometOperatorEnabled(conf, "sort_merge_join") && + case op: ShuffledHashJoinExec + if isCometOperatorEnabled(conf, "hash_join") && op.children.forall(isCometNative(_)) => val newOp = transform1(op) newOp match { case Some(nativeOp) => - CometSortMergeJoinExec( + CometHashJoinExec( nativeOp, op, op.leftKeys, op.rightKeys, op.joinType, op.condition, + op.buildSide, op.left, op.right, SerializedPlan(None)) @@ -355,13 +356,13 @@ class CometSparkSessionExtensions op } - case op: ShuffledHashJoinExec - if isCometOperatorEnabled(conf, "hash_join") && + case op: BroadcastHashJoinExec + if isCometOperatorEnabled(conf, "broadcast_hash_join") && op.children.forall(isCometNative(_)) => val newOp = transform1(op) newOp match { case Some(nativeOp) => - CometHashJoinExec( + CometBroadcastHashJoinExec( nativeOp, op, op.leftKeys, @@ -376,20 +377,19 @@ class CometSparkSessionExtensions op } - case op: BroadcastHashJoinExec - if isCometOperatorEnabled(conf, "broadcast_hash_join") && + case op: SortMergeJoinExec + if isCometOperatorEnabled(conf, "sort_merge_join") && op.children.forall(isCometNative(_)) => val newOp = transform1(op) newOp match { case Some(nativeOp) => - CometBroadcastHashJoinExec( + CometSortMergeJoinExec( nativeOp, op, op.leftKeys, op.rightKeys, op.joinType, op.condition, - op.buildSide, op.left, op.right, SerializedPlan(None)) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index e72088179a..24f9f32795 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -257,28 +257,8 @@ class CometBatchRDD( } override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { - new Iterator[ColumnarBatch] { - val partition = split.asInstanceOf[CometBatchPartition] - val 
batchesIter = partition.value.value.map(CometExec.decodeBatches(_)).toIterator - var iter: Iterator[ColumnarBatch] = null - - override def hasNext: Boolean = { - if (iter != null) { - if (iter.hasNext) { - return true - } - } - if (batchesIter.hasNext) { - iter = batchesIter.next() - return iter.hasNext - } - false - } - - override def next(): ColumnarBatch = { - iter.next() - } - } + val partition = split.asInstanceOf[CometBatchPartition] + partition.value.value.toIterator.flatMap(CometExec.decodeBatches) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 4f9addf218..84734a175a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -589,13 +589,14 @@ case class CometHashAggregateExec( Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child) } -case class CometSortMergeJoinExec( +case class CometHashJoinExec( override val nativeOp: Operator, override val originalPlan: SparkPlan, leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, condition: Option[Expression], + buildSide: BuildSide, override val left: SparkPlan, override val right: SparkPlan, override val serializedPlanOpt: SerializedPlan) @@ -608,10 +609,11 @@ case class CometSortMergeJoinExec( override def equals(obj: Any): Boolean = { obj match { - case other: CometSortMergeJoinExec => + case other: CometHashJoinExec => this.leftKeys == other.leftKeys && this.rightKeys == other.rightKeys && this.condition == other.condition && + this.buildSide == other.buildSide && this.left == other.left && this.right == other.right && this.serializedPlanOpt == other.serializedPlanOpt @@ -621,10 +623,10 @@ case class CometSortMergeJoinExec( } override def hashCode(): Int = - Objects.hashCode(leftKeys, rightKeys, condition, left, right) + Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right) } -case class CometHashJoinExec( +case class CometBroadcastHashJoinExec( override val nativeOp: Operator, override val originalPlan: SparkPlan, leftKeys: Seq[Expression], @@ -644,7 +646,7 @@ case class CometHashJoinExec( override def equals(obj: Any): Boolean = { obj match { - case other: CometHashJoinExec => + case other: CometBroadcastHashJoinExec => this.leftKeys == other.leftKeys && this.rightKeys == other.rightKeys && this.condition == other.condition && @@ -658,17 +660,16 @@ case class CometHashJoinExec( } override def hashCode(): Int = - Objects.hashCode(leftKeys, rightKeys, condition, left, right) + Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right) } -case class CometBroadcastHashJoinExec( +case class CometSortMergeJoinExec( override val nativeOp: Operator, override val originalPlan: SparkPlan, leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, condition: Option[Expression], - buildSide: BuildSide, override val left: SparkPlan, override val right: SparkPlan, override val serializedPlanOpt: SerializedPlan) @@ -681,11 +682,10 @@ case class CometBroadcastHashJoinExec( override def equals(obj: Any): Boolean = { obj match { - case other: CometBroadcastHashJoinExec => + case other: CometSortMergeJoinExec => this.leftKeys == other.leftKeys && this.rightKeys == other.rightKeys && this.condition == other.condition && - this.buildSide == other.buildSide && this.left == other.left && this.right == other.right && this.serializedPlanOpt == other.serializedPlanOpt