Skip to content
17 changes: 14 additions & 3 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ use datafusion::logical_expr::{
};
use datafusion::prelude::{lit, JoinType};
use datafusion::{
arrow, error::Result, logical_expr::utils::split_conjunction, prelude::Column,
scalar::ScalarValue,
arrow, error::Result, logical_expr::utils::split_conjunction,
logical_expr::utils::split_conjunction_owned, prelude::Column, scalar::ScalarValue,
};
use std::collections::HashSet;
use std::sync::Arc;
Expand Down Expand Up @@ -1327,19 +1327,28 @@ pub async fn from_read_rel(
table_ref: TableReference,
schema: DFSchema,
projection: &Option<MaskExpression>,
filter: &Option<Box<Expression>>,
) -> Result<LogicalPlan> {
let schema = schema.replace_qualifier(table_ref.clone());

let filters = if let Some(f) = filter {
let filter_expr = consumer.consume_expression(f, &schema).await?;
split_conjunction_owned(filter_expr)
} else {
vec![]
};

let plan = {
let provider = match consumer.resolve_table_ref(&table_ref).await? {
Some(ref provider) => Arc::clone(provider),
_ => return plan_err!("No table named '{table_ref}'"),
};

LogicalPlanBuilder::scan(
LogicalPlanBuilder::scan_with_filters(
table_ref,
provider_as_source(Arc::clone(&provider)),
None,
filters,
)?
.build()?
};
Expand Down Expand Up @@ -1382,6 +1391,7 @@ pub async fn from_read_rel(
table_reference,
substrait_schema,
&read.projection,
&read.filter,
)
.await
}
Expand Down Expand Up @@ -1464,6 +1474,7 @@ pub async fn from_read_rel(
table_reference,
substrait_schema,
&read.projection,
&read.filter,
)
.await
}
Expand Down
22 changes: 20 additions & 2 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ use datafusion::logical_expr::expr::{
AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList,
InSubquery, WindowFunction, WindowFunctionParams,
};
use datafusion::logical_expr::utils::conjunction;
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
use datafusion::prelude::Expr;
use pbjson_types::Any as ProtoAny;
Expand Down Expand Up @@ -540,7 +541,7 @@ pub fn to_substrait_rel(
}

pub fn from_table_scan(
_producer: &mut impl SubstraitProducer,
producer: &mut impl SubstraitProducer,
scan: &TableScan,
) -> Result<Box<Rel>> {
let projection = scan.projection.as_ref().map(|p| {
Expand All @@ -560,11 +561,28 @@ pub fn from_table_scan(
let table_schema = scan.source.schema().to_dfschema_ref()?;
let base_schema = to_substrait_named_struct(&table_schema)?;

let filter_option = if scan.filters.is_empty() {
None
} else {
let table_schema_qualified = Arc::new(
DFSchema::try_from_qualified_schema(
scan.table_name.clone(),
&(scan.source.schema()),
)
.unwrap(),
);

let combined_expr = conjunction(scan.filters.clone()).unwrap();
let filter_expr =
producer.handle_expr(&combined_expr, &table_schema_qualified)?;
Some(Box::new(filter_expr))
};

Ok(Box::new(Rel {
rel_type: Some(RelType::Read(Box::new(ReadRel {
common: None,
base_schema: Some(base_schema),
filter: None,
filter: filter_option,
best_effort_filter: None,
projection,
advanced_extension: None,
Expand Down
66 changes: 66 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,11 @@ async fn roundtrip_repartition_hash() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn roundtrip_read_filter() -> Result<()> {
roundtrip_verify_read_filter_count("SELECT a FROM data where a < 5", 1).await
}

fn check_post_join_filters(rel: &Rel) -> Result<()> {
// search for target_rel and field value in proto
match &rel.rel_type {
Expand Down Expand Up @@ -1319,6 +1324,56 @@ async fn verify_post_join_filter_value(proto: Box<Plan>) -> Result<()> {
Ok(())
}

fn count_read_filters(rel: &Rel, filter_count: &mut u32) -> Result<()> {
// search for target_rel and field value in proto
match &rel.rel_type {
Some(RelType::Read(read)) => {
// increment counter for read filter if not None
if read.filter.is_some() {
*filter_count += 1;
}
Ok(())
}
Some(RelType::Filter(filter)) => {
count_read_filters(filter.input.as_ref().unwrap().as_ref(), filter_count)
}
_ => Ok(()),
}
}

async fn assert_read_filter_count(
proto: Box<Plan>,
expected_filter_count: u32,
) -> Result<()> {
let mut filter_count: u32 = 0;
for relation in &proto.relations {
match relation.rel_type.as_ref() {
Some(rt) => match rt {
plan_rel::RelType::Rel(rel) => {
match count_read_filters(rel, &mut filter_count) {
Err(e) => return Err(e),
Ok(_) => continue,
}
}
plan_rel::RelType::Root(root) => {
match count_read_filters(
root.input.as_ref().unwrap(),
&mut filter_count,
) {
Err(e) => return Err(e),
Ok(_) => continue,
}
}
},
None => return plan_err!("Cannot parse plan relation: None"),
}
}

assert_eq!(expected_filter_count, filter_count);

Ok(())
}

async fn assert_expected_plan_unoptimized(
sql: &str,
expected_plan_str: &str,
Expand Down Expand Up @@ -1489,6 +1544,17 @@ async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
verify_post_join_filter_value(proto).await
}

async fn roundtrip_verify_read_filter_count(
sql: &str,
expected_filter_count: u32,
) -> Result<()> {
let ctx = create_context().await?;
let proto = roundtrip_with_ctx(sql, ctx).await?;

// verify that filter counts in read relations are as expected
assert_read_filter_count(proto, expected_filter_count).await
}

async fn roundtrip_all_types(sql: &str) -> Result<()> {
roundtrip_with_ctx(sql, create_all_type_context().await?).await?;
Ok(())
Expand Down