From 47ef59e71e1cc4bb070ac41a299e673d2b1ea4e0 Mon Sep 17 00:00:00 2001 From: DreaMer963 Date: Sat, 15 Jan 2022 16:16:58 +0800 Subject: [PATCH] fix: sql planner creates cross join instead of inner join from select predicates --- datafusion/src/sql/planner.rs | 109 ++++++++++++++++++++++++++++------ 1 file changed, 90 insertions(+), 19 deletions(-) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index bbd5aa7c5696b..ae9f2724db3f4 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -697,7 +697,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { alias: Option, ) -> Result { let plans = self.plan_from_tables(&select.from, ctes)?; - let plan = match &select.selection { Some(predicate_expr) => { // build join schema @@ -714,33 +713,80 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?; let mut all_join_keys = HashSet::new(); - let mut left = plans[0].clone(); - for right in plans.iter().skip(1) { - let left_schema = left.schema(); - let right_schema = right.schema(); + + let mut plans = plans.into_iter(); + let mut left = plans.next().unwrap(); // have at least one plan + + // List of the plans that have not yet been joined + let mut remaining_plans: Vec> = + plans.into_iter().map(Some).collect(); + + // Take from the list of remaining plans, + loop { let mut join_keys = vec![]; - for (l, r) in &possible_join_keys { - if left_schema.field_from_column(l).is_ok() - && right_schema.field_from_column(r).is_ok() - { - join_keys.push((l.clone(), r.clone())); - } else if left_schema.field_from_column(r).is_ok() - && right_schema.field_from_column(l).is_ok() - { - join_keys.push((r.clone(), l.clone())); - } - } + + // Search all remaining plans for the next to + // join. Prefer the first one that has a join + // predicate in the predicate lists + let plan_with_idx = + remaining_plans.iter().enumerate().find(|(_idx, plan)| { + // skip plans that have been joined already + let plan = if let Some(plan) = plan { + plan + } else { + return false; + }; + + // can we find a match? + let left_schema = left.schema(); + let right_schema = plan.schema(); + for (l, r) in &possible_join_keys { + if left_schema.field_from_column(l).is_ok() + && right_schema.field_from_column(r).is_ok() + { + join_keys.push((l.clone(), r.clone())); + } else if left_schema.field_from_column(r).is_ok() + && right_schema.field_from_column(l).is_ok() + { + join_keys.push((r.clone(), l.clone())); + } + } + // stop if we found join keys + !join_keys.is_empty() + }); + + // If we did not find join keys, either there are + // no more plans, or we can't find any plans that + // can be joined with predicates if join_keys.is_empty() { - left = - LogicalPlanBuilder::from(left).cross_join(right)?.build()?; + assert!(plan_with_idx.is_none()); + + // pick the first non null plan to join + let plan_with_idx = remaining_plans + .iter() + .enumerate() + .find(|(_idx, plan)| plan.is_some()); + if let Some((idx, _)) = plan_with_idx { + let plan = std::mem::take(&mut remaining_plans[idx]).unwrap(); + left = LogicalPlanBuilder::from(left) + .cross_join(&plan)? + .build()?; + } else { + // no more plans to join + break; + } } else { + // have a plan + let (idx, _) = plan_with_idx.expect("found plan node"); + let plan = std::mem::take(&mut remaining_plans[idx]).unwrap(); + let left_keys: Vec = join_keys.iter().map(|(l, _)| l.clone()).collect(); let right_keys: Vec = join_keys.iter().map(|(_, r)| r.clone()).collect(); let builder = LogicalPlanBuilder::from(left); left = builder - .join(right, JoinType::Inner, (left_keys, right_keys))? + .join(&plan, JoinType::Inner, (left_keys, right_keys))? .build()?; } @@ -3818,6 +3864,31 @@ mod tests { \n TableScan: public.person projection=None"; quick_test(sql, expected); } + + #[test] + fn cross_join_to_inner_join() { + let sql = "select person.id from person, orders, lineitem where person.id = lineitem.l_item_id and orders.o_item_id = lineitem.l_description;"; + let expected = "Projection: #person.id\ + \n Join: #lineitem.l_description = #orders.o_item_id\ + \n Join: #person.id = #lineitem.l_item_id\ + \n TableScan: person projection=None\ + \n TableScan: lineitem projection=None\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + #[test] + fn cross_join_not_to_inner_join() { + let sql = "select person.id from person, orders, lineitem where person.id = person.age;"; + let expected = "Projection: #person.id\ + \n Filter: #person.id = #person.age\ + \n CrossJoin:\ + \n CrossJoin:\ + \n TableScan: person projection=None\ + \n TableScan: orders projection=None\ + \n TableScan: lineitem projection=None"; + quick_test(sql, expected); + } } fn parse_sql_number(n: &str) -> Result {