Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 127 additions & 50 deletions datafusion/physical-expr-adapter/src/schema_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Schema, SchemaRef};
use arrow::datatypes::{DataType, Field, SchemaRef};
use datafusion_common::{
Result, ScalarValue, exec_err,
nested_struct::validate_struct_compatibility,
Expand Down Expand Up @@ -260,20 +260,20 @@ impl DefaultPhysicalExprAdapter {
impl PhysicalExprAdapter for DefaultPhysicalExprAdapter {
fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
let rewriter = DefaultPhysicalExprAdapterRewriter {
logical_file_schema: &self.logical_file_schema,
physical_file_schema: &self.physical_file_schema,
logical_file_schema: Arc::clone(&self.logical_file_schema),
physical_file_schema: Arc::clone(&self.physical_file_schema),
};
expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr)))
.data()
}
}

struct DefaultPhysicalExprAdapterRewriter<'a> {
logical_file_schema: &'a Schema,
physical_file_schema: &'a Schema,
struct DefaultPhysicalExprAdapterRewriter {
logical_file_schema: SchemaRef,
physical_file_schema: SchemaRef,
}

impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
impl DefaultPhysicalExprAdapterRewriter {
fn rewrite_expr(
&self,
expr: Arc<dyn PhysicalExpr>,
Expand Down Expand Up @@ -421,18 +421,13 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
};
let physical_field = self.physical_file_schema.field(physical_column_index);

let column = match (
column.index() == physical_column_index,
logical_field.data_type() == physical_field.data_type(),
) {
// If the column index matches and the data types match, we can use the column as is
(true, true) => return Ok(Transformed::no(expr)),
// If the indexes or data types do not match, we need to create a new column expression
(true, _) => column.clone(),
(false, _) => {
Column::new_with_schema(logical_field.name(), self.physical_file_schema)?
}
};
if column.index() == physical_column_index
&& logical_field.data_type() == physical_field.data_type()
{
return Ok(Transformed::no(expr));
}

let column = self.resolve_column(column, physical_column_index)?;

if logical_field.data_type() == physical_field.data_type() {
// If the data types match, we can use the column as is
Expand All @@ -443,24 +438,60 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
// TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
// since that's much cheaper to evaluate.
// See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928
//
self.create_cast_column_expr(column, logical_field)
}

/// Resolve the [`Column`] expression to use against the physical schema.
///
/// Reuses the incoming column when its index already points at the
/// physical column; otherwise looks the column up by name in the
/// physical file schema. The caller is expected to have already handled
/// the fast path where both the index and the data type match.
fn resolve_column(
    &self,
    column: &Column,
    physical_column_index: usize,
) -> Result<Column> {
    if column.index() != physical_column_index {
        // Index mismatch: rebind by name against the physical schema so the
        // expression references the correct physical column position.
        return Column::new_with_schema(
            column.name(),
            self.physical_file_schema.as_ref(),
        );
    }
    Ok(column.clone())
}

/// Validates type compatibility and creates a CastColumnExpr if needed.
///
/// Checks whether the physical field can be cast to the logical field type,
/// handling both struct and scalar types. Returns a CastColumnExpr with the
/// appropriate configuration.
fn create_cast_column_expr(
&self,
column: Column,
logical_field: &Field,
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
let actual_physical_field = self.physical_file_schema.field(column.index());

// For struct types, use validate_struct_compatibility which handles:
// - Missing fields in source (filled with nulls)
// - Extra fields in source (ignored)
// - Recursive validation of nested structs
// For non-struct types, use Arrow's can_cast_types
match (physical_field.data_type(), logical_field.data_type()) {
match (actual_physical_field.data_type(), logical_field.data_type()) {
(DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => {
validate_struct_compatibility(physical_fields, logical_fields)?;
validate_struct_compatibility(
physical_fields.as_ref(),
logical_fields.as_ref(),
)?;
}
_ => {
let is_compatible =
can_cast_types(physical_field.data_type(), logical_field.data_type());
let is_compatible = can_cast_types(
actual_physical_field.data_type(),
logical_field.data_type(),
);
if !is_compatible {
return exec_err!(
"Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
column.name(),
physical_field.data_type(),
actual_physical_field.data_type(),
logical_field.data_type()
);
}
Expand All @@ -469,7 +500,7 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {

let cast_expr = Arc::new(CastColumnExpr::new(
Arc::new(column),
Arc::new(physical_field.clone()),
Arc::new(actual_physical_field.clone()),
Arc::new(logical_field.clone()),
None,
));
Expand Down Expand Up @@ -777,30 +808,32 @@ mod tests {

let result = adapter.rewrite(column_expr).unwrap();

let physical_struct_fields: Fields = vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, true),
]
.into();
let physical_field = Arc::new(Field::new(
"data",
DataType::Struct(physical_struct_fields),
false,
));

let logical_struct_fields: Fields = vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8View, true),
]
.into();
let logical_field = Arc::new(Field::new(
"data",
DataType::Struct(logical_struct_fields),
false,
));

let expected = Arc::new(CastColumnExpr::new(
Arc::new(Column::new("data", 0)),
Arc::new(Field::new(
"data",
DataType::Struct(
vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, true),
]
.into(),
),
false,
)),
Arc::new(Field::new(
"data",
DataType::Struct(
vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8View, true),
]
.into(),
),
false,
)),
physical_field,
logical_field,
None,
)) as Arc<dyn PhysicalExpr>;

