diff --git a/native/Cargo.lock b/native/Cargo.lock index a3f6f6d30c..3692f04883 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -346,13 +346,13 @@ checksum = "0c24e9d990669fbd16806bff449e4ac644fd9b1fca014760087732fe4102f131" [[package]] name = "async-trait" -version = "0.1.81" +version = "0.1.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" +checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -463,9 +463,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.17.0" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fd4c6dcc3b0aea2f5c0b4b82c2b15fe39ddbc76041a310848f4706edf76bb31" +checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" [[package]] name = "byteorder" @@ -487,9 +487,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.14" +version = "1.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d2eb3cd3d1bf4529e31c215ee6f93ec5a3d536d9f578f93d9d33ee19562932" +checksum = "e9d013ecb737093c0e86b151a7b837993cf9ec6c502946cfb44bedc392421e0b" dependencies = [ "jobserver", "libc", @@ -659,9 +659,9 @@ dependencies = [ [[package]] name = "constant_time_eq" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" [[package]] name = "core-foundation-sys" @@ -671,9 +671,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpp_demangle" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e8227005286ec39567949b33df9896bcadfa6051bccca2488129f108ca23119" +checksum = "96e58d342ad113c2b878f16d5d034c03be492ae460cdbc02b7f0f2284d310c7d" dependencies = [ "cfg-if", ] @@ -811,7 +811,7 @@ dependencies = [ [[package]] name = "datafusion" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -833,7 +833,6 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", - "datafusion-physical-expr-functions-aggregate", "datafusion-physical-optimizer", "datafusion-physical-plan", "datafusion-sql", @@ -860,7 +859,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow-schema", "async-trait", @@ -945,7 +944,6 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "datafusion-physical-plan", "num", "rand", "regex", @@ -956,7 +954,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -978,7 +976,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "log", "tokio", @@ -987,7 +985,7 @@ dependencies = [ [[package]] name = "datafusion-execution" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "chrono", @@ -1007,7 +1005,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1028,7 +1026,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "datafusion-common", @@ -1038,7 +1036,7 @@ dependencies = [ [[package]] name = "datafusion-functions" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "arrow-buffer", @@ -1064,7 +1062,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1084,7 +1082,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1097,7 +1095,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "arrow-array", @@ -1119,7 +1117,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "datafusion-common", "datafusion-expr", @@ -1130,7 +1128,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "async-trait", @@ -1149,7 +1147,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1180,7 +1178,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1190,25 +1188,10 @@ dependencies = [ "rand", ] -[[package]] -name = "datafusion-physical-expr-functions-aggregate" -version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" -dependencies = [ - "ahash", - "arrow", - "datafusion-common", - "datafusion-expr", - "datafusion-expr-common", - "datafusion-functions-aggregate-common", - "datafusion-physical-expr-common", - "rand", -] - [[package]] name = "datafusion-physical-optimizer" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "datafusion-common", "datafusion-execution", @@ -1220,7 +1203,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "ahash", "arrow", @@ -1238,7 +1221,6 @@ dependencies = [ "datafusion-functions-aggregate-common", "datafusion-physical-expr", "datafusion-physical-expr-common", - "datafusion-physical-expr-functions-aggregate", "futures", "half", "hashbrown", @@ -1255,7 +1237,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=dff590b#dff590bfd2bb9993b2c8ce6f76a3bdd973e520a8" +source = "git+https://github.com/apache/datafusion.git?rev=91b1d2b#91b1d2bfe8f603df94e846b91d8475a0af2e5240" dependencies = [ "arrow", "arrow-array", @@ -1448,7 +1430,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -1624,9 +1606,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c" +checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" dependencies = [ "equivalent", "hashbrown", @@ -2128,9 +2110,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.3" +version = "0.36.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27b64972346851a39438c60b341ebc01bba47464ae329e55cf343eb93964efd9" +checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" dependencies = [ "memchr", ] @@ -2435,7 +2417,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2556,9 +2538,9 @@ checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" [[package]] name = "rgb" -version = "0.8.48" +version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f86ae463694029097b846d8f99fd5536740602ae00022c0c50c5600720b2f71" +checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" dependencies = [ "bytemuck", ] @@ -2571,18 +2553,18 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustc_version" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ "semver", ] [[package]] name = "rustix" -version = "0.38.34" +version = "0.38.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +checksum = "a85d50532239da68e9addb745ba38ff4612a242c1c7ceea689c4bc7c2f43c36f" dependencies = [ "bitflags 2.6.0", "errno", @@ -2657,7 +2639,7 @@ checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2775,7 +2757,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2815,7 +2797,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2826,9 +2808,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "symbolic-common" -version = "12.10.0" +version = "12.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16629323a4ec5268ad23a575110a724ad4544aae623451de600c747bf87b36cf" +checksum = "9c1db5ac243c7d7f8439eb3b8f0357888b37cf3732957e91383b0ad61756374e" dependencies = [ "debugid", "memmap2", @@ -2838,9 +2820,9 @@ dependencies = [ [[package]] name = "symbolic-demangle" -version = "12.10.0" +version = "12.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c043a45f08f41187414592b3ceb53fb0687da57209cc77401767fb69d5b596" +checksum = "ea26e430c27d4a8a5dea4c4b81440606c7c1a415bd611451ef6af8c81416afc3" dependencies = [ "cpp_demangle", "rustc-demangle", @@ -2860,9 +2842,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.76" +version = "2.0.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" dependencies = [ "proc-macro2", "quote", @@ -2899,7 +2881,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2959,9 +2941,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.3" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" dependencies = [ "backtrace", "bytes", @@ -2977,7 +2959,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -2999,7 +2981,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] @@ -3149,7 +3131,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", "wasm-bindgen-shared", ] @@ -3171,7 +3153,7 @@ checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3410,7 +3392,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.76", + "syn 2.0.77", ] [[package]] diff --git a/native/Cargo.toml b/native/Cargo.toml index 33711b1df5..54e568943c 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -39,14 +39,14 @@ arrow-buffer = { version = "52.2.0" } arrow-data = { version = "52.2.0" } arrow-schema = { version = "52.2.0" } parquet = { version = "52.2.0", default-features = false, features = ["experimental"] } -datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "dff590b" } -datafusion = { default-features = false, git = "https://github.com/apache/datafusion.git", rev = "dff590b", features = ["unicode_expressions", "crypto_expressions"] } -datafusion-functions = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", features = ["crypto_expressions"] } -datafusion-functions-nested = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", default-features = false } -datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", default-features = false } -datafusion-execution = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", default-features = false } -datafusion-physical-plan = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", default-features = false } -datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "dff590b", default-features = false } +datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b" } +datafusion = { default-features = false, git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", features = ["unicode_expressions", "crypto_expressions"] } +datafusion-functions = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", features = ["crypto_expressions"] } +datafusion-functions-nested = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", default-features = false } +datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", default-features = false } +datafusion-execution = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", default-features = false } +datafusion-physical-plan = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", default-features = false } +datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "91b1d2b", default-features = false } datafusion-comet-spark-expr = { path = "spark-expr", version = "0.3.0" } datafusion-comet-proto = { path = "proto", version = "0.3.0" } chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/native/core/src/execution/datafusion/expressions/avg.rs b/native/core/src/execution/datafusion/expressions/avg.rs index 5e7b555c8a..7820497d46 100644 --- a/native/core/src/execution/datafusion/expressions/avg.rs +++ b/native/core/src/execution/datafusion/expressions/avg.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; use arrow::compute::sum; use arrow_array::{ builder::PrimitiveBuilder, @@ -25,20 +24,24 @@ use arrow_array::{ }; use arrow_schema::{DataType, Field}; use datafusion::logical_expr::{ - type_coercion::aggregates::avg_return_type, Accumulator, EmitTo, GroupsAccumulator, + type_coercion::aggregates::avg_return_type, Accumulator, EmitTo, GroupsAccumulator, Signature, }; use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; use std::{any::Any, sync::Arc}; use arrow_array::ArrowNativeTypeOp; - +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{AggregateUDFImpl, ReversedUDAF}; use DataType::*; /// AVG aggregate expression #[derive(Debug, Clone)] pub struct Avg { name: String, + signature: Signature, expr: Arc, input_data_type: DataType, result_data_type: DataType, @@ -51,6 +54,7 @@ impl Avg { Self { name: name.into(), + signature: Signature::user_defined(Immutable), expr, input_data_type: data_type, result_data_type, @@ -58,17 +62,13 @@ impl Avg { } } -impl AggregateExpr for Avg { +impl AggregateUDFImpl for Avg { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.result_data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { // instantiate specialized accumulator based for the type match (&self.input_data_type, &self.result_data_type) { (Float64, Float64) => Ok(Box::::default()), @@ -80,7 +80,7 @@ impl AggregateExpr for Avg { } } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "sum"), @@ -95,19 +95,22 @@ impl AggregateExpr for Avg { ]) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - fn name(&self) -> &str { &self.name } - fn groups_accumulator_supported(&self) -> bool { + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { // instantiate specialized accumulator based for the type match (&self.input_data_type, &self.result_data_type) { (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::::new( @@ -126,6 +129,14 @@ impl AggregateExpr for Avg { fn default_value(&self, _data_type: &DataType) -> Result { Ok(ScalarValue::Float64(None)) } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + avg_return_type(self.name(), &arg_types[0]) + } } impl PartialEq for Avg { diff --git a/native/core/src/execution/datafusion/expressions/avg_decimal.rs b/native/core/src/execution/datafusion/expressions/avg_decimal.rs index b035d32449..0462f2d3d5 100644 --- a/native/core/src/execution/datafusion/expressions/avg_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/avg_decimal.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer, compute::sum}; use arrow_array::{ builder::PrimitiveBuilder, @@ -24,16 +23,20 @@ use arrow_array::{ Array, ArrayRef, Decimal128Array, Int64Array, PrimitiveArray, }; use arrow_schema::{DataType, Field}; -use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator}; +use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator, Signature}; use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; use std::{any::Any, sync::Arc}; use arrow_array::ArrowNativeTypeOp; use arrow_data::decimal::{ validate_decimal_precision, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; - +use datafusion::logical_expr::Volatility::Immutable; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::avg_return_type; +use datafusion_expr::{AggregateUDFImpl, ReversedUDAF}; use num::{integer::div_ceil, Integer}; use DataType::*; @@ -41,6 +44,7 @@ use DataType::*; #[derive(Debug, Clone)] pub struct AvgDecimal { name: String, + signature: Signature, expr: Arc, sum_data_type: DataType, result_data_type: DataType, @@ -56,6 +60,7 @@ impl AvgDecimal { ) -> Self { Self { name: name.into(), + signature: Signature::user_defined(Immutable), expr, result_data_type: result_type, sum_data_type: sum_type, @@ -63,17 +68,13 @@ impl AvgDecimal { } } -impl AggregateExpr for AvgDecimal { +impl AggregateUDFImpl for AvgDecimal { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.result_data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { match (&self.sum_data_type, &self.result_data_type) { (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => { Ok(Box::new(AvgDecimalAccumulator::new( @@ -91,7 +92,7 @@ impl AggregateExpr for AvgDecimal { } } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "sum"), @@ -106,23 +107,22 @@ impl AggregateExpr for AvgDecimal { ]) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - fn name(&self) -> &str { &self.name } - fn reverse_expr(&self) -> Option> { - None + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { // instantiate specialized accumulator based for the type match (&self.sum_data_type, &self.result_data_type) { (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => { @@ -154,6 +154,14 @@ impl AggregateExpr for AvgDecimal { ), } } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + avg_return_type(self.name(), &arg_types[0]) + } } impl PartialEq for AvgDecimal { diff --git a/native/core/src/execution/datafusion/expressions/bitwise_not.rs b/native/core/src/execution/datafusion/expressions/bitwise_not.rs index c7b2bc0677..a2b9ebe5b5 100644 --- a/native/core/src/execution/datafusion/expressions/bitwise_not.rs +++ b/native/core/src/execution/datafusion/expressions/bitwise_not.rs @@ -26,12 +26,11 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion::{error::DataFusionError, logical_expr::ColumnarValue}; use datafusion_common::{Result, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; - macro_rules! compute_op { ($OPERAND:expr, $DT:ident) => {{ let operand = $OPERAND diff --git a/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs b/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs index b66ca5b2cb..462a22247f 100644 --- a/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs +++ b/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs @@ -21,9 +21,10 @@ use crate::{ use arrow::record_batch::RecordBatch; use arrow_array::cast::as_primitive_array; use arrow_schema::{DataType, Schema}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion::physical_plan::ColumnarValue; use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, PhysicalExpr}; +use datafusion_physical_expr::PhysicalExpr; use std::{ any::Any, fmt::Display, diff --git a/native/core/src/execution/datafusion/expressions/checkoverflow.rs b/native/core/src/execution/datafusion/expressions/checkoverflow.rs index ff2cffd428..e922171bd2 100644 --- a/native/core/src/execution/datafusion/expressions/checkoverflow.rs +++ b/native/core/src/execution/datafusion/expressions/checkoverflow.rs @@ -29,11 +29,10 @@ use arrow::{ }; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; - /// This is from Spark `CheckOverflow` expression. Spark `CheckOverflow` expression rounds decimals /// to given scale and check if the decimals can fit in given precision. As `cast` kernel rounds /// decimals already, Comet `CheckOverflow` expression only checks if the decimals can fit in the diff --git a/native/core/src/execution/datafusion/expressions/correlation.rs b/native/core/src/execution/datafusion/expressions/correlation.rs index 3dcf6cca8b..6bf35e7115 100644 --- a/native/core/src/execution/datafusion/expressions/correlation.rs +++ b/native/core/src/execution/datafusion/expressions/correlation.rs @@ -20,16 +20,20 @@ use arrow::compute::{and, filter, is_not_null}; use std::{any::Any, sync::Arc}; use crate::execution::datafusion::expressions::{ - covariance::CovarianceAccumulator, stats::StatsType, stddev::StddevAccumulator, - utils::down_cast_any_ref, + covariance::CovarianceAccumulator, stddev::StddevAccumulator, }; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; use datafusion::logical_expr::Accumulator; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{Result, ScalarValue}; -use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::expressions::StatsType; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// CORR aggregate expression /// The implementation mostly is the same as the DataFusion's implementation. The reason @@ -39,6 +43,7 @@ use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, Ph #[derive(Debug)] pub struct Correlation { name: String, + signature: Signature, expr1: Arc, expr2: Arc, null_on_divide_by_zero: bool, @@ -56,6 +61,7 @@ impl Correlation { assert!(matches!(data_type, DataType::Float64)); Self { name: name.into(), + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), expr1, expr2, null_on_divide_by_zero, @@ -63,23 +69,34 @@ impl Correlation { } } -impl AggregateExpr for Correlation { +impl AggregateUDFImpl for Correlation { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(CorrelationAccumulator::try_new( self.null_on_divide_by_zero, )?)) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "count"), @@ -113,18 +130,6 @@ impl AggregateExpr for Correlation { ), ]) } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr1), Arc::clone(&self.expr2)] - } - - fn name(&self) -> &str { - &self.name - } - - fn default_value(&self, _data_type: &DataType) -> Result { - Ok(ScalarValue::Float64(None)) - } } impl PartialEq for Correlation { diff --git a/native/core/src/execution/datafusion/expressions/covariance.rs b/native/core/src/execution/datafusion/expressions/covariance.rs index 11b345a22b..9166e39766 100644 --- a/native/core/src/execution/datafusion/expressions/covariance.rs +++ b/native/core/src/execution/datafusion/expressions/covariance.rs @@ -19,20 +19,21 @@ use std::{any::Any, sync::Arc}; -use crate::execution::datafusion::expressions::stats::StatsType; use arrow::{ array::{ArrayRef, Float64Array}, compute::cast, datatypes::{DataType, Field}, }; use datafusion::logical_expr::Accumulator; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{ downcast_value, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_physical_expr::{ - aggregate::utils::down_cast_any_ref, expressions::format_state_name, AggregateExpr, - PhysicalExpr, -}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::expressions::StatsType; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// COVAR_SAMP and COVAR_POP aggregate expression /// The implementation mostly is the same as the DataFusion's implementation. The reason @@ -41,6 +42,7 @@ use datafusion_physical_expr::{ #[derive(Debug, Clone)] pub struct Covariance { name: String, + signature: Signature, expr1: Arc, expr2: Arc, stats_type: StatsType, @@ -61,6 +63,7 @@ impl Covariance { assert!(matches!(data_type, DataType::Float64)); Self { name: name.into(), + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), expr1, expr2, stats_type, @@ -69,24 +72,35 @@ impl Covariance { } } -impl AggregateExpr for Covariance { +impl AggregateUDFImpl for Covariance { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(CovarianceAccumulator::try_new( self.stats_type, self.null_on_divide_by_zero, )?)) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "count"), @@ -110,18 +124,6 @@ impl AggregateExpr for Covariance { ), ]) } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr1), Arc::clone(&self.expr2)] - } - - fn name(&self) -> &str { - &self.name - } - - fn default_value(&self, _data_type: &DataType) -> Result { - Ok(ScalarValue::Float64(None)) - } } impl PartialEq for Covariance { diff --git a/native/core/src/execution/datafusion/expressions/mod.rs b/native/core/src/execution/datafusion/expressions/mod.rs index 4435f6b690..10c9d30920 100644 --- a/native/core/src/execution/datafusion/expressions/mod.rs +++ b/native/core/src/execution/datafusion/expressions/mod.rs @@ -30,13 +30,11 @@ pub mod comet_scalar_funcs; pub mod correlation; pub mod covariance; pub mod negative; -pub mod stats; pub mod stddev; pub mod strings; pub mod subquery; pub mod sum_decimal; pub mod unbound; -mod utils; pub mod variance; pub use datafusion_comet_spark_expr::{EvalMode, SparkError}; diff --git a/native/core/src/execution/datafusion/expressions/negative.rs b/native/core/src/execution/datafusion/expressions/negative.rs index fbcd194f08..8dfe717422 100644 --- a/native/core/src/execution/datafusion/expressions/negative.rs +++ b/native/core/src/execution/datafusion/expressions/negative.rs @@ -21,6 +21,7 @@ use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeT use arrow_array::RecordBatch; use arrow_buffer::IntervalDayTime; use arrow_schema::{DataType, Schema}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion::{ logical_expr::{interval_arithmetic::Interval, ColumnarValue}, physical_expr::PhysicalExpr, @@ -28,7 +29,6 @@ use datafusion::{ use datafusion_comet_spark_expr::SparkError; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::sort_properties::ExprProperties; -use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; use std::{ any::Any, hash::{Hash, Hasher}, diff --git a/native/core/src/execution/datafusion/expressions/normalize_nan.rs b/native/core/src/execution/datafusion/expressions/normalize_nan.rs index d2192feef7..c5331ad7bd 100644 --- a/native/core/src/execution/datafusion/expressions/normalize_nan.rs +++ b/native/core/src/execution/datafusion/expressions/normalize_nan.rs @@ -29,10 +29,9 @@ use arrow::{ }; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_physical_expr::PhysicalExpr; -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; - #[derive(Debug, Hash)] pub struct NormalizeNaNAndZero { pub data_type: DataType, diff --git a/native/core/src/execution/datafusion/expressions/stats.rs b/native/core/src/execution/datafusion/expressions/stats.rs deleted file mode 100644 index 1f4e64d0b4..0000000000 --- a/native/core/src/execution/datafusion/expressions/stats.rs +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/// Enum used for differentiating population and sample for statistical functions -#[derive(PartialEq, Eq, Debug, Clone, Copy)] -pub enum StatsType { - /// Population - Population, - /// Sample - Sample, -} diff --git a/native/core/src/execution/datafusion/expressions/stddev.rs b/native/core/src/execution/datafusion/expressions/stddev.rs index bc96a56808..1ba495e215 100644 --- a/native/core/src/execution/datafusion/expressions/stddev.rs +++ b/native/core/src/execution/datafusion/expressions/stddev.rs @@ -17,16 +17,18 @@ use std::{any::Any, sync::Arc}; -use crate::execution::datafusion::expressions::{ - stats::StatsType, utils::down_cast_any_ref, variance::VarianceAccumulator, -}; +use crate::execution::datafusion::expressions::variance::VarianceAccumulator; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; use datafusion::logical_expr::Accumulator; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::expressions::StatsType; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression /// The implementation mostly is the same as the DataFusion's implementation. The reason @@ -36,6 +38,7 @@ use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, Ph #[derive(Debug)] pub struct Stddev { name: String, + signature: Signature, expr: Arc, stats_type: StatsType, null_on_divide_by_zero: bool, @@ -54,6 +57,7 @@ impl Stddev { assert!(matches!(data_type, DataType::Float64)); Self { name: name.into(), + signature: Signature::coercible(vec![DataType::Float64], Volatility::Immutable), expr, stats_type, null_on_divide_by_zero, @@ -61,31 +65,42 @@ impl Stddev { } } -impl AggregateExpr for Stddev { +impl AggregateUDFImpl for Stddev { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(StddevAccumulator::try_new( self.stats_type, self.null_on_divide_by_zero, )?)) } - fn create_sliding_accumulator(&self) -> Result> { + fn create_sliding_accumulator( + &self, + _acc_args: AccumulatorArgs, + ) -> Result> { Ok(Box::new(StddevAccumulator::try_new( self.stats_type, self.null_on_divide_by_zero, )?)) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "count"), @@ -101,14 +116,6 @@ impl AggregateExpr for Stddev { ]) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } - fn default_value(&self, _data_type: &DataType) -> Result { Ok(ScalarValue::Float64(None)) } diff --git a/native/core/src/execution/datafusion/expressions/strings.rs b/native/core/src/execution/datafusion/expressions/strings.rs index 96e45eae2e..200b4ec5a3 100644 --- a/native/core/src/execution/datafusion/expressions/strings.rs +++ b/native/core/src/execution/datafusion/expressions/strings.rs @@ -17,10 +17,7 @@ #![allow(deprecated)] -use crate::execution::{ - datafusion::expressions::utils::down_cast_any_ref, - kernels::strings::{string_space, substring}, -}; +use crate::execution::kernels::strings::{string_space, substring}; use arrow::{ compute::{ contains_dyn, contains_utf8_scalar_dyn, ends_with_dyn, ends_with_utf8_scalar_dyn, like_dyn, @@ -30,6 +27,7 @@ use arrow::{ }; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{DataFusionError, ScalarValue::Utf8}; use datafusion_physical_expr::PhysicalExpr; use std::{ diff --git a/native/core/src/execution/datafusion/expressions/subquery.rs b/native/core/src/execution/datafusion/expressions/subquery.rs index cf6f8d846d..3eeb29c16e 100644 --- a/native/core/src/execution/datafusion/expressions/subquery.rs +++ b/native/core/src/execution/datafusion/expressions/subquery.rs @@ -15,9 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::{ + execution::utils::bytes_to_i128, + jvm_bridge::{jni_static_call, BinaryWrapper, JVMClasses, StringWrapper}, +}; use arrow_array::RecordBatch; use arrow_schema::{DataType, Schema, TimeUnit}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{internal_err, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; use jni::{ @@ -31,11 +36,6 @@ use std::{ sync::Arc, }; -use crate::{ - execution::{datafusion::expressions::utils::down_cast_any_ref, utils::bytes_to_i128}, - jvm_bridge::{jni_static_call, BinaryWrapper, JVMClasses, StringWrapper}, -}; - #[derive(Debug, Hash)] pub struct Subquery { /// The ID of the execution context that owns this subquery. We use this ID to retrieve the diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index 37030b67a8..e957bd25e2 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::unlikely; use arrow::{ array::BooleanBufferBuilder, buffer::{BooleanBuffer, NullBuffer}, @@ -25,15 +26,18 @@ use arrow_array::{ use arrow_data::decimal::validate_decimal_precision; use arrow_schema::{DataType, Field}; use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{Result as DFResult, ScalarValue}; -use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, AggregateExpr, PhysicalExpr}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature}; +use datafusion_physical_expr::PhysicalExpr; use std::{any::Any, ops::BitAnd, sync::Arc}; -use crate::unlikely; - #[derive(Debug)] pub struct SumDecimal { name: String, + signature: Signature, expr: Arc, /// The data type of the SUM result @@ -56,6 +60,7 @@ impl SumDecimal { }; Self { name: name.into(), + signature: Signature::user_defined(Immutable), expr, result_type: data_type, precision, @@ -65,27 +70,19 @@ impl SumDecimal { } } -impl AggregateExpr for SumDecimal { +impl AggregateUDFImpl for SumDecimal { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> DFResult { - Ok(Field::new( - &self.name, - self.result_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> DFResult> { + fn accumulator(&self, _args: AccumulatorArgs) -> DFResult> { Ok(Box::new(SumDecimalAccumulator::new( self.precision, self.scale, ))) } - fn state_fields(&self) -> DFResult> { + fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { let fields = vec![ Field::new(&self.name, self.result_type.clone(), self.nullable), Field::new("is_empty", DataType::Boolean, false), @@ -93,19 +90,26 @@ impl AggregateExpr for SumDecimal { Ok(fields) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - fn name(&self) -> &str { &self.name } - fn groups_accumulator_supported(&self) -> bool { + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(self.result_type.clone()) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } - fn create_groups_accumulator(&self) -> DFResult> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> DFResult> { Ok(Box::new(SumDecimalGroupsAccumulator::new( self.result_type.clone(), self.precision, @@ -119,6 +123,10 @@ impl AggregateExpr for SumDecimal { &DataType::Decimal128(self.precision, self.scale), ) } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } } impl PartialEq for SumDecimal { diff --git a/native/core/src/execution/datafusion/expressions/unbound.rs b/native/core/src/execution/datafusion/expressions/unbound.rs index 95f9912c98..a6babd0f7e 100644 --- a/native/core/src/execution/datafusion/expressions/unbound.rs +++ b/native/core/src/execution/datafusion/expressions/unbound.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; use arrow_array::RecordBatch; use arrow_schema::{DataType, Schema}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion::physical_plan::ColumnarValue; use datafusion_common::{internal_err, Result}; use datafusion_physical_expr::PhysicalExpr; diff --git a/native/core/src/execution/datafusion/expressions/utils.rs b/native/core/src/execution/datafusion/expressions/utils.rs deleted file mode 100644 index 540fca86b1..0000000000 --- a/native/core/src/execution/datafusion/expressions/utils.rs +++ /dev/null @@ -1,19 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// re-export for legacy reasons -pub use datafusion_comet_spark_expr::utils::down_cast_any_ref; diff --git a/native/core/src/execution/datafusion/expressions/variance.rs b/native/core/src/execution/datafusion/expressions/variance.rs index 5cfbf29471..2f4d8091c2 100644 --- a/native/core/src/execution/datafusion/expressions/variance.rs +++ b/native/core/src/execution/datafusion/expressions/variance.rs @@ -17,14 +17,18 @@ use std::{any::Any, sync::Arc}; -use crate::execution::datafusion::expressions::{stats::StatsType, utils::down_cast_any_ref}; use arrow::{ array::{ArrayRef, Float64Array}, datatypes::{DataType, Field}, }; use datafusion::logical_expr::Accumulator; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; -use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{AggregateUDFImpl, Signature}; +use datafusion_physical_expr::expressions::StatsType; +use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; /// VAR_SAMP and VAR_POP aggregate expression /// The implementation mostly is the same as the DataFusion's implementation. The reason @@ -34,6 +38,7 @@ use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, Ph #[derive(Debug)] pub struct Variance { name: String, + signature: Signature, expr: Arc, stats_type: StatsType, null_on_divide_by_zero: bool, @@ -52,6 +57,7 @@ impl Variance { assert!(matches!(data_type, DataType::Float64)); Self { name: name.into(), + signature: Signature::numeric(1, Immutable), expr, stats_type, null_on_divide_by_zero, @@ -59,31 +65,39 @@ impl Variance { } } -impl AggregateExpr for Variance { +impl AggregateUDFImpl for Variance { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(VarianceAccumulator::try_new( self.stats_type, self.null_on_divide_by_zero, )?)) } - fn create_sliding_accumulator(&self) -> Result> { + fn create_sliding_accumulator(&self, _args: AccumulatorArgs) -> Result> { Ok(Box::new(VarianceAccumulator::try_new( self.stats_type, self.null_on_divide_by_zero, )?)) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(&self.name, "count"), @@ -99,14 +113,6 @@ impl AggregateExpr for Variance { ]) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } - fn default_value(&self, _data_type: &DataType) -> Result { Ok(ScalarValue::Float64(None)) } diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index b9b8828241..be27495ab0 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -33,7 +33,6 @@ use crate::{ correlation::Correlation, covariance::Covariance, negative, - stats::StatsType, stddev::Stddev, strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExpr, SubstringExpr}, subquery::Subquery, @@ -55,7 +54,6 @@ use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_aggregate::min_max::min_udaf; use datafusion::functions_aggregate::sum::sum_udaf; -use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; use datafusion::physical_plan::windows::BoundedWindowAggExec; use datafusion::physical_plan::InputOrderMode; use datafusion::{ @@ -70,7 +68,7 @@ use datafusion::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, Literal as DataFusionLiteral, NotExpr, }, - AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr, + PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr, }, physical_optimizer::join_selection::swap_hash_join, physical_plan::{ @@ -83,6 +81,8 @@ use datafusion::{ }, prelude::SessionContext, }; +use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; + use datafusion_comet_proto::{ spark_expression::{ self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, @@ -105,8 +105,10 @@ use datafusion_common::{ JoinType as DFJoinType, ScalarValue, }; use datafusion_expr::expr::find_df_window_func; -use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition}; -use datafusion_physical_expr::expressions::Literal; +use datafusion_expr::{ + AggregateUDF, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, +}; +use datafusion_physical_expr::expressions::{Literal, StatsType}; use datafusion_physical_expr::window::WindowExpr; use itertools::Itertools; use jni::objects::GlobalRef; @@ -116,7 +118,7 @@ use std::{collections::HashMap, sync::Arc}; // For clippy error on type_complexity. type ExecResult = Result; -type PhyAggResult = Result>, ExecutionError>; +type PhyAggResult = Result>, ExecutionError>; type PhyExprResult = Result, String)>, ExecutionError>; type PartitionPhyExprResult = Result>, ExecutionError>; @@ -1275,7 +1277,7 @@ impl PhysicalPlanner { &self, spark_expr: &AggExpr, schema: SchemaRef, - ) -> Result, ExecutionError> { + ) -> Result, ExecutionError> { match spark_expr.expr_struct.as_ref().unwrap() { AggExprStruct::Count(expr) => { assert!(!expr.children.is_empty()); @@ -1347,12 +1349,24 @@ impl PhysicalPlanner { match datatype { DataType::Decimal128(_, _) => { - Ok(Arc::new(SumDecimal::new("sum", child, datatype))) + let func = AggregateUDF::new_from_impl(SumDecimal::new( + "sum", + Arc::clone(&child), + datatype, + )); + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("sum") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } _ => { // cast to the result data type of SUM if necessary, we should not expect // a cast failure since it should have already been checked at Spark side - let child = Arc::new(CastExpr::new(child, datatype.clone(), None)); + let child = + Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None)); AggregateExprBuilder::new(sum_udaf(), vec![child]) .schema(schema) @@ -1369,24 +1383,45 @@ impl PhysicalPlanner { let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap()); match datatype { - DataType::Decimal128(_, _) => Ok(Arc::new(AvgDecimal::new( - child, - "avg", - datatype, - input_datatype, - ))), + DataType::Decimal128(_, _) => { + let func = AggregateUDF::new_from_impl(AvgDecimal::new( + Arc::clone(&child), + "avg", + datatype, + input_datatype, + )); + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("avg") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + } _ => { // cast to the result data type of AVG if the result data type is different // from the input type, e.g. AVG(Int32). We should not expect a cast // failure since it should have already been checked at Spark side. - let child = Arc::new(CastExpr::new(child, datatype.clone(), None)); - Ok(Arc::new(Avg::new(child, "avg", datatype))) + let child: Arc = + Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None)); + let func = AggregateUDF::new_from_impl(Avg::new( + Arc::clone(&child), + "avg", + datatype, + )); + AggregateExprBuilder::new(Arc::new(func), vec![child]) + .schema(schema) + .alias("avg") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } } } AggExprStruct::First(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; - let func = datafusion_expr::AggregateUDF::new_from_impl(FirstValue::new()); + let func = AggregateUDF::new_from_impl(FirstValue::new()); AggregateExprBuilder::new(Arc::new(func), vec![child]) .schema(schema) @@ -1398,7 +1433,7 @@ impl PhysicalPlanner { } AggExprStruct::Last(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; - let func = datafusion_expr::AggregateUDF::new_from_impl(LastValue::new()); + let func = AggregateUDF::new_from_impl(LastValue::new()); AggregateExprBuilder::new(Arc::new(func), vec![child]) .schema(schema) @@ -1448,22 +1483,40 @@ impl PhysicalPlanner { self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); match expr.stats_type { - 0 => Ok(Arc::new(Covariance::new( - child1, - child2, - "covariance", - datatype, - StatsType::Sample, - expr.null_on_divide_by_zero, - ))), - 1 => Ok(Arc::new(Covariance::new( - child1, - child2, - "covariance_pop", - datatype, - StatsType::Population, - expr.null_on_divide_by_zero, - ))), + 0 => { + let func = AggregateUDF::new_from_impl(Covariance::new( + Arc::clone(&child1), + Arc::clone(&child2), + "covariance", + datatype, + StatsType::Sample, + expr.null_on_divide_by_zero, + )); + + Self::create_aggr_func_expr( + "covariance", + schema, + vec![child1, child2], + func, + ) + } + 1 => { + let func = AggregateUDF::new_from_impl(Covariance::new( + Arc::clone(&child1), + Arc::clone(&child2), + "covariance_pop", + datatype, + StatsType::Population, + expr.null_on_divide_by_zero, + )); + + Self::create_aggr_func_expr( + "covariance_pop", + schema, + vec![child1, child2], + func, + ) + } stats_type => Err(ExecutionError::GeneralError(format!( "Unknown StatisticsType {:?} for Variance", stats_type @@ -1474,20 +1527,28 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); match expr.stats_type { - 0 => Ok(Arc::new(Variance::new( - child, - "variance", - datatype, - StatsType::Sample, - expr.null_on_divide_by_zero, - ))), - 1 => Ok(Arc::new(Variance::new( - child, - "variance_pop", - datatype, - StatsType::Population, - expr.null_on_divide_by_zero, - ))), + 0 => { + let func = AggregateUDF::new_from_impl(Variance::new( + Arc::clone(&child), + "variance", + datatype, + StatsType::Sample, + expr.null_on_divide_by_zero, + )); + + Self::create_aggr_func_expr("variance", schema, vec![child], func) + } + 1 => { + let func = AggregateUDF::new_from_impl(Variance::new( + Arc::clone(&child), + "variance_pop", + datatype, + StatsType::Population, + expr.null_on_divide_by_zero, + )); + + Self::create_aggr_func_expr("variance_pop", schema, vec![child], func) + } stats_type => Err(ExecutionError::GeneralError(format!( "Unknown StatisticsType {:?} for Variance", stats_type @@ -1498,20 +1559,28 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); match expr.stats_type { - 0 => Ok(Arc::new(Stddev::new( - child, - "stddev", - datatype, - StatsType::Sample, - expr.null_on_divide_by_zero, - ))), - 1 => Ok(Arc::new(Stddev::new( - child, - "stddev_pop", - datatype, - StatsType::Population, - expr.null_on_divide_by_zero, - ))), + 0 => { + let func = AggregateUDF::new_from_impl(Stddev::new( + Arc::clone(&child), + "stddev", + datatype, + StatsType::Sample, + expr.null_on_divide_by_zero, + )); + + Self::create_aggr_func_expr("stddev", schema, vec![child], func) + } + 1 => { + let func = AggregateUDF::new_from_impl(Stddev::new( + Arc::clone(&child), + "stddev_pop", + datatype, + StatsType::Population, + expr.null_on_divide_by_zero, + )); + + Self::create_aggr_func_expr("stddev_pop", schema, vec![child], func) + } stats_type => Err(ExecutionError::GeneralError(format!( "Unknown StatisticsType {:?} for stddev", stats_type @@ -1524,13 +1593,14 @@ impl PhysicalPlanner { let child2 = self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - Ok(Arc::new(Correlation::new( - child1, - child2, + let func = AggregateUDF::new_from_impl(Correlation::new( + Arc::clone(&child1), + Arc::clone(&child2), "correlation", datatype, expr.null_on_divide_by_zero, - ))) + )); + Self::create_aggr_func_expr("correlation", schema, vec![child1, child2], func) } } } @@ -1791,6 +1861,21 @@ impl PhysicalPlanner { Ok(scalar_expr) } + + fn create_aggr_func_expr( + name: &str, + schema: SchemaRef, + children: Vec>, + func: AggregateUDF, + ) -> Result, ExecutionError> { + AggregateExprBuilder::new(Arc::new(func), children) + .schema(schema) + .alias(name) + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| e.into()) + } } impl From for ExecutionError { diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 1cb790efcb..2d99a854d3 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -240,7 +240,7 @@ fn prepare_datafusion_session_context( // this is the threshold of number of groups / number of rows and the // maximum value is 1.0, so we set the threshold a little higher just // to be safe - ScalarValue::Float64(Some(1.1)), + &ScalarValue::Float64(Some(1.1)), ); for (key, value) in conf.iter().filter(|(k, _)| k.starts_with("datafusion.")) { diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 0a371a6e61..a5d1569129 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -35,7 +35,6 @@ datafusion = { workspace = true } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } -datafusion-physical-plan = { workspace = true } chrono-tz = { workspace = true } num = { workspace = true } regex = { workspace = true } diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs index 12dbfdcdc7..6a3974fe1e 100644 --- a/native/spark-expr/src/cast.rs +++ b/native/spark-expr/src/cast.rs @@ -50,6 +50,7 @@ use datafusion_expr::ColumnarValue; use datafusion_physical_expr::PhysicalExpr; use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use num::{ cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive, @@ -57,7 +58,7 @@ use num::{ use regex::Regex; use crate::timezone; -use crate::utils::{array_with_timezone, down_cast_any_ref}; +use crate::utils::array_with_timezone; use crate::{EvalMode, SparkError, SparkResult}; diff --git a/native/spark-expr/src/if_expr.rs b/native/spark-expr/src/if_expr.rs index 9a90b727ce..193a90fb55 100644 --- a/native/spark-expr/src/if_expr.rs +++ b/native/spark-expr/src/if_expr.rs @@ -26,11 +26,10 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::Result; use datafusion_physical_expr::{expressions::CaseExpr, PhysicalExpr}; -use crate::utils::down_cast_any_ref; - /// IfExpr is a wrapper around CaseExpr, because `IF(a, b, c)` is semantically equivalent to /// `CASE WHEN a THEN b ELSE c END`. #[derive(Debug, Hash)] diff --git a/native/spark-expr/src/regexp.rs b/native/spark-expr/src/regexp.rs index 221fd1f047..c7626285a2 100644 --- a/native/spark-expr/src/regexp.rs +++ b/native/spark-expr/src/regexp.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::down_cast_any_ref; use crate::SparkError; use arrow::compute::take; use arrow_array::builder::BooleanBuilder; use arrow_array::types::Int32Type; use arrow_array::{Array, BooleanArray, DictionaryArray, RecordBatch, StringArray}; use arrow_schema::{DataType, Schema}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr::PhysicalExpr; diff --git a/native/spark-expr/src/structs.rs b/native/spark-expr/src/structs.rs index 49017b6710..cda8246d90 100644 --- a/native/spark-expr/src/structs.rs +++ b/native/spark-expr/src/structs.rs @@ -19,6 +19,7 @@ use arrow::record_batch::RecordBatch; use arrow_array::{Array, StructArray}; use arrow_schema::{DataType, Field, Schema}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; use std::{ @@ -28,8 +29,6 @@ use std::{ sync::Arc, }; -use crate::utils::down_cast_any_ref; - #[derive(Debug, Hash)] pub struct CreateNamedStruct { values: Vec>, diff --git a/native/spark-expr/src/temporal.rs b/native/spark-expr/src/temporal.rs index 415db6070a..91953dd600 100644 --- a/native/spark-expr/src/temporal.rs +++ b/native/spark-expr/src/temporal.rs @@ -28,10 +28,11 @@ use arrow::{ }; use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{DataFusionError, ScalarValue::Utf8}; use datafusion_physical_expr::PhysicalExpr; -use crate::utils::{array_with_timezone, down_cast_any_ref}; +use crate::utils::array_with_timezone; use crate::kernels::temporal::{ date_trunc_array_fmt_dyn, date_trunc_dyn, timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn, diff --git a/native/spark-expr/src/utils.rs b/native/spark-expr/src/utils.rs index 6945e82b3e..db4ad1956a 100644 --- a/native/spark-expr/src/utils.rs +++ b/native/spark-expr/src/utils.rs @@ -20,7 +20,6 @@ use arrow_array::{ types::{Int32Type, TimestampMicrosecondType}, }; use arrow_schema::{ArrowError, DataType}; -use std::any::Any; use std::sync::Arc; use crate::timezone::Tz; @@ -30,23 +29,6 @@ use arrow::{ }; use chrono::{DateTime, Offset, TimeZone}; -use datafusion_physical_plan::PhysicalExpr; - -/// A utility function from DataFusion. It is not exposed by DataFusion. -pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { - if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() - } else if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() - } else { - any - } -} - /// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or /// to apply timezone offset. //