Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/math.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

##########
## Math expression Tests
##########

statement ok
CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION 'tests/data/aggregate_simple.csv';

# Round
query R
SELECT ROUND(c1) FROM aggregate_simple
----
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0

# Round
query R
SELECT round(c1/3, 2) FROM aggregate_simple order by c1
----
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0

# Round
query R
SELECT round(c1, 4) FROM aggregate_simple order by c1
----
0
0
0
0
0
0
0
0
0
0
0.0001
0.0001
0.0001
0.0001
0.0001

# Round
query RRRRRRRR
SELECT round(125.2345, -3), round(125.2345, -2), round(125.2345, -1), round(125.2345), round(125.2345, 0), round(125.2345, 1), round(125.2345, 2), round(125.2345, 3)
----
0 100 130 125 125 125.2 125.23 125.235
5 changes: 3 additions & 2 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ scalar_expr!(
num,
"nearest integer greater than or equal to argument"
);
scalar_expr!(Round, round, num, "round to nearest integer");
nary_scalar_expr!(Round, round, "round to nearest integer");
scalar_expr!(Trunc, trunc, num, "truncate toward zero");
scalar_expr!(Abs, abs, num, "absolute value");
scalar_expr!(Signum, signum, num, "sign of the argument (-1, 0, +1) ");
Expand Down Expand Up @@ -766,7 +766,8 @@ mod test {
test_unary_scalar_expr!(Atan, atan);
test_unary_scalar_expr!(Floor, floor);
test_unary_scalar_expr!(Ceil, ceil);
test_unary_scalar_expr!(Round, round);
test_nary_scalar_expr!(Round, round, input);
test_nary_scalar_expr!(Round, round, input, decimal_places);
test_unary_scalar_expr!(Trunc, trunc);
test_unary_scalar_expr!(Abs, abs);
test_unary_scalar_expr!(Signum, signum);
Expand Down
4 changes: 3 additions & 1 deletion datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10),
BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2),
BuiltinScalarFunction::Random => Arc::new(math_expressions::random),
BuiltinScalarFunction::Round => Arc::new(math_expressions::round),
BuiltinScalarFunction::Round => {
Arc::new(|args| make_scalar_function(math_expressions::round)(args))
}
BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum),
BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin),
BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt),
Expand Down
102 changes: 101 additions & 1 deletion datafusion/physical-expr/src/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ macro_rules! make_function_inputs2 {
})
.collect::<$ARRAY_TYPE>()
}};
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1);
let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2);

arg1.iter()
.zip(arg2.iter())
.map(|(a1, a2)| match (a1, a2) {
(Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)),
_ => None,
})
.collect::<$ARRAY_TYPE1>()
}};
}

math_unary_function!("sqrt", sqrt);
Expand All @@ -124,7 +136,6 @@ math_unary_function!("acos", acos);
math_unary_function!("atan", atan);
math_unary_function!("floor", floor);
math_unary_function!("ceil", ceil);
math_unary_function!("round", round);
math_unary_function!("trunc", trunc);
math_unary_function!("abs", abs);
math_unary_function!("signum", signum);
Expand All @@ -149,6 +160,59 @@ pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(ColumnarValue::Array(Arc::new(array)))
}

/// Round SQL function
pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 && args.len() != 2 {
return Err(DataFusionError::Internal(format!(
"round function requires one or two arguments, got {}",
args.len()
)));
}

let mut decimal_places =
&(Arc::new(Int64Array::from_value(0, args[0].len())) as ArrayRef);
Comment on lines +172 to +173
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be more efficient to make a ColumnarValue to avoid creating a large empty array. That would complicate the implementation as it would have to handle both the array argument form and the scalar argument form

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, not sure if it is worth.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, there is other function that also create an array like that for one arg case, e.g. log, so I will change them all in a follow up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it is reasonable to start with a working implementation and optimize as a follow on 👍


if args.len() == 2 {
decimal_places = &args[1];
}

match args[0].data_type() {
DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
&args[0],
decimal_places,
"value",
"decimal_places",
Float64Array,
Int64Array,
{
|value: f64, decimal_places: i64| {
(value * 10.0_f64.powi(decimal_places.try_into().unwrap())).round()
/ 10.0_f64.powi(decimal_places.try_into().unwrap())
}
}
)) as ArrayRef),

DataType::Float32 => Ok(Arc::new(make_function_inputs2!(
&args[0],
decimal_places,
"value",
"decimal_places",
Float32Array,
Int64Array,
{
|value: f32, decimal_places: i64| {
(value * 10.0_f32.powi(decimal_places.try_into().unwrap())).round()
/ 10.0_f32.powi(decimal_places.try_into().unwrap())
}
}
)) as ArrayRef),

other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function round"
))),
}
}

/// Power SQL function
pub fn power(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
Expand Down Expand Up @@ -365,4 +429,40 @@ mod tests {
assert_eq!(floats.value(2), 4.0);
assert_eq!(floats.value(3), 4.0);
}

#[test]
fn test_round_f32() {
let args: Vec<ArrayRef> = vec![
Arc::new(Float32Array::from(vec![125.2345; 10])), // input
Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
];

let result = round(&args).expect("failed to initialize function round");
let floats =
as_float32_array(&result).expect("failed to initialize function round");

let expected = Float32Array::from(vec![
125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I spot checked that these answers are consistent with postgres 👍 (including negative)

postgres=# select round(125.2345, -1);
 round
-------
   130
(1 row)

postgres=# select round(125.2345, -2);
 round
-------
   100
(1 row)

postgres=# select round(125.2345, 3);
  round
---------
 125.235
(1 row)

]);

assert_eq!(floats, &expected);
}

#[test]
fn test_round_f64() {
let args: Vec<ArrayRef> = vec![
Arc::new(Float64Array::from(vec![125.2345; 10])), // input
Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
];

let result = round(&args).expect("failed to initialize function round");
let floats =
as_float64_array(&result).expect("failed to initialize function round");

let expected = Float64Array::from(vec![
125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
]);

assert_eq!(floats, &expected);
}
}
7 changes: 6 additions & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,12 @@ pub fn parse_expr(
ScalarFunction::Log10 => Ok(log10(parse_expr(&args[0], registry)?)),
ScalarFunction::Floor => Ok(floor(parse_expr(&args[0], registry)?)),
ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry)?)),
ScalarFunction::Round => Ok(round(parse_expr(&args[0], registry)?)),
ScalarFunction::Round => Ok(round(
args.to_owned()
.iter()
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
ScalarFunction::Trunc => Ok(trunc(parse_expr(&args[0], registry)?)),
ScalarFunction::Abs => Ok(abs(parse_expr(&args[0], registry)?)),
ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], registry)?)),
Expand Down
Loading