diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 7b94ed263b0e4..5a9ee8502eaa9 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -26,7 +26,7 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{ Result, ScalarValue, exec_err, nested_struct::validate_struct_compatibility, @@ -260,20 +260,20 @@ impl DefaultPhysicalExprAdapter { impl PhysicalExprAdapter for DefaultPhysicalExprAdapter { fn rewrite(&self, expr: Arc) -> Result> { let rewriter = DefaultPhysicalExprAdapterRewriter { - logical_file_schema: &self.logical_file_schema, - physical_file_schema: &self.physical_file_schema, + logical_file_schema: Arc::clone(&self.logical_file_schema), + physical_file_schema: Arc::clone(&self.physical_file_schema), }; expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) .data() } } -struct DefaultPhysicalExprAdapterRewriter<'a> { - logical_file_schema: &'a Schema, - physical_file_schema: &'a Schema, +struct DefaultPhysicalExprAdapterRewriter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, } -impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { +impl DefaultPhysicalExprAdapterRewriter { fn rewrite_expr( &self, expr: Arc, @@ -421,18 +421,13 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { }; let physical_field = self.physical_file_schema.field(physical_column_index); - let column = match ( - column.index() == physical_column_index, - logical_field.data_type() == physical_field.data_type(), - ) { - // If the column index matches and the data types match, we can use the column as is - (true, true) => return Ok(Transformed::no(expr)), - // If the indexes or data types do not match, we need to create a new column expression - (true, _) => column.clone(), - (false, _) => { - Column::new_with_schema(logical_field.name(), self.physical_file_schema)? - } - }; + if column.index() == physical_column_index + && logical_field.data_type() == physical_field.data_type() + { + return Ok(Transformed::no(expr)); + } + + let column = self.resolve_column(column, physical_column_index)?; if logical_field.data_type() == physical_field.data_type() { // If the data types match, we can use the column as is @@ -443,24 +438,60 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123` // since that's much cheaper to evalaute. // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928 - // + self.create_cast_column_expr(column, logical_field) + } + + /// Resolves a column expression, handling index and type mismatches. + /// + /// Returns the appropriate Column expression when the column's index or data type + /// don't match the physical schema. Assumes that the early-exit case (both index + /// and type match) has already been checked by the caller. + fn resolve_column( + &self, + column: &Column, + physical_column_index: usize, + ) -> Result { + if column.index() == physical_column_index { + Ok(column.clone()) + } else { + Column::new_with_schema(column.name(), self.physical_file_schema.as_ref()) + } + } + + /// Validates type compatibility and creates a CastColumnExpr if needed. + /// + /// Checks whether the physical field can be cast to the logical field type, + /// handling both struct and scalar types. Returns a CastColumnExpr with the + /// appropriate configuration. + fn create_cast_column_expr( + &self, + column: Column, + logical_field: &Field, + ) -> Result>> { + let actual_physical_field = self.physical_file_schema.field(column.index()); + // For struct types, use validate_struct_compatibility which handles: // - Missing fields in source (filled with nulls) // - Extra fields in source (ignored) // - Recursive validation of nested structs // For non-struct types, use Arrow's can_cast_types - match (physical_field.data_type(), logical_field.data_type()) { + match (actual_physical_field.data_type(), logical_field.data_type()) { (DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => { - validate_struct_compatibility(physical_fields, logical_fields)?; + validate_struct_compatibility( + physical_fields.as_ref(), + logical_fields.as_ref(), + )?; } _ => { - let is_compatible = - can_cast_types(physical_field.data_type(), logical_field.data_type()); + let is_compatible = can_cast_types( + actual_physical_field.data_type(), + logical_field.data_type(), + ); if !is_compatible { return exec_err!( "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)", column.name(), - physical_field.data_type(), + actual_physical_field.data_type(), logical_field.data_type() ); } @@ -469,7 +500,7 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { let cast_expr = Arc::new(CastColumnExpr::new( Arc::new(column), - Arc::new(physical_field.clone()), + Arc::new(actual_physical_field.clone()), Arc::new(logical_field.clone()), None, )); @@ -777,30 +808,32 @@ mod tests { let result = adapter.rewrite(column_expr).unwrap(); + let physical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + let physical_field = Arc::new(Field::new( + "data", + DataType::Struct(physical_struct_fields), + false, + )); + + let logical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8View, true), + ] + .into(); + let logical_field = Arc::new(Field::new( + "data", + DataType::Struct(logical_struct_fields), + false, + )); + let expected = Arc::new(CastColumnExpr::new( Arc::new(Column::new("data", 0)), - Arc::new(Field::new( - "data", - DataType::Struct( - vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ] - .into(), - ), - false, - )), - Arc::new(Field::new( - "data", - DataType::Struct( - vec![ - Field::new("id", DataType::Int64, false), - Field::new("name", DataType::Utf8View, true), - ] - .into(), - ), - false, - )), + physical_field, + logical_field, None, )) as Arc; @@ -1193,8 +1226,8 @@ mod tests { )]); let rewriter = DefaultPhysicalExprAdapterRewriter { - logical_file_schema: &logical_schema, - physical_file_schema: &physical_schema, + logical_file_schema: Arc::new(logical_schema), + physical_file_schema: Arc::new(physical_schema), }; // Test that when a field exists in physical schema, it returns None @@ -1415,4 +1448,48 @@ mod tests { assert!(format!("{:?}", adapter1).contains("BatchAdapter")); assert!(format!("{:?}", adapter2).contains("BatchAdapter")); } + + #[test] + fn test_rewrite_column_index_and_type_mismatch() { + let physical_schema = Schema::new(vec![ + Field::new("b", DataType::Utf8, true), + Field::new("a", DataType::Int32, false), // Index 1 + ]); + + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, false), // Index 0, Different Type + Field::new("b", DataType::Utf8, true), + ]); + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); + + // Logical column "a" is at index 0 + let column_expr = Arc::new(Column::new("a", 0)); + + let result = adapter.rewrite(column_expr).unwrap(); + + // Should be a CastColumnExpr + let cast_expr = result + .as_any() + .downcast_ref::() + .expect("Expected CastColumnExpr"); + + // Verify the inner column points to the correct physical index (1) + let inner_col = cast_expr + .expr() + .as_any() + .downcast_ref::() + .expect("Expected inner Column"); + assert_eq!(inner_col.name(), "a"); + assert_eq!(inner_col.index(), 1); // Physical index is 1 + + // Verify cast types + assert_eq!( + cast_expr.data_type(&Schema::empty()).unwrap(), + DataType::Int64 + ); + } }