From 2f8fb401ad8d9687730f8a1ddbc3e0f7aafa7399 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 2 Oct 2025 14:50:29 +0800 Subject: [PATCH 1/2] Enhance CastColumnExpr integration and support in projection mapping and utility functions --- .../src/equivalence/projection.rs | 126 +++++++++++++++--- .../src/equivalence/properties/mod.rs | 13 +- .../src/expressions/cast_column.rs | 31 +++++ .../physical-expr/src/intervals/utils.rs | 28 +++- 4 files changed, 174 insertions(+), 24 deletions(-) diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index a4ed8187cfadd..149f22cdb6c1c 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -18,7 +18,7 @@ use std::ops::Deref; use std::sync::Arc; -use crate::expressions::Column; +use crate::expressions::{CastColumnExpr, Column}; use crate::PhysicalExpr; use arrow::datatypes::SchemaRef; @@ -96,28 +96,56 @@ impl ProjectionMapping { let mut map = IndexMap::<_, ProjectionTargets>::new(); for (expr_idx, (expr, name)) in expr.into_iter().enumerate() { let target_expr = Arc::new(Column::new(&name, expr_idx)) as _; - let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::() { - Some(col) => { - // Sometimes, an expression and its name in the input_schema - // doesn't match. This can cause problems, so we make sure - // that the expression name matches with the name in `input_schema`. - // Conceptually, `source_expr` and `expression` should be the same. - let idx = col.index(); - let matching_field = input_schema.field(idx); - let matching_name = matching_field.name(); - if col.name() != matching_name { - return internal_err!( - "Input field name {} does not match with the projection expression {}", - matching_name, - col.name() + let source_expr = expr + .transform_down(|e| { + if let Some(col) = e.as_any().downcast_ref::() { + // Sometimes, an expression and its name in the input_schema + // doesn't match. This can cause problems, so we make sure + // that the expression name matches with the name in `input_schema`. + // Conceptually, `source_expr` and `expression` should be the same. + let idx = col.index(); + let matching_field = input_schema.field(idx); + let matching_name = matching_field.name(); + if col.name() != matching_name { + return internal_err!( + "Input field name {} does not match with the projection expression {}", + matching_name, + col.name() + ); + } + let matching_column = Column::new(matching_name, idx); + Ok(Transformed::yes(Arc::new(matching_column))) + } else if let Some(cast_column) = + e.as_any().downcast_ref::() + { + let new_input_field = cast_column + .expr() + .as_any() + .downcast_ref::() + .and_then(|col| { + input_schema + .fields() + .get(col.index()) + .map(Arc::clone) + }) + .unwrap_or_else(|| Arc::clone(cast_column.input_field())); + + if new_input_field.as_ref() == cast_column.input_field().as_ref() { + return Ok(Transformed::no(e)); + } + + let new_expr = CastColumnExpr::new( + Arc::clone(cast_column.expr()), + new_input_field, + Arc::clone(cast_column.target_field()), + Some(cast_column.cast_options().clone()), ); + Ok(Transformed::yes(Arc::new(new_expr))) + } else { + Ok(Transformed::no(e)) } - let matching_column = Column::new(matching_name, idx); - Ok(Transformed::yes(Arc::new(matching_column))) - } - None => Ok(Transformed::no(e)), - }) - .data()?; + }) + .data()?; map.entry(source_expr) .or_default() .push((target_expr, expr_idx)); @@ -253,7 +281,7 @@ mod tests { use super::*; use crate::equivalence::tests::output_schema; use crate::equivalence::{convert_to_orderings, EquivalenceProperties}; - use crate::expressions::{col, BinaryExpr}; + use crate::expressions::{col, BinaryExpr, CastColumnExpr}; use crate::utils::tests::TestScalarUDF; use crate::{PhysicalExprRef, ScalarFunctionExpr}; @@ -278,6 +306,12 @@ mod tests { let col_d = &col("d", &schema)?; let col_e = &col("e", &schema)?; let col_ts = &col("ts", &schema)?; + let cast_column_expr = Arc::new(CastColumnExpr::new( + Arc::clone(col_a), + Arc::new(schema.field(0).clone()), + Arc::new(Field::new("a_cast", DataType::Int64, true)), + None, + )) as Arc; let a_plus_b = Arc::new(BinaryExpr::new( Arc::clone(col_a), Operator::Plus, @@ -713,6 +747,26 @@ mod tests { vec![("c_new", option_asc), ("b_new", option_desc)], ], ), + // ---------- TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a ASC] + vec![(col_a, option_asc)], + ], + // projection exprs + vec![ + (col_a, "a_new".to_string()), + (&cast_column_expr, "a_cast".to_string()), + ], + // expected + vec![ + // [a_new ASC] + vec![("a_new", option_asc)], + // [a_cast ASC] + vec![("a_cast", option_asc)], + ], + ), ]; for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate() @@ -757,6 +811,34 @@ mod tests { Ok(()) } + #[test] + fn projection_mapping_normalizes_cast_column_input_field() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let column = col("a", &schema)?; + let mismatched_input = Arc::new(Field::new("alias_a", DataType::Int32, true)); + let target_field = Arc::new(Field::new("a", DataType::Int32, true)); + let cast_column = Arc::new(CastColumnExpr::new( + Arc::clone(&column), + mismatched_input, + target_field, + None, + )) as Arc; + + let mapping = ProjectionMapping::try_new( + vec![(cast_column, "a_alias".to_string())], + &schema, + )?; + + let (source_expr, _) = mapping.iter().next().unwrap(); + let cast_column = source_expr + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(cast_column.input_field().name(), "a"); + assert_eq!(cast_column.input_field().data_type(), &DataType::Int32); + Ok(()) + } + #[test] fn project_orderings2() -> Result<()> { let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index 2404b8f0dd3eb..32788361582db 100644 --- a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -33,7 +33,7 @@ use self::dependency::{ use crate::equivalence::{ AcrossPartitions, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; -use crate::expressions::{with_new_schema, CastExpr, Column, Literal}; +use crate::expressions::{with_new_schema, CastColumnExpr, CastExpr, Column, Literal}; use crate::{ ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, @@ -850,6 +850,17 @@ impl EquivalenceProperties { sort_expr.options, )); } + } else if let Some(cast_column) = + r_expr.as_any().downcast_ref::() + { + if cast_column.expr().eq(&sort_expr.expr) + && cast_column.is_widening_cast(&expr_type) + { + result.push(PhysicalSortExpr::new( + r_expr, + sort_expr.options, + )); + } } } result.push(sort_expr); diff --git a/datafusion/physical-expr/src/expressions/cast_column.rs b/datafusion/physical-expr/src/expressions/cast_column.rs index 80d71c3def408..3a08ce4fdd833 100644 --- a/datafusion/physical-expr/src/expressions/cast_column.rs +++ b/datafusion/physical-expr/src/expressions/cast_column.rs @@ -108,6 +108,37 @@ impl CastColumnExpr { pub fn target_field(&self) -> &FieldRef { &self.target_field } + + /// Options forwarded to [`cast_column`]. + pub fn cast_options(&self) -> &CastOptions<'static> { + &self.cast_options + } + + /// Returns `true` if this cast widens the source type and therefore + /// preserves ordering semantics. + pub fn is_widening_cast(&self, src: &DataType) -> bool { + if self.target_field.data_type() == src { + return true; + } + + use DataType::*; + + matches!( + (src, self.target_field.data_type()), + (Int8, Int16 | Int32 | Int64) + | (Int16, Int32 | Int64) + | (Int32, Int64) + | (UInt8, UInt16 | UInt32 | UInt64) + | (UInt16, UInt32 | UInt64) + | (UInt32, UInt64) + | ( + Int8 | Int16 | Int32 | UInt8 | UInt16 | UInt32, + Float32 | Float64 + ) + | (Int64 | UInt64, Float64) + | (Utf8, LargeUtf8) + ) + } } impl Display for CastColumnExpr { diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index 22752a00e9259..5fef90bfba0ed 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use crate::{ - expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr}, + expressions::{BinaryExpr, CastColumnExpr, CastExpr, Column, Literal, NegativeExpr}, PhysicalExpr, }; @@ -55,6 +55,8 @@ pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { } } else if let Some(cast) = expr_any.downcast_ref::() { check_support(cast.expr(), schema) + } else if let Some(cast_column) = expr_any.downcast_ref::() { + check_support(cast_column.expr(), schema) } else if let Some(negative) = expr_any.downcast_ref::() { check_support(negative.arg(), schema) } else { @@ -62,6 +64,30 @@ pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { } } +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, CastColumnExpr}; + use arrow::datatypes::{DataType, Field, Schema}; + use std::sync::Arc; + + #[test] + fn supports_cast_column_expr() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let column = col("a", &schema).unwrap(); + let input_field = Arc::new(schema.field(0).clone()); + let target_field = Arc::new(Field::new("a_cast", DataType::Int64, true)); + let expr = Arc::new(CastColumnExpr::new( + Arc::clone(&column), + input_field, + target_field, + None, + )) as Arc; + + assert!(check_support(&expr, &schema)); + } +} + // This function returns the inverse operator of the given operator. pub fn get_inverse_op(op: Operator) -> Result { match op { From 7c9d879eaef5d587c10c39f7c5fcef00900ac012 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 2 Oct 2025 18:34:56 +0800 Subject: [PATCH 2/2] Refactor tests for `CastColumnExpr` support: move test module to the end of the file --- .../physical-expr/src/intervals/utils.rs | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index 5fef90bfba0ed..7fdbd26c959ca 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -64,30 +64,6 @@ pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { } } -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::{col, CastColumnExpr}; - use arrow::datatypes::{DataType, Field, Schema}; - use std::sync::Arc; - - #[test] - fn supports_cast_column_expr() { - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); - let column = col("a", &schema).unwrap(); - let input_field = Arc::new(schema.field(0).clone()); - let target_field = Arc::new(Field::new("a_cast", DataType::Int64, true)); - let expr = Arc::new(CastColumnExpr::new( - Arc::clone(&column), - input_field, - target_field, - None, - )) as Arc; - - assert!(check_support(&expr, &schema)); - } -} - // This function returns the inverse operator of the given operator. pub fn get_inverse_op(op: Operator) -> Result { match op { @@ -217,3 +193,27 @@ fn interval_dt_to_duration_ms(dt: &IntervalDayTime) -> Result { ) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, CastColumnExpr}; + use arrow::datatypes::{DataType, Field, Schema}; + use std::sync::Arc; + + #[test] + fn supports_cast_column_expr() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let column = col("a", &schema).unwrap(); + let input_field = Arc::new(schema.field(0).clone()); + let target_field = Arc::new(Field::new("a_cast", DataType::Int64, true)); + let expr = Arc::new(CastColumnExpr::new( + Arc::clone(&column), + input_field, + target_field, + None, + )) as Arc; + + assert!(check_support(&expr, &schema)); + } +}