diff --git a/datafusion/core/tests/sql/arrow_typeof.rs b/datafusion/core/tests/sql/arrow_typeof.rs index 9f971f27b9346..06b9327f97a2e 100644 --- a/datafusion/core/tests/sql/arrow_typeof.rs +++ b/datafusion/core/tests/sql/arrow_typeof.rs @@ -62,10 +62,21 @@ async fn arrow_typeof_i32() -> Result<()> { } #[tokio::test] -async fn arrow_typeof_f64() -> Result<()> { +async fn arrow_typeof_decimal128() -> Result<()> { let ctx = SessionContext::new(); let sql = "SELECT arrow_typeof(1.0)"; let actual = execute(&ctx, sql).await; + let expected = "Decimal128(2, 1)"; + assert_eq!(expected, &actual[0][0]); + + Ok(()) +} + +#[tokio::test] +async fn arrow_typeof_f64() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "SELECT arrow_typeof(1.0::double)"; + let actual = execute(&ctx, sql).await; let expected = "Float64"; assert_eq!(expected, &actual[0][0]); diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs index 7c74cdd52f0e8..20af706ecebde 100644 --- a/datafusion/core/tests/sql/decimal.rs +++ b/datafusion/core/tests/sql/decimal.rs @@ -27,11 +27,11 @@ async fn decimal_cast() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+---------------+", - "| Float64(1.23) |", - "+---------------+", - "| 1.2300 |", - "+---------------+", + "+---------------------------+", + "| Decimal128(Some(123),3,2) |", + "+---------------------------+", + "| 1.2300 |", + "+---------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -42,11 +42,11 @@ async fn decimal_cast() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+---------------+", - "| Float64(1.23) |", - "+---------------+", - "| 1.2300 |", - "+---------------+", + "+---------------------------+", + "| Decimal128(Some(123),3,2) |", + "+---------------------------+", + "| 1.2300 |", + "+---------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -57,11 +57,11 @@ async fn decimal_cast() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+-----------------+", - "| Float64(1.2345) |", - "+-----------------+", - "| 1.23 |", - "+-----------------+", + "+-----------------------------+", + "| Decimal128(Some(12345),5,4) |", + "+-----------------------------+", + "| 1.23 |", + "+-----------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -550,25 +550,25 @@ async fn decimal_arithmetic_op() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+--------------------------------------+", - "| decimal_simple.c1 / Float64(0.00001) |", - "+--------------------------------------+", - "| 1.000000000000 |", - "| 2.000000000000 |", - "| 2.000000000000 |", - "| 3.000000000000 |", - "| 3.000000000000 |", - "| 3.000000000000 |", - "| 4.000000000000 |", - "| 4.000000000000 |", - "| 4.000000000000 |", - "| 4.000000000000 |", - "| 5.000000000000 |", - "| 5.000000000000 |", - "| 5.000000000000 |", - "| 5.000000000000 |", - "| 5.000000000000 |", - "+--------------------------------------+", + "+---------------------------------------------+", + "| decimal_simple.c1 / Decimal128(Some(1),6,5) |", + "+---------------------------------------------+", + "| 1.000000000000 |", + "| 2.000000000000 |", + "| 2.000000000000 |", + "| 3.000000000000 |", + "| 3.000000000000 |", + "| 3.000000000000 |", + "| 4.000000000000 |", + "| 4.000000000000 |", + "| 4.000000000000 |", + "| 4.000000000000 |", + "| 5.000000000000 |", + "| 5.000000000000 |", + "| 5.000000000000 |", + "| 5.000000000000 |", + "| 5.000000000000 |", + "+---------------------------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -609,25 +609,25 @@ async fn decimal_arithmetic_op() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+--------------------------------------+", - "| decimal_simple.c5 % Float64(0.00001) |", - "+--------------------------------------+", - "| 0.0000040 |", - "| 0.0000050 |", - "| 0.0000090 |", - "| 0.0000020 |", - "| 0.0000050 |", - "| 0.0000010 |", - "| 0.0000040 |", - "| 0.0000000 |", - "| 0.0000000 |", - "| 0.0000040 |", - "| 0.0000020 |", - "| 0.0000080 |", - "| 0.0000030 |", - "| 0.0000080 |", - "| 0.0000000 |", - "+--------------------------------------+", + "+---------------------------------------------+", + "| decimal_simple.c5 % Decimal128(Some(1),6,5) |", + "+---------------------------------------------+", + "| 0.0000040 |", + "| 0.0000050 |", + "| 0.0000090 |", + "| 0.0000020 |", + "| 0.0000050 |", + "| 0.0000010 |", + "| 0.0000040 |", + "| 0.0000000 |", + "| 0.0000000 |", + "| 0.0000040 |", + "| 0.0000020 |", + "| 0.0000080 |", + "| 0.0000030 |", + "| 0.0000080 |", + "| 0.0000000 |", + "+---------------------------------------------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 54d1b24e81e79..2dc95b302ddd3 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -234,13 +234,13 @@ async fn select_values_list() -> Result<()> { let sql = "EXPLAIN VALUES (1, 'a', -1, 1.1),(NULL, 'b', -3, 0.5)"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+---------------+-----------------------------------------------------------------------------------------------------------+", - "| plan_type | plan |", - "+---------------+-----------------------------------------------------------------------------------------------------------+", - "| logical_plan | Values: (Int64(1), Utf8(\"a\"), Int64(-1), Float64(1.1)), (Int64(NULL), Utf8(\"b\"), Int64(-3), Float64(0.5)) |", - "| physical_plan | ValuesExec |", - "| | |", - "+---------------+-----------------------------------------------------------------------------------------------------------+", + "+---------------+----------------------------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+----------------------------------------------------------------------------------------------------------------------------------+", + "| logical_plan | Values: (Int64(1), Utf8(\"a\"), Int64(-1), Decimal128(Some(11),2,1)), (Int64(NULL), Utf8(\"b\"), Int64(-3), Decimal128(Some(5),2,1)) |", + "| physical_plan | ValuesExec |", + "| | |", + "+---------------+----------------------------------------------------------------------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index dd44459f5081e..95b35945047db 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -1619,20 +1619,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr, schema, ctes)?), UnaryOperator::Minus => { match expr { - // optimization: if it's a number literal, we apply the negative operator - // here directly to calculate the new literal. - SQLExpr::Value(Value::Number(n, _)) => match n.parse::() { - Ok(n) => Ok(lit(-n)), - Err(_) => Ok(lit(-n - .parse::() - .map_err(|_e| { - DataFusionError::Internal(format!( - "negative operator can be only applied to integer and float operands, got: {}", - n)) - })?)), - }, + SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n, true), // not a literal, apply negative operator on expression - _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr, schema, ctes)?))), + _ => Ok(Expr::Negative(Box::new( + self.sql_expr_to_logical_expr(expr, schema, ctes)?, + ))), } } _ => Err(DataFusionError::NotImplemented(format!( @@ -1651,7 +1642,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|row| { row.into_iter() .map(|v| match v { - SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n), + SQLExpr::Value(Value::Number(n, _)) => { + parse_sql_number(&n, false) + } SQLExpr::Value( Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), ) => Ok(lit(s)), @@ -1695,7 +1688,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ctes: &mut HashMap, ) -> Result { match sql { - SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n), + SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n, false), SQLExpr::Value(Value::SingleQuotedString(ref s) | Value::DoubleQuotedString(ref s)) => Ok(lit(s.clone())), SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)), SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)), @@ -2698,11 +2691,57 @@ pub fn convert_data_type(sql_type: &SQLDataType) -> Result { } } -// Parse number in sql string, convert to Expr::Literal -fn parse_sql_number(n: &str) -> Result { - match n.parse::() { - Ok(n) => Ok(lit(n)), - Err(_) => Ok(lit(n.parse::().unwrap())), +fn create_decimal_from_string(n: &str) -> Result { + let (value, scale) = match n.split_once('.') { + None => { + // it's not an int and there's no decimal + (n.to_string(), 0_u8) + } + Some((w, d)) => { + // the length of d will be the scale, concat w and d for the complete number + let scale = d.len(); + let new_n = [w, d].join(""); + (new_n, scale as u8) + } + }; + + match value.parse::() { + Ok(i) => { + let precision = if i.is_positive() { + value.len() + } else { + value.len() - 1 + }; + // rust considers 0 to be negative, so input of .0 will end up with a length of zero. + // here we make sure precision is always at least 1 + let precision = precision.max(1); + Ok(ScalarValue::Decimal128(Some(i), precision as u8, scale)) + } + Err(e) => Err(DataFusionError::SQL(ParserError(format!( + "Internal error: Unable to parse {} to integer or decimal: {}", + value, e + )))), + } +} + +// Parse number in sql string, convert to Expr::Literal of int or decimal +fn parse_sql_number(n: &str, negative: bool) -> Result { + // if can parse to i64, then do it and return + let n_int = n.parse::(); + if let Ok(n) = n_int { + return if negative { Ok(lit(-n)) } else { Ok(lit(n)) }; + } + + // if it's a negative, add the - to the string + let n = if negative { + ["-", n].join("") + } else { + n.to_string() + }; + + match create_decimal_from_string(&n) { + Ok(scalar) => Ok(lit(scalar)), + Err(e) => Err(e), } } @@ -2722,6 +2761,60 @@ mod tests { ); } + #[test] + fn test_whole_number_negative() { + quick_test( + "SELECT -1", + "Projection: Int64(-1)\ + \n EmptyRelation", + ); + } + + #[test] + fn test_decimal() { + quick_test( + "SELECT 1.0", + "Projection: Decimal128(Some(10),2,1)\ + \n EmptyRelation", + ); + } + + #[test] + fn test_decimal_negative() { + quick_test( + "SELECT -1.0", + "Projection: Decimal128(Some(-10),2,1)\ + \n EmptyRelation", + ); + } + + #[test] + fn test_decimal_parts() { + quick_test( + "SELECT .0 as dot_zero, 0. as zero_dot, 0.0 as zero_dot_zero", + "Projection: Decimal128(Some(0),1,1) AS dot_zero, Decimal128(Some(0),1,0) AS zero_dot, Decimal128(Some(0),1,1) AS zero_dot_zero\ + \n EmptyRelation", + ); + } + + #[test] + fn test_decimal_just_whole() { + quick_test( + "SELECT 0.", + "Projection: Decimal128(Some(0),2,1)\ + \n EmptyRelation", + ); + } + + #[test] + fn test_decimal_negative_dot_decimal() { + quick_test( + "SELECT -.0", + "Projection: Decimal128(Some(0),1,1)\ + \n EmptyRelation", + ); + } + #[test] fn test_real_f32() { quick_test(