diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index d92bf57895..5cd12d9e5e 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -24,7 +24,10 @@ use datafusion::{
     arrow::{compute::SortOptions, datatypes::SchemaRef},
     common::DataFusionError,
     execution::FunctionRegistry,
-    logical_expr::Operator as DataFusionOperator,
+    logical_expr::{
+        expr::find_df_window_func, Operator as DataFusionOperator, WindowFrame, WindowFrameBound,
+        WindowFrameUnits,
+    },
     physical_expr::{
         execution_props::ExecutionProps,
         expressions::{
@@ -42,7 +45,8 @@ use datafusion::{
         limit::LocalLimitExec,
         projection::ProjectionExec,
         sorts::sort::SortExec,
-        ExecutionPlan, Partitioning,
+        windows::BoundedWindowAggExec,
+        ExecutionPlan, InputOrderMode, Partitioning, WindowExpr,
     },
     prelude::SessionContext,
 };
@@ -87,12 +91,15 @@ use crate::{
         },
         operators::{CopyExec, ExecutionError, ScanExec},
         serde::to_arrow_datatype,
-        spark_expression,
         spark_expression::{
-            agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, Expr,
-            ScalarFunc,
+            self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr,
+            Expr, ScalarFunc,
+        },
+        spark_operator::{
+            lower_window_frame_bound::LowerFrameBoundStruct, operator::OpStruct,
+            upper_window_frame_bound::UpperFrameBoundStruct, BuildSide, JoinType, Operator,
+            WindowFrameType,
         },
-        spark_operator::{operator::OpStruct, BuildSide, JoinType, Operator},
         spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
     },
 };
@@ -870,6 +877,50 @@ impl PhysicalPlanner {
                     )?),
                 ))
             }
+            OpStruct::Window(wnd) => {
+                dbg!(&children[0]);
+                let (scans, child) = self.create_plan(&children[0], inputs)?;
+                //dbg!(&child);
+                let input_schema = child.schema();
+                dbg!(&input_schema);
+                let sort_exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = wnd
+                    .order_by_list
+                    .iter()
+                    .map(|expr| self.create_sort_expr(expr, input_schema.clone()))
+                    .collect();
+
+                let partition_exprs: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = wnd
+                    .partition_by_list
+                    .iter()
+                    .map(|expr| self.create_expr(expr, input_schema.clone()))
+                    .collect();
+
+                let sort_exprs = &sort_exprs?;
+                let partition_exprs = &partition_exprs?;
+
+                let window_expr: Result<Vec<Arc<dyn WindowExpr>>, ExecutionError> = wnd
+                    .window_expr
+                    .iter()
+                    .map(|expr| {
+                        self.create_window_expr(
+                            expr,
+                            input_schema.clone(),
+                            partition_exprs,
+                            sort_exprs,
+                        )
+                    })
+                    .collect();
+
+                Ok((
+                    scans,
+                    Arc::new(BoundedWindowAggExec::try_new(
+                        window_expr?,
+                        child,
+                        partition_exprs.to_vec(),
+                        InputOrderMode::Sorted,
+                    )?),
+                ))
+            }
             OpStruct::Expand(expand) => {
                 assert!(children.len() == 1);
                 let (scans, child) = self.create_plan(&children[0], inputs)?;
@@ -1330,6 +1381,106 @@ impl PhysicalPlanner {
         }
     }
 
+    /// Create a DataFusion window physical expression from a Spark window expression
+    fn create_window_expr<'a>(
+        &'a self,
+        spark_expr: &'a crate::execution::spark_operator::WindowExpr,
+        input_schema: SchemaRef,
+        partition_by: &[Arc<dyn PhysicalExpr>],
+        sort_exprs: &[PhysicalSortExpr],
+    ) -> Result<Arc<dyn WindowExpr>, ExecutionError> {
+        let (window_func_name, window_func_args) =
+            match &spark_expr.func.as_ref().unwrap().expr_struct.as_ref() {
+                Some(ExprStruct::ScalarFunc(f)) => (f.func.clone(), f.args.clone()),
+                other => {
+                    return Err(ExecutionError::GeneralError(format!(
+                        "{other:?} not supported for window function"
+                    )))
+                }
+            };
+
+        let window_func = match find_df_window_func(&window_func_name) {
+            Some(f) => f,
+            _ => {
+                return Err(ExecutionError::GeneralError(format!(
+ "{window_func_name} not supported for window function" + ))) + } + }; + + let window_args = window_func_args + .iter() + .map(|expr| self.create_expr(expr, input_schema.clone())) + .collect::, ExecutionError>>()?; + + let spark_window_frame = match spark_expr + .spec + .as_ref() + .and_then(|inner| inner.frame_specification.as_ref()) + { + Some(frame) => frame, + _ => { + return Err(ExecutionError::DeserializeError( + "Cannot deserialize window frame".to_string(), + )) + } + }; + + let units = match spark_window_frame.frame_type() { + WindowFrameType::Rows => WindowFrameUnits::Rows, + WindowFrameType::Range => WindowFrameUnits::Range, + }; + + let lower_bound: WindowFrameBound = match spark_window_frame + .lower_bound + .as_ref() + .and_then(|inner| inner.lower_frame_bound_struct.as_ref()) + { + Some(l) => match l { + LowerFrameBoundStruct::UnboundedPreceding(_) => { + WindowFrameBound::Preceding(ScalarValue::Null) + } + LowerFrameBoundStruct::Preceding(offset) => { + WindowFrameBound::Preceding(ScalarValue::Int32(Some(offset.offset))) + } + LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow, + }, + None => WindowFrameBound::Preceding(ScalarValue::Null), + }; + + let upper_bound: WindowFrameBound = match spark_window_frame + .upper_bound + .as_ref() + .and_then(|inner| inner.upper_frame_bound_struct.as_ref()) + { + Some(u) => match u { + UpperFrameBoundStruct::UnboundedFollowing(_) => { + WindowFrameBound::Preceding(ScalarValue::Null) + } + UpperFrameBoundStruct::Following(offset) => { + WindowFrameBound::Preceding(ScalarValue::Int32(Some(offset.offset))) + } + UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow, + }, + None => WindowFrameBound::Following(ScalarValue::Null), + }; + + let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound); + + dbg!(&window_func); + datafusion::physical_plan::windows::create_window_expr( + &window_func, + window_func_name, + &window_args, + partition_by, + sort_exprs, + window_frame.into(), + &input_schema, + false, // TODO: Ignore nulls + ) + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + } + /// Create a DataFusion physical partitioning from Spark physical partitioning fn create_partitioning( &self, diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs index bc194238ba..b4c1ba5727 100644 --- a/core/src/execution/jni_api.rs +++ b/core/src/execution/jni_api.rs @@ -310,6 +310,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( try_unwrap_or_throw(&e, |mut env| { // Retrieve the query let exec_context = get_execution_context(exec_context); + //dbg!(&exec_context.spark_plan.children[0]); let exec_context_id = exec_context.id; diff --git a/core/src/execution/proto/operator.proto b/core/src/execution/proto/operator.proto index de25f94dae..5d80a72a06 100644 --- a/core/src/execution/proto/operator.proto +++ b/core/src/execution/proto/operator.proto @@ -42,6 +42,7 @@ message Operator { Expand expand = 107; SortMergeJoin sort_merge_join = 108; HashJoin hash_join = 109; + Window window = 110; } } @@ -120,3 +121,60 @@ enum BuildSide { BuildLeft = 0; BuildRight = 1; } + +message WindowExpr { + spark.spark_expression.Expr func = 1; + WindowSpecDefinition spec = 2; +} + +enum WindowFrameType { + Rows = 0; + Range = 1; +} + +message WindowFrame { + WindowFrameType frame_type = 1; + LowerWindowFrameBound lower_bound = 2; + UpperWindowFrameBound upper_bound = 3; +} + +message LowerWindowFrameBound { + oneof lower_frame_bound_struct { + UnboundedPreceding 
+    Preceding preceding = 2;
+    CurrentRow currentRow = 3;
+  }
+}
+
+message UpperWindowFrameBound {
+  oneof upper_frame_bound_struct {
+    UnboundedFollowing unboundedFollowing = 1;
+    Following following = 2;
+    CurrentRow currentRow = 3;
+  }
+}
+
+message Preceding {
+  int32 offset = 1;
+}
+
+message Following {
+  int32 offset = 1;
+}
+
+message UnboundedPreceding {}
+message UnboundedFollowing {}
+message CurrentRow {}
+
+message WindowSpecDefinition {
+  repeated spark.spark_expression.Expr partitionSpec = 1;
+  repeated spark.spark_expression.Expr orderSpec = 2;
+  WindowFrame frameSpecification = 3;
+}
+
+message Window {
+  repeated WindowExpr window_expr = 1;
+  repeated spark.spark_expression.Expr order_by_list = 2;
+  repeated spark.spark_expression.Expr partition_by_list = 3;
+  Operator child = 4;
+}
\ No newline at end of file
diff --git a/docs/source/contributor-guide/debugging.md b/docs/source/contributor-guide/debugging.md
index d1f62a5dbe..b1f92ea85e 100644
--- a/docs/source/contributor-guide/debugging.md
+++ b/docs/source/contributor-guide/debugging.md
@@ -31,6 +31,62 @@ This HOWTO describes how to debug JVM code and Native code concurrently. The gui
 _Caveat: The steps here have only been tested with JDK 11_ on Mac (M1)
 
+## Expand Comet exception details
+By default, Comet outputs only the exception details specific to Comet. These can be extended by enabling
+DataFusion [backtraces](https://arrow.apache.org/datafusion/user-guide/example-usage.html#enable-backtraces).
+
+```scala
+scala> spark.sql("my_failing_query").show(false)
+
+24/03/05 17:00:07 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)/ 1]
+org.apache.comet.CometNativeException: Internal error: MIN/MAX is not expected to receive scalars of incompatible types (Date32("NULL"), Int32(15901)).
+This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker
+ at org.apache.comet.Native.executePlan(Native Method)
+ at org.apache.comet.CometExecIterator.executeNative(CometExecIterator.scala:65)
+ at org.apache.comet.CometExecIterator.getNextBatch(CometExecIterator.scala:111)
+ at org.apache.comet.CometExecIterator.hasNext(CometExecIterator.scala:126)
+
+```
+To do that with Comet, enable the `backtrace` feature in https://github.com/apache/arrow-datafusion-comet/blob/main/core/Cargo.toml
+
+```
+datafusion-common = { version = "36.0.0", features = ["backtrace"] }
+datafusion = { default-features = false, version = "36.0.0", features = ["unicode_expressions", "backtrace"] }
+```
+
+Then build Comet as [described](https://github.com/apache/arrow-datafusion-comet/blob/main/README.md#getting-started)
+
+Start Comet with `RUST_BACKTRACE=1`
+
+```commandline
+RUST_BACKTRACE=1 $SPARK_HOME/spark-shell --jars spark/target/comet-spark-spark3.4_2.12-0.1.0-SNAPSHOT.jar --conf spark.sql.extensions=org.apache.comet.CometSparkSessionExtensions --conf spark.comet.enabled=true --conf spark.comet.exec.enabled=true --conf spark.comet.exec.all.enabled=true
+```
+
+Get the expanded exception details:
+```scala
+scala> spark.sql("my_failing_query").show(false)
+24/03/05 17:00:49 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)
+org.apache.comet.CometNativeException: Internal error: MIN/MAX is not expected to receive scalars of incompatible types (Date32("NULL"), Int32(15901))
+
+backtrace: 0: std::backtrace::Backtrace::create
+1: datafusion_physical_expr::aggregate::min_max::min
+2: ::update_batch
+ 3: as futures_core::stream::Stream>::poll_next
+4: comet::execution::jni_api::Java_org_apache_comet_Native_executePlan::{{closure}}
+5: _Java_org_apache_comet_Native_executePlan
+.
+This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker
+ at org.apache.comet.Native.executePlan(Native Method)
+at org.apache.comet.CometExecIterator.executeNative(CometExecIterator.scala:65)
+at org.apache.comet.CometExecIterator.getNextBatch(CometExecIterator.scala:111)
+at org.apache.comet.CometExecIterator.hasNext(CometExecIterator.scala:126)
+
+...
+```
+Note:
+- Backtrace coverage in DataFusion is still improving, so a given error may not yet include one; feel free to file a [ticket](https://github.com/apache/arrow-datafusion/issues)
+- Backtraces are not free to capture, so they are intended for debugging purposes only
+
 ## Debugging for Advanced Developers
 
 Add a `.lldbinit` to comet/core.
This is not strictly necessary but will be useful if you want to diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index d6ec85f5be..1b32ce05ac 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -532,6 +533,16 @@ class CometSparkSessionExtensions s } + case w: WindowExec => + QueryPlanSerde.operator2Proto(w) match { + case Some(nativeOp) => + val cometOp = + CometWindowExec(w, w.windowExpression, w.partitionSpec, w.orderSpec, w.child) + CometSinkPlaceHolder(nativeOp, w, cometOp) + case None => + w + } + case s: TakeOrderedAndProjectExec => val info1 = createMessage( !isCometOperatorEnabled(conf, "takeOrderedAndProjectExec"), diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index c1c8b5c56b..9fab5f9524 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, Shuffle import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -200,6 +201,91 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } } + def windowExprToProto( + windowExpr: WindowExpression, + inputs: Seq[Attribute]): Option[OperatorOuterClass.WindowExpr] = { + val func = exprToProto(windowExpr.windowFunction, inputs).getOrElse(return None) + + val f = windowExpr.windowSpec.frameSpecification + + val (frameType, lowerBound, upperBound) = f match { + case SpecifiedWindowFrame(frameType, lBound, uBound) => + val frameProto = frameType match { + case RowFrame => OperatorOuterClass.WindowFrameType.Rows + case RangeFrame => OperatorOuterClass.WindowFrameType.Range + } + + val lBoundProto = lBound match { + case UnboundedPreceding => + OperatorOuterClass.LowerWindowFrameBound + .newBuilder() + .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build()) + .build() + case CurrentRow => + OperatorOuterClass.LowerWindowFrameBound + .newBuilder() + .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build()) + .build() + case e => + OperatorOuterClass.LowerWindowFrameBound + .newBuilder() + .setPreceding( + OperatorOuterClass.Preceding + .newBuilder() + .setOffset(e.eval().asInstanceOf[Int]) + .build()) + .build() + } + + val uBoundProto = uBound match { + case 
UnboundedFollowing => + OperatorOuterClass.UpperWindowFrameBound + .newBuilder() + .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build()) + .build() + case CurrentRow => + OperatorOuterClass.UpperWindowFrameBound + .newBuilder() + .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build()) + .build() + case e => + OperatorOuterClass.UpperWindowFrameBound + .newBuilder() + .setFollowing( + OperatorOuterClass.Following + .newBuilder() + .setOffset(e.eval().asInstanceOf[Int]) + .build()) + .build() + } + + (frameProto, lBoundProto, uBoundProto) + case _ => + ( + OperatorOuterClass.WindowFrameType.Rows, + OperatorOuterClass.LowerWindowFrameBound + .newBuilder() + .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build()) + .build(), + OperatorOuterClass.UpperWindowFrameBound + .newBuilder() + .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build()) + .build()) + } + + val frame = OperatorOuterClass.WindowFrame + .newBuilder() + .setFrameType(frameType) + .setLowerBound(lowerBound) + .setUpperBound(upperBound) + .build() + + val spec = + OperatorOuterClass.WindowSpecDefinition.newBuilder().setFrameSpecification(frame).build() + + Some(OperatorOuterClass.WindowExpr.newBuilder().setFunc(func).setSpec(spec).build()) + } + def aggExprToProto( aggExpr: AggregateExpression, inputs: Seq[Attribute], @@ -1999,6 +2085,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None } + case r @ Rank(_) => + val exprChildren = r.children.map(exprToProtoInternal(_, inputs)) + scalarExprToProto("rank", exprChildren: _*) + + case r @ RowNumber() => + val exprChildren = r.children.map(exprToProtoInternal(_, inputs)) + scalarExprToProto("row_number", exprChildren: _*) + + case l @ Lag(_, _, _, _) => + val exprChildren = l.children.map(exprToProtoInternal(_, inputs)) + scalarExprToProto("lag", exprChildren: _*) + // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for // char types. Use rpad to achieve the behavior. // See https://github.com/apache/spark/pull/38151 @@ -2661,6 +2759,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case _: TakeOrderedAndProjectExec => true case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true case _: BroadcastExchangeExec => true + case _: WindowExec => true case _ => false } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala new file mode 100644 index 0000000000..8c60c89e48 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import scala.collection.JavaConverters.asJavaIterableConverter + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression, SortOrder, WindowExpression} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.comet.CometWindowExec.getNativePlan +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.metric.{SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.serde.OperatorOuterClass +import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType, windowExprToProto} + +/** + * Comet physical plan node for Spark `WindowsExec`. + * + * It is used to execute a `WindowsExec` physical operator by using Comet native engine. It is not + * like other physical plan nodes which are wrapped by `CometExec`, because it contains two native + * executions separated by a Comet shuffle exchange. + */ +case class CometWindowExec( + override val originalPlan: SparkPlan, + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) + extends CometExec + with UnaryExecNode { + + override def nodeName: String = "CometWindowExec" + + override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute) + + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + private lazy val readMetrics = + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "shuffleReadElapsedCompute" -> + SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"), + "numPartitions" -> SQLMetrics.createMetric( + sparkContext, + "number of partitions")) ++ readMetrics ++ writeMetrics + + override def supportsColumnar: Boolean = true + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val childRDD = child.executeColumnar() + + childRDD.mapPartitionsInternal { iter => + CometExec.getCometIterator( + Seq(iter), + getNativePlan(output, windowExpression, partitionSpec, orderSpec, child).get) + } + } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + + protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + this.copy(child = newChild) + +} + +object CometWindowExec { + def getNativePlan( + outputAttributes: Seq[Attribute], + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan): Option[Operator] = { + + val orderSpecs = orderSpec.map(exprToProto(_, child.output)) + val partitionSpecs = partitionSpec.map(exprToProto(_, child.output)) + val scanBuilder = OperatorOuterClass.Scan.newBuilder() + val scanOpBuilder = OperatorOuterClass.Operator.newBuilder() + + val scanTypes = outputAttributes.flatten { attr => + serializeDataType(attr.dataType) + } + + val windowExprs = windowExpression.map(w => + windowExprToProto( + w.asInstanceOf[Alias].child.asInstanceOf[WindowExpression], + 
outputAttributes)) + + val windowBuilder = OperatorOuterClass.Window + .newBuilder() + + if (windowExprs.forall(_.isDefined)) { + windowBuilder + .addAllWindowExpr(windowExprs.map(_.get).asJava) + + if (orderSpecs.forall(_.isDefined)) { + windowBuilder.addAllOrderByList(orderSpecs.map(_.get).asJava) + } + + if (partitionSpecs.forall(_.isDefined)) { + windowBuilder.addAllPartitionByList(partitionSpecs.map(_.get).asJava) + } + + scanBuilder.addAllFields(scanTypes.asJava) + + val opBuilder = OperatorOuterClass.Operator + .newBuilder() + .addChildren(scanOpBuilder.setScan(scanBuilder)) + + Some(opBuilder.setWindow(windowBuilder).build()) + } else None + } +} diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index a0afe6b0c0..0154cc1fc1 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -1430,6 +1430,66 @@ class CometExecSuite extends CometTestBase { } }) } + + test("Windows support") { + // Disable native exec requirements because to invoke TakeOrderedAndProjectExec + // we can through Window which is not currently supported + Seq("true", "false").foreach(aqeEnabled => + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) { + withTable("t1") { + val numRows = 10 + spark + .range(numRows) + .selectExpr("if (id % 2 = 0, null, id) AS a", s"$numRows - id AS b") + .repartition(3) // Force repartition to test data will come to single partition + .write + .saveAsTable("t1") + + val df1 = spark.sql(""" + |SELECT a, b, lag(a) OVER(ORDER BY a) AS rn + |FROM t1 LIMIT 3 + |""".stripMargin) + + assert(df1.rdd.getNumPartitions == 1) + df1.printSchema() + df1.show(false) + println("df: " + df1.collectAsList()) + checkAnswer(df1, Seq(Row(null, 2, 1), Row(null, 4, 2), Row(null, 6, 3))) + + // val cometWindowExec = stripAQEPlan(df1.queryExecution.executedPlan).collect { + // case b: CometWindowExec => b + // } + // + // assert(cometWindowExec.length == 1) + // + // val df2 = spark.sql(""" + // |SELECT b, RANK() OVER(ORDER BY a, b) AS rk, DENSE_RANK(b) OVER(ORDER BY a, b) AS s + // |FROM t1 LIMIT 2 + // |""".stripMargin) + // assert(df2.rdd.getNumPartitions == 1) + // checkAnswer(df2, Seq(Row(2, 1, 1), Row(4, 2, 2))) + // + // checkSparkAnswer("select _1, rank() over (order by _2) from tbl") + // checkSparkAnswer("select _1, rank() over (order by _2 nulls last) from tbl") + // checkSparkAnswer("select _1, rank() over (order by _2 nulls first) from tbl") + // checkSparkAnswer("select _1, rank() over (order by _2 desc) from tbl") + // checkSparkAnswer("select _1, rank() over (order by _2 desc nulls last) from tbl") + // checkSparkAnswer("select _1, rank() over (order by _2 desc nulls first) from tbl") + // + // checkSparkAnswer("select _1, row_number() over (order by _2) from tbl") + // checkSparkAnswer("select _1, row_number() over (partition by _2 order by _2) from tbl") + // + // checkSparkAnswer( + // "select _1, sum(_2) over (order by _2 rows between 1 preceding and 1 following) from tbl") + // checkSparkAnswer( + // "select _1, sum(_2) over (order by _2 rows between 1 preceding and current row) from tbl") + // checkSparkAnswer( + // "select _1, sum(_2) over (order by _2 rows between current row and 1 following) from tbl") + } + }) + } } case class BucketedTableTestSpec(
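
For reference, the query shape exercised by the new test (and therefore by the `WindowExec` to `CometWindowExec` path added above) can also be reproduced interactively. The snippet below is an illustrative sketch only: it assumes a `spark-shell` session with the Comet extensions and the shuffle settings used in the test configuration above, and the view name `t1` is hypothetical.

```scala
// Illustrative only: a window query of the kind that, with this change, is
// serialized via windowExprToProto and planned as a native window operator.
// Assumes a SparkSession with the Comet extensions enabled, as in the test above.
spark
  .range(10)
  .selectExpr("if (id % 2 = 0, null, id) AS a", "10 - id AS b")
  .createOrReplaceTempView("t1")

val df = spark.sql(
  """
    |SELECT a, b, lag(a) OVER (ORDER BY a) AS prev_a
    |FROM t1
    |""".stripMargin)

df.show(false)
```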