diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 85434b8674..1cd67cd3b0 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType}
+import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.SerializableConfiguration
 import org.apache.spark.util.io.ChunkedByteBuffer
@@ -1350,6 +1350,14 @@ case class CometUnionExec(
 
 trait CometBaseAggregate {
 
+  // Returns true if `dt` is a MapType or transitively contains one (inside a struct or array).
+  private def containsMapType(dt: DataType): Boolean = dt match {
+    case _: MapType => true
+    case StructType(fields) => fields.exists(f => containsMapType(f.dataType))
+    case ArrayType(elementType, _) => containsMapType(elementType)
+    case _ => false
+  }
+
   def doConvert(
       aggregate: BaseAggregateExec,
       builder: Operator.Builder,
@@ -1377,12 +1385,8 @@ trait CometBaseAggregate {
       return None
     }
 
-    if (groupingExpressions.exists(expr =>
-        expr.dataType match {
-          case _: MapType => true
-          case _ => false
-        })) {
-      withInfo(aggregate, "Grouping on map types is not supported")
+    if (groupingExpressions.exists(expr => containsMapType(expr.dataType))) {
+      withInfo(aggregate, "Grouping on map-containing types is not supported")
       return None
     }
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 95f3774e01..1fe416f105 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -408,6 +408,32 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("grouping on struct containing map should fallback to Spark") {
+    withSQLConf(
+      CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+      val query =
+        """SELECT col1.data['key']
+          |FROM VALUES
+          | (NAMED_STRUCT('data', MAP('key', 'value', 'num', '42'))),
+          | (NAMED_STRUCT('data', MAP('key', 'other', 'num', '7')))
+          |t (col1)
+          |GROUP BY col1
+          |HAVING col1.data['num'] IS NOT NULL
+          |ORDER BY col1.data['key']
+          |""".stripMargin
+
+      val (_, cometPlan) =
+        checkSparkAnswerAndFallbackReason(
+          query,
+          "Grouping on map-containing types is not supported")
+
+      assert(
+        stripAQEPlan(cometPlan).collect { case s: CometHashAggregateExec => s }.isEmpty,
+        "Expected aggregate to fall back to Spark for grouping on Struct(Map(...))")
+    }
+  }
+
   test("simple SUM, COUNT, MIN, MAX, AVG with non-distinct + null group keys") {
     Seq(true, false).foreach { dictionaryEnabled =>
       withParquetTable(
@@ -1990,4 +2016,19 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     sparkPlan.collect { case s: CometHashAggregateExec => s }.size
   }
 
+  test("group by array of map falls back to Spark (issue #4123)") {
+    withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
+      checkSparkAnswerAndFallbackReason(
+        """SELECT a, COUNT(*)
+          |FROM VALUES
+          | (ARRAY(MAP('x', 10))),
+          | (ARRAY(MAP('y', 20))),
+          | (ARRAY(MAP('x', 10)))
+          |t (a)
+          |GROUP BY a
+          |""".stripMargin,
+        "Grouping on map-containing types is not supported")
+    }
+  }
 }