Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ use datafusion_comet_spark_expr::{
Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, MinuteExpr, RLike,
SecondExpr, TimestampTruncExpr,
};
use datafusion_common::scalar::ScalarStructBuilder;
use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter},
JoinType as DFJoinType, ScalarValue,
Expand Down Expand Up @@ -300,6 +301,7 @@ impl PhysicalPlanner {
}
DataType::Binary => ScalarValue::Binary(None),
DataType::Decimal128(p, s) => ScalarValue::Decimal128(None, p, s),
DataType::Struct(fields) => ScalarStructBuilder::new_null(fields),
DataType::Null => ScalarValue::Null,
dt => {
return Err(ExecutionError::GeneralError(format!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1136,7 +1136,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
None
}

case Literal(value, dataType) if supportedDataType(dataType) =>
case Literal(value, dataType)
if supportedDataType(dataType, allowStruct = value == null) =>
val exprBuilder = ExprOuterClass.Literal.newBuilder()

if (value == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, batchSize)
withParquetTable(path.toString, "tbl") {
val sqlString = "SELECT _4 + null, _15 - null, _16 * null FROM tbl"
val sqlString =
"SELECT _4 + null, _15 - null, _16 * null, cast(null as struct<_1:int>) FROM tbl"
val df2 = sql(sqlString)
val rows = df2.collect()
assert(rows.length == batchSize)
assert(rows.forall(_ == Row(null, null, null)))
assert(rows.forall(_ == Row(null, null, null, null)))

checkSparkAnswerAndOperator(sqlString)
}
Expand Down