From 1a21669b4345032591b2758be70b789435e6045e Mon Sep 17 00:00:00 2001 From: Kevin <4733573+kczimm@users.noreply.github.com.> Date: Fri, 2 May 2025 08:27:27 -0500 Subject: [PATCH 1/5] infer placeholder datatype for InSubquery only infer subquery if exactly 1 field --- datafusion/expr/src/expr.rs | 103 +++++++++++++++++++++++++++++++++++- 1 file changed, 102 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index b8e4204a9c9ea..7b2210f50b315 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1775,6 +1775,26 @@ impl Expr { | Expr::SimilarTo(Like { expr, pattern, .. }) => { rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?; } + Expr::InSubquery(InSubquery { + expr, + subquery, + negated: _, + }) => { + let subquery_schema = subquery.subquery.schema(); + let fields = subquery_schema.fields(); + + // only supports subquery with exactly 1 field + if let [first_field] = &fields[..] { + rewrite_placeholder( + expr.as_mut(), + &Expr::Column(Column { + relation: None, + name: first_field.name().clone(), + }), + schema, + )?; + } + } Expr::Placeholder(_) => { has_placeholder = true; } @@ -3198,7 +3218,8 @@ mod test { use crate::expr_fn::col; use crate::{ case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue, - ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, + LogicalPlan, LogicalTableSource, Projection, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, TableScan, Volatility, }; use arrow::datatypes::{Field, Schema}; use sqlparser::ast; @@ -3260,6 +3281,86 @@ mod test { } } + #[test] + fn infer_placeholder_in_subquery() -> Result<()> { + // Schema for my_table: A (Int32), B (Int32) + let schema = Arc::new(Schema::new(vec![ + Field::new("A", DataType::Int32, true), + Field::new("B", DataType::Int32, true), + ])); + + let source = Arc::new(LogicalTableSource::new(schema.clone())); + + // Simulate: SELECT * FROM my_table WHERE $1 IN (SELECT A FROM my_table 
WHERE B > 3); + let placeholder = Expr::Placeholder(Placeholder { + id: "$1".to_string(), + data_type: None, + }); + + // Subquery: SELECT A FROM my_table WHERE B > 3 + let subquery_filter = Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("B")), + op: Operator::Gt, + right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)))), + }); + + let subquery_scan = LogicalPlan::TableScan(TableScan { + table_name: TableReference::from("my_table"), + source, + projected_schema: Arc::new(DFSchema::try_from(schema.clone())?), + projection: None, + filters: vec![subquery_filter.clone()], + fetch: None, + }); + + let projected_fields = vec![Field::new("A", DataType::Int32, true)]; + let projected_schema = Arc::new(DFSchema::from_unqualified_fields( + projected_fields.into(), + Default::default(), + )?); + + let subquery = Subquery { + subquery: Arc::new(LogicalPlan::Projection(Projection { + expr: vec![col("A")], + input: Arc::new(subquery_scan), + schema: projected_schema, + })), + outer_ref_columns: vec![], + }; + + let in_subquery = Expr::InSubquery(InSubquery { + expr: Box::new(placeholder), + subquery, + negated: false, + }); + + let df_schema = DFSchema::try_from(schema)?; + + let (inferred_expr, contains_placeholder) = + in_subquery.infer_placeholder_types(&df_schema)?; + + assert!( + contains_placeholder, + "Expression should contain a placeholder" + ); + + match inferred_expr { + Expr::InSubquery(in_subquery) => match *in_subquery.expr { + Expr::Placeholder(placeholder) => { + assert_eq!( + placeholder.data_type, + Some(DataType::Int32), + "Placeholder $1 should infer Int32" + ); + } + _ => panic!("Expected Placeholder expression in InSubquery"), + }, + _ => panic!("Expected InSubquery expression"), + } + + Ok(()) + } + #[test] fn infer_placeholder_like_and_similar_to() { // name LIKE $1 From cc97f1f25081bc0636c5b8b2e7782c5126b3d4af Mon Sep 17 00:00:00 2001 From: Kevin <4733573+kczimm@users.noreply.github.com.> Date: Wed, 7 May 2025 09:48:04 -0500 Subject: [PATCH 2/5] 
infer placeholder after limit as Int64 --- datafusion/expr/src/expr.rs | 2 + datafusion/expr/src/logical_plan/plan.rs | 48 ++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7b2210f50b315..813bae7302dbb 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1790,6 +1790,7 @@ impl Expr { &Expr::Column(Column { relation: None, name: first_field.name().clone(), + spans: Spans::new(), }), schema, )?; @@ -3326,6 +3327,7 @@ mod test { schema: projected_schema, })), outer_ref_columns: vec![], + spans: Spans::new(), }; let in_subquery = Expr::InSubquery(InSubquery { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index edf5f1126be93..be928bf493e30 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1494,6 +1494,14 @@ impl LogicalPlan { let mut param_types: HashMap<String, Option<DataType>> = HashMap::new(); self.apply_with_subqueries(|plan| { + if let LogicalPlan::Limit(Limit { fetch: Some(e), .. 
}) = plan { + if let Expr::Placeholder(Placeholder { id, data_type }) = &**e { + param_types.insert( + id.clone(), + Some(data_type.as_ref().cloned().unwrap_or(DataType::Int64)), + ); + } + } plan.apply_expressions(|expr| { expr.apply(|expr| { if let Expr::Placeholder(Placeholder { id, data_type }) = expr { @@ -1507,6 +1515,9 @@ impl LogicalPlan { (_, Some(dt)) => { param_types.insert(id.clone(), Some(dt.clone())); } + (Some(Some(_)), None) => { + // we have already inferred the datatype + } _ => { param_types.insert(id.clone(), None); } @@ -4029,6 +4040,43 @@ mod tests { .build() } + #[test] + fn test_resolved_placeholder_limit() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("A", DataType::Int32, true)])); + let source = Arc::new(LogicalTableSource::new(schema.clone())); + + let placeholder_value = "$1"; + + // SELECT * FROM my_table LIMIT $1 + let plan = LogicalPlan::Limit(Limit { + skip: None, + fetch: Some(Box::new(Expr::Placeholder(Placeholder { + id: placeholder_value.to_string(), + data_type: None, + }))), + input: Arc::new(LogicalPlan::TableScan(TableScan { + table_name: TableReference::from("my_table"), + source, + projected_schema: Arc::new(DFSchema::try_from(schema.clone())?), + projection: None, + filters: vec![], + fetch: None, + })), + }); + + let params = plan.get_parameter_types().expect("to infer type"); + assert_eq!(params.len(), 1); + + let parameter_type = params + .clone() + .get(placeholder_value) + .expect("to get type") + .clone(); + assert_eq!(parameter_type, Some(DataType::Int64)); + + Ok(()) + } + #[test] fn test_display_indent() -> Result<()> { let plan = display_plan()?; From 44084f0d83387f5892e9aca825a06c2eb15bc1be Mon Sep 17 00:00:00 2001 From: Kevin <4733573+kczimm@users.noreply.github.com.> Date: Wed, 7 May 2025 10:50:24 -0500 Subject: [PATCH 3/5] fix clippy lints --- datafusion/expr/src/expr.rs | 4 ++-- datafusion/expr/src/logical_plan/plan.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff 
--git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 813bae7302dbb..3e5d95a2652a6 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -3290,7 +3290,7 @@ mod test { Field::new("B", DataType::Int32, true), ])); - let source = Arc::new(LogicalTableSource::new(schema.clone())); + let source = Arc::new(LogicalTableSource::new(Arc::clone(&schema))); // Simulate: SELECT * FROM my_table WHERE $1 IN (SELECT A FROM my_table WHERE B > 3); let placeholder = Expr::Placeholder(Placeholder { @@ -3308,7 +3308,7 @@ mod test { let subquery_scan = LogicalPlan::TableScan(TableScan { table_name: TableReference::from("my_table"), source, - projected_schema: Arc::new(DFSchema::try_from(schema.clone())?), + projected_schema: Arc::new(DFSchema::try_from(Arc::clone(&schema))?), projection: None, filters: vec![subquery_filter.clone()], fetch: None, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index be928bf493e30..27f3d662917b4 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -4043,7 +4043,7 @@ mod tests { #[test] fn test_resolved_placeholder_limit() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("A", DataType::Int32, true)])); - let source = Arc::new(LogicalTableSource::new(schema.clone())); + let source = Arc::new(LogicalTableSource::new(Arc::clone(&schema))); let placeholder_value = "$1"; @@ -4057,7 +4057,7 @@ mod tests { input: Arc::new(LogicalPlan::TableScan(TableScan { table_name: TableReference::from("my_table"), source, - projected_schema: Arc::new(DFSchema::try_from(schema.clone())?), + projected_schema: Arc::new(DFSchema::try_from(Arc::clone(&schema))?), projection: None, filters: vec![], fetch: None, From 17c806d8d2da908abf02b4aeecd7a041563cd362 Mon Sep 17 00:00:00 2001 From: Kevin <4733573+kczimm@users.noreply.github.com.> Date: Thu, 8 May 2025 11:24:19 -0500 Subject: [PATCH 4/5] add comment about subqueries --- 
datafusion/expr/src/expr.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 3e5d95a2652a6..165d275c3012c 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1784,6 +1784,7 @@ impl Expr { let fields = subquery_schema.fields(); // only supports subquery with exactly 1 field + // https://github.com/apache/datafusion/blob/main/datafusion/sql/src/expr/subquery.rs#L120 if let [first_field] = &fields[..] { rewrite_placeholder( expr.as_mut(), From 2f5c686d368eec57cc6232a4b65bc04187ce5647 Mon Sep 17 00:00:00 2001 From: Kevin <4733573+kczimm@users.noreply.github.com.> Date: Thu, 8 May 2025 12:37:10 -0500 Subject: [PATCH 5/5] add comments; improve test; infer Limit skip as Int64 --- datafusion/expr/src/logical_plan/plan.rs | 91 ++++++++++++++++++++---- 1 file changed, 77 insertions(+), 14 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 27f3d662917b4..52ec7065d17a3 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1494,13 +1494,29 @@ impl LogicalPlan { let mut param_types: HashMap<String, Option<DataType>> = HashMap::new(); self.apply_with_subqueries(|plan| { - if let LogicalPlan::Limit(Limit { fetch: Some(e), .. }) = plan { - if let Expr::Placeholder(Placeholder { id, data_type }) = &**e { + if let LogicalPlan::Limit(Limit { + fetch: Some(f), + skip, + .. 
+ }) = plan + { + if let Expr::Placeholder(Placeholder { id, data_type }) = &**f { + // Valid assumption, https://github.com/apache/datafusion/blob/41e7aed3a943134c40d1b18cb9d424b358b5e5b1/datafusion/optimizer/src/analyzer/type_coercion.rs#L242 param_types.insert( id.clone(), Some(data_type.as_ref().cloned().unwrap_or(DataType::Int64)), ); } + + if let Some(s) = skip { + if let Expr::Placeholder(Placeholder { id, data_type }) = &**s { + // Valid assumption, https://github.com/apache/datafusion/blob/41e7aed3a943134c40d1b18cb9d424b358b5e5b1/datafusion/optimizer/src/analyzer/type_coercion.rs#L242 + param_types.insert( + id.clone(), + Some(data_type.as_ref().cloned().unwrap_or(DataType::Int64)), + ); + } + } } plan.apply_expressions(|expr| { expr.apply(|expr| { @@ -1516,7 +1532,8 @@ impl LogicalPlan { param_types.insert(id.clone(), Some(dt.clone())); } (Some(Some(_)), None) => { - // we have already inferred the datatype + // we have already inferred the datatype like + // the LIMIT case handled specially above. 
} _ => { param_types.insert(id.clone(), None); @@ -4045,13 +4062,16 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("A", DataType::Int32, true)])); let source = Arc::new(LogicalTableSource::new(Arc::clone(&schema))); - let placeholder_value = "$1"; + let placeholders = ["$1", "$2"]; - // SELECT * FROM my_table LIMIT $1 + // SELECT * FROM my_table LIMIT $1 OFFSET $2 let plan = LogicalPlan::Limit(Limit { - skip: None, + skip: Some(Box::new(Expr::Placeholder(Placeholder { + id: placeholders[1].to_string(), + data_type: None, + }))), fetch: Some(Box::new(Expr::Placeholder(Placeholder { - id: placeholder_value.to_string(), + id: placeholders[0].to_string(), data_type: None, }))), input: Arc::new(LogicalPlan::TableScan(TableScan { @@ -4064,15 +4084,58 @@ mod tests { })), }); + // try to infer the placeholder datatypes for the plan + let schema = DFSchema::try_from(Arc::clone(&schema))?; + let plan = plan + .map_expressions(|e| { + let (e, has_placeholder) = e.infer_placeholder_types(&schema)?; + Ok(if !has_placeholder { + Transformed::no(e) + } else { + Transformed::yes(e) + }) + }) + .expect("map expressions") + .data; + + let LogicalPlan::Limit(Limit { + fetch: Some(f), + skip: Some(s), + .. + }) = &plan + else { + panic!("plan is not Limit with fetch and skip"); + }; + + if !matches!( + (&**f, &**s), + ( + Expr::Placeholder(Placeholder { + data_type: None, + .. + }), + Expr::Placeholder(Placeholder { + data_type: None, + .. 
+ }) + ) + ) { + panic!( + "expected fetch and skip to be placeholders with datatypes uninferred" + ); + } + let params = plan.get_parameter_types().expect("to infer type"); - assert_eq!(params.len(), 1); + assert_eq!(params.len(), 2); - let parameter_type = params - .clone() - .get(placeholder_value) - .expect("to get type") - .clone(); - assert_eq!(parameter_type, Some(DataType::Int64)); + for placeholder in placeholders { + let parameter_type = params + .clone() + .get(placeholder) + .expect("to get fetch type") + .clone(); + assert_eq!(parameter_type, Some(DataType::Int64)); + } Ok(()) }