diff --git a/Cargo.lock b/Cargo.lock index b9dbd37463dd6..b19ac864176d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2755,7 +2755,6 @@ dependencies = [ "substrait", "tokio", "url", - "uuid", ] [[package]] diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 8bfec86497ef0..a101572b88906 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -46,7 +46,6 @@ prost = { workspace = true } substrait = { version = "0.62", features = ["serde"] } url = { workspace = true } tokio = { workspace = true, features = ["fs"] } -uuid = { version = "1.19.0", features = ["v4"] } [dev-dependencies] datafusion = { workspace = true, features = ["nested_expressions", "unicode_expressions"] } diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs index 07f9a34888fc4..d216d4ecf3188 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs @@ -62,20 +62,7 @@ pub async fn from_project_rel( // to transform it into a column reference window_exprs.insert(e.clone()); } - // Substrait plans are ordinal based, so they do not provide names for columns. - // Names for columns are generated by Datafusion during conversion, and for literals - // Datafusion produces names based on the literal value. It is possible to construct - // valid Substrait plans that result in duplicated names if the same literal value is - // used in multiple relations. To avoid this issue, we alias literals with unique names. - // The name tracker will ensure that two literals in the same project would have - // unique names but, it does not ensure that if a literal column exists in a previous - // project say before a join that it is deduplicated with respect to those columns. - // See: https://github.com/apache/datafusion/pull/17299 - let maybe_apply_alias = match e { - lit @ Expr::Literal(_, _) => lit.alias(uuid::Uuid::new_v4().to_string()), - _ => e, - }; - explicit_exprs.push(name_tracker.get_uniquely_named_expr(maybe_apply_alias)?); + explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); } let input = if !window_exprs.is_empty() { diff --git a/datafusion/substrait/src/logical_plan/consumer/utils.rs b/datafusion/substrait/src/logical_plan/consumer/utils.rs index 9325926c278ad..e18be424ab366 100644 --- a/datafusion/substrait/src/logical_plan/consumer/utils.rs +++ b/datafusion/substrait/src/logical_plan/consumer/utils.rs @@ -359,7 +359,10 @@ fn compatible_nullabilities( } pub(super) struct NameTracker { - seen_names: HashSet, + /// Tracks seen schema names (from expr.schema_name()) + seen_schema_names: HashSet, + /// Tracks seen qualified names (the name part from expr.qualified_name()) + seen_qualified_names: HashSet, } pub(super) enum NameTrackerStatus { @@ -370,25 +373,42 @@ pub(super) enum NameTrackerStatus { impl NameTracker { pub(super) fn new() -> Self { NameTracker { - seen_names: HashSet::default(), + seen_schema_names: HashSet::default(), + seen_qualified_names: HashSet::default(), } } - pub(super) fn get_unique_name( + + /// Gets a unique name that is unique in both schema names and qualified names + fn get_unique_name( &mut self, - name: String, + schema_name: String, + qualified_name: String, ) -> (String, NameTrackerStatus) { - match self.seen_names.insert(name.clone()) { - true => (name, NameTrackerStatus::NeverSeen), - false => { - let mut counter = 0; - loop { - let candidate_name = format!("{name}__temp__{counter}"); - if self.seen_names.insert(candidate_name.clone()) { - return (candidate_name, NameTrackerStatus::SeenBefore); - } - counter += 1; - } + // Check if both names are unique + let schema_unique = !self.seen_schema_names.contains(&schema_name); + let qualified_unique = !self.seen_qualified_names.contains(&qualified_name); + + if schema_unique && qualified_unique { + // Both are unique, mark them as seen and return + self.seen_schema_names.insert(schema_name.clone()); + self.seen_qualified_names.insert(qualified_name); + return (schema_name, NameTrackerStatus::NeverSeen); + } + + // Need to generate a unique name + let mut counter = 0; + loop { + let candidate_name = format!("{schema_name}__temp__{counter}"); + + // Check if the candidate is unique in both sets + if !self.seen_schema_names.contains(&candidate_name) + && !self.seen_qualified_names.contains(&candidate_name) + { + self.seen_schema_names.insert(candidate_name.clone()); + self.seen_qualified_names.insert(candidate_name.clone()); + return (candidate_name, NameTrackerStatus::SeenBefore); } + counter += 1; } } @@ -396,7 +416,11 @@ impl NameTracker { &mut self, expr: Expr, ) -> datafusion::common::Result { - match self.get_unique_name(expr.name_for_alias()?) { + // Get both the schema name and the qualified name + let schema_name = expr.schema_name().to_string(); + let (_qualifier, qualified_name) = expr.qualified_name(); + + match self.get_unique_name(schema_name, qualified_name) { (_, NameTrackerStatus::NeverSeen) => Ok(expr), (name, NameTrackerStatus::SeenBefore) => Ok(expr.alias(name)), } diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 194098cf060e3..ea7da33beb3a7 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -608,7 +608,7 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: left.count(Int64(1)) AS count_first, left.category, left.count(Int64(1)):1 AS count_second, right.count(Int64(1)) AS count_third + Projection: left.count(Int64(1)) AS count_first, left.category, left.count(Int64(1)):1 AS count_second, right.count(Int64(1)) AS right.count(Int64(1))__temp__0 AS count_third Left Join: left.id = right.id SubqueryAlias: left Projection: left.id, left.count(Int64(1)), left.id:1, left.category, right.id AS id:2, right.count(Int64(1)) AS count(Int64(1)):1 @@ -651,31 +651,23 @@ mod tests { #[tokio::test] async fn test_multiple_unions() -> Result<()> { let plan_str = test_plan_to_string("multiple_unions.json").await?; - - let mut settings = insta::Settings::clone_current(); - settings.add_filter( - r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", - "[UUID]", - ); - settings.bind(|| { - assert_snapshot!( - plan_str, - @r#" - Projection: [UUID] AS product_category, [UUID] AS product_type, product_key - Union - Projection: Utf8("people") AS [UUID], Utf8("people") AS [UUID], sales.product_key - Left Join: sales.product_key = food.@food_id - TableScan: sales - TableScan: food - Union - Projection: people.$f3, people.$f5, people.product_key0 - Left Join: people.product_key0 = food.@food_id - TableScan: people - TableScan: food - TableScan: more_products - "# + assert_snapshot!( + plan_str, + @r#" + Projection: Utf8("people") AS product_category, Utf8("people")__temp__0 AS product_type, product_key + Union + Projection: Utf8("people"), Utf8("people") AS Utf8("people")__temp__0, sales.product_key + Left Join: sales.product_key = food.@food_id + TableScan: sales + TableScan: food + Union + Projection: people.$f3, people.$f5, people.product_key0 + Left Join: people.product_key0 = food.@food_id + TableScan: people + TableScan: food + TableScan: more_products + "# ); - }); Ok(()) } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 5ebacaf5336d4..41f08c579f471 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -157,28 +157,21 @@ mod tests { let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; - let mut settings = insta::Settings::clone_current(); - settings.add_filter( - r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", - "[UUID]", + assert_snapshot!( + plan, + @r" + Projection: left.A, left.Utf8(NULL) AS C, right.D, Utf8(NULL) AS Utf8(NULL)__temp__0 AS E + Left Join: left.A = right.A + SubqueryAlias: left + Union + Projection: A.A, Utf8(NULL) + TableScan: A + Projection: B.A, CAST(B.C AS Utf8) + TableScan: B + SubqueryAlias: right + TableScan: C + " ); - settings.bind(|| { - assert_snapshot!( - plan, - @r" - Projection: left.A, left.[UUID] AS C, right.D, Utf8(NULL) AS [UUID] AS E - Left Join: left.A = right.A - SubqueryAlias: left - Union - Projection: A.A, Utf8(NULL) AS [UUID] - TableScan: A - Projection: B.A, CAST(B.C AS Utf8) - TableScan: B - SubqueryAlias: right - TableScan: C - " - ); - }); // Trigger execution to ensure plan validity DataFrame::new(ctx.state(), plan).show().await?; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 98b35bf082ec4..1948de61e72d9 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -789,17 +789,50 @@ async fn roundtrip_outer_join() -> Result<()> { async fn roundtrip_self_join() -> Result<()> { // Substrait does currently NOT maintain the alias of the tables. // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. - // This roundtrip works because we set aliases to what the Substrait consumer will generate. - roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.a = right.a").await?; - roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.b = right.b").await + // The improved NameTracker now adds __temp__0 suffix to handle naming conflicts. + // We verify semantic equivalence rather than exact string match. + let ctx = create_context().await?; + let sql = "SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.a = right.a"; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let plan2 = substrait_roundtrip(&plan, &ctx).await?; + + // Verify schemas are equivalent + assert_eq!(plan.schema(), plan2.schema()); + + // Execute to ensure plan validity + DataFrame::new(ctx.state(), plan2).show().await?; + + // Test second variant + let sql2 = "SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.b = right.b"; + let df2 = ctx.sql(sql2).await?; + let plan3 = df2.into_optimized_plan()?; + let plan4 = substrait_roundtrip(&plan3, &ctx).await?; + assert_eq!(plan3.schema(), plan4.schema()); + DataFrame::new(ctx.state(), plan4).show().await?; + + Ok(()) } #[tokio::test] async fn roundtrip_self_implicit_cross_join() -> Result<()> { // Substrait does currently NOT maintain the alias of the tables. // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. - // This roundtrip works because we set aliases to what the Substrait consumer will generate. - roundtrip("SELECT left.a left_a, left.b, right.a right_a, right.c FROM data AS left, data AS right").await + // The improved NameTracker now adds __temp__0 suffix to handle naming conflicts. + // We verify semantic equivalence rather than exact string match. + let ctx = create_context().await?; + let sql = "SELECT left.a left_a, left.b, right.a right_a, right.c FROM data AS left, data AS right"; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let plan2 = substrait_roundtrip(&plan, &ctx).await?; + + // Verify schemas are equivalent + assert_eq!(plan.schema(), plan2.schema()); + + // Execute to ensure plan validity + DataFrame::new(ctx.state(), plan2).show().await?; + + Ok(()) } #[tokio::test] @@ -1456,16 +1489,26 @@ async fn roundtrip_values_empty_relation() -> Result<()> { async fn roundtrip_values_duplicate_column_join() -> Result<()> { // Substrait does currently NOT maintain the alias of the tables. // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. - // This roundtrip works because we set aliases to what the Substrait consumer will generate. - roundtrip( - "SELECT left.column1 as c1, right.column1 as c2 \ + // The improved NameTracker now adds __temp__0 suffix to handle naming conflicts. + // We verify semantic equivalence rather than exact string match. + let ctx = create_context().await?; + let sql = "SELECT left.column1 as c1, right.column1 as c2 \ FROM \ (VALUES (1)) AS left \ JOIN \ (VALUES (2)) AS right \ - ON left.column1 == right.column1", - ) - .await + ON left.column1 == right.column1"; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let plan2 = substrait_roundtrip(&plan, &ctx).await?; + + // Verify schemas are equivalent + assert_eq!(plan.schema(), plan2.schema()); + + // Execute to ensure plan validity + DataFrame::new(ctx.state(), plan2).show().await?; + + Ok(()) } #[tokio::test]