diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index ffeff3e9df47f..7c6c45f44db7e 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -68,8 +68,8 @@ use datafusion::logical_expr::{ }; use datafusion::prelude::{lit, JoinType}; use datafusion::{ - arrow, error::Result, logical_expr::utils::split_conjunction, prelude::Column, - scalar::ScalarValue, + arrow, error::Result, logical_expr::utils::split_conjunction, + logical_expr::utils::split_conjunction_owned, prelude::Column, scalar::ScalarValue, }; use std::collections::HashSet; use std::sync::Arc; @@ -1327,19 +1327,28 @@ pub async fn from_read_rel( table_ref: TableReference, schema: DFSchema, projection: &Option, + filter: &Option>, ) -> Result { let schema = schema.replace_qualifier(table_ref.clone()); + let filters = if let Some(f) = filter { + let filter_expr = consumer.consume_expression(f, &schema).await?; + split_conjunction_owned(filter_expr) + } else { + vec![] + }; + let plan = { let provider = match consumer.resolve_table_ref(&table_ref).await? { Some(ref provider) => Arc::clone(provider), _ => return plan_err!("No table named '{table_ref}'"), }; - LogicalPlanBuilder::scan( + LogicalPlanBuilder::scan_with_filters( table_ref, provider_as_source(Arc::clone(&provider)), None, + filters, )? .build()? }; @@ -1382,6 +1391,7 @@ pub async fn from_read_rel( table_reference, substrait_schema, &read.projection, + &read.filter, ) .await } @@ -1464,6 +1474,7 @@ pub async fn from_read_rel( table_reference, substrait_schema, &read.projection, + &read.filter, ) .await } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index fc24d5bb91f0f..cc7efed419c2d 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -55,6 +55,7 @@ use datafusion::logical_expr::expr::{ AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction, WindowFunctionParams, }; +use datafusion::logical_expr::utils::conjunction; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; use pbjson_types::Any as ProtoAny; @@ -540,7 +541,7 @@ pub fn to_substrait_rel( } pub fn from_table_scan( - _producer: &mut impl SubstraitProducer, + producer: &mut impl SubstraitProducer, scan: &TableScan, ) -> Result> { let projection = scan.projection.as_ref().map(|p| { @@ -560,11 +561,28 @@ pub fn from_table_scan( let table_schema = scan.source.schema().to_dfschema_ref()?; let base_schema = to_substrait_named_struct(&table_schema)?; + let filter_option = if scan.filters.is_empty() { + None + } else { + let table_schema_qualified = Arc::new( + DFSchema::try_from_qualified_schema( + scan.table_name.clone(), + &(scan.source.schema()), + ) + .unwrap(), + ); + + let combined_expr = conjunction(scan.filters.clone()).unwrap(); + let filter_expr = + producer.handle_expr(&combined_expr, &table_schema_qualified)?; + Some(Box::new(filter_expr)) + }; + Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, base_schema: Some(base_schema), - filter: None, + filter: filter_option, best_effort_filter: None, projection, advanced_extension: None, diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index e6b8bdbc047e3..f989d05c80dd1 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -1234,6 +1234,11 @@ async fn roundtrip_repartition_hash() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_read_filter() -> Result<()> { + roundtrip_verify_read_filter_count("SELECT a FROM data where a < 5", 1).await +} + fn check_post_join_filters(rel: &Rel) -> Result<()> { // search for target_rel and field value in proto match &rel.rel_type { @@ -1319,6 +1324,56 @@ async fn verify_post_join_filter_value(proto: Box) -> Result<()> { Ok(()) } +fn count_read_filters(rel: &Rel, filter_count: &mut u32) -> Result<()> { + // search for target_rel and field value in proto + match &rel.rel_type { + Some(RelType::Read(read)) => { + // increment counter for read filter if not None + if read.filter.is_some() { + *filter_count += 1; + } + Ok(()) + } + Some(RelType::Filter(filter)) => { + count_read_filters(filter.input.as_ref().unwrap().as_ref(), filter_count) + } + _ => Ok(()), + } +} + +async fn assert_read_filter_count( + proto: Box, + expected_filter_count: u32, +) -> Result<()> { + let mut filter_count: u32 = 0; + for relation in &proto.relations { + match relation.rel_type.as_ref() { + Some(rt) => match rt { + plan_rel::RelType::Rel(rel) => { + match count_read_filters(rel, &mut filter_count) { + Err(e) => return Err(e), + Ok(_) => continue, + } + } + plan_rel::RelType::Root(root) => { + match count_read_filters( + root.input.as_ref().unwrap(), + &mut filter_count, + ) { + Err(e) => return Err(e), + Ok(_) => continue, + } + } + }, + None => return plan_err!("Cannot parse plan relation: None"), + } + } + + assert_eq!(expected_filter_count, filter_count); + + Ok(()) +} + async fn assert_expected_plan_unoptimized( sql: &str, expected_plan_str: &str, @@ -1489,6 +1544,17 @@ async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { verify_post_join_filter_value(proto).await } +async fn roundtrip_verify_read_filter_count( + sql: &str, + expected_filter_count: u32, +) -> Result<()> { + let ctx = create_context().await?; + let proto = roundtrip_with_ctx(sql, ctx).await?; + + // verify that filter counts in read relations are as expected + assert_read_filter_count(proto, expected_filter_count).await +} + async fn roundtrip_all_types(sql: &str) -> Result<()> { roundtrip_with_ctx(sql, create_all_type_context().await?).await?; Ok(())