diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index ed9a68c19536c..31eafc7443900 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -334,7 +334,7 @@ impl FunctionalDependencies { left_func_dependencies.extend(right_func_dependencies); left_func_dependencies } - JoinType::LeftSemi | JoinType::LeftAnti => { + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { // These joins preserve functional dependencies of the left side: left_func_dependencies } diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index d502e7836da3a..e98f34199b277 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -44,6 +44,20 @@ pub enum JoinType { LeftAnti, /// Right Anti Join RightAnti, + /// Left Mark join + /// + /// Returns one record for each record from the left input. The output contains an additional + /// column "mark" which is true if there is at least one match in the right input where the + /// join condition evaluates to true. Otherwise, the mark column is false. For more details see + /// [1]. This join type is used to decorrelate EXISTS subqueries used inside disjunctive + /// predicates. + /// + /// Note: We currently do not implement the full null semantics for the mark join described + /// in [1], which will be needed if we add support for ANY subqueries. In our version the mark column will + /// only be true when a match was found and false when no match was found, never null. 
+ /// + /// [1]: http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf + LeftMark, } impl JoinType { @@ -63,6 +77,7 @@ impl Display for JoinType { JoinType::RightSemi => "RightSemi", JoinType::LeftAnti => "LeftAnti", JoinType::RightAnti => "RightAnti", + JoinType::LeftMark => "LeftMark", }; write!(f, "{join_type}") } @@ -82,6 +97,7 @@ impl FromStr for JoinType { "RIGHTSEMI" => Ok(JoinType::RightSemi), "LEFTANTI" => Ok(JoinType::LeftAnti), "RIGHTANTI" => Ok(JoinType::RightAnti), + "LEFTMARK" => Ok(JoinType::LeftMark), _ => _not_impl_err!("The join type {s} does not exist or is not implemented"), } } @@ -101,6 +117,7 @@ impl Display for JoinSide { match self { JoinSide::Left => write!(f, "left"), JoinSide::Right => write!(f, "right"), + JoinSide::None => write!(f, "none"), } } } @@ -113,6 +130,9 @@ pub enum JoinSide { Left, /// Right side of the join Right, + /// Neither side of the join, used for Mark joins where the mark column does not belong to + /// either side of the join + None, } impl JoinSide { @@ -121,6 +141,7 @@ impl JoinSide { match self { JoinSide::Left => JoinSide::Right, JoinSide::Right => JoinSide::Left, + JoinSide::None => JoinSide::None, } } } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index e5d352a63c7a3..2c71cb80d7558 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -3864,6 +3864,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftAnti, JoinType::RightAnti, + JoinType::LeftMark, ]; let default_partition_count = SessionConfig::new().target_partitions(); @@ -3881,7 +3882,10 @@ mod tests { let join_schema = physical_plan.schema(); match join_type { - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark => { let left_exprs: Vec> = vec![ Arc::new(Column::new_with_schema("c1", &join_schema)?), Arc::new(Column::new_with_schema("c2", 
&join_schema)?), diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index aa4bcb6837493..ff8f16f4ee9c0 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -328,7 +328,8 @@ fn adjust_input_keys_ordering( JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti - | JoinType::Full => vec![], + | JoinType::Full + | JoinType::LeftMark => vec![], }; } PartitionMode::Auto => { @@ -1959,6 +1960,7 @@ pub(crate) mod tests { JoinType::Full, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightSemi, JoinType::RightAnti, ]; @@ -1981,7 +1983,8 @@ pub(crate) mod tests { | JoinType::Right | JoinType::Full | JoinType::LeftSemi - | JoinType::LeftAnti => { + | JoinType::LeftAnti + | JoinType::LeftMark => { // Join on (a == c) let top_join_on = vec![( Arc::new(Column::new_with_schema("a", &join.schema()).unwrap()) @@ -1999,7 +2002,7 @@ pub(crate) mod tests { let expected = match join_type { // Should include 3 RepartitionExecs - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![ + JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => vec![ top_join_plan.as_str(), join_plan.as_str(), "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", @@ -2098,7 +2101,7 @@ pub(crate) mod tests { assert_optimized!(expected, top_join.clone(), true); assert_optimized!(expected, top_join, false); } - JoinType::LeftSemi | JoinType::LeftAnti => {} + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {} } } diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 1c63df1f0281f..2bf706f33d609 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ 
b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -132,6 +132,9 @@ fn swap_join_type(join_type: JoinType) -> JoinType { JoinType::RightSemi => JoinType::LeftSemi, JoinType::LeftAnti => JoinType::RightAnti, JoinType::RightAnti => JoinType::LeftAnti, + JoinType::LeftMark => { + unreachable!("LeftMark join type does not support swapping") + } } } @@ -573,6 +576,7 @@ fn hash_join_convert_symmetric_subrule( hash_join.right().equivalence_properties(), hash_join.right().schema(), ), + JoinSide::None => return false, }; let name = schema.field(*index).name(); @@ -588,6 +592,7 @@ fn hash_join_convert_symmetric_subrule( match side { JoinSide::Left => hash_join.left().output_ordering(), JoinSide::Right => hash_join.right().output_ordering(), + JoinSide::None => unreachable!(), } .map(|p| p.to_vec()) }) diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index c7677d725b036..fdbda1fe52f72 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -384,6 +384,7 @@ fn try_pushdown_requirements_to_join( return Ok(None); } } + JoinSide::None => return Ok(None), }; let join_type = smj.join_type(); let probe_side = SortMergeJoinExec::probe_side(&join_type); @@ -410,6 +411,7 @@ fn try_pushdown_requirements_to_join( JoinSide::Right => { required_input_ordering[1] = new_req; } + JoinSide::None => unreachable!(), } required_input_ordering })) @@ -421,7 +423,11 @@ fn expr_source_side( left_columns_len: usize, ) -> Option { match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftMark => { let all_column_sides = required_exprs .iter() .filter_map(|r| { diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 5c03bc3a91108..d7a3460e49879 100644 --- 
a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -234,6 +234,30 @@ async fn test_anti_join_1k_filtered() { .await } +#[tokio::test] +async fn test_left_mark_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::LeftMark, + None, + ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_mark_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::LeftMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .await +} + type JoinFilterBuilder = Box, Arc) -> JoinFilter>; struct JoinFuzzTestCase { diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index e50ffb59d24a4..b7839c4873af8 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -20,6 +20,7 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; +use std::iter::once; use std::sync::Arc; use crate::dml::CopyTo; @@ -1326,6 +1327,25 @@ pub fn change_redundant_column(fields: &Fields) -> Vec { }) .collect() } + +fn mark_field(schema: &DFSchema) -> (Option, Arc) { + let mut table_references = schema + .iter() + .filter_map(|(qualifier, _)| qualifier) + .collect::>(); + table_references.dedup(); + let table_reference = if table_references.len() == 1 { + table_references.pop().cloned() + } else { + None + }; + + ( + table_reference, + Arc::new(Field::new("mark", DataType::Boolean, false)), + ) +} + /// Creates a schema for a join operation. 
/// The fields from the left side are first pub fn build_join_schema( @@ -1392,6 +1412,10 @@ pub fn build_join_schema( .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect() } + JoinType::LeftMark => left_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .chain(once(mark_field(right))) + .collect(), JoinType::RightSemi | JoinType::RightAnti => { // Only use the right side for the schema right_fields diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index a301c48659d7c..8ba2a44842bcf 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -532,7 +532,9 @@ impl LogicalPlan { left.head_output_expr() } } - JoinType::LeftSemi | JoinType::LeftAnti => left.head_output_expr(), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + left.head_output_expr() + } JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(), }, LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { @@ -1290,7 +1292,9 @@ impl LogicalPlan { _ => None, } } - JoinType::LeftSemi | JoinType::LeftAnti => left.max_rows(), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + left.max_rows() + } JoinType::RightSemi | JoinType::RightAnti => right.max_rows(), }, LogicalPlan::Repartition(Repartition { input, .. 
}) => input.max_rows(), diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 0ffc954388f5a..fa04835f0967b 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -181,7 +181,10 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re })?; Ok(()) } - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark => { check_inner_plan(left, can_contain_outer_ref)?; check_inner_plan(right, false) } diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index cc1687cffe921..7fdad5ba4b6e9 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -17,7 +17,6 @@ //! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins use std::collections::BTreeSet; -use std::iter; use std::ops::Deref; use std::sync::Arc; @@ -34,11 +33,10 @@ use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::{conjunction, split_conjunction_owned}; use datafusion_expr::{ - exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, + exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, }; -use itertools::chain; use log::debug; /// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins @@ -138,17 +136,14 @@ fn rewrite_inner_subqueries( Expr::Exists(Exists { subquery: Subquery { subquery, .. }, negated, - }) => { - match existence_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? 
- { - Some((plan, exists_expr)) => { - cur_input = plan; - Ok(Transformed::yes(exists_expr)) - } - None if negated => Ok(Transformed::no(not_exists(subquery))), - None => Ok(Transformed::no(exists(subquery))), + }) => match mark_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? { + Some((plan, exists_expr)) => { + cur_input = plan; + Ok(Transformed::yes(exists_expr)) } - } + None if negated => Ok(Transformed::no(not_exists(subquery))), + None => Ok(Transformed::no(exists(subquery))), + }, Expr::InSubquery(InSubquery { expr, subquery: Subquery { subquery, .. }, @@ -159,7 +154,7 @@ fn rewrite_inner_subqueries( .map_or(plan_err!("single expression required."), |output_expr| { Ok(Expr::eq(*expr.clone(), output_expr)) })?; - match existence_join( + match mark_join( &cur_input, Arc::clone(&subquery), Some(in_predicate), @@ -283,10 +278,6 @@ fn build_join_top( build_join(left, subquery, in_predicate_opt, join_type, subquery_alias) } -/// Existence join is emulated by adding a non-nullable column to the subquery and using a left join -/// and checking if the column is null or not. If native support is added for Existence/Mark then -/// we should use that instead. -/// /// This is used to handle the case when the subquery is embedded in a more complex boolean /// expression like and OR. 
For example /// @@ -296,37 +287,26 @@ fn build_join_top( /// /// ```text /// Projection: t1.id -/// Filter: t1.id < 0 OR __correlated_sq_1.__exists IS NOT NULL -/// Left Join: Filter: t1.id = __correlated_sq_1.id +/// Filter: t1.id < 0 OR __correlated_sq_1.mark +/// LeftMark Join: Filter: t1.id = __correlated_sq_1.id /// TableScan: t1 /// SubqueryAlias: __correlated_sq_1 -/// Projection: t2.id, true as __exists +/// Projection: t2.id /// TableScan: t2 -fn existence_join( +fn mark_join( left: &LogicalPlan, subquery: Arc, in_predicate_opt: Option, negated: bool, alias_generator: &Arc, ) -> Result> { - // Add non nullable column to emulate existence join - let always_true_expr = lit(true).alias("__exists"); - let cols = chain( - subquery.schema().columns().into_iter().map(Expr::Column), - iter::once(always_true_expr), - ); - let subquery = LogicalPlanBuilder::from(subquery).project(cols)?.build()?; let alias = alias_generator.next("__correlated_sq"); - let exists_col = Expr::Column(Column::new(Some(alias.clone()), "__exists")); - let exists_expr = if negated { - exists_col.is_null() - } else { - exists_col.is_not_null() - }; + let exists_col = Expr::Column(Column::new(Some(alias.clone()), "mark")); + let exists_expr = if negated { !exists_col } else { exists_col }; Ok( - build_join(left, &subquery, in_predicate_opt, JoinType::Left, alias)? + build_join(left, &subquery, in_predicate_opt, JoinType::LeftMark, alias)? 
.map(|plan| (plan, exists_expr)), ) } diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 42eff7100fbe1..94c04d6328ed6 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -677,7 +677,11 @@ fn split_join_requirements( ) -> (RequiredIndicies, RequiredIndicies) { match join_type { // In these cases requirements are split between left/right children: - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftMark => { // Decrease right side indices by `left_len` so that they point to valid // positions within the right child: indices.split_off(left_len) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index a0262d7d95dfe..11436be98e39a 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -161,7 +161,7 @@ pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) { JoinType::Full => (false, false), // No columns from the right side of the join can be referenced in output // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::LeftSemi | JoinType::LeftAnti => (true, false), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false), // No columns from the left side of the join can be referenced in output // predicates for semi/anti joins, so whether we specify t/f doesn't matter. 
JoinType::RightSemi | JoinType::RightAnti => (false, true), @@ -186,6 +186,7 @@ pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) { JoinType::LeftSemi | JoinType::RightSemi => (true, true), JoinType::LeftAnti => (false, true), JoinType::RightAnti => (true, false), + JoinType::LeftMark => (false, true), } } @@ -677,11 +678,13 @@ fn infer_join_predicates_from_on_filters( on_filters, inferred_predicates, ), - JoinType::Left | JoinType::LeftSemi => infer_join_predicates_impl::( - join_col_keys, - on_filters, - inferred_predicates, - ), + JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => { + infer_join_predicates_impl::( + join_col_keys, + on_filters, + inferred_predicates, + ) + } JoinType::Right | JoinType::RightSemi => { infer_join_predicates_impl::( join_col_keys, diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index ec7a0a1364b6a..8a3aa4bb84599 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -248,7 +248,7 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed { let (left_limit, right_limit) = if is_no_join_condition(&join) { match join.join_type { Left | Right | Full | Inner => (Some(limit), Some(limit)), - LeftAnti | LeftSemi => (Some(limit), None), + LeftAnti | LeftSemi | LeftMark => (Some(limit), None), RightAnti | RightSemi => (None, Some(limit)), } } else { diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index c1851ddb22b53..7305bc1b0a2b8 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -632,7 +632,7 @@ impl EquivalenceGroup { } result } - JoinType::LeftSemi | JoinType::LeftAnti => self.clone(), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(), JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), } } diff --git 
a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 2d11e03814a31..c56c179c17ebd 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -524,6 +524,7 @@ impl HashJoinExec { | JoinType::Full | JoinType::LeftAnti | JoinType::LeftSemi + | JoinType::LeftMark )); let mode = if pipeline_breaking { @@ -3091,6 +3092,94 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] + #[tokio::test] + async fn join_left_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + )]; + + let (columns, batches) = join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::LeftMark, + false, + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[apply(batch_sizes)] + #[tokio::test] + async fn partitioned_join_left_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + )]; + + let (columns, batches) = partitioned_join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::LeftMark, + false, + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + #[test] fn join_with_hash_collision() -> Result<()> { let mut hashmap_left = RawTable::with_capacity(2); @@ -3476,6 +3565,15 @@ mod tests { "| 30 | 6 | 90 |", "+----+----+----+", ]; + let expected_left_mark = vec![ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; let test_cases = vec![ (JoinType::Inner, expected_inner), @@ -3486,6 +3584,7 @@ mod tests { (JoinType::LeftAnti, expected_left_anti), (JoinType::RightSemi, expected_right_semi), (JoinType::RightAnti, expected_right_anti), + (JoinType::LeftMark, expected_left_mark), ]; for (join_type, expected) in test_cases { @@ -3768,6 +3867,7 @@ mod tests { JoinType::LeftAnti, JoinType::RightSemi, JoinType::RightAnti, + JoinType::LeftMark, ]; for join_type in join_types { diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 358ff02473a67..957230f513720 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -1244,6 +1244,37 @@ pub(crate) mod tests { Ok(()) } + #[tokio::test] + async fn join_left_mark_with_filter() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = build_left_table(); + let right = build_right_table(); + + let filter = prepare_join_filter(); + let (columns, 
batches) = multi_partitioned_join_collect( + left, + right, + &JoinType::LeftMark, + Some(filter), + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); + let expected = [ + "+----+----+-----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+-----+-------+", + "| 11 | 8 | 110 | false |", + "| 5 | 5 | 50 | true |", + "| 9 | 8 | 90 | false |", + "+----+----+-----+-------+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + #[tokio::test] async fn test_overallocation() -> Result<()> { let left = build_table( @@ -1269,6 +1300,7 @@ pub(crate) mod tests { JoinType::Full, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightSemi, JoinType::RightAnti, ]; diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index b299b495c5044..20fafcc347737 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -35,7 +35,9 @@ use std::sync::Arc; use std::task::{Context, Poll}; use arrow::array::*; -use arrow::compute::{self, concat_batches, filter_record_batch, take, SortOptions}; +use arrow::compute::{ + self, concat_batches, filter_record_batch, is_not_null, take, SortOptions, +}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use arrow::ipc::reader::FileReader; @@ -178,7 +180,8 @@ impl SortMergeJoinExec { | JoinType::Left | JoinType::Full | JoinType::LeftAnti - | JoinType::LeftSemi => JoinSide::Left, + | JoinType::LeftSemi + | JoinType::LeftMark => JoinSide::Left, } } @@ -186,7 +189,10 @@ impl SortMergeJoinExec { fn maintains_input_order(join_type: JoinType) -> Vec { match join_type { JoinType::Inner => vec![true, false], - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![true, false], + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark => vec![true, false], JoinType::Right | 
JoinType::RightSemi | JoinType::RightAnti => { vec![false, true] } @@ -784,6 +790,29 @@ fn get_corrected_filter_mask( corrected_mask.extend(vec![Some(false); null_matched]); Some(corrected_mask.finish()) } + JoinType::LeftMark => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null joined rows for records which have no matching join key + let null_matched = expected_size - corrected_mask.len(); + corrected_mask.extend(vec![Some(false); null_matched]); + Some(corrected_mask.finish()) + } JoinType::LeftSemi => { for i in 0..row_indices_length { let last_index = @@ -860,6 +889,7 @@ impl Stream for SMJStream { self.join_type, JoinType::Left | JoinType::LeftSemi + | JoinType::LeftMark | JoinType::Right | JoinType::LeftAnti ) @@ -943,6 +973,7 @@ impl Stream for SMJStream { | JoinType::LeftSemi | JoinType::Right | JoinType::LeftAnti + | JoinType::LeftMark ) { continue; @@ -964,6 +995,7 @@ impl Stream for SMJStream { | JoinType::LeftSemi | JoinType::Right | JoinType::LeftAnti + | JoinType::LeftMark ) { let out = self.filter_joined_batch()?; @@ -1264,6 +1296,8 @@ impl SMJStream { let mut join_streamed = false; // Whether to join buffered rows let mut join_buffered = false; + // For Mark join we store a dummy id to indicate that the row has a match + let mut mark_row_as_match = false; // determine whether we need to join streamed/buffered rows match self.current_ordering { @@ -1275,12 +1309,14 @@ impl SMJStream { | JoinType::RightSemi | JoinType::Full | JoinType::LeftAnti + | JoinType::LeftMark ) { join_streamed = 
!self.streamed_joined; } } Ordering::Equal => { - if matches!(self.join_type, JoinType::LeftSemi) { + if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftMark) { + mark_row_as_match = matches!(self.join_type, JoinType::LeftMark); // if the join filter is specified then its needed to output the streamed index // only if it has not been emitted before // the `join_filter_matched_idxs` keeps track on if streamed index has a successful @@ -1357,9 +1393,11 @@ impl SMJStream { } else { Some(self.buffered_data.scanning_batch_idx) }; + // For Mark join we store a dummy id to indicate that the row has a match + let scanning_idx = mark_row_as_match.then_some(0); self.streamed_batch - .append_output_pair(scanning_batch_idx, None); + .append_output_pair(scanning_batch_idx, scanning_idx); self.output_size += 1; self.buffered_data.scanning_finish(); self.streamed_joined = true; @@ -1461,24 +1499,25 @@ impl SMJStream { // The row indices of joined buffered batch let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); - let mut buffered_columns = - if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { - vec![] - } else if let Some(buffered_idx) = chunk.buffered_batch_idx { - get_buffered_columns( - &self.buffered_data, - buffered_idx, - &buffered_indices, - )? - } else { - // If buffered batch none, meaning it is null joined batch. - // We need to create null arrays for buffered columns to join with streamed rows. - self.buffered_schema - .fields() - .iter() - .map(|f| new_null_array(f.data_type(), buffered_indices.len())) - .collect::>() - }; + let mut buffered_columns = if matches!(self.join_type, JoinType::LeftMark) { + vec![Arc::new(is_not_null(&buffered_indices)?) as ArrayRef] + } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + vec![] + } else if let Some(buffered_idx) = chunk.buffered_batch_idx { + get_buffered_columns( + &self.buffered_data, + buffered_idx, + &buffered_indices, + )? 
+ } else { + // If buffered batch none, meaning it is null joined batch. + // We need to create null arrays for buffered columns to join with streamed rows. + create_unmatched_columns( + self.join_type, + &self.buffered_schema, + buffered_indices.len(), + ) + }; let streamed_columns_length = streamed_columns.len(); @@ -1489,7 +1528,7 @@ impl SMJStream { get_filter_column(&self.filter, &buffered_columns, &streamed_columns) } else if matches!( self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark ) { // unwrap is safe here as we check is_some on top of if statement let buffered_columns = get_buffered_columns( @@ -1517,7 +1556,6 @@ impl SMJStream { }; let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; - // Apply join filter if any if !filter_columns.is_empty() { if let Some(f) = &self.filter { @@ -1553,6 +1591,7 @@ impl SMJStream { | JoinType::LeftSemi | JoinType::Right | JoinType::LeftAnti + | JoinType::LeftMark ) { self.output_record_batches .batches @@ -1691,6 +1730,7 @@ impl SMJStream { | JoinType::LeftSemi | JoinType::Right | JoinType::LeftAnti + | JoinType::LeftMark )) { self.output_record_batches.batches.clear(); @@ -1721,16 +1761,18 @@ impl SMJStream { let buffered_columns_length = self.buffered_schema.fields.len(); let streamed_columns_length = self.streamed_schema.fields.len(); - if matches!(self.join_type, JoinType::Left | JoinType::Right) { + if matches!( + self.join_type, + JoinType::Left | JoinType::LeftMark | JoinType::Right + ) { let null_mask = compute::not(corrected_mask)?; let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?; - let mut buffered_columns = self - .buffered_schema - .fields() - .iter() - .map(|f| new_null_array(f.data_type(), null_joined_batch.num_rows())) - .collect::>(); + let mut buffered_columns = create_unmatched_columns( + self.join_type, + &self.buffered_schema, + null_joined_batch.num_rows(), + ); let columns = if 
matches!(self.join_type, JoinType::Right) { let streamed_columns = null_joined_batch @@ -1777,6 +1819,22 @@ impl SMJStream { } } +fn create_unmatched_columns( + join_type: JoinType, + schema: &SchemaRef, + size: usize, +) -> Vec { + if matches!(join_type, JoinType::LeftMark) { + vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef] + } else { + schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), size)) + .collect::>() + } +} + /// Gets the arrays which join filters are applied on. fn get_filter_column( join_filter: &Option, @@ -2716,6 +2774,39 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_left_mark() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), // 5 is double on the right + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + )]; + + let (_, batches) = join_collect(left, right, on, LeftMark).await?; + let expected = [ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_with_duplicated_column_names() -> Result<()> { let left = build_table( @@ -3047,7 +3138,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti]; + let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() @@ -3125,7 +3216,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti]; + let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() @@ -3181,7 +3272,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti]; + let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() @@ -3282,7 +3373,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti]; + let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 
eb6a30d17e925..3e0cd48da2bf3 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -62,6 +62,7 @@ use arrow::array::{ use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_buffer::ArrowNativeType; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::bisect; use datafusion_common::{internal_err, plan_err, JoinSide, JoinType, Result}; @@ -670,7 +671,11 @@ fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> if build_side == JoinSide::Left { matches!( join_type, - JoinType::Left | JoinType::LeftAnti | JoinType::Full | JoinType::LeftSemi + JoinType::Left + | JoinType::LeftAnti + | JoinType::Full + | JoinType::LeftSemi + | JoinType::LeftMark ) } else { matches!( @@ -709,6 +714,20 @@ where { // Store the result in a tuple let result = match (build_side, join_type) { + (JoinSide::Left, JoinType::LeftMark) => { + let build_indices = (0..prune_length) + .map(L::Native::from_usize) + .collect::>(); + let probe_indices = (0..prune_length) + .map(|idx| { + // For mark join we output a dummy index 0 to indicate the row had a match + visited_rows + .contains(&(idx + deleted_offset)) + .then_some(R::Native::from_usize(0).unwrap()) + }) + .collect(); + (build_indices, probe_indices) + } // In the case of `Left` or `Right` join, or `Full` join, get the anti indices (JoinSide::Left, JoinType::Left | JoinType::LeftAnti) | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) @@ -872,6 +891,7 @@ pub(crate) fn join_with_probe_batch( JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftSemi + | JoinType::LeftMark | JoinType::RightSemi ) { Ok(None) @@ -1707,6 +1727,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1791,6 +1812,7 @@ mod tests { JoinType::RightSemi, 
JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1855,6 +1877,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1906,6 +1929,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1933,6 +1957,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -2298,6 +2323,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -2380,6 +2406,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -2454,6 +2481,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 090cf9aa628a7..e7c191f9835ec 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -20,6 +20,7 @@ use std::collections::HashSet; use std::fmt::{self, Debug}; use std::future::Future; +use std::iter::once; use std::ops::{IndexMut, Range}; use std::sync::Arc; use std::task::{Context, Poll}; @@ -619,6 +620,7 @@ fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> JoinType::RightSemi => false, // doesn't introduce nulls JoinType::LeftAnti => false, // doesn't introduce nulls (or can it??) JoinType::RightAnti => false, // doesn't introduce nulls (or can it??) 
+ JoinType::LeftMark => false, }; if force_nullable { @@ -635,44 +637,10 @@ pub fn build_join_schema( right: &Schema, join_type: &JoinType, ) -> (Schema, Vec) { - let (fields, column_indices): (SchemaBuilder, Vec) = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - let left_fields = left - .fields() - .iter() - .map(|f| output_join_field(f, join_type, true)) - .enumerate() - .map(|(index, f)| { - ( - f, - ColumnIndex { - index, - side: JoinSide::Left, - }, - ) - }); - let right_fields = right - .fields() - .iter() - .map(|f| output_join_field(f, join_type, false)) - .enumerate() - .map(|(index, f)| { - ( - f, - ColumnIndex { - index, - side: JoinSide::Right, - }, - ) - }); - - // left then right - left_fields.chain(right_fields).unzip() - } - JoinType::LeftSemi | JoinType::LeftAnti => left - .fields() + let left_fields = || { + left.fields() .iter() - .cloned() + .map(|f| output_join_field(f, join_type, true)) .enumerate() .map(|(index, f)| { ( @@ -683,11 +651,13 @@ pub fn build_join_schema( }, ) }) - .unzip(), - JoinType::RightSemi | JoinType::RightAnti => right + }; + + let right_fields = || { + right .fields() .iter() - .cloned() + .map(|f| output_join_field(f, join_type, false)) .enumerate() .map(|(index, f)| { ( @@ -698,7 +668,25 @@ pub fn build_join_schema( }, ) }) - .unzip(), + }; + + let (fields, column_indices): (SchemaBuilder, Vec) = match join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + // left then right + left_fields().chain(right_fields()).unzip() + } + JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(), + JoinType::LeftMark => { + let right_field = once(( + Field::new("mark", arrow_schema::DataType::Boolean, false), + ColumnIndex { + index: 0, + side: JoinSide::None, + }, + )); + left_fields().chain(right_field).unzip() + } + JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(), }; let metadata = left @@ -902,6 +890,16 @@ fn 
estimate_join_cardinality( column_statistics: outer_stats.column_statistics, }) } + + JoinType::LeftMark => { + let num_rows = *left_stats.num_rows.get_value()?; + let mut column_statistics = left_stats.column_statistics; + column_statistics.push(ColumnStatistics::new_unknown()); + Some(PartialJoinStatistics { + num_rows, + column_statistics, + }) + } } } @@ -1153,7 +1151,11 @@ impl OnceFut { pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool { matches!( join_type, - JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi | JoinType::Full + JoinType::Left + | JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::Full ) } @@ -1171,6 +1173,13 @@ pub(crate) fn get_final_indices_from_bit_map( join_type: JoinType, ) -> (UInt64Array, UInt32Array) { let left_size = left_bit_map.len(); + if join_type == JoinType::LeftMark { + let left_indices = (0..left_size as u64).collect::(); + let right_indices = (0..left_size) + .map(|idx| left_bit_map.get_bit(idx).then_some(0)) + .collect::(); + return (left_indices, right_indices); + } let left_indices = if join_type == JoinType::LeftSemi { (0..left_size) .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64)) @@ -1254,7 +1263,10 @@ pub(crate) fn build_batch_from_indices( let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); for column_index in column_indices { - let array = if column_index.side == build_side { + let array = if column_index.side == JoinSide::None { + // LeftMark join, the mark column is a true if the indices is not null, otherwise it will be false + Arc::new(compute::is_not_null(probe_indices)?) + } else if column_index.side == build_side { let array = build_input_buffer.column(column_index.index); if array.is_empty() || build_indices.null_count() == build_indices.len() { // Outer join would generate a null index when finding no match at our side. 
@@ -1323,7 +1335,7 @@ pub(crate) fn adjust_indices_by_join_type( // the left_indices will not be used later for the `right anti` join Ok((left_indices, right_indices)) } - JoinType::LeftSemi | JoinType::LeftAnti => { + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { // matched or unmatched left row will be produced in the end of loop // When visit the right batch, we can output the matched left row and don't need to wait the end of loop Ok(( @@ -1646,7 +1658,7 @@ pub(crate) fn symmetric_join_output_partitioning( let left_partitioning = left.output_partitioning(); let right_partitioning = right.output_partitioning(); match join_type { - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left_partitioning.clone() } JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(), @@ -1671,11 +1683,13 @@ pub(crate) fn asymmetric_join_output_partitioning( left.schema().fields().len(), ), JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Full => { - Partitioning::UnknownPartitioning( - right.output_partitioning().partition_count(), - ) - } + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::Full + | JoinType::LeftMark => Partitioning::UnknownPartitioning( + right.output_partitioning().partition_count(), + ), } } diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 7f8bce6b206e3..65cd33d523cd6 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -84,6 +84,7 @@ enum JoinType { LEFTANTI = 5; RIGHTSEMI = 6; RIGHTANTI = 7; + LEFTMARK = 8; } enum JoinConstraint { @@ -541,9 +542,10 @@ message ParquetOptions { string created_by = 16; } -enum JoinSide{ +enum JoinSide { LEFT_SIDE = 0; 
RIGHT_SIDE = 1; + NONE = 2; } message Precision{ diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index d848f795c6841..a554e4ed28051 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -778,6 +778,7 @@ impl From for JoinSide { match t { protobuf::JoinSide::LeftSide => JoinSide::Left, protobuf::JoinSide::RightSide => JoinSide::Right, + protobuf::JoinSide::None => JoinSide::None, } } } diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index e8b46fbf7012f..e8235ef7b9dd5 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -3761,6 +3761,7 @@ impl serde::Serialize for JoinSide { let variant = match self { Self::LeftSide => "LEFT_SIDE", Self::RightSide => "RIGHT_SIDE", + Self::None => "NONE", }; serializer.serialize_str(variant) } @@ -3774,6 +3775,7 @@ impl<'de> serde::Deserialize<'de> for JoinSide { const FIELDS: &[&str] = &[ "LEFT_SIDE", "RIGHT_SIDE", + "NONE", ]; struct GeneratedVisitor; @@ -3816,6 +3818,7 @@ impl<'de> serde::Deserialize<'de> for JoinSide { match value { "LEFT_SIDE" => Ok(JoinSide::LeftSide), "RIGHT_SIDE" => Ok(JoinSide::RightSide), + "NONE" => Ok(JoinSide::None), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -3838,6 +3841,7 @@ impl serde::Serialize for JoinType { Self::Leftanti => "LEFTANTI", Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", + Self::Leftmark => "LEFTMARK", }; serializer.serialize_str(variant) } @@ -3857,6 +3861,7 @@ impl<'de> serde::Deserialize<'de> for JoinType { "LEFTANTI", "RIGHTSEMI", "RIGHTANTI", + "LEFTMARK", ]; struct GeneratedVisitor; @@ -3905,6 +3910,7 @@ impl<'de> serde::Deserialize<'de> for JoinType { "LEFTANTI" => Ok(JoinType::Leftanti), "RIGHTSEMI" => Ok(JoinType::Rightsemi), "RIGHTANTI" => Ok(JoinType::Rightanti), + "LEFTMARK" => 
Ok(JoinType::Leftmark), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index 939a4b3c2cd2a..68e7f74c7f493 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -883,6 +883,7 @@ pub enum JoinType { Leftanti = 5, Rightsemi = 6, Rightanti = 7, + Leftmark = 8, } impl JoinType { /// String value of the enum field names used in the ProtoBuf definition. @@ -899,6 +900,7 @@ impl JoinType { Self::Leftanti => "LEFTANTI", Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", + Self::Leftmark => "LEFTMARK", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -912,6 +914,7 @@ impl JoinType { "LEFTANTI" => Some(Self::Leftanti), "RIGHTSEMI" => Some(Self::Rightsemi), "RIGHTANTI" => Some(Self::Rightanti), + "LEFTMARK" => Some(Self::Leftmark), _ => None, } } @@ -1069,6 +1072,7 @@ impl CompressionTypeVariant { pub enum JoinSide { LeftSide = 0, RightSide = 1, + None = 2, } impl JoinSide { /// String value of the enum field names used in the ProtoBuf definition. @@ -1079,6 +1083,7 @@ impl JoinSide { match self { Self::LeftSide => "LEFT_SIDE", Self::RightSide => "RIGHT_SIDE", + Self::None => "NONE", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -1086,6 +1091,7 @@ impl JoinSide { match value { "LEFT_SIDE" => Some(Self::LeftSide), "RIGHT_SIDE" => Some(Self::RightSide), + "NONE" => Some(Self::None), _ => None, } } diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index f9b8973e2d413..02a642a4af937 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -759,6 +759,7 @@ impl From for protobuf::JoinSide { match t { JoinSide::Left => protobuf::JoinSide::LeftSide, JoinSide::Right => protobuf::JoinSide::RightSide, + JoinSide::None => protobuf::JoinSide::None, } } } diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 939a4b3c2cd2a..68e7f74c7f493 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -883,6 +883,7 @@ pub enum JoinType { Leftanti = 5, Rightsemi = 6, Rightanti = 7, + Leftmark = 8, } impl JoinType { /// String value of the enum field names used in the ProtoBuf definition. @@ -899,6 +900,7 @@ impl JoinType { Self::Leftanti => "LEFTANTI", Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", + Self::Leftmark => "LEFTMARK", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -912,6 +914,7 @@ impl JoinType { "LEFTANTI" => Some(Self::Leftanti), "RIGHTSEMI" => Some(Self::Rightsemi), "RIGHTANTI" => Some(Self::Rightanti), + "LEFTMARK" => Some(Self::Leftmark), _ => None, } } @@ -1069,6 +1072,7 @@ impl CompressionTypeVariant { pub enum JoinSide { LeftSide = 0, RightSide = 1, + None = 2, } impl JoinSide { /// String value of the enum field names used in the ProtoBuf definition. @@ -1079,6 +1083,7 @@ impl JoinSide { match self { Self::LeftSide => "LEFT_SIDE", Self::RightSide => "RIGHT_SIDE", + Self::None => "NONE", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -1086,6 +1091,7 @@ impl JoinSide { match value { "LEFT_SIDE" => Some(Self::LeftSide), "RIGHT_SIDE" => Some(Self::RightSide), + "NONE" => Some(Self::None), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 27bda7dd5ace6..f25fb0bf2561c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -213,6 +213,7 @@ impl From for JoinType { protobuf::JoinType::Rightsemi => JoinType::RightSemi, protobuf::JoinType::Leftanti => JoinType::LeftAnti, protobuf::JoinType::Rightanti => JoinType::RightAnti, + protobuf::JoinType::Leftmark => JoinType::LeftMark, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 5a6f3a32c668e..8af7b19d9091e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -685,6 +685,7 @@ impl From for protobuf::JoinType { JoinType::RightSemi => protobuf::JoinType::Rightsemi, JoinType::LeftAnti => protobuf::JoinType::Leftanti, JoinType::RightAnti => protobuf::JoinType::Rightanti, + JoinType::LeftMark => protobuf::JoinType::Leftmark, } } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 2c38a1d36c1ea..6348aba49082e 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -552,7 +552,7 @@ impl Unparser<'_> { relation, global: false, join_operator: self - .join_operator_to_sql(join.join_type, join_constraint), + .join_operator_to_sql(join.join_type, join_constraint)?, }; let mut from = select.pop_from().unwrap(); from.push_join(ast_join); @@ -855,8 +855,8 @@ impl Unparser<'_> { &self, join_type: JoinType, constraint: ast::JoinConstraint, - ) -> ast::JoinOperator { - match join_type { + ) -> Result { + Ok(match join_type { JoinType::Inner => ast::JoinOperator::Inner(constraint), JoinType::Left => 
ast::JoinOperator::LeftOuter(constraint), JoinType::Right => ast::JoinOperator::RightOuter(constraint), @@ -865,7 +865,8 @@ impl Unparser<'_> { JoinType::LeftSemi => ast::JoinOperator::LeftSemi(constraint), JoinType::RightAnti => ast::JoinOperator::RightAnti(constraint), JoinType::RightSemi => ast::JoinOperator::RightSemi(constraint), - } + JoinType::LeftMark => unimplemented!("Unparsing of Left Mark join type"), + }) } /// Convert the components of a USING clause to the USING AST. Returns diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 36de19f1c3aa7..027b5ca8dcfb0 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -1056,13 +1056,11 @@ where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0) ---- logical_plan 01)Projection: t1.t1_id, t1.t1_name, t1.t1_int -02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL -03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists -04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0) -05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] -06)--------SubqueryAlias: __correlated_sq_1 -07)----------Projection: t2.t2_id, Boolean(true) AS __exists -08)------------TableScan: t2 projection=[t2_id] +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0) +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------TableScan: t2 projection=[t2_id] query ITI rowsort select t1.t1_id, @@ -1085,13 +1083,12 @@ where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t ---- logical_plan 01)Projection: t1.t1_id, t1.t1_name, t1.t1_int -02)--Filter: t1.t1_id = Int32(11) OR __correlated_sq_1.__exists IS NULL -03)----Projection: t1.t1_id, 
t1.t1_name, t1.t1_int, __correlated_sq_1.__exists -04)------Left Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0) -05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] -06)--------SubqueryAlias: __correlated_sq_1 -07)----------Projection: CAST(t2.t2_id AS Int64) + Int64(1), Boolean(true) AS __exists -08)------------TableScan: t2 projection=[t2_id] +02)--Filter: t1.t1_id = Int32(11) OR NOT __correlated_sq_1.mark +03)----LeftMark Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0) +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------Projection: CAST(t2.t2_id AS Int64) + Int64(1) +07)----------TableScan: t2 projection=[t2_id] query ITI rowsort select t1.t1_id, @@ -1113,13 +1110,11 @@ where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id) ---- logical_plan 01)Projection: t1.t1_id, t1.t1_name, t1.t1_int -02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL -03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists -04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id -05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] -06)--------SubqueryAlias: __correlated_sq_1 -07)----------Projection: t2.t2_id, Boolean(true) AS __exists -08)------------TableScan: t2 projection=[t2_id] +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------TableScan: t2 projection=[t2_id] query ITI rowsort select t1.t1_id, @@ -1132,6 +1127,9 @@ where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id) 22 b 2 44 d 4 +statement ok +set datafusion.explain.logical_plan_only = false; + # 
not_exists_subquery_to_join_with_correlated_outer_filter_disjunction query TT explain select t1.t1_id, @@ -1142,13 +1140,27 @@ where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id) ---- logical_plan 01)Projection: t1.t1_id, t1.t1_name, t1.t1_int -02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NULL -03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists -04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id -05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] -06)--------SubqueryAlias: __correlated_sq_1 -07)----------Projection: t2.t2_id, Boolean(true) AS __exists -08)------------TableScan: t2 projection=[t2_id] +02)--Filter: t1.t1_id > Int32(40) OR NOT __correlated_sq_1.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------TableScan: t2 projection=[t2_id] +physical_plan +01)CoalesceBatchesExec: target_batch_size=2 +02)--FilterExec: t1_id@0 > 40 OR NOT mark@3, projection=[t1_id@0, t1_name@1, t1_int@2] +03)----CoalesceBatchesExec: target_batch_size=2 +04)------HashJoinExec: mode=Partitioned, join_type=LeftMark, on=[(t1_id@0, t2_id@0)] +05)--------CoalesceBatchesExec: target_batch_size=2 +06)----------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------MemoryExec: partitions=1, partition_sizes=[1] +09)--------CoalesceBatchesExec: target_batch_size=2 +10)----------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +11)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +12)--------------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +set datafusion.explain.logical_plan_only = true; query ITI rowsort select t1.t1_id, @@ -1170,16 +1182,14 @@ where t1.t1_id in (select t3.t3_id from t3) and 
(t1.t1_id > 40 or t1.t1_id in (s ---- logical_plan 01)Projection: t1.t1_id, t1.t1_name, t1.t1_int -02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.__exists IS NOT NULL -03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_2.__exists -04)------Left Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0) -05)--------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id -06)----------TableScan: t1 projection=[t1_id, t1_name, t1_int] -07)----------SubqueryAlias: __correlated_sq_1 -08)------------TableScan: t3 projection=[t3_id] -09)--------SubqueryAlias: __correlated_sq_2 -10)----------Projection: t2.t2_id, Boolean(true) AS __exists -11)------------TableScan: t2 projection=[t2_id] +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0) +04)------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------TableScan: t3 projection=[t3_id] +08)------SubqueryAlias: __correlated_sq_2 +09)--------TableScan: t2 projection=[t2_id] query ITI rowsort select t1.t1_id, @@ -1192,6 +1202,18 @@ where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (s 22 b 2 44 d 4 +# Handle duplicate values in exists query +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 cross join t3 where t1.t1_id = t2.t2_id) +---- +11 a 1 +22 b 2 +44 d 4 + # Nested subqueries query ITI rowsort select t1.t1_id, diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 2aaf8ec0aa06b..289aa7b7f4489 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1226,6 +1226,7 @@ fn from_substrait_jointype(join_type: i32) -> Result { join_rel::JoinType::Outer => 
Ok(JoinType::Full), join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti), join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi), + join_rel::JoinType::LeftMark => Ok(JoinType::LeftMark), _ => plan_err!("unsupported join type {substrait_join_type:?}"), } } else { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 408885f70687f..c73029f130ad8 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -725,7 +725,10 @@ fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { JoinType::Full => join_rel::JoinType::Outer, JoinType::LeftAnti => join_rel::JoinType::LeftAnti, JoinType::LeftSemi => join_rel::JoinType::LeftSemi, - JoinType::RightAnti | JoinType::RightSemi => unimplemented!(), + JoinType::LeftMark => join_rel::JoinType::LeftMark, + JoinType::RightAnti | JoinType::RightSemi => { + unimplemented!() + } } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 04530dd34d4bf..8fbdefe2852ef 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -453,15 +453,15 @@ async fn roundtrip_inlist_5() -> Result<()> { // on roundtrip there is an additional projection during TableScan which includes all column of the table, // using assert_expected_plan here as a workaround assert_expected_plan( - "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", - "Projection: data.a, data.f\ - \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR Boolean(true) IS NOT NULL\ - \n Projection: data.a, data.f, Boolean(true)\ - \n Left Join: data.a = data2.a\ - \n TableScan: data projection=[a, f]\ - \n Projection: data2.a, Boolean(true)\ - \n Filter: data2.f = Utf8(\"b\") OR data2.f = 
Utf8(\"c\") OR data2.f = Utf8(\"d\")\ - \n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]", + "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", + + "Projection: data.a, data.f\ + \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data2.mark\ + \n LeftMark Join: data.a = data2.a\ + \n TableScan: data projection=[a, f]\ + \n Projection: data2.a\ + \n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\ + \n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]", true).await }