From 2ba19f0095f39f2098ee0febcdaac415371343ed Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Thu, 8 Aug 2024 19:54:31 +0200 Subject: [PATCH] Add support for null literal with struct type --- native/core/src/execution/datafusion/planner.rs | 2 ++ .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 3 ++- .../test/scala/org/apache/comet/CometExpressionSuite.scala | 5 +++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index b604e98ba8..75b42ab049 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -97,6 +97,7 @@ use datafusion_comet_spark_expr::{ Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, }; +use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, JoinType as DFJoinType, ScalarValue, @@ -300,6 +301,7 @@ impl PhysicalPlanner { } DataType::Binary => ScalarValue::Binary(None), DataType::Decimal128(p, s) => ScalarValue::Decimal128(None, p, s), + DataType::Struct(fields) => ScalarStructBuilder::new_null(fields), DataType::Null => ScalarValue::Null, dt => { return Err(ExecutionError::GeneralError(format!( diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 5f3cc7a2ea..4ab8002bf9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1136,7 +1136,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None } - case Literal(value, dataType) if supportedDataType(dataType) => + case Literal(value, dataType) + if supportedDataType(dataType, allowStruct = value == null) => val exprBuilder = ExprOuterClass.Literal.newBuilder() if (value == null) { diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index ded5bc5c5a..0700b7aabe 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -118,11 +118,12 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, batchSize) withParquetTable(path.toString, "tbl") { - val sqlString = "SELECT _4 + null, _15 - null, _16 * null FROM tbl" + val sqlString = + "SELECT _4 + null, _15 - null, _16 * null, cast(null as struct<_1:int>) FROM tbl" val df2 = sql(sqlString) val rows = df2.collect() assert(rows.length == batchSize) - assert(rows.forall(_ == Row(null, null, null))) + assert(rows.forall(_ == Row(null, null, null, null))) checkSparkAnswerAndOperator(sqlString) }