diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index b86dc0f48c149..3f9a04cb978c3 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -368,15 +368,34 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 // parse ON expression
                 let expr = self.sql_to_rex(sql_expr, &join_schema)?;
 
+                // expression that didn't match equi-join pattern
+                let mut filter = vec![];
+
                 // extract join keys
-                extract_join_keys(&expr, &mut keys)?;
+                extract_join_keys(&expr, &mut keys, &mut filter);
 
                 let (left_keys, right_keys): (Vec<Column>, Vec<Column>) =
                     keys.into_iter().unzip();
                 // return the logical plan representing the join
-                LogicalPlanBuilder::from(left)
-                    .join(right, join_type, left_keys, right_keys)?
+                let join = LogicalPlanBuilder::from(left)
+                    .join(right, join_type, left_keys, right_keys)?;
+
+                if filter.is_empty() {
+                    join.build()
+                } else if join_type == JoinType::Inner {
+                    join.filter(
+                        filter
+                            .iter()
+                            .skip(1)
+                            .fold(filter[0].clone(), |acc, e| acc.and(e.clone())),
+                    )?
                     .build()
+                } else {
+                    Err(DataFusionError::NotImplemented(format!(
+                        "Unsupported expressions in {:?} JOIN: {:?}",
+                        join_type, filter
+                    )))
+                }
             }
             JoinConstraint::Using(idents) => {
                 let keys: Vec<Column> = idents
@@ -1549,39 +1568,41 @@ fn remove_join_expressions(
     }
 }
 
-/// Parse equijoin ON condition which could be a single Eq or multiple conjunctive Eqs
+/// Extracts join keys from an ON condition: a single Eq or multiple conjunctive Eqs
+/// Filters matching this pattern are added to `accum`
+/// Filters that don't match this pattern are added to `accum_filter`
+/// Examples:
 ///
-/// Examples
+/// foo = bar => accum=[(foo, bar)] accum_filter=[]
+/// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[]
+/// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1]
 ///
-/// foo = bar
-/// foo = bar AND bar = baz AND ...
-///
-fn extract_join_keys(expr: &Expr, accum: &mut Vec<(Column, Column)>) -> Result<()> {
+fn extract_join_keys(
+    expr: &Expr,
+    accum: &mut Vec<(Column, Column)>,
+    accum_filter: &mut Vec<Expr>,
+) {
     match expr {
         Expr::BinaryExpr { left, op, right } => match op {
             Operator::Eq => match (left.as_ref(), right.as_ref()) {
                 (Expr::Column(l), Expr::Column(r)) => {
                     accum.push((l.clone(), r.clone()));
-                    Ok(())
                 }
-                other => Err(DataFusionError::SQL(ParserError(format!(
-                    "Unsupported expression '{:?}' in JOIN condition",
-                    other
-                )))),
+                _other => {
+                    accum_filter.push(expr.clone());
+                }
             },
             Operator::And => {
-                extract_join_keys(left, accum)?;
-                extract_join_keys(right, accum)
+                extract_join_keys(left, accum, accum_filter);
+                extract_join_keys(right, accum, accum_filter);
+            }
+            _other => {
+                accum_filter.push(expr.clone());
             }
-            other => Err(DataFusionError::SQL(ParserError(format!(
-                "Unsupported expression '{:?}' in JOIN condition",
-                other
-            )))),
         },
-        other => Err(DataFusionError::SQL(ParserError(format!(
-            "Unsupported expression '{:?}' in JOIN condition",
-            other
-        )))),
+        _other => {
+            accum_filter.push(expr.clone());
+        }
     }
 }
 
@@ -2701,6 +2722,20 @@ mod tests {
         quick_test(sql, expected);
     }
 
+    #[test]
+    fn equijoin_unsupported_expression() {
+        let sql = "SELECT id, order_id \
+            FROM person \
+            JOIN orders \
+            ON id = customer_id AND order_id > 1 ";
+        let expected = "Projection: #person.id, #orders.order_id\
+        \n  Filter: #orders.order_id Gt Int64(1)\
+        \n    Join: #person.id = #orders.customer_id\
+        \n      TableScan: person projection=None\
+        \n      TableScan: orders projection=None";
+        quick_test(sql, expected);
+    }
+
     #[test]
     fn join_with_table_name() {
         let sql = "SELECT id, order_id \
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index c06a4bb1462ee..3445df3b741b7 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -1694,6 +1694,28 @@ async fn equijoin() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn equijoin_and_other_condition() -> Result<()> {
+    let mut ctx = create_join_context("t1_id", "t2_id")?;
+    let sql =
+        "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id";
+    let actual = execute(&mut ctx, sql).await;
+    let expected = vec![vec!["11", "a", "z"], vec!["22", "b", "y"]];
+    assert_eq!(expected, actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn equijoin_and_unsupported_condition() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id")?;
+    let sql =
+        "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id";
+    let res = ctx.create_logical_plan(sql);
+    assert!(res.is_err());
+    assert_eq!(format!("{}", res.unwrap_err()), "This feature is not implemented: Unsupported expressions in Left JOIN: [#t2.t2_name GtEq Utf8(\"y\")]");
+    Ok(())
+}
+
 #[tokio::test]
 async fn left_join() -> Result<()> {
     let mut ctx = create_join_context("t1_id", "t2_id")?;