diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index d726b9fad1256..d40709a467cf3 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -22,9 +22,10 @@ use arrow::array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, AsArray use arrow::datatypes::Field; use arrow::datatypes::{ ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, - Decimal64Type, FieldRef, Float64Type, Int64Type, UInt64Type, - DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL32_MAX_PRECISION, - DECIMAL64_MAX_PRECISION, + Decimal64Type, DurationMicrosecondType, DurationMillisecondType, + DurationNanosecondType, DurationSecondType, FieldRef, Float64Type, Int64Type, + TimeUnit, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, }; use datafusion_common::types::{ logical_float64, logical_int16, logical_int32, logical_int64, logical_int8, @@ -93,6 +94,27 @@ macro_rules! downcast_sum { DataType::Decimal256(_, _) => { $helper!(Decimal256Type, $args.return_field.data_type().clone()) } + DataType::Duration(TimeUnit::Second) => { + $helper!(DurationSecondType, $args.return_field.data_type().clone()) + } + DataType::Duration(TimeUnit::Millisecond) => { + $helper!( + DurationMillisecondType, + $args.return_field.data_type().clone() + ) + } + DataType::Duration(TimeUnit::Microsecond) => { + $helper!( + DurationMicrosecondType, + $args.return_field.data_type().clone() + ) + } + DataType::Duration(TimeUnit::Nanosecond) => { + $helper!( + DurationNanosecondType, + $args.return_field.data_type().clone() + ) + } _ => { not_impl_err!( "Sum not supported for {}: {}", @@ -159,6 +181,9 @@ impl Sum { vec![TypeSignatureClass::Float], NativeType::Float64, )]), + TypeSignature::Coercible(vec![Coercion::new_exact( + TypeSignatureClass::Duration, + )]), ], Volatility::Immutable, ), @@ -208,6 +233,7 @@ impl AggregateUDFImpl for Sum { let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal256(new_precision, *scale)) } + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), other => { exec_err!("[return_type] SUM not supported for {}", other) } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 83faae0db5956..a1b868b0b028f 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -4623,6 +4623,16 @@ SELECT max(column1), max(column2), max(column3), max(column4) FROM d; ---- 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 0.022 secs 0 days 0 hours 0 mins 0.000033 secs 0 days 0 hours 0 mins 0.000000044 secs +query ???? +SELECT avg(column1), avg(column2), avg(column3), avg(column4) FROM d; +---- +0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs + +query ???? +SELECT sum(column1), sum(column2), sum(column3), sum(column4) FROM d; +---- +0 days 0 hours 0 mins 12 secs 0 days 0 hours 0 mins 0.024 secs 0 days 0 hours 0 mins 0.000036 secs 0 days 0 hours 0 mins 0.000000048 secs + # GROUP BY follows a different code path query ????I SELECT min(column1), min(column2), min(column3), min(column4), column5 FROM d GROUP BY column5; @@ -4634,6 +4644,16 @@ SELECT max(column1), max(column2), max(column3), max(column4), column5 FROM d GR ---- 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 0.022 secs 0 days 0 hours 0 mins 0.000033 secs 0 days 0 hours 0 mins 0.000000044 secs 1 +query ????I +SELECT avg(column1), avg(column2), avg(column3), avg(column4), column5 FROM d GROUP BY column5; +---- +0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs 1 + +query ????I +SELECT sum(column1), sum(column2), sum(column3), sum(column4), column5 FROM d GROUP BY column5; +---- +0 days 0 hours 0 mins 12 secs 0 days 0 hours 0 mins 0.024 secs 0 days 0 hours 0 mins 0.000036 secs 0 days 0 hours 0 mins 0.000000048 secs 1 + statement ok INSERT INTO d VALUES (arrow_cast(3, 'Duration(Second)'), arrow_cast(1, 'Duration(Millisecond)'), arrow_cast(7, 'Duration(Microsecond)'), arrow_cast(2, 'Duration(Nanosecond)'), 1), @@ -4649,6 +4669,16 @@ SELECT min(column1), min(column2), min(column3), min(column4), column5 FROM d GR ---- 0 days 0 hours 0 mins 0 secs 0 days 0 hours 0 mins 0.001 secs 0 days 0 hours 0 mins 0.000003 secs 0 days 0 hours 0 mins 0.000000002 secs 1 +query ????I +SELECT avg(column1), avg(column2), avg(column3), avg(column4), column5 FROM d GROUP BY column5 ORDER BY column5; +---- +0 days 0 hours 0 mins 3 secs 0 days 0 hours 0 mins 0.008 secs 0 days 0 hours 0 mins 0.000012 secs 0 days 0 hours 0 mins 0.000000014 secs 1 + +query ????I +SELECT sum(column1), sum(column2), sum(column3), sum(column4), column5 FROM d GROUP BY column5 ORDER BY column5; +---- +0 days 0 hours 0 mins 15 secs 0 days 0 hours 0 mins 0.034 secs 0 days 0 hours 0 mins 0.000048 secs 0 days 0 hours 0 mins 0.000000058 secs 1 + statement ok drop table d;