diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index b552a071d6..cdaf3d5f88 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -22,6 +22,7 @@ package org.apache.comet.serde
 
 import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -134,7 +135,34 @@ object CometArrayContains extends CometExpressionSerde[ArrayContains] {
 
     val arrayContainsScalarExpr =
       scalarFunctionExprToProto("array_has", arrayExprProto, keyExprProto)
-    optExprWithInfo(arrayContainsScalarExpr, expr, expr.children: _*)
+
+    // Handle NULL array input - return NULL if array is NULL (matching Spark's behavior)
+    val isNotNullExpr = createUnaryExpr(
+      expr,
+      expr.children.head,
+      inputs,
+      binding,
+      (builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
+
+    val nullLiteralProto = exprToProto(Literal(null, BooleanType), Seq.empty)
+
+    if (arrayContainsScalarExpr.isDefined && isNotNullExpr.isDefined &&
+      nullLiteralProto.isDefined) {
+      val caseWhenExpr = ExprOuterClass.CaseWhen
+        .newBuilder()
+        .addWhen(isNotNullExpr.get)
+        .addThen(arrayContainsScalarExpr.get)
+        .setElseExpr(nullLiteralProto.get)
+        .build()
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setCaseWhen(caseWhenExpr)
+          .build())
+    } else {
+      withInfo(expr, expr.children: _*)
+      None
+    }
   }
 }
 
@@ -395,6 +423,15 @@ object CometCreateArray extends CometExpressionSerde[CreateArray] {
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
     val children = expr.children
+
+    // Handle empty array: return literal directly to avoid DataFusion coerce_types bug
+    // when make_array is called with 0 arguments (issue #3338)
+    if (children.isEmpty) {
+      val emptyArrayLiteral =
+        Literal.create(new GenericArrayData(Array.empty[Any]), expr.dataType)
+      return exprToProtoInternal(emptyArrayLiteral, inputs, binding)
+    }
+
     val childExprs = children.map(exprToProtoInternal(_, inputs, binding))
 
     if (childExprs.forall(_.isDefined)) {
diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_contains.sql b/spark/src/test/resources/sql-tests/expressions/array/array_contains.sql
index 86ad0cc488..cdbe3e68c2 100644
--- a/spark/src/test/resources/sql-tests/expressions/array/array_contains.sql
+++ b/spark/src/test/resources/sql-tests/expressions/array/array_contains.sql
@@ -35,5 +35,24 @@ query spark_answer_only
 SELECT array_contains(array(1, 2, 3), val) FROM test_array_contains
 
 -- literal + literal
-query ignore(https://github.com/apache/datafusion-comet/issues/3345)
+-- Note: array_contains(array(), 1) still has a bug (issue #3346) so we use spark_answer_only
+-- The NULL array case (cast(NULL as array<int>)) was fixed in issue #3345
+query spark_answer_only
 SELECT array_contains(array(1, 2, 3), 2), array_contains(array(1, 2, 3), 4), array_contains(array(), 1), array_contains(cast(NULL as array<int>), 1)
+
+-- Additional NULL array tests (issue #3345 fix verification)
+-- NULL array with integer value
+query
+SELECT array_contains(cast(NULL as array<int>), 1)
+
+-- NULL array with string value
+query
+SELECT array_contains(cast(NULL as array<string>), 'test')
+
+-- NULL array with NULL value
+query
+SELECT array_contains(cast(NULL as array<int>), cast(NULL as int))
+
+-- NULL array with column value
+query
+SELECT array_contains(cast(NULL as array<int>), val) FROM test_array_contains
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index cf49117364..b22d0f72db 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -325,6 +325,38 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     }
   }
 
+  test("array_contains - NULL array returns NULL") {
+    // Test that array_contains returns NULL when the array argument is NULL
+    // This matches Spark's SQL three-valued logic behavior
+    withTempDir { dir =>
+      withTempView("t1") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 100)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+        // Test NULL array with non-null value
+        checkSparkAnswerAndOperator(
+          sql("SELECT array_contains(cast(null as array<int>), 1) FROM t1"))
+        checkSparkAnswerAndOperator(
+          sql("SELECT array_contains(cast(null as array<string>), 'test') FROM t1"))
+        checkSparkAnswerAndOperator(
+          sql("SELECT array_contains(cast(null as array<double>), 1.5) FROM t1"))
+
+        // Test NULL array with NULL value
+        checkSparkAnswerAndOperator(
+          sql("SELECT array_contains(cast(null as array<int>), cast(null as int)) FROM t1"))
+
+        // Test NULL array with column value
+        checkSparkAnswerAndOperator(
+          sql("SELECT array_contains(cast(null as array<int>), _2) FROM t1"))
+
+        // Test non-null array with values (to ensure fix doesn't break normal operation)
+        checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, 2, 3), 2) FROM t1"))
+        checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, 2, 3), 5) FROM t1"))
+      }
+    }
+  }
+
   test("array_contains - test all types (convert from Parquet)") {
     withTempDir { dir =>
       val path = new Path(dir.toURI.toString, "test.parquet")
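
For reference, a minimal sketch (not part of the patch) of the Spark semantics that the IsNotNull/CaseWhen wrapping in CometArrayContains reproduces on the native side. It assumes a local SparkSession; the object name ArrayContainsNullSemanticsExample is illustrative only.

// Illustrative sketch only - not part of the diff above.
// Spark's three-valued logic for array_contains: a NULL array yields NULL, not false,
// which is what CASE WHEN arr IS NOT NULL THEN array_has(arr, key) ELSE NULL END emulates.
import org.apache.spark.sql.SparkSession

object ArrayContainsNullSemanticsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[1]")
      .appName("array_contains null semantics")
      .getOrCreate()
    spark.sql("SELECT array_contains(CAST(NULL AS ARRAY<INT>), 1)").show() // NULL
    spark.sql("SELECT array_contains(array(1, 2, 3), 2)").show() // true
    spark.sql("SELECT array_contains(array(1, 2, 3), 5)").show() // false
    spark.stop()
  }
}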