Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,12 @@ impl OptimizerRule for PushDownFilter {
})
}
LogicalPlan::Extension(extension_plan) => {
// This check prevents the Filter from being removed when the extension node has no children,
// so we return the original Filter unchanged.
if extension_plan.node.inputs().is_empty() {
filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}
let prevent_cols =
extension_plan.node.prevent_predicate_push_down_columns();

Expand Down Expand Up @@ -3786,4 +3792,83 @@ Projection: a, b
\n TableScan: test";
assert_optimized_plan_eq(plan, expected_after)
}

#[test]
fn test_push_down_filter_to_user_defined_node() -> Result<()> {
// Define a custom user-defined logical node
#[derive(Debug, Hash, Eq, PartialEq)]
struct TestUserNode {
schema: DFSchemaRef,
}

impl PartialOrd for TestUserNode {
fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
None
}
}

impl TestUserNode {
fn new() -> Self {
let schema = Arc::new(
DFSchema::new_with_metadata(
vec![(None, Field::new("a", DataType::Int64, false).into())],
Default::default(),
)
.unwrap(),
);

Self { schema }
}
}

impl UserDefinedLogicalNodeCore for TestUserNode {
fn name(&self) -> &str {
"test_node"
}

fn inputs(&self) -> Vec<&LogicalPlan> {
vec![]
}

fn schema(&self) -> &DFSchemaRef {
&self.schema
}

fn expressions(&self) -> Vec<Expr> {
vec![]
}

fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "TestUserNode")
}

fn with_exprs_and_inputs(
&self,
exprs: Vec<Expr>,
inputs: Vec<LogicalPlan>,
) -> Result<Self> {
assert!(exprs.is_empty());
assert!(inputs.is_empty());
Ok(Self {
schema: Arc::clone(&self.schema),
})
}
}

// Create a node and build a plan with a filter
let node = LogicalPlan::Extension(Extension {
node: Arc::new(TestUserNode::new()),
});

let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;

// Check the original plan format (not part of the test assertions)
let expected_before = "Filter: Boolean(false)\
\n TestUserNode";
assert_eq!(format!("{plan}"), expected_before);

// Check that the filter is pushed down to the user-defined node
let expected_after = "Filter: Boolean(false)\n TestUserNode";
assert_optimized_plan_eq(plan, expected_after)
}
}