diff --git a/datafusion/src/optimizer/common_subexpr_eliminate.rs b/datafusion/src/optimizer/common_subexpr_eliminate.rs index 2ed45be25bc17..895bcf4f5978d 100644 --- a/datafusion/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/src/optimizer/common_subexpr_eliminate.rs @@ -111,13 +111,19 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result { - let schemas = plan.all_schemas(); - let all_schema = - schemas.into_iter().fold(DFSchema::empty(), |mut lhs, rhs| { - lhs.merge(rhs); - lhs - }); - let data_type = predicate.get_type(&all_schema)?; + let schema = plan.schema().as_ref().clone(); + let data_type = if let Ok(data_type) = predicate.get_type(&schema) { + data_type + } else { + // predicate type could not be resolved in schema, fall back to all schemas + let schemas = plan.all_schemas(); + let all_schema = + schemas.into_iter().fold(DFSchema::empty(), |mut lhs, rhs| { + lhs.merge(rhs); + lhs + }); + predicate.get_type(&all_schema)? + }; let mut id_array = vec![]; expr_to_identifier(predicate, &mut expr_set, &mut id_array, data_type)?; diff --git a/datafusion/tests/dataframe.rs b/datafusion/tests/dataframe.rs index 6520e749350c2..116315e9b9b27 100644 --- a/datafusion/tests/dataframe.rs +++ b/datafusion/tests/dataframe.rs @@ -28,6 +28,7 @@ use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; use datafusion::logical_plan::{col, Expr}; use datafusion::{datasource::MemTable, prelude::JoinType}; +use datafusion_expr::lit; #[tokio::test] async fn join() -> Result<()> { @@ -120,3 +121,34 @@ async fn sort_on_unprojected_columns() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn filter_with_alias_overwrite() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from_slice(&[1, 10, 10, 100]))], + ) + .unwrap(); + + let mut ctx = ExecutionContext::new(); + let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); + ctx.register_table("t", Arc::new(provider)).unwrap(); + + let df = ctx + .table("t") + .unwrap() + .select(vec![(col("a").eq(lit(10))).alias("a")]) + .unwrap() + .filter(col("a")) + .unwrap(); + let results = df.collect().await.unwrap(); + + let expected = vec![ + "+------+", "| a |", "+------+", "| true |", "| true |", "+------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +}