diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 6b408521c5cf9..0dbb78a2680ec 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1140,6 +1140,12 @@ impl OptimizerRule for PushDownFilter { }) } LogicalPlan::Extension(extension_plan) => { + // This check prevents the Filter from being removed when the extension node has no children, + // so we return the original Filter unchanged. + if extension_plan.node.inputs().is_empty() { + filter.input = Arc::new(LogicalPlan::Extension(extension_plan)); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); @@ -3786,4 +3792,83 @@ Projection: a, b \n TableScan: test"; assert_optimized_plan_eq(plan, expected_after) } + + #[test] + fn test_push_down_filter_to_user_defined_node() -> Result<()> { + // Define a custom user-defined logical node + #[derive(Debug, Hash, Eq, PartialEq)] + struct TestUserNode { + schema: DFSchemaRef, + } + + impl PartialOrd for TestUserNode { + fn partial_cmp(&self, _other: &Self) -> Option { + None + } + } + + impl TestUserNode { + fn new() -> Self { + let schema = Arc::new( + DFSchema::new_with_metadata( + vec![(None, Field::new("a", DataType::Int64, false).into())], + Default::default(), + ) + .unwrap(), + ); + + Self { schema } + } + } + + impl UserDefinedLogicalNodeCore for TestUserNode { + fn name(&self) -> &str { + "test_node" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "TestUserNode") + } + + fn with_exprs_and_inputs( + &self, + exprs: Vec, + inputs: Vec, + ) -> Result { + assert!(exprs.is_empty()); + assert!(inputs.is_empty()); + Ok(Self { + schema: Arc::clone(&self.schema), + }) + } + } + + // Create a node and build a plan with a filter + let node = LogicalPlan::Extension(Extension { + node: Arc::new(TestUserNode::new()), + }); + + let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?; + + // Check the original plan format (not part of the test assertions) + let expected_before = "Filter: Boolean(false)\ + \n TestUserNode"; + assert_eq!(format!("{plan}"), expected_before); + + // Check that the filter is pushed down to the user-defined node + let expected_after = "Filter: Boolean(false)\n TestUserNode"; + assert_optimized_plan_eq(plan, expected_after) + } }