Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions datafusion/src/optimizer/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,19 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result<Logi
}))
}
LogicalPlan::Filter(Filter { predicate, input }) => {
let schemas = plan.all_schemas();
let all_schema =
schemas.into_iter().fold(DFSchema::empty(), |mut lhs, rhs| {
lhs.merge(rhs);
lhs
});
let data_type = predicate.get_type(&all_schema)?;
let schema = plan.schema().as_ref().clone();
let data_type = if let Ok(data_type) = predicate.get_type(&schema) {
data_type
} else {
// predicate type could not be resolved in schema, fall back to all schemas
let schemas = plan.all_schemas();
let all_schema =
schemas.into_iter().fold(DFSchema::empty(), |mut lhs, rhs| {
lhs.merge(rhs);
lhs
});
predicate.get_type(&all_schema)?
};

let mut id_array = vec![];
expr_to_identifier(predicate, &mut expr_set, &mut id_array, data_type)?;
Expand Down
32 changes: 32 additions & 0 deletions datafusion/tests/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::logical_plan::{col, Expr};
use datafusion::{datasource::MemTable, prelude::JoinType};
use datafusion_expr::lit;

#[tokio::test]
async fn join() -> Result<()> {
Expand Down Expand Up @@ -120,3 +121,34 @@ async fn sort_on_unprojected_columns() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn filter_with_alias_overwrite() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(Int32Array::from_slice(&[1, 10, 10, 100]))],
)
.unwrap();

let mut ctx = ExecutionContext::new();
let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap();
ctx.register_table("t", Arc::new(provider)).unwrap();

let df = ctx
.table("t")
.unwrap()
.select(vec![(col("a").eq(lit(10))).alias("a")])
.unwrap()
.filter(col("a"))
.unwrap();
let results = df.collect().await.unwrap();

let expected = vec![
"+------+", "| a |", "+------+", "| true |", "| true |", "+------+",
];
assert_batches_eq!(expected, &results);

Ok(())
}