From ca5c6bd810c6fc07ae8b1d6d551c3f018616fa42 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Fri, 7 Apr 2023 10:50:09 +0800 Subject: [PATCH 1/2] RowAccumulator support for Decimal128 --- .../physical-expr/src/aggregate/average.rs | 39 ++++++++++++++----- .../physical-expr/src/aggregate/min_max.rs | 3 ++ .../src/aggregate/row_accumulator.rs | 1 + datafusion/physical-expr/src/aggregate/sum.rs | 3 ++ datafusion/row/src/accessor.rs | 15 +++++++ datafusion/row/src/layout.rs | 15 ++++--- datafusion/row/src/reader.rs | 17 +++++++- datafusion/row/src/writer.rs | 19 +++++++-- 8 files changed, 93 insertions(+), 19 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index de5f78f0a79f4..6a4025d8b8c3f 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -266,16 +266,35 @@ impl RowAccumulator for AvgRowAccumulator { } fn evaluate(&self, accessor: &RowAccessor) -> Result { - assert_eq!(self.sum_datatype, DataType::Float64); - Ok(match accessor.get_u64_opt(self.state_index()) { - None => ScalarValue::Float64(None), - Some(0) => ScalarValue::Float64(None), - Some(n) => ScalarValue::Float64( - accessor - .get_f64_opt(self.state_index() + 1) - .map(|f| f / n as f64), - ), - }) + match self.sum_datatype { + DataType::Decimal128(p, s) => { + Ok(match accessor.get_u64_opt(self.state_index()) { + None => ScalarValue::Decimal128(None, p, s), + Some(0) => ScalarValue::Decimal128(None, p, s), + Some(n) => ScalarValue::Decimal128( + accessor + .get_i128_opt(self.state_index() + 1) + .map(|f| f / n as i128), + p, s), + }) + } + DataType::Float64 => { + Ok(match accessor.get_u64_opt(self.state_index()) { + None => ScalarValue::Float64(None), + Some(0) => ScalarValue::Float64(None), + Some(n) => ScalarValue::Float64( + accessor + .get_f64_opt(self.state_index() + 1) + .map(|f| f / n as f64), + ), + }) + } + _ => { + Err(DataFusionError::Internal( + "Sum should be f64 on average".to_string(), + )) + } + } } #[inline(always)] diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 8a9d39a15f15a..dce10a62b3194 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -485,6 +485,9 @@ macro_rules! min_max_v2 { ScalarValue::Int8(rhs) => { typed_min_max_v2!($INDEX, $ACC, rhs, i8, $OP) } + ScalarValue::Decimal128(rhs, ..) => { + typed_min_max_v2!($INDEX, $ACC, rhs, i128, $OP) + } e => { return Err(DataFusionError::Internal(format!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", diff --git a/datafusion/physical-expr/src/aggregate/row_accumulator.rs b/datafusion/physical-expr/src/aggregate/row_accumulator.rs index d26da8f4cec91..00717a113f9be 100644 --- a/datafusion/physical-expr/src/aggregate/row_accumulator.rs +++ b/datafusion/physical-expr/src/aggregate/row_accumulator.rs @@ -79,5 +79,6 @@ pub fn is_row_accumulator_support_dtype(data_type: &DataType) -> bool { | DataType::Int64 | DataType::Float32 | DataType::Float64 + | DataType::Decimal128(_, _) ) } diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index a815a33c8c7fe..e1356160a739f 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -220,6 +220,9 @@ pub(crate) fn add_to_row( ScalarValue::Int64(rhs) => { sum_row!(index, accessor, rhs, i64) } + ScalarValue::Decimal128(rhs, _, _) => { + sum_row!(index, accessor, rhs, i128) + } _ => { let msg = format!("Row sum updater is not expected to receive a scalar {s:?}"); diff --git a/datafusion/row/src/accessor.rs b/datafusion/row/src/accessor.rs index f8e34578dbdac..e7b4ed85016a3 100644 --- a/datafusion/row/src/accessor.rs +++ b/datafusion/row/src/accessor.rs @@ -193,6 +193,7 @@ impl<'a> RowAccessor<'a> { fn_get_idx!(i64, 8); fn_get_idx!(f32, 4); fn_get_idx!(f64, 8); + fn_get_idx!(i128, 16); fn_get_idx_opt!(bool); fn_get_idx_opt!(u8); @@ -205,6 +206,7 @@ impl<'a> RowAccessor<'a> { fn_get_idx_opt!(i64); fn_get_idx_opt!(f32); fn_get_idx_opt!(f64); + fn_get_idx_opt!(i128); fn_get_idx_scalar!(bool, Boolean); fn_get_idx_scalar!(u8, UInt8); @@ -218,6 +220,14 @@ impl<'a> RowAccessor<'a> { fn_get_idx_scalar!(f32, Float32); fn_get_idx_scalar!(f64, Float64); + fn get_decimal128_scalar(&self, idx: usize, p: u8, s: i8) -> ScalarValue { + if self.is_valid_at(idx) { + ScalarValue::Decimal128(Some(self.get_i128(idx)), p, s) + } else { + ScalarValue::Decimal128(None, p, s) + } + } + pub fn get_as_scalar(&self, dt: &DataType, index: usize) -> ScalarValue { match dt { DataType::Boolean => self.get_bool_scalar(index), @@ -231,6 +241,7 @@ impl<'a> RowAccessor<'a> { DataType::UInt64 => self.get_u64_scalar(index), DataType::Float32 => self.get_f32_scalar(index), DataType::Float64 => self.get_f64_scalar(index), + DataType::Decimal128(p, s) => self.get_decimal128_scalar(index, *p, *s), _ => unreachable!(), } } @@ -264,6 +275,7 @@ impl<'a> RowAccessor<'a> { fn_set_idx!(i64, 8); fn_set_idx!(f32, 4); fn_set_idx!(f64, 8); + fn_set_idx!(i128, 16); fn set_i8(&mut self, idx: usize, value: i8) { self.assert_index_valid(idx); @@ -285,6 +297,7 @@ impl<'a> RowAccessor<'a> { fn_add_idx!(i64); fn_add_idx!(f32); fn_add_idx!(f64); + fn_add_idx!(i128); fn_max_min_idx!(u8, max); fn_max_min_idx!(u16, max); @@ -296,6 +309,7 @@ impl<'a> RowAccessor<'a> { fn_max_min_idx!(i64, max); fn_max_min_idx!(f32, max); fn_max_min_idx!(f64, max); + fn_max_min_idx!(i128, max); fn_max_min_idx!(u8, min); fn_max_min_idx!(u16, min); @@ -307,4 +321,5 @@ impl<'a> RowAccessor<'a> { fn_max_min_idx!(i64, min); fn_max_min_idx!(f32, min); fn_max_min_idx!(f64, min); + fn_max_min_idx!(i128, min); } diff --git a/datafusion/row/src/layout.rs b/datafusion/row/src/layout.rs index 502812cb9f66b..31f05f75cbe92 100644 --- a/datafusion/row/src/layout.rs +++ b/datafusion/row/src/layout.rs @@ -164,11 +164,15 @@ fn word_aligned_offsets(null_width: usize, schema: &Schema) -> (Vec, usiz let mut offset = null_width; for f in schema.fields() { offsets.push(offset); - assert!(!matches!(f.data_type(), DataType::Decimal128(_, _))); - // All of the current support types can fit into one single 8-bytes word. - // When we decide to support Decimal type in the future, its width would be - // of two 8-bytes words and should adapt the width calculation below. - offset += 8; + assert!(!matches!(f.data_type(), DataType::Decimal256(_, _))); + // All of the current support types can fit into one single 8-bytes word except for Decimal128. + // For Decimal128, its width is of two 8-bytes words. + match f.data_type() { + DataType::Decimal128(_, _) => { + offset += 16 + } + _ => offset += 8 + } } (offsets, offset - null_width) } @@ -241,6 +245,7 @@ fn supported_type(dt: &DataType, row_type: RowType) -> bool { | Float64 | Date32 | Date64 + | Decimal128(_, _) ) } } diff --git a/datafusion/row/src/reader.rs b/datafusion/row/src/reader.rs index 634b814ad35ae..53fd332996ced 100644 --- a/datafusion/row/src/reader.rs +++ b/datafusion/row/src/reader.rs @@ -213,6 +213,10 @@ impl<'a> RowReader<'a> { get_idx!(i64, self, idx, 8) } + fn get_decimal128(&self, idx: usize) -> i128 { + get_idx!(i128, self, idx, 16) + } + fn get_utf8(&self, idx: usize) -> &str { self.assert_index_valid(idx); let offset_size = self.get_u64(idx); @@ -260,6 +264,14 @@ impl<'a> RowReader<'a> { } } + fn get_decimal128_opt(&self, idx: usize) -> Option { + if self.is_valid_at(idx) { + Some(self.get_decimal128(idx)) + } else { + None + } + } + fn get_utf8_opt(&self, idx: usize) -> Option<&str> { if self.is_valid_at(idx) { Some(self.get_utf8(idx)) @@ -328,6 +340,7 @@ fn_read_field!(f64, Float64Builder); fn_read_field!(date32, Date32Builder); fn_read_field!(date64, Date64Builder); fn_read_field!(utf8, StringBuilder); +fn_read_field!(decimal128, Decimal128Builder); pub(crate) fn read_field_binary( to: &mut Box, @@ -374,6 +387,7 @@ fn read_field( Date64 => read_field_date64(to, col_idx, row), Utf8 => read_field_utf8(to, col_idx, row), Binary => read_field_binary(to, col_idx, row), + Decimal128(_, _) => read_field_decimal128(to, col_idx, row), _ => unimplemented!(), } } @@ -401,6 +415,7 @@ fn read_field_null_free( Date64 => read_field_date64_null_free(to, col_idx, row), Utf8 => read_field_utf8_null_free(to, col_idx, row), Binary => read_field_binary_null_free(to, col_idx, row), + Decimal128(_, _) => read_field_decimal128_null_free(to, col_idx, row), _ => unimplemented!(), } -} +} \ No newline at end of file diff --git a/datafusion/row/src/writer.rs b/datafusion/row/src/writer.rs index 12339afe77ddc..765829998dc3f 100644 --- a/datafusion/row/src/writer.rs +++ b/datafusion/row/src/writer.rs @@ -22,9 +22,7 @@ use arrow::array::*; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util::{round_upto_power_of_2, set_bit_raw, unset_bit_raw}; -use datafusion_common::cast::{ - as_binary_array, as_date32_array, as_date64_array, as_string_array, -}; +use datafusion_common::cast::{as_binary_array, as_date32_array, as_date64_array, as_decimal128_array, as_string_array}; use datafusion_common::Result; use std::cmp::max; use std::sync::Arc; @@ -225,6 +223,10 @@ impl RowWriter { set_idx!(8, self, idx, value) } + fn set_decimal128(&mut self, idx: usize, value: i128) { + set_idx!(16, self, idx, value) + } + fn set_offset_size(&mut self, idx: usize, size: u32) { let offset_and_size: u64 = (self.varlena_offset as u64) << 32 | (size as u64); self.set_u64(idx, offset_and_size); @@ -375,6 +377,16 @@ pub(crate) fn write_field_binary( to.set_binary(col_idx, s); } +pub(crate) fn write_field_decimal128( + to: &mut RowWriter, + from: &Arc, + col_idx: usize, + row_idx: usize, +) { + let from = as_decimal128_array(from).unwrap(); + to.set_decimal128(col_idx, from.value(row_idx)); +} + fn write_field( col_idx: usize, row_idx: usize, @@ -399,6 +411,7 @@ fn write_field( Date64 => write_field_date64(row, col, col_idx, row_idx), Utf8 => write_field_utf8(row, col, col_idx, row_idx), Binary => write_field_binary(row, col, col_idx, row_idx), + Decimal128(_, _) => write_field_decimal128(row, col, col_idx, row_idx), _ => unimplemented!(), } } From b47846f06bffca61cdf26cb715d056210c6fc8c2 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Wed, 12 Apr 2023 17:21:00 +0800 Subject: [PATCH 2/2] add test --- .../sqllogictests/test_files/aggregate.slt | 25 +++++++ .../physical-expr/src/aggregate/average.rs | 70 +++++++++++-------- datafusion/row/src/layout.rs | 6 +- datafusion/row/src/reader.rs | 2 +- datafusion/row/src/writer.rs | 5 +- 5 files changed, 73 insertions(+), 35 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt index 10368341d85f3..9e122d3a26e5a 100644 --- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt +++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt @@ -1615,6 +1615,31 @@ SELECT SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2 FROM test_table; ---- NULL +# Creating the decimal table +statement ok +CREATE TABLE test_decimal_table (c1 INT, c2 DECIMAL(5, 2), c3 DECIMAL(5, 1), c4 DECIMAL(5, 1)) + +# Inserting data +statement ok +INSERT INTO test_decimal_table VALUES (1, 10.10, 100.1, NULL), (1, 20.20, 200.2, NULL), (2, 10.10, 700.1, NULL), (2, 20.20, 700.1, NULL), (3, 10.1, 100.1, NULL), (3, 10.1, NULL, NULL) + +# aggregate_decimal_with_group_by +query IIRRRRIIR rowsort +select c1, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c3), count(c4), sum(c4) from test_decimal_table group by c1 +---- +1 2 15.15 30.3 10.1 20.2 2 0 NULL +2 2 15.15 30.3 10.1 20.2 2 0 NULL +3 2 10.1 20.2 10.1 10.1 1 0 NULL + +# aggregate_decimal_with_group_by_decimal +query RIRRRRIR rowsort +select c3, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c4), sum(c4) from test_decimal_table group by c3 +---- +100.1 2 10.1 20.2 10.1 10.1 0 NULL +200.2 1 20.2 20.2 20.2 20.2 0 NULL +700.1 2 15.15 30.3 10.1 20.2 0 NULL +NULL 1 10.1 10.1 10.1 10.1 0 NULL + # Restore the default dialect statement ok set datafusion.sql_parser.dialect = 'Generic'; diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 66bde2b13a755..dcb2e9f9cc92d 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -143,7 +143,8 @@ impl AggregateExpr for Avg { ) -> Result> { Ok(Box::new(AvgRowAccumulator::new( start_index, - self.sum_data_type.clone(), + &self.sum_data_type, + &self.rt_data_type, ))) } @@ -236,7 +237,7 @@ impl Accumulator for AvgAccumulator { }) } _ => Err(DataFusionError::Internal( - "Sum should be f64 on average".to_string(), + "Sum should be f64 or decimal128 on average".to_string(), )), } } @@ -250,13 +251,19 @@ impl Accumulator for AvgAccumulator { struct AvgRowAccumulator { state_index: usize, sum_datatype: DataType, + return_data_type: DataType, } impl AvgRowAccumulator { - pub fn new(start_index: usize, sum_datatype: DataType) -> Self { + pub fn new( + start_index: usize, + sum_datatype: &DataType, + return_data_type: &DataType, + ) -> Self { Self { state_index: start_index, - sum_datatype, + sum_datatype: sum_datatype.clone(), + return_data_type: return_data_type.clone(), } } } @@ -300,32 +307,37 @@ impl RowAccumulator for AvgRowAccumulator { fn evaluate(&self, accessor: &RowAccessor) -> Result { match self.sum_datatype { DataType::Decimal128(p, s) => { - Ok(match accessor.get_u64_opt(self.state_index()) { - None => ScalarValue::Decimal128(None, p, s), - Some(0) => ScalarValue::Decimal128(None, p, s), - Some(n) => ScalarValue::Decimal128( - accessor - .get_i128_opt(self.state_index() + 1) - .map(|f| f / n as i128), - p, s), - }) - } - DataType::Float64 => { - Ok(match accessor.get_u64_opt(self.state_index()) { - None => ScalarValue::Float64(None), - Some(0) => ScalarValue::Float64(None), - Some(n) => ScalarValue::Float64( - accessor - .get_f64_opt(self.state_index() + 1) - .map(|f| f / n as f64), - ), - }) - } - _ => { - Err(DataFusionError::Internal( - "Sum should be f64 on average".to_string(), - )) + match accessor.get_u64_opt(self.state_index()) { + None => Ok(ScalarValue::Decimal128(None, p, s)), + Some(0) => Ok(ScalarValue::Decimal128(None, p, s)), + Some(n) => { + // now the sum_type and return type is not the same, need to convert the sum type to return type + accessor.get_i128_opt(self.state_index() + 1).map_or_else( + || Ok(ScalarValue::Decimal128(None, p, s)), + |f| { + calculate_result_decimal_for_avg( + f, + n as i128, + s, + &self.return_data_type, + ) + }, + ) + } + } } + DataType::Float64 => Ok(match accessor.get_u64_opt(self.state_index()) { + None => ScalarValue::Float64(None), + Some(0) => ScalarValue::Float64(None), + Some(n) => ScalarValue::Float64( + accessor + .get_f64_opt(self.state_index() + 1) + .map(|f| f / n as f64), + ), + }), + _ => Err(DataFusionError::Internal( + "Sum should be f64 or decimal128 on average".to_string(), + )), } } diff --git a/datafusion/row/src/layout.rs b/datafusion/row/src/layout.rs index 31f05f75cbe92..6a8e8a78ec9d8 100644 --- a/datafusion/row/src/layout.rs +++ b/datafusion/row/src/layout.rs @@ -168,10 +168,8 @@ fn word_aligned_offsets(null_width: usize, schema: &Schema) -> (Vec, usiz // All of the current support types can fit into one single 8-bytes word except for Decimal128. // For Decimal128, its width is of two 8-bytes words. match f.data_type() { - DataType::Decimal128(_, _) => { - offset += 16 - } - _ => offset += 8 + DataType::Decimal128(_, _) => offset += 16, + _ => offset += 8, } } (offsets, offset - null_width) diff --git a/datafusion/row/src/reader.rs b/datafusion/row/src/reader.rs index 53fd332996ced..a8dc8211f0c75 100644 --- a/datafusion/row/src/reader.rs +++ b/datafusion/row/src/reader.rs @@ -418,4 +418,4 @@ fn read_field_null_free( Decimal128(_, _) => read_field_decimal128_null_free(to, col_idx, row), _ => unimplemented!(), } -} \ No newline at end of file +} diff --git a/datafusion/row/src/writer.rs b/datafusion/row/src/writer.rs index 765829998dc3f..7bf9ac0267b74 100644 --- a/datafusion/row/src/writer.rs +++ b/datafusion/row/src/writer.rs @@ -22,7 +22,10 @@ use arrow::array::*; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util::{round_upto_power_of_2, set_bit_raw, unset_bit_raw}; -use datafusion_common::cast::{as_binary_array, as_date32_array, as_date64_array, as_decimal128_array, as_string_array}; +use datafusion_common::cast::{ + as_binary_array, as_date32_array, as_date64_array, as_decimal128_array, + as_string_array, +}; use datafusion_common::Result; use std::cmp::max; use std::sync::Arc;