diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index b4d723eb1f..3cdf799c28 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -17,9 +17,39 @@ //! Converts Spark physical plan to DataFusion physical plan +use super::expressions::EvalMode; +use crate::execution::datafusion::expressions::comet_scalar_funcs::create_comet_physical_fun; +use crate::{ + errors::ExpressionError, + execution::{ + datafusion::{ + expressions::{ + avg::Avg, + avg_decimal::AvgDecimal, + bitwise_not::BitwiseNotExpr, + bloom_filter_might_contain::BloomFilterMightContain, + checkoverflow::CheckOverflow, + correlation::Correlation, + covariance::Covariance, + negative, + stats::StatsType, + stddev::Stddev, + strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExpr, SubstringExpr}, + subquery::Subquery, + sum_decimal::SumDecimal, + unbound::UnboundColumn, + variance::Variance, + NormalizeNaNAndZero, + }, + operators::expand::CometExpandExec, + shuffle_writer::ShuffleWriterExec, + }, + operators::{CopyExec, ExecutionError, ScanExec}, + serde::to_arrow_datatype, + }, +}; use arrow_schema::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION}; use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf}; -use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::physical_plan::windows::BoundedWindowAggExec; use datafusion::physical_plan::InputOrderMode; @@ -49,53 +79,6 @@ use datafusion::{ }, prelude::SessionContext, }; -use datafusion_common::{ - tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, - JoinType as DFJoinType, ScalarValue, -}; -use datafusion_expr::expr::find_df_window_func; -use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition}; -use datafusion_physical_expr::window::WindowExpr; -use datafusion_physical_expr_common::aggregate::create_aggregate_expr; -use datafusion_physical_expr_common::expressions::Literal; -use itertools::Itertools; -use jni::objects::GlobalRef; -use num::{BigInt, ToPrimitive}; -use std::cmp::max; -use std::{collections::HashMap, sync::Arc}; - -use crate::{ - errors::ExpressionError, - execution::{ - datafusion::{ - expressions::{ - avg::Avg, - avg_decimal::AvgDecimal, - bitwise_not::BitwiseNotExpr, - bloom_filter_might_contain::BloomFilterMightContain, - checkoverflow::CheckOverflow, - correlation::Correlation, - covariance::Covariance, - negative, - stats::StatsType, - stddev::Stddev, - strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExpr, SubstringExpr}, - subquery::Subquery, - sum_decimal::SumDecimal, - unbound::UnboundColumn, - variance::Variance, - NormalizeNaNAndZero, - }, - operators::expand::CometExpandExec, - shuffle_writer::ShuffleWriterExec, - }, - operators::{CopyExec, ExecutionError, ScanExec}, - serde::to_arrow_datatype, - }, -}; - -use super::expressions::EvalMode; -use crate::execution::datafusion::expressions::comet_scalar_funcs::create_comet_physical_fun; use datafusion_comet_proto::{ spark_expression::{ self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, @@ -112,6 +95,20 @@ use datafusion_comet_spark_expr::{ Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, }; +use datafusion_common::{ + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, + JoinType as DFJoinType, ScalarValue, +}; +use datafusion_expr::expr::find_df_window_func; +use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition}; +use datafusion_physical_expr::window::WindowExpr; +use datafusion_physical_expr_common::aggregate::create_aggregate_expr; +use datafusion_physical_expr_common::expressions::Literal; +use itertools::Itertools; +use jni::objects::GlobalRef; +use num::{BigInt, ToPrimitive}; +use std::cmp::max; +use std::{collections::HashMap, sync::Arc}; // For clippy error on type_complexity. type ExecResult = Result; @@ -1234,15 +1231,38 @@ impl PhysicalPlanner { ) -> Result, ExecutionError> { match spark_expr.expr_struct.as_ref().unwrap() { AggExprStruct::Count(expr) => { + assert!(!expr.children.is_empty()); + // Using `count_udaf` from Comet is exceptionally slow for some reason, so + // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))` + // https://github.com/apache/datafusion-comet/issues/744 + let children = expr .children .iter() .map(|child| self.create_expr(child, schema.clone())) .collect::, _>>()?; + // create `IS NOT NULL expr` and join them with `AND` if there are multiple + let not_null_expr: Arc = children.iter().skip(1).fold( + Arc::new(IsNotNullExpr::new(children[0].clone())) as Arc, + |acc, child| { + Arc::new(BinaryExpr::new( + acc, + DataFusionOperator::And, + Arc::new(IsNotNullExpr::new(child.clone())), + )) + }, + ); + + let child = Arc::new(IfExpr::new( + not_null_expr, + Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); + create_aggregate_expr( - &count_udaf(), - &children, + &sum_udaf(), + &[child], &[], &[], &[], diff --git a/spark/benchmarks/CometAggregateBenchmark-jdk11-results.txt b/spark/benchmarks/CometAggregateBenchmark-jdk11-results.txt deleted file mode 100644 index 9e3e15bc67..0000000000 --- a/spark/benchmarks/CometAggregateBenchmark-jdk11-results.txt +++ /dev/null @@ -1,24 +0,0 @@ -================================================================================================ -Grouped Aggregate (single group key + single aggregate SUM) -================================================================================================ - -OpenJDK 64-Bit Server VM 11.0.24+8-post-Ubuntu-1ubuntu322.04 on Linux 6.5.0-41-generic -AMD Ryzen 9 7950X3D 16-Core Processor -Grouped HashAgg Exec: single group key (cardinality 1048576), single aggregate SUM: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------------------------------- -SQL Parquet - Spark (SUM) 2663 2744 115 3.9 254.0 1.0X -SQL Parquet - Comet (Scan, Exec) (SUM) 1067 1084 24 9.8 101.8 2.5X - - -================================================================================================ -Grouped Aggregate (single group key + single aggregate COUNT) -================================================================================================ - -OpenJDK 64-Bit Server VM 11.0.24+8-post-Ubuntu-1ubuntu322.04 on Linux 6.5.0-41-generic -AMD Ryzen 9 7950X3D 16-Core Processor -Grouped HashAgg Exec: single group key (cardinality 1048576), single aggregate COUNT: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------------------------------------------- -SQL Parquet - Spark (COUNT) 2532 2552 28 4.1 241.5 1.0X -SQL Parquet - Comet (Scan, Exec) (COUNT) 4590 4592 4 2.3 437.7 0.6X - -