Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1454,7 +1454,11 @@ impl SessionState {
rules.push(Arc::new(FilterNullJoinKeys::default()));
}
rules.push(Arc::new(ReduceOuterJoin::new()));
// TODO: https://github.com/apache/arrow-datafusion/issues/3557
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it makes sense to me that we need to simplify expressons after coercion

// remove this, after the issue fixed.
rules.push(Arc::new(TypeCoercion::new()));
// after the type coercion, can do simplify expression again
rules.push(Arc::new(SimplifyExpressions::new()));
rules.push(Arc::new(FilterPushDown::new()));
rules.push(Arc::new(LimitPushDown::new()));
rules.push(Arc::new(SingleDistinctToGroupBy::new()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ mod tests {
use arrow::array::{Int32Array, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_physical_expr::expressions::cast;
use datafusion_physical_expr::PhysicalExpr;

use crate::error::Result;
Expand Down Expand Up @@ -525,7 +526,7 @@ mod tests {
expressions::binary(
expressions::col("a", &schema)?,
Operator::Gt,
expressions::lit(1u32),
cast(expressions::lit(1u32), &schema, DataType::Int32)?,
&schema,
)?,
source,
Expand Down Expand Up @@ -568,7 +569,7 @@ mod tests {
expressions::binary(
expressions::col("a", &schema)?,
Operator::Gt,
expressions::lit(1u32),
cast(expressions::lit(1u32), &schema, DataType::Int32)?,
&schema,
)?,
source,
Expand Down
22 changes: 15 additions & 7 deletions datafusion/core/src/physical_plan/file_format/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -871,13 +871,14 @@ mod tests {
physical_plan::collect,
};
use arrow::array::Float32Array;
use arrow::datatypes::DataType::Decimal128;
use arrow::record_batch::RecordBatch;
use arrow::{
array::{Int64Array, Int8Array, StringArray},
datatypes::{DataType, Field},
};
use chrono::{TimeZone, Utc};
use datafusion_expr::{col, lit};
use datafusion_expr::{cast, col, lit};
use futures::StreamExt;
use object_store::local::LocalFileSystem;
use object_store::path::Path;
Expand Down Expand Up @@ -1768,6 +1769,7 @@ mod tests {
// In this case, construct four types of statistics to filtered with the decimal predication.

// INT32: c1 > 5, the c1 is decimal(9,2)
// The type of scalar value if decimal(9,2), don't need to do cast
let expr = col("c1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2)));
let schema =
Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 2), false)]);
Expand Down Expand Up @@ -1809,11 +1811,15 @@ mod tests {
);

// INT32: c1 > 5, but parquet decimal type has different precision or scale to arrow decimal
// The c1 type is decimal(9,0) in the parquet file, and the type of scalar is decimal(5,2).
// We should convert all type to the coercion type, which is decimal(11,2)
// The decimal of arrow is decimal(5,2), the decimal of parquet is decimal(9,0)
let expr = col("c1").gt(lit(ScalarValue::Decimal128(Some(500), 5, 2)));
let expr = cast(col("c1"), DataType::Decimal128(11, 2)).gt(cast(
lit(ScalarValue::Decimal128(Some(500), 5, 2)),
Decimal128(11, 2),
));
let schema =
Schema::new(vec![Field::new("c1", DataType::Decimal128(5, 2), false)]);
// The decimal of parquet is decimal(9,0)
Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 0), false)]);
let schema_descr = get_test_schema_descr(vec![(
"c1",
PhysicalType::INT32,
Expand Down Expand Up @@ -1901,11 +1907,13 @@ mod tests {
vec![1]
);

// FIXED_LENGTH_BYTE_ARRAY: c1 = 100, the c1 is decimal(28,2)
// FIXED_LENGTH_BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2)
// the type of parquet is decimal(18,2)
let expr = col("c1").eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3)));
let schema =
Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 3), false)]);
Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]);
// cast the type of c1 to decimal(28,3)
let left = cast(col("c1"), DataType::Decimal128(28, 3));
let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3)));
let schema_descr = get_test_schema_descr(vec![(
"c1",
PhysicalType::FIXED_LEN_BYTE_ARRAY,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1685,7 +1685,7 @@ mod tests {
use crate::execution::runtime_env::RuntimeEnv;
use crate::logical_plan::plan::Extension;
use crate::physical_plan::{
expressions, DisplayFormatType, Partitioning, Statistics,
expressions, DisplayFormatType, Partitioning, PhysicalPlanner, Statistics,
};
use crate::prelude::{SessionConfig, SessionContext};
use crate::scalar::ScalarValue;
Expand Down Expand Up @@ -1736,10 +1736,10 @@ mod tests {
let exec_plan = plan(&logical_plan).await?;

// verify that the plan correctly casts u8 to i64
// the cast from u8 to i64 for literal will be simplified, and get lit(int64(5))
// the cast here is implicit so has CastOptions with safe=true
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }";
assert!(format!("{:?}", exec_plan).contains(expected));

Ok(())
}

Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1834,11 +1834,11 @@ async fn aggregate_avg_add() -> Result<()> {
assert_eq!(results.len(), 1);

let expected = vec![
"+--------------+---------------------------+---------------------------+---------------------------+",
"| AVG(test.c1) | AVG(test.c1) + Float64(1) | AVG(test.c1) + Float64(2) | Float64(1) + AVG(test.c1) |",
"+--------------+---------------------------+---------------------------+---------------------------+",
"| 1.5 | 2.5 | 3.5 | 2.5 |",
"+--------------+---------------------------+---------------------------+---------------------------+",
"+--------------+-------------------------+-------------------------+-------------------------+",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

"| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | Int64(1) + AVG(test.c1) |",
"+--------------+-------------------------+-------------------------+-------------------------+",
"| 1.5 | 2.5 | 3.5 | 2.5 |",
"+--------------+-------------------------+-------------------------+-------------------------+",
];
assert_batches_sorted_eq!(expected, &results);

Expand Down
114 changes: 57 additions & 57 deletions datafusion/core/tests/sql/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,25 +376,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+----------------------------------------------------+",
"| decimal_simple.c1 + Decimal128(Some(1000000),27,6) |",
"+----------------------------------------------------+",
"| 1.000010 |",
"| 1.000020 |",
"| 1.000020 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"+----------------------------------------------------+",
"+------------------------------+",
"| decimal_simple.c1 + Int64(1) |",
"+------------------------------+",
"| 1.000010 |",
"| 1.000020 |",
"| 1.000020 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"+------------------------------+",
];
assert_batches_eq!(expected, &actual);
// array decimal(10,6) + array decimal(12,7) => decimal(13,7)
Expand Down Expand Up @@ -434,25 +434,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+----------------------------------------------------+",
"| decimal_simple.c1 - Decimal128(Some(1000000),27,6) |",
"+----------------------------------------------------+",
"| -0.999990 |",
"| -0.999980 |",
"| -0.999980 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"+----------------------------------------------------+",
"+------------------------------+",
"| decimal_simple.c1 - Int64(1) |",
"+------------------------------+",
"| -0.999990 |",
"| -0.999980 |",
"| -0.999980 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"+------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down Expand Up @@ -492,25 +492,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+-----------------------------------------------------+",
"| decimal_simple.c1 * Decimal128(Some(20000000),31,6) |",
"+-----------------------------------------------------+",
"| 0.000200 |",
"| 0.000400 |",
"| 0.000400 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"+-----------------------------------------------------+",
"+-------------------------------+",
"| decimal_simple.c1 * Int64(20) |",
"+-------------------------------+",
"| 0.000200 |",
"| 0.000400 |",
"| 0.000400 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"+-------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down
13 changes: 6 additions & 7 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ async fn csv_in_set_test() -> Result<()> {

#[tokio::test]
async fn multiple_or_predicates() -> Result<()> {
// TODO https://github.com/apache/arrow-datafusion/issues/3587
let ctx = SessionContext::new();
register_tpch_csv(&ctx, "lineitem").await?;
register_tpch_csv(&ctx, "part").await?;
Expand Down Expand Up @@ -424,15 +425,13 @@ async fn multiple_or_predicates() -> Result<()> {
let plan = state.optimize(&plan)?;
// Note that we expect `#part.p_partkey = #lineitem.l_partkey` to have been
// factored out and appear only once in the following plan
let expected =vec![
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
" Projection: #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Inner Join: #lineitem.l_partkey = #part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" Filter: #part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND CAST(#part.p_size AS Int64) BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND CAST(#part.p_size AS Int64) BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND CAST(#part.p_size AS Int64) BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Inner Join: #lineitem.l_partkey = #part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,12 +523,12 @@ async fn use_between_expression_in_select_query() -> Result<()> {
.unwrap()
.to_string();

// TODO https://github.com/apache/arrow-datafusion/issues/3587
// Only test that the projection exprs are correct, rather than entire output
let needle = "ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as test.c1 BETWEEN Int64(2) AND Int64(3)]";
assert_contains!(&formatted, needle);
let needle = "Projection: #test.c1 >= Int64(2) AND #test.c1 <= Int64(3)";
let needle = "Projection: #test.c1 BETWEEN Int64(2) AND Int64(3)";
assert_contains!(&formatted, needle);

Ok(())
}

Expand Down
Loading