96 changes: 96 additions & 0 deletions datafusion/core/tests/sql/joins.rs
@@ -2304,3 +2304,99 @@ async fn error_cross_join() -> Result<()> {

    Ok(())
}

#[tokio::test]
async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> {
    let test_repartition_joins = vec![true, false];
    for repartition_joins in test_repartition_joins {
        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;

        // reduce to inner join
        let sql = "select * from t1 cross join t2 where t1.t1_id + 12 = t2.t2_id + 1";
        let msg = format!("Creating logical plan for '{}'", sql);
        let plan = ctx
            .create_logical_plan(&("explain ".to_owned() + sql))
            .expect(&msg);
        let state = ctx.state();
        let plan = state.optimize(&plan)?;
        let expected = vec![
            "Explain [plan_type:Utf8, plan:Utf8]",
            "  Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
            "    Inner Join: t1.t1_id + Int64(12) = t2.t2_id + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(12):Int64;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id + Int64(1):Int64;N]",
            "      Projection: t1.t1_id, t1.t1_name, t1.t1_int, CAST(t1.t1_id AS Int64) + Int64(12) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(12):Int64;N]",
            "        TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
            "      Projection: t2.t2_id, t2.t2_name, t2.t2_int, CAST(t2.t2_id AS Int64) + Int64(1) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id + Int64(1):Int64;N]",
            "        TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
        ];
        let formatted = plan.display_indent_schema().to_string();
        let actual: Vec<&str> = formatted.trim().lines().collect();
        assert_eq!(
            expected, actual,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected, actual
        );
        let expected = vec![
            "+-------+---------+--------+-------+---------+--------+",
            "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
            "+-------+---------+--------+-------+---------+--------+",
            "| 11    | a       | 1      | 22    | y       | 1      |",
            "| 33    | c       | 3      | 44    | x       | 3      |",
            "| 44    | d       | 4      | 55    | w       | 3      |",
            "+-------+---------+--------+-------+---------+--------+",
        ];

        let results = execute_to_batches(&ctx, sql).await;
        assert_batches_sorted_eq!(expected, &results);
    }

    Ok(())
}

#[tokio::test]
async fn reduce_cross_join_with_cast_expr_join_key() -> Result<()> {
    let test_repartition_joins = vec![true, false];
    for repartition_joins in test_repartition_joins {
        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;

        // reduce to inner join; a cast will be inserted for t2.t2_id
        let sql =
            "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = cast(t2.t2_id as BIGINT)";
Contributor: 👍

        let msg = format!("Creating logical plan for '{}'", sql);
        let plan = ctx
            .create_logical_plan(&("explain ".to_owned() + sql))
            .expect(&msg);
        let state = ctx.state();
        let plan = state.optimize(&plan)?;
        let expected = vec![
            "Explain [plan_type:Utf8, plan:Utf8]",
            "  Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N, t2_id:UInt32;N, t1_name:Utf8;N]",
            "    Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
Comment on lines +2372 to +2373

Contributor: Looks like these two projections are not collapsed. This is unrelated to this PR, but we need to take a look at the PushDownProjection rule.

Contributor: cc @jackwener, maybe @jackwener has fixed that in #4487.

Member (@jackwener, Dec 7, 2022): This isn't related to this PR. The extra projection is generated to keep the same column order. It needs further optimization in PushDownProjection. I am currently working on an enhancement to PushDownProjection that should resolve this.

            "      Inner Join: t1.t1_id + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N, t2_id:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]",
            "        Projection: t1.t1_id, t1.t1_name, CAST(t1.t1_id AS Int64) + Int64(11) [t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N]",
            "          TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
            "        Projection: t2.t2_id, CAST(t2.t2_id AS Int64) AS CAST(t2.t2_id AS Int64) [t2_id:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]",
            "          TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
        ];
        let formatted = plan.display_indent_schema().to_string();
        let actual: Vec<&str> = formatted.trim().lines().collect();
        assert_eq!(
            expected, actual,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected, actual
        );
        let expected = vec![
            "+-------+-------+---------+",
            "| t1_id | t2_id | t1_name |",
            "+-------+-------+---------+",
            "| 11    | 22    | a       |",
            "| 33    | 44    | c       |",
            "| 44    | 55    | d       |",
            "+-------+-------+---------+",
        ];

        let results = execute_to_batches(&ctx, sql).await;
        assert_batches_sorted_eq!(expected, &results);
    }

    Ok(())
}
4 changes: 3 additions & 1 deletion datafusion/expr/src/lib.rs
@@ -67,7 +67,9 @@ pub use function::{
};
pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral};
pub use logical_plan::{
    builder::{build_join_schema, union, UNNAMED_TABLE},
    builder::{
        build_join_schema, union, wrap_projection_for_join_if_necessary, UNNAMED_TABLE,
    },
    Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable,
    CreateMemoryTable, CreateView, CrossJoin, Distinct, DropTable, DropView,
    EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit,
65 changes: 64 additions & 1 deletion datafusion/expr/src/logical_plan/builder.rs
@@ -42,8 +42,9 @@ use datafusion_common::{
    ToDFSchema,
};
use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;

/// Default table name for unnamed table
pub const UNNAMED_TABLE: &str = "?table?";
@@ -995,6 +996,68 @@ pub fn table_scan(
    LogicalPlanBuilder::scan(name.unwrap_or(UNNAMED_TABLE), table_source, projection)
}

/// Wrap projection for a plan, if the join keys contain normal (non-column) expressions.
pub fn wrap_projection_for_join_if_necessary(
    join_keys: &[Expr],
    input: LogicalPlan,
) -> Result<(LogicalPlan, Vec<Column>, bool)> {
    let input_schema = input.schema();
    let alias_join_keys: Vec<Expr> = join_keys
        .iter()
        .map(|key| {
            // The display_name() of a cast expression ignores the cast info and shows the inner expression's name.
            // If we do not add an alias, adding the projection fails with a duplicate field name error in the schema.
Contributor: good comments

            // For example:
            // input scan: [a, b, c],
            // join keys: [cast(a as int)]
            //
            // then a and cast(a as int) will use the same field name - `a` in the projection schema.
Comment on lines +1008 to +1014

Contributor: Why do the Cast expressions ignore the cast info? This does not look like the expected behavior. I think the cast info should be displayed whether the Cast is explicit or implicit.

https://github.com/apache/arrow-datafusion/blob/cedb05aedf3cea030bfa8774b8575d8f4806a1c8/datafusion/expr/src/expr.rs#L1108-L1115

@andygrove @liukun4515 @jackwener What do you think?

Member (@jackwener, Dec 7, 2022): Yes, I think we need to account for the Cast:

1. SQL can explicitly include a cast.
2. After type inference, implicit casts are generated as well.

But I need time to think carefully about the details.
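
A minimal sketch of the behavior under discussion, assuming the `col` and `cast` expression helpers re-exported from `datafusion_expr` and the `Expr::display_name` behavior described in the code comment above (illustrative only, not part of this diff):

#[test]
fn cast_display_name_collision_sketch() -> datafusion_common::Result<()> {
    use arrow::datatypes::DataType;
    use datafusion_expr::{cast, col};

    // Both expressions render as "a", so a projection containing both would have
    // duplicate field names unless the cast is aliased.
    let plain = col("a");
    let casted = cast(col("a"), DataType::Int64);
    assert_eq!(plain.display_name()?, "a");
    assert_eq!(casted.display_name()?, "a");

    // Aliasing the cast with its Debug form (as the new helper does) keeps the names distinct.
    let aliased = casted.clone().alias(format!("{:?}", casted));
    assert_ne!(aliased.display_name()?, "a");
    Ok(())
}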

            // https://github.com/apache/arrow-datafusion/issues/4478
            if matches!(key, Expr::Cast(_))
                || matches!(
                    key,
                    Expr::TryCast {
                        expr: _,
                        data_type: _
                    }
                )
            {
                let alias = format!("{:?}", key);
                key.clone().alias(alias)
            } else {
                key.clone()
            }
        })
        .collect::<Vec<_>>();

    let need_project = join_keys.iter().any(|key| !matches!(key, Expr::Column(_)));
    let plan = if need_project {
        let mut projection = expand_wildcard(input_schema, &input)?;
        let join_key_items = alias_join_keys
            .iter()
            .flat_map(|expr| expr.try_into_col().is_err().then_some(expr))
            .cloned()
            .collect::<HashSet<Expr>>();
        projection.extend(join_key_items);

        LogicalPlanBuilder::from(input)
            .project(projection)?
            .build()?
    } else {
        input
    };

    let join_on = alias_join_keys
        .into_iter()
        .map(|key| {
            key.try_into_col()
                .or_else(|_| Ok(Column::from_name(key.display_name()?)))
        })
        .collect::<Result<Vec<_>>>()?;

    Ok((plan, join_on, need_project))
}
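
// Illustrative sketch only (not part of this diff): one way a caller could combine
// wrap_projection_for_join_if_necessary with LogicalPlanBuilder::join to build an inner
// equijoin from expression join keys. The function name and its inputs are hypothetical,
// and the join signature (right, join_type, (left_cols, right_cols), filter) is assumed.
#[allow(dead_code)]
fn join_on_key_expressions_sketch(
    left: LogicalPlan,
    right: LogicalPlan,
    left_keys: &[Expr],
    right_keys: &[Expr],
) -> Result<LogicalPlan> {
    // Wrap each side: non-column keys are projected as extra columns and come back as `Column`s.
    let (left_plan, left_cols, left_projected) =
        wrap_projection_for_join_if_necessary(left_keys, left)?;
    let (right_plan, right_cols, right_projected) =
        wrap_projection_for_join_if_necessary(right_keys, right)?;

    // Join on the (possibly synthesized) key columns.
    let builder = LogicalPlanBuilder::from(left_plan).join(
        &right_plan,
        JoinType::Inner,
        (left_cols, right_cols),
        None,
    )?;

    // If either side was wrapped, a caller would typically add a final projection here to
    // drop the synthesized key columns; the boolean flags signal when that is needed.
    let _ = (left_projected, right_projected);
    builder.build()
}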

/// Basic TableSource implementation intended for use in tests and documentation. It is expected
/// that users will provide their own TableSource implementations or use DataFusion's
/// DefaultTableSource.