diff --git a/native/Cargo.lock b/native/Cargo.lock index 7705d10251..5d8dce283b 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -882,6 +882,7 @@ dependencies = [ "datafusion-comet-spark-expr", "datafusion-common", "datafusion-expr", + "datafusion-functions-nested", "datafusion-physical-expr", "datafusion-physical-expr-common", "flate2", @@ -1053,6 +1054,27 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "datafusion-functions-nested" +version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87" +dependencies = [ + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-ord", + "arrow-schema", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions", + "datafusion-functions-aggregate", + "itertools 0.12.1", + "log", + "paste", + "rand", +] + [[package]] name = "datafusion-optimizer" version = "41.0.0" @@ -2323,7 +2345,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" dependencies = [ "anyhow", - "itertools 0.12.1", + "itertools 0.10.5", "proc-macro2", "quote", "syn 2.0.72", diff --git a/native/Cargo.toml b/native/Cargo.toml index 68dd51aa74..a41934fe3e 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -42,6 +42,7 @@ parquet = { version = "52.2.0", default-features = false, features = ["experimen datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "41.0.0-rc1" } datafusion = { default-features = false, git = "https://github.com/apache/datafusion.git", rev = "41.0.0-rc1", features = ["unicode_expressions", "crypto_expressions"] } datafusion-functions = { git = "https://github.com/apache/datafusion.git", rev = "41.0.0-rc1", features = ["crypto_expressions"] } +datafusion-functions-nested = { git = "https://github.com/apache/datafusion.git", rev = "41.0.0-rc1", default-features = false } datafusion-expr = { git = 
"https://github.com/apache/datafusion.git", rev = "41.0.0-rc1", default-features = false } datafusion-physical-plan = { git = "https://github.com/apache/datafusion.git", rev = "41.0.0-rc1", default-features = false } datafusion-physical-expr-common = { git = "https://github.com/apache/datafusion.git", rev = "41.0.0-rc1", default-features = false } diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 5b5c237ce1..55c46e1d6c 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -67,6 +67,7 @@ itertools = "0.11.0" paste = "1.0.14" datafusion-common = { workspace = true } datafusion = { workspace = true } +datafusion-functions-nested = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-expr = { workspace = true } diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index a29b83803c..6d61f4a4d6 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -236,10 +236,11 @@ fn prepare_datafusion_session_context( let runtime = RuntimeEnv::new(rt_config).unwrap(); - Ok(SessionContext::new_with_config_rt( - session_config, - Arc::new(runtime), - )) + let mut session_ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime)); + + datafusion_functions_nested::register_all(&mut session_ctx)?; + + Ok(session_ctx) } fn parse_bool(conf: &HashMap, name: &str) -> CometResult { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8b994d623b..ca17f6f0b4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2361,6 +2361,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim .build() } + // DataFusion's make_array only supports nullable element types, hence the guard; see + // 
https://github.com/apache/datafusion/issues/11923 + case array @ CreateArray(children, _) if array.dataType.containsNull => + val childExprs = children.map(exprToProto(_, inputs, binding)) + + if (childExprs.forall(_.isDefined)) { + scalarExprToProtoWithReturnType("make_array", array.dataType, childExprs: _*) + } else { + withInfo(expr, "unsupported arguments for CreateArray", children: _*) + None + } + case _ => withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*) None diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 8eeb998bd1..b9c5c3ca85 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1998,4 +1998,16 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + test("CreateArray") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + val df = spark.read.parquet(path.toString) + checkSparkAnswerAndOperator(df.select(array(col("_2"), col("_3"), col("_4")))) + checkSparkAnswerAndOperator(df.select(array(col("_4"), col("_11"), lit(null)))) + } + } + } }