Expand Down Expand Up @@ -1193,8 +1226,8 @@ mod tests {
)]);

let rewriter = DefaultPhysicalExprAdapterRewriter {
logical_file_schema: &logical_schema,
physical_file_schema: &physical_schema,
logical_file_schema: Arc::new(logical_schema),
physical_file_schema: Arc::new(physical_schema),
};

// Test that when a field exists in physical schema, it returns None
Expand Down Expand Up @@ -1415,4 +1448,48 @@ mod tests {
assert!(format!("{:?}", adapter1).contains("BatchAdapter"));
assert!(format!("{:?}", adapter2).contains("BatchAdapter"));
}

/// Exercises the combined case where the logical column differs from the
/// physical column in BOTH position (index 0 vs index 1) and data type
/// (Int64 vs Int32): the rewrite must rebind the column to the physical
/// index and wrap it in a `CastColumnExpr` up to the logical type.
#[test]
fn test_rewrite_column_index_and_type_mismatch() {
    // Physical (file) schema: "a" lives at index 1 and is Int32.
    let physical_schema = Schema::new(vec![
        Field::new("b", DataType::Utf8, true),
        Field::new("a", DataType::Int32, false), // Index 1
    ]);

    // Logical (table) schema: "a" lives at index 0 and is Int64.
    let logical_schema = Schema::new(vec![
        Field::new("a", DataType::Int64, false), // Index 0, Different Type
        Field::new("b", DataType::Utf8, true),
    ]);

    let factory = DefaultPhysicalExprAdapterFactory;
    let adapter = factory
        .create(Arc::new(logical_schema), Arc::new(physical_schema))
        .unwrap();

    // Logical column "a" is at index 0
    let column_expr = Arc::new(Column::new("a", 0));

    let result = adapter.rewrite(column_expr).unwrap();

    // Should be a CastColumnExpr
    let cast_expr = result
        .as_any()
        .downcast_ref::<CastColumnExpr>()
        .expect("Expected CastColumnExpr");

    // Verify the inner column points to the correct physical index (1)
    let inner_col = cast_expr
        .expr()
        .as_any()
        .downcast_ref::<Column>()
        .expect("Expected inner Column");
    assert_eq!(inner_col.name(), "a");
    assert_eq!(inner_col.index(), 1); // Physical index is 1

    // Verify cast types: the expression's output type must be the logical
    // Int64, regardless of the (empty) schema passed to data_type().
    assert_eq!(
        cast_expr.data_type(&Schema::empty()).unwrap(),
        DataType::Int64
    );
}
}