diff --git a/README.md b/README.md index 2d0e3f70943b0..ceb71f64ba4ed 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,20 @@ Optional features: [apache parquet]: https://parquet.apache.org/ [parquet modular encryption]: https://parquet.apache.org/docs/file-format/data-pages/encryption/ +## Schema adaptation and nested casting + +Data sources can evolve independently from the table schema a query expects. +DataFusion's [`SchemaAdapter`](docs/source/library-user-guide/schema_adapter.md) +bridges this gap by invoking `cast_column` to coerce arrays into the desired +[`Field`] types. The function walks nested `Struct` values, fills in missing +fields with `NULL`, and ensures each level matches the target schema. + +See [Schema Adapter and Column Casting](docs/source/library-user-guide/schema_adapter.md) +for examples and notes on performance trade-offs when deeply nested structs are +cast. + +[`field`]: https://docs.rs/arrow/latest/arrow/datatypes/struct.Field.html + ## DataFusion API Evolution and Deprecation Guidelines Public methods in Apache DataFusion evolve over time: while we try to maintain a diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 3a558fa867894..e70d1d27fe23b 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -82,6 +82,7 @@ pub use functional_dependencies::{ }; use hashbrown::hash_map::DefaultHashBuilder; pub use join_type::{JoinConstraint, JoinSide, JoinType}; +pub use nested_struct::cast_column; pub use null_equality::NullEquality; pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs index f349b360f2385..6e8a380df9215 100644 --- a/datafusion/common/src/nested_struct.rs +++ b/datafusion/common/src/nested_struct.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::error::{DataFusionError, Result, _plan_err}; +use crate::error::{Result, _plan_err}; use arrow::{ array::{new_null_array, Array, ArrayRef, StructArray}, - compute::cast, + compute::{cast_with_options, CastOptions}, datatypes::{DataType::Struct, Field, FieldRef}, }; use std::sync::Arc; @@ -52,36 +52,44 @@ use std::sync::Arc; fn cast_struct_column( source_col: &ArrayRef, target_fields: &[Arc], + cast_options: &CastOptions, ) -> Result { - if let Some(struct_array) = source_col.as_any().downcast_ref::() { - let mut children: Vec<(Arc, Arc)> = Vec::new(); + if let Some(source_struct) = source_col.as_any().downcast_ref::() { + validate_struct_compatibility(source_struct.fields(), target_fields)?; + + let mut fields: Vec> = Vec::with_capacity(target_fields.len()); + let mut arrays: Vec = Vec::with_capacity(target_fields.len()); let num_rows = source_col.len(); for target_child_field in target_fields { - let field_arc = Arc::clone(target_child_field); - match struct_array.column_by_name(target_child_field.name()) { + fields.push(Arc::clone(target_child_field)); + match source_struct.column_by_name(target_child_field.name()) { Some(source_child_col) => { let adapted_child = - cast_column(source_child_col, target_child_field)?; - children.push((field_arc, adapted_child)); + cast_column(source_child_col, target_child_field, cast_options) + .map_err(|e| { + e.context(format!( + "While casting struct field '{}'", + target_child_field.name() + )) + })?; + arrays.push(adapted_child); } None => { - children.push(( - field_arc, - new_null_array(target_child_field.data_type(), num_rows), - )); + arrays.push(new_null_array(target_child_field.data_type(), num_rows)); } } } - let struct_array = StructArray::from(children); + let struct_array = + StructArray::new(fields.into(), arrays, source_struct.nulls().cloned()); Ok(Arc::new(struct_array)) } else { // Return error if source is not a struct type - Err(DataFusionError::Plan(format!( + _plan_err!( "Cannot cast column of type {:?} to struct type. Source must be a struct to cast to struct.", source_col.data_type() - ))) + ) } } @@ -94,6 +102,28 @@ fn cast_struct_column( /// - **Struct Types**: Delegates to `cast_struct_column` for struct-to-struct casting only /// - **Non-Struct Types**: Uses Arrow's standard `cast` function for primitive type conversions /// +/// ## Cast Options +/// The `cast_options` argument controls how Arrow handles values that cannot be represented +/// in the target type. When `safe` is `false` (DataFusion's default) the cast will return an +/// error if such a value is encountered. Setting `safe` to `true` instead produces `NULL` +/// for out-of-range or otherwise invalid values. The options also allow customizing how +/// temporal values are formatted when cast to strings. +/// +/// ``` +/// use std::sync::Arc; +/// use arrow::array::{Int64Array, ArrayRef}; +/// use arrow::compute::CastOptions; +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::nested_struct::cast_column; +/// +/// let source: ArrayRef = Arc::new(Int64Array::from(vec![1, i64::MAX])); +/// let target = Field::new("ints", DataType::Int32, true); +/// // Permit lossy conversions by producing NULL on overflow instead of erroring +/// let options = CastOptions { safe: true, ..Default::default() }; +/// let result = cast_column(&source, &target, &options).unwrap(); +/// assert!(result.is_null(1)); +/// ``` +/// /// ## Struct Casting Requirements /// The struct casting logic requires that the source column must already be a struct type. /// This makes the function useful for: @@ -104,6 +134,7 @@ fn cast_struct_column( /// # Arguments /// * `source_col` - The source array to cast /// * `target_field` - The target field definition (including type and metadata) +/// * `cast_options` - Options that govern strictness and formatting of the cast /// /// # Returns /// A `Result` containing the cast array @@ -114,10 +145,20 @@ fn cast_struct_column( /// - Arrow's cast function fails for non-struct types /// - Memory allocation fails during struct construction /// - Invalid data type combinations are encountered -pub fn cast_column(source_col: &ArrayRef, target_field: &Field) -> Result { +pub fn cast_column( + source_col: &ArrayRef, + target_field: &Field, + cast_options: &CastOptions, +) -> Result { match target_field.data_type() { - Struct(target_fields) => cast_struct_column(source_col, target_fields), - _ => Ok(cast(source_col, target_field.data_type())?), + Struct(target_fields) => { + cast_struct_column(source_col, target_fields, cast_options) + } + _ => Ok(cast_with_options( + source_col, + target_field.data_type(), + cast_options, + )?), } } @@ -141,7 +182,7 @@ pub fn cast_column(source_col: &ArrayRef, target_field: &Field) -> Result Resulti64, 'b','c' ignored, 'd' filled with nulls +/// // Result: Ok(()) - 'a' can cast i32->i64, 'b','c' ignored, 'd' filled with nulls /// /// // Incompatible: matching field has incompatible types /// // Source: {a: string} @@ -159,7 +200,7 @@ pub fn cast_column(source_col: &ArrayRef, target_field: &Field) -> Result Result { +) -> Result<()> { // Check compatibility for each target field for target_field in target_fields { // Look for matching field in source by name @@ -167,6 +208,15 @@ pub fn validate_struct_compatibility( .iter() .find(|f| f.name() == target_field.name()) { + // Ensure nullability is compatible. It is invalid to cast a nullable + // source field to a non-nullable target field as this may discard + // null values. + if source_field.is_nullable() && !target_field.is_nullable() { + return _plan_err!( + "Cannot cast nullable struct field '{}' to non-nullable field", + target_field.name() + ); + } // Check if the matching field types are compatible match (source_field.data_type(), target_field.data_type()) { // Recursively validate nested structs @@ -193,15 +243,21 @@ pub fn validate_struct_compatibility( } // Extra fields in source are OK - they'll be ignored - Ok(true) + Ok(()) } #[cfg(test)] mod tests { + use super::*; + use crate::format::DEFAULT_CAST_OPTIONS; use arrow::{ - array::{Int32Array, Int64Array, StringArray}, - datatypes::{DataType, Field}, + array::{ + BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, MapArray, + MapBuilder, StringArray, StringBuilder, + }, + buffer::NullBuffer, + datatypes::{DataType, Field, FieldRef, Int32Type}, }; /// Macro to extract and downcast a column from a StructArray macro_rules! get_column_as { @@ -215,11 +271,35 @@ mod tests { }; } + fn field(name: &str, data_type: DataType) -> Field { + Field::new(name, data_type, true) + } + + fn non_null_field(name: &str, data_type: DataType) -> Field { + Field::new(name, data_type, false) + } + + fn arc_field(name: &str, data_type: DataType) -> FieldRef { + Arc::new(field(name, data_type)) + } + + fn struct_type(fields: Vec) -> DataType { + Struct(fields.into()) + } + + fn struct_field(name: &str, fields: Vec) -> Field { + field(name, struct_type(fields)) + } + + fn arc_struct_field(name: &str, fields: Vec) -> FieldRef { + Arc::new(struct_field(name, fields)) + } + #[test] fn test_cast_simple_column() { let source = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; - let target_field = Field::new("ints", DataType::Int64, true); - let result = cast_column(&source, &target_field).unwrap(); + let target_field = field("ints", DataType::Int64); + let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); let result = result.as_any().downcast_ref::().unwrap(); assert_eq!(result.len(), 3); assert_eq!(result.value(0), 1); @@ -227,28 +307,45 @@ mod tests { assert_eq!(result.value(2), 3); } + #[test] + fn test_cast_column_with_options() { + let source = Arc::new(Int64Array::from(vec![1, i64::MAX])) as ArrayRef; + let target_field = field("ints", DataType::Int32); + + let safe_opts = CastOptions { + // safe: false - return Err for failure + safe: false, + ..DEFAULT_CAST_OPTIONS + }; + assert!(cast_column(&source, &target_field, &safe_opts).is_err()); + + let unsafe_opts = CastOptions { + // safe: true - return Null for failure + safe: true, + ..DEFAULT_CAST_OPTIONS + }; + let result = cast_column(&source, &target_field, &unsafe_opts).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.value(0), 1); + assert!(result.is_null(1)); + } + #[test] fn test_cast_struct_with_missing_field() { let a_array = Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef; let source_struct = StructArray::from(vec![( - Arc::new(Field::new("a", DataType::Int32, true)), + arc_field("a", DataType::Int32), Arc::clone(&a_array), )]); let source_col = Arc::new(source_struct) as ArrayRef; - let target_field = Field::new( + let target_field = struct_field( "s", - Struct( - vec![ - Arc::new(Field::new("a", DataType::Int32, true)), - Arc::new(Field::new("b", DataType::Utf8, true)), - ] - .into(), - ), - true, + vec![field("a", DataType::Int32), field("b", DataType::Utf8)], ); - let result = cast_column(&source_col, &target_field).unwrap(); + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); assert_eq!(struct_array.fields().len(), 2); let a_result = get_column_as!(&struct_array, "a", Int32Array); @@ -264,13 +361,9 @@ mod tests { #[test] fn test_cast_struct_source_not_struct() { let source = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef; - let target_field = Field::new( - "s", - Struct(vec![Arc::new(Field::new("a", DataType::Int32, true))].into()), - true, - ); + let target_field = struct_field("s", vec![field("a", DataType::Int32)]); - let result = cast_column(&source, &target_field); + let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); assert!(error_msg.contains("Cannot cast column of type")); @@ -278,16 +371,34 @@ mod tests { assert!(error_msg.contains("Source must be a struct")); } + #[test] + fn test_cast_struct_incompatible_child_type() { + let a_array = Arc::new(BinaryArray::from(vec![ + Some(b"a".as_ref()), + Some(b"b".as_ref()), + ])) as ArrayRef; + let source_struct = + StructArray::from(vec![(arc_field("a", DataType::Binary), a_array)]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field("s", vec![field("a", DataType::Int32)]); + + let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast struct field 'a'")); + } + #[test] fn test_validate_struct_compatibility_incompatible_types() { // Source struct: {field1: Binary, field2: String} let source_fields = vec![ - Arc::new(Field::new("field1", DataType::Binary, true)), - Arc::new(Field::new("field2", DataType::Utf8, true)), + arc_field("field1", DataType::Binary), + arc_field("field2", DataType::Utf8), ]; // Target struct: {field1: Int32} - let target_fields = vec![Arc::new(Field::new("field1", DataType::Int32, true))]; + let target_fields = vec![arc_field("field1", DataType::Int32)]; let result = validate_struct_compatibility(&source_fields, &target_fields); assert!(result.is_err()); @@ -301,29 +412,293 @@ mod tests { fn test_validate_struct_compatibility_compatible_types() { // Source struct: {field1: Int32, field2: String} let source_fields = vec![ - Arc::new(Field::new("field1", DataType::Int32, true)), - Arc::new(Field::new("field2", DataType::Utf8, true)), + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), ]; // Target struct: {field1: Int64} (Int32 can cast to Int64) - let target_fields = vec![Arc::new(Field::new("field1", DataType::Int64, true))]; + let target_fields = vec![arc_field("field1", DataType::Int64)]; let result = validate_struct_compatibility(&source_fields, &target_fields); assert!(result.is_ok()); - assert!(result.unwrap()); } #[test] fn test_validate_struct_compatibility_missing_field_in_source() { // Source struct: {field2: String} (missing field1) - let source_fields = vec![Arc::new(Field::new("field2", DataType::Utf8, true))]; + let source_fields = vec![arc_field("field2", DataType::Utf8)]; // Target struct: {field1: Int32} - let target_fields = vec![Arc::new(Field::new("field1", DataType::Int32, true))]; + let target_fields = vec![arc_field("field1", DataType::Int32)]; // Should be OK - missing fields will be filled with nulls let result = validate_struct_compatibility(&source_fields, &target_fields); assert!(result.is_ok()); - assert!(result.unwrap()); + } + + #[test] + fn test_validate_struct_compatibility_additional_field_in_source() { + // Source struct: {field1: Int32, field2: String} (extra field2) + let source_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; + + // Target struct: {field1: Int32} + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + // Should be OK - extra fields in source are ignored + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_cast_struct_parent_nulls_retained() { + let a_array = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let fields = vec![arc_field("a", DataType::Int32)]; + let nulls = Some(NullBuffer::from(vec![true, false])); + let source_struct = StructArray::new(fields.clone().into(), vec![a_array], nulls); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field("s", vec![field("a", DataType::Int64)]); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_array.null_count(), 1); + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_null(1)); + + let a_result = get_column_as!(&struct_array, "a", Int64Array); + assert_eq!(a_result.value(0), 1); + assert_eq!(a_result.value(1), 2); + } + + #[test] + fn test_validate_struct_compatibility_nullable_to_non_nullable() { + // Source struct: {field1: Int32 nullable} + let source_fields = vec![arc_field("field1", DataType::Int32)]; + + // Target struct: {field1: Int32 non-nullable} + let target_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("field1")); + assert!(error_msg.contains("non-nullable")); + } + + #[test] + fn test_validate_struct_compatibility_non_nullable_to_nullable() { + // Source struct: {field1: Int32 non-nullable} + let source_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))]; + + // Target struct: {field1: Int32 nullable} + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_nested_nullable_to_non_nullable() { + // Source struct: {field1: {nested: Int32 nullable}} + let source_fields = vec![Arc::new(non_null_field( + "field1", + struct_type(vec![field("nested", DataType::Int32)]), + ))]; + + // Target struct: {field1: {nested: Int32 non-nullable}} + let target_fields = vec![Arc::new(non_null_field( + "field1", + struct_type(vec![non_null_field("nested", DataType::Int32)]), + ))]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("nested")); + assert!(error_msg.contains("non-nullable")); + } + + #[test] + fn test_cast_nested_struct_with_extra_and_missing_fields() { + // Source inner struct has fields a, b, extra + let a = Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef; + let b = Arc::new(Int32Array::from(vec![Some(2), Some(3)])) as ArrayRef; + let extra = Arc::new(Int32Array::from(vec![Some(9), Some(10)])) as ArrayRef; + + let inner = StructArray::from(vec![ + (arc_field("a", DataType::Int32), a), + (arc_field("b", DataType::Int32), b), + (arc_field("extra", DataType::Int32), extra), + ]); + + let source_struct = StructArray::from(vec![( + arc_struct_field( + "inner", + vec![ + field("a", DataType::Int32), + field("b", DataType::Int32), + field("extra", DataType::Int32), + ], + ), + Arc::new(inner) as ArrayRef, + )]); + let source_col = Arc::new(source_struct) as ArrayRef; + + // Target inner struct reorders fields, adds "missing", and drops "extra" + let target_field = struct_field( + "outer", + vec![struct_field( + "inner", + vec![ + field("b", DataType::Int64), + field("a", DataType::Int32), + field("missing", DataType::Int32), + ], + )], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let outer = result.as_any().downcast_ref::().unwrap(); + let inner = get_column_as!(&outer, "inner", StructArray); + assert_eq!(inner.fields().len(), 3); + + let b = get_column_as!(inner, "b", Int64Array); + assert_eq!(b.value(0), 2); + assert_eq!(b.value(1), 3); + assert!(!b.is_null(0)); + assert!(!b.is_null(1)); + + let a = get_column_as!(inner, "a", Int32Array); + assert_eq!(a.value(0), 1); + assert!(a.is_null(1)); + + let missing = get_column_as!(inner, "missing", Int32Array); + assert!(missing.is_null(0)); + assert!(missing.is_null(1)); + } + + #[test] + fn test_cast_struct_with_array_and_map_fields() { + // Array field with second row null + let arr_array = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + ])) as ArrayRef; + + // Map field with second row null + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, string_builder, int_builder); + map_builder.keys().append_value("a"); + map_builder.values().append_value(1); + map_builder.append(true).unwrap(); + map_builder.append(false).unwrap(); + let map_array = Arc::new(map_builder.finish()) as ArrayRef; + + let source_struct = StructArray::from(vec![ + ( + arc_field( + "arr", + DataType::List(Arc::new(field("item", DataType::Int32))), + ), + arr_array, + ), + ( + arc_field( + "map", + DataType::Map( + Arc::new(non_null_field( + "entries", + struct_type(vec![ + non_null_field("keys", DataType::Utf8), + field("values", DataType::Int32), + ]), + )), + false, + ), + ), + map_array, + ), + ]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![ + field( + "arr", + DataType::List(Arc::new(field("item", DataType::Int32))), + ), + field( + "map", + DataType::Map( + Arc::new(non_null_field( + "entries", + struct_type(vec![ + non_null_field("keys", DataType::Utf8), + field("values", DataType::Int32), + ]), + )), + false, + ), + ), + ], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let arr = get_column_as!(&struct_array, "arr", ListArray); + assert!(!arr.is_null(0)); + assert!(arr.is_null(1)); + let arr0 = arr.value(0); + let values = arr0.as_any().downcast_ref::().unwrap(); + assert_eq!(values.value(0), 1); + assert_eq!(values.value(1), 2); + + let map = get_column_as!(&struct_array, "map", MapArray); + assert!(!map.is_null(0)); + assert!(map.is_null(1)); + let map0 = map.value(0); + let entries = map0.as_any().downcast_ref::().unwrap(); + let keys = get_column_as!(entries, "keys", StringArray); + let vals = get_column_as!(entries, "values", Int32Array); + assert_eq!(keys.value(0), "a"); + assert_eq!(vals.value(0), 1); + } + + #[test] + fn test_cast_struct_field_order_differs() { + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let b = Arc::new(Int32Array::from(vec![Some(3), None])) as ArrayRef; + + let source_struct = StructArray::from(vec![ + (arc_field("a", DataType::Int32), a), + (arc_field("b", DataType::Int32), b), + ]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![field("b", DataType::Int64), field("a", DataType::Int32)], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let b_col = get_column_as!(&struct_array, "b", Int64Array); + assert_eq!(b_col.value(0), 3); + assert!(b_col.is_null(1)); + + let a_col = get_column_as!(&struct_array, "a", Int32Array); + assert_eq!(a_col.value(0), 1); + assert_eq!(a_col.value(1), 2); } } diff --git a/datafusion/datasource/src/schema_adapter.rs b/datafusion/datasource/src/schema_adapter.rs index 16de00500b020..17d99ce0761f7 100644 --- a/datafusion/datasource/src/schema_adapter.rs +++ b/datafusion/datasource/src/schema_adapter.rs @@ -26,8 +26,8 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, }; use datafusion_common::{ - nested_struct::{cast_column, validate_struct_compatibility}, - plan_err, ColumnStatistics, + cast_column, format::DEFAULT_CAST_OPTIONS, + nested_struct::validate_struct_compatibility, plan_err, ColumnStatistics, }; use std::{fmt::Debug, sync::Arc}; /// Function used by [`SchemaMapping`] to adapt a column from the file schema to @@ -245,18 +245,18 @@ pub(crate) struct DefaultSchemaAdapter { /// Checks if a file field can be cast to a table field /// -/// Returns Ok(true) if casting is possible, or an error explaining why casting is not possible +/// Returns `Ok(())` if casting is possible, or an error explaining why casting is not possible pub(crate) fn can_cast_field( file_field: &Field, table_field: &Field, -) -> datafusion_common::Result { +) -> datafusion_common::Result<()> { match (file_field.data_type(), table_field.data_type()) { - (DataType::Struct(source_fields), DataType::Struct(target_fields)) => { - validate_struct_compatibility(source_fields, target_fields) + (DataType::Struct(file_fields), DataType::Struct(table_fields)) => { + validate_struct_compatibility(file_fields, table_fields) } _ => { if can_cast_types(file_field.data_type(), table_field.data_type()) { - Ok(true) + Ok(()) } else { plan_err!( "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}", @@ -302,7 +302,9 @@ impl SchemaAdapter for DefaultSchemaAdapter { Arc::new(SchemaMapping::new( Arc::clone(&self.projected_table_schema), field_mappings, - Arc::new(|array: &ArrayRef, field: &Field| cast_column(array, field)), + Arc::new(|array: &ArrayRef, field: &Field| { + cast_column(array, field, &DEFAULT_CAST_OPTIONS) + }), )), projection, )) @@ -321,7 +323,7 @@ pub(crate) fn create_field_mapping( can_map_field: F, ) -> datafusion_common::Result<(Vec>, Vec)> where - F: Fn(&Field, &Field) -> datafusion_common::Result, + F: Fn(&Field, &Field) -> datafusion_common::Result<()>, { let mut projection = Vec::with_capacity(file_schema.fields().len()); let mut field_mappings = vec![None; projected_table_schema.fields().len()]; @@ -330,10 +332,9 @@ where if let Some((table_idx, table_field)) = projected_table_schema.fields().find(file_field.name()) { - if can_map_field(file_field, table_field)? { - field_mappings[table_idx] = Some(projection.len()); - projection.push(file_idx); - } + can_map_field(file_field, table_field)?; + field_mappings[table_idx] = Some(projection.len()); + projection.push(file_idx); } } @@ -462,26 +463,50 @@ impl SchemaMapper for SchemaMapping { mod tests { use super::*; use arrow::{ - array::{Array, ArrayRef, StringBuilder, StructArray, TimestampMillisecondArray}, + array::{ + Array, ArrayRef, Float64Array, Int32Array, Int32Builder, Int64Array, + ListArray, MapArray, MapBuilder, StringArray, StringBuilder, StructArray, + TimestampMillisecondArray, + }, compute::cast, - datatypes::{DataType, Field, TimeUnit}, + datatypes::{DataType, Field, Int32Type, TimeUnit}, record_batch::RecordBatch, }; use datafusion_common::{stats::Precision, Result, ScalarValue, Statistics}; + fn field(name: &str, data_type: DataType) -> Field { + Field::new(name, data_type, true) + } + + fn schema(fields: Vec) -> Schema { + Schema::new(fields) + } + + fn schema_ref(fields: Vec) -> SchemaRef { + Arc::new(schema(fields)) + } + + fn struct_field(name: &str, fields: Vec) -> Field { + field(name, DataType::Struct(fields.into())) + } + + fn arc_field(name: &str, data_type: DataType) -> Arc { + Arc::new(field(name, data_type)) + } + #[test] fn test_schema_mapping_map_statistics_basic() { // Create table schema (a, b, c) - let table_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Float64, true), - ])); + let table_schema = schema_ref(vec![ + field("a", DataType::Int32), + field("b", DataType::Utf8), + field("c", DataType::Float64), + ]); // Create file schema (b, a) - different order, missing c - let file_schema = Schema::new(vec![ - Field::new("b", DataType::Utf8, true), - Field::new("a", DataType::Int32, true), + let file_schema = schema(vec![ + field("b", DataType::Utf8), + field("a", DataType::Int32), ]); // Create SchemaAdapter @@ -527,13 +552,13 @@ mod tests { #[test] fn test_schema_mapping_map_statistics_empty() { // Create schemas - let table_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - ])); - let file_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), + let table_schema = schema_ref(vec![ + field("a", DataType::Int32), + field("b", DataType::Utf8), + ]); + let file_schema = schema(vec![ + field("a", DataType::Int32), + field("b", DataType::Utf8), ]); let adapter = DefaultSchemaAdapter { @@ -556,24 +581,37 @@ mod tests { #[test] fn test_can_cast_field() { // Same type should work - let from_field = Field::new("col", DataType::Int32, true); - let to_field = Field::new("col", DataType::Int32, true); - assert!(can_cast_field(&from_field, &to_field).unwrap()); + let from_field = field("col", DataType::Int32); + let to_field = field("col", DataType::Int32); + can_cast_field(&from_field, &to_field).unwrap(); // Casting Int32 to Float64 is allowed - let from_field = Field::new("col", DataType::Int32, true); - let to_field = Field::new("col", DataType::Float64, true); - assert!(can_cast_field(&from_field, &to_field).unwrap()); + let from_field = field("col", DataType::Int32); + let to_field = field("col", DataType::Float64); + can_cast_field(&from_field, &to_field).unwrap(); // Casting Float64 to Utf8 should work (converts to string) - let from_field = Field::new("col", DataType::Float64, true); - let to_field = Field::new("col", DataType::Utf8, true); - assert!(can_cast_field(&from_field, &to_field).unwrap()); + let from_field = field("col", DataType::Float64); + let to_field = field("col", DataType::Utf8); + can_cast_field(&from_field, &to_field).unwrap(); + + // Struct fields with compatible child types should work + let from_field = struct_field("col", vec![field("a", DataType::Int32)]); + let to_field = struct_field("col", vec![field("a", DataType::Int64)]); + can_cast_field(&from_field, &to_field).unwrap(); + + // Struct fields with incompatible child types should fail + let from_field = struct_field("col", vec![field("a", DataType::Binary)]); + let to_field = struct_field("col", vec![field("a", DataType::Int32)]); + let result = can_cast_field(&from_field, &to_field); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast struct field 'a'")); // Binary to Utf8 is not supported - this is an example of a cast that should fail // Note: We use Binary instead of Utf8->Int32 because Arrow actually supports that cast - let from_field = Field::new("col", DataType::Binary, true); - let to_field = Field::new("col", DataType::Decimal128(10, 2), true); + let from_field = field("col", DataType::Binary); + let to_field = field("col", DataType::Decimal128(10, 2)); let result = can_cast_field(&from_field, &to_field); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); @@ -583,21 +621,21 @@ mod tests { #[test] fn test_create_field_mapping() { // Define the table schema - let table_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Float64, true), - ])); + let table_schema = schema_ref(vec![ + field("a", DataType::Int32), + field("b", DataType::Utf8), + field("c", DataType::Float64), + ]); // Define file schema: different order, missing column c, and b has different type - let file_schema = Schema::new(vec![ - Field::new("b", DataType::Float64, true), // Different type but castable to Utf8 - Field::new("a", DataType::Int32, true), // Same type - Field::new("d", DataType::Boolean, true), // Not in table schema + let file_schema = schema(vec![ + field("b", DataType::Float64), // Different type but castable to Utf8 + field("a", DataType::Int32), // Same type + field("d", DataType::Boolean), // Not in table schema ]); // Custom can_map_field function that allows all mappings for testing - let allow_all = |_: &Field, _: &Field| Ok(true); + let allow_all = |_: &Field, _: &Field| Ok(()); // Test field mapping let (field_mappings, projection) = @@ -610,15 +648,6 @@ mod tests { assert_eq!(field_mappings, vec![Some(1), Some(0), None]); assert_eq!(projection, vec![0, 1]); // Projecting file columns b, a - // Test with a failing mapper - let fails_all = |_: &Field, _: &Field| Ok(false); - let (field_mappings, projection) = - create_field_mapping(&file_schema, &table_schema, fails_all).unwrap(); - - // Should have no mappings or projections if all cast checks fail - assert_eq!(field_mappings, vec![None, None, None]); - assert_eq!(projection, Vec::::new()); - // Test with error-producing mapper let error_mapper = |_: &Field, _: &Field| plan_err!("Test error"); let result = create_field_mapping(&file_schema, &table_schema, error_mapper); @@ -629,10 +658,10 @@ mod tests { #[test] fn test_schema_mapping_new() { // Define the projected table schema - let projected_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - ])); + let projected_schema = schema_ref(vec![ + field("a", DataType::Int32), + field("b", DataType::Utf8), + ]); // Define field mappings from table to file let field_mappings = vec![Some(1), Some(0)]; @@ -641,7 +670,9 @@ mod tests { let mapping = SchemaMapping::new( Arc::clone(&projected_schema), field_mappings.clone(), - Arc::new(|array: &ArrayRef, field: &Field| cast_column(array, field)), + Arc::new(|array: &ArrayRef, field: &Field| { + cast_column(array, field, &DEFAULT_CAST_OPTIONS) + }), ); // Check that fields were set correctly @@ -650,13 +681,13 @@ mod tests { // Test with a batch to ensure it works properly let batch = RecordBatch::try_new( - Arc::new(Schema::new(vec![ - Field::new("b_file", DataType::Utf8, true), - Field::new("a_file", DataType::Int32, true), - ])), + schema_ref(vec![ + field("b_file", DataType::Utf8), + field("a_file", DataType::Int32), + ]), vec![ - Arc::new(arrow::array::StringArray::from(vec!["hello", "world"])), - Arc::new(arrow::array::Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["hello", "world"])), + Arc::new(Int32Array::from(vec![1, 2])), ], ) .unwrap(); @@ -674,17 +705,17 @@ mod tests { #[test] fn test_map_schema_error_path() { // Define the table schema - let table_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Decimal128(10, 2), true), // Use Decimal which has stricter cast rules - ])); + let table_schema = schema_ref(vec![ + field("a", DataType::Int32), + field("b", DataType::Utf8), + field("c", DataType::Decimal128(10, 2)), // Use Decimal which has stricter cast rules + ]); // Define file schema with incompatible type for column c - let file_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Float64, true), // Different but castable - Field::new("c", DataType::Binary, true), // Not castable to Decimal128 + let file_schema = schema(vec![ + field("a", DataType::Int32), + field("b", DataType::Float64), // Different but castable + field("c", DataType::Binary), // Not castable to Decimal128 ]); // Create DefaultSchemaAdapter @@ -702,11 +733,11 @@ mod tests { #[test] fn test_map_schema_happy_path() { // Define the table schema - let table_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Decimal128(10, 2), true), - ])); + let table_schema = schema_ref(vec![ + field("a", DataType::Int32), + field("b", DataType::Utf8), + field("c", DataType::Decimal128(10, 2)), + ]); // Create DefaultSchemaAdapter let adapter = DefaultSchemaAdapter { @@ -714,9 +745,9 @@ mod tests { }; // Define compatible file schema (missing column c) - let compatible_file_schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), // Can be cast to Int32 - Field::new("b", DataType::Float64, true), // Can be cast to Utf8 + let compatible_file_schema = schema(vec![ + field("a", DataType::Int64), // Can be cast to Int32 + field("b", DataType::Float64), // Can be cast to Utf8 ]); // Test successful schema mapping @@ -729,8 +760,8 @@ mod tests { let file_batch = RecordBatch::try_new( Arc::new(compatible_file_schema.clone()), vec![ - Arc::new(arrow::array::Int64Array::from(vec![100, 200])), - Arc::new(arrow::array::Float64Array::from(vec![1.5, 2.5])), + Arc::new(Int64Array::from(vec![100, 200])), + Arc::new(Float64Array::from(vec![1.5, 2.5])), ], ) .unwrap(); @@ -798,6 +829,257 @@ mod tests { Ok(()) } + #[test] + fn test_map_batch_nested_struct_with_extra_and_missing_fields() -> Result<()> { + // File schema has extra field "address"; table schema adds field "salary" and casts age + let file_schema = schema(vec![struct_field( + "person", + vec![ + field("name", DataType::Utf8), + field("age", DataType::Int32), + field("address", DataType::Utf8), + ], + )]); + + let table_schema = schema_ref(vec![struct_field( + "person", + vec![ + field("age", DataType::Int64), + field("name", DataType::Utf8), + field("salary", DataType::Int32), + ], + )]); + + let name = Arc::new(StringArray::from(vec![Some("Alice"), None])) as ArrayRef; + let age = Arc::new(Int32Array::from(vec![Some(30), Some(40)])) as ArrayRef; + let address = + Arc::new(StringArray::from(vec![Some("Earth"), Some("Mars")])) as ArrayRef; + let person = StructArray::from(vec![ + (arc_field("name", DataType::Utf8), name), + (arc_field("age", DataType::Int32), age), + (arc_field("address", DataType::Utf8), address), + ]); + let batch = + RecordBatch::try_new(Arc::new(file_schema.clone()), vec![Arc::new(person)])?; + + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + let (mapper, _) = adapter.map_schema(&file_schema)?; + let mapped = mapper.map_batch(batch)?; + assert_eq!(*mapped.schema(), *table_schema); + + let person = mapped + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let age = person + .column_by_name("age") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(age.value(0), 30); + assert_eq!(age.value(1), 40); + let name = person + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(name.value(0), "Alice"); + assert!(name.is_null(1)); + let salary = person + .column_by_name("salary") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(salary.is_null(0)); + assert!(salary.is_null(1)); + Ok(()) + } + + #[test] + fn test_map_batch_struct_with_array_and_map() -> Result<()> { + fn map_type(value_type: DataType) -> DataType { + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("keys", DataType::Utf8, false), + Field::new("values", value_type, true), + ] + .into(), + ), + false, + )), + false, + ) + } + + let file_schema = Schema::new(vec![Field::new( + "s", + DataType::Struct( + vec![ + Field::new( + "arr", + DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true, + ))), + true, + ), + Field::new("map", map_type(DataType::Int32), true), + ] + .into(), + ), + true, + )]); + + let table_schema = Arc::new(Schema::new(vec![Field::new( + "s", + DataType::Struct( + vec![ + Field::new( + "arr", + DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true, + ))), + true, + ), + Field::new("map", map_type(DataType::Int32), true), + ] + .into(), + ), + true, + )])); + + let arr = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + ])) as ArrayRef; + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + let mut builder = MapBuilder::new(None, string_builder, int_builder); + builder.keys().append_value("a"); + builder.values().append_value(1); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + let map = Arc::new(builder.finish()) as ArrayRef; + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new( + "arr", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )), + arr, + ), + ( + Arc::new(Field::new("map", map_type(DataType::Int32), true)), + map, + ), + ]); + let batch = RecordBatch::try_new( + Arc::new(file_schema.clone()), + vec![Arc::new(struct_array)], + )?; + + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + let (mapper, _) = adapter.map_schema(&file_schema)?; + let mapped = mapper.map_batch(batch)?; + + let s = mapped + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let arr = s + .column_by_name("arr") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(!arr.is_null(0)); + assert!(arr.is_null(1)); + let arr0 = arr.value(0); + let first = arr0.as_any().downcast_ref::().unwrap(); + assert_eq!(first.value(0), 1); + assert_eq!(first.value(1), 2); + + let map = s + .column_by_name("map") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(!map.is_null(0)); + assert!(map.is_null(1)); + let map0 = map.value(0); + let entries = map0.as_any().downcast_ref::().unwrap(); + let keys = entries + .column_by_name("keys") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(keys.value(0), "a"); + Ok(()) + } + + #[test] + fn test_map_batch_field_order_differs() { + let table_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int64, true), + Field::new("a", DataType::Int32, true), + ])); + + let file_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + let (mapper, projection) = adapter.map_schema(&file_schema).unwrap(); + assert_eq!(projection, vec![0, 1]); + + let batch = RecordBatch::try_new( + Arc::new(file_schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(3), None])) as ArrayRef, + ], + ) + .unwrap(); + + let mapped = mapper.map_batch(batch).unwrap(); + assert_eq!(*mapped.schema(), *table_schema); + let b = mapped + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(b.value(0), 3); + assert!(b.is_null(1)); + let a = mapped + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a.value(0), 1); + assert_eq!(a.value(1), 2); + } + fn create_test_schemas_with_nested_fields() -> (SchemaRef, SchemaRef) { let file_schema = Arc::new(Schema::new(vec![Field::new( "info", @@ -913,7 +1195,7 @@ mod tests { .expect("Expected location field in struct"); let location_array = location_col .as_any() - .downcast_ref::() + .downcast_ref::() .expect("Expected location to be a StringArray"); assert_eq!(location_array.value(0), "San Francisco"); assert_eq!(location_array.value(1), "New York"); diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index a21ad5bbbcc30..3e15a8d453af2 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -247,8 +247,11 @@ impl fmt::Display for ColumnarValue { #[cfg(test)] mod tests { use super::*; - use arrow::array::Int32Array; - + use arrow::{ + array::{Int32Array, Int64Array, StructArray}, + compute::CastOptions, + datatypes::{DataType, Field}, + }; #[test] fn values_to_arrays() { // (input, expected) @@ -359,6 +362,66 @@ mod tests { Arc::new(Int32Array::from(vec![val; len])) } + #[test] + fn cast_struct_respects_safe_cast_options() { + let int64_array = Arc::new(Int64Array::from(vec![i64::from(i32::MAX) + 1])); + let struct_array = StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Int64, true)), + int64_array as ArrayRef, + )]); + let value = ColumnarValue::Array(Arc::new(struct_array)); + + let target_type = DataType::Struct( + vec![Arc::new(Field::new("a", DataType::Int32, true))].into(), + ); + + let cast_options = CastOptions { + safe: true, + ..DEFAULT_CAST_OPTIONS + }; + + let result = value + .cast_to(&target_type, Some(&cast_options)) + .expect("cast should succeed"); + + let arr = match result { + ColumnarValue::Array(arr) => arr, + other => panic!("expected array, got {other:?}"), + }; + + let struct_array = arr.as_any().downcast_ref::().unwrap(); + let int_array = struct_array + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(int_array.is_null(0)); + } + + #[test] + fn cast_struct_respects_unsafe_cast_options() { + let int64_array = Arc::new(Int64Array::from(vec![i64::from(i32::MAX) + 1])); + let struct_array = StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Int64, true)), + int64_array as ArrayRef, + )]); + let value = ColumnarValue::Array(Arc::new(struct_array)); + + let target_type = DataType::Struct( + vec![Arc::new(Field::new("a", DataType::Int32, true))].into(), + ); + + let cast_options = CastOptions { + safe: false, + ..DEFAULT_CAST_OPTIONS + }; + + let result = value.cast_to(&target_type, Some(&cast_options)); + assert!(result.is_err()); + } + #[test] fn test_display_scalar() { let column = ColumnarValue::from(ScalarValue::from("foo")); diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index 5e92dbe227fdd..6cdc4a0e94b13 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -24,29 +24,28 @@ use std::sync::Arc; use arrow::array::AsArray; use arrow::{ - array::{new_null_array, ArrayRef, BooleanArray}, + array::{new_null_array, ArrayRef, BooleanArray, StringArray}, datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::{RecordBatch, RecordBatchOptions}, }; // pub use for backwards compatibility pub use datafusion_common::pruning::PruningStatistics; -use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; -use datafusion_physical_plan::metrics::Count; -use log::{debug, trace}; - -use datafusion_common::error::{DataFusionError, Result}; -use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ + cast_column, + error::{DataFusionError, Result}, + format::DEFAULT_CAST_OPTIONS, internal_err, plan_datafusion_err, plan_err, - tree_node::{Transformed, TreeNode}, - ScalarValue, + tree_node::{Transformed, TransformedResult, TreeNode}, + Column, DFSchema, ScalarValue, }; -use datafusion_common::{Column, DFSchema}; use datafusion_expr_common::operator::Operator; +use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; +use datafusion_physical_plan::metrics::Count; use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; +use log::{debug, trace}; /// Used to prove that arbitrary predicates (boolean expression) can not /// possibly evaluate to `true` given information about a column provided by @@ -929,7 +928,17 @@ fn build_statistics_record_batch( // cast statistics array to required data type (e.g. parquet // provides timestamp statistics as "Int64") - let array = arrow::compute::cast(&array, data_type)?; + let array = if matches!(array.data_type(), DataType::Binary) + && matches!(stat_field.data_type(), DataType::Utf8) + { + let array = array.as_binary::(); + let array = StringArray::from_iter(array.iter().map(|maybe_bytes| { + maybe_bytes.and_then(|b| String::from_utf8(b.to_vec()).ok()) + })); + Arc::new(array) as ArrayRef + } else { + cast_column(&array, stat_field, &DEFAULT_CAST_OPTIONS)? + }; arrays.push(array); } @@ -1863,7 +1872,7 @@ pub(crate) enum StatisticsType { #[cfg(test)] mod tests { - use std::collections::HashMap; + use std::collections::{HashMap, HashSet}; use std::ops::{Not, Rem}; use super::*; @@ -1873,7 +1882,9 @@ mod tests { use arrow::array::Decimal128Array; use arrow::{ - array::{BinaryArray, Int32Array, Int64Array, StringArray, UInt64Array}, + array::{ + BinaryArray, Int32Array, Int64Array, StringArray, StructArray, UInt64Array, + }, datatypes::TimeUnit, }; use datafusion_expr::expr::InList; @@ -2496,6 +2507,230 @@ mod tests { "); } + #[test] + fn test_build_statistics_struct_casting() { + // Request a struct column where statistics provide a struct with a different + // inner type + let field = Field::new( + "s_struct_min", + DataType::Struct(vec![Field::new("a", DataType::Int32, true)].into()), + true, + ); + let required_columns = RequiredColumns::from(vec![( + phys_expr::Column::new("s", 0), + StatisticsType::Min, + field.clone(), + )]); + + // statistics return struct with Int64 child that should be cast to Int32 + let stats_array: ArrayRef = Arc::new(StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Int64, true)), + Arc::new(Int64Array::from(vec![1])) as ArrayRef, + )])); + + struct TestStats { + min: ArrayRef, + } + + impl PruningStatistics for TestStats { + fn min_values(&self, column: &Column) -> Option { + if column.name() == "s" { + Some(self.min.clone()) + } else { + None + } + } + + fn max_values(&self, _column: &Column) -> Option { + None + } + + fn num_containers(&self) -> usize { + 1 + } + + fn null_counts(&self, _column: &Column) -> Option { + None + } + + fn row_counts(&self, _column: &Column) -> Option { + None + } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } + } + + let statistics = TestStats { min: stats_array }; + let batch = + build_statistics_record_batch(&statistics, &required_columns).unwrap(); + + let struct_array = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let child = struct_array + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(child.value(0), 1); + assert_eq!(batch.schema().field(0), &field); + } + + #[test] + fn test_build_statistics_struct_incompatible_type() { + // Request a struct column where statistics provide an incompatible field type + let field = Field::new( + "s_struct_min", + DataType::Struct(vec![Field::new("a", DataType::Int32, true)].into()), + true, + ); + let required_columns = RequiredColumns::from(vec![( + phys_expr::Column::new("s", 0), + StatisticsType::Min, + field, + )]); + + // statistics return struct with nested struct child that cannot be cast to Int32 + let inner_field = Arc::new(Field::new("b", DataType::Int32, true)); + let inner_struct: ArrayRef = Arc::new(StructArray::from(vec![( + inner_field.clone(), + Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef, + )])); + let stats_array: ArrayRef = Arc::new(StructArray::from(vec![( + Arc::new(Field::new( + "a", + DataType::Struct(vec![inner_field].into()), + true, + )), + inner_struct, + )])); + + struct TestStats { + min: ArrayRef, + } + + impl PruningStatistics for TestStats { + fn min_values(&self, column: &Column) -> Option { + if column.name() == "s" { + Some(self.min.clone()) + } else { + None + } + } + + fn max_values(&self, _column: &Column) -> Option { + None + } + + fn num_containers(&self) -> usize { + 1 + } + + fn null_counts(&self, _column: &Column) -> Option { + None + } + + fn row_counts(&self, _column: &Column) -> Option { + None + } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } + } + + let statistics = TestStats { min: stats_array }; + let err = + build_statistics_record_batch(&statistics, &required_columns).unwrap_err(); + assert!( + err.to_string().contains("Cannot cast struct field"), + "{}", + err + ); + } + + #[test] + fn test_build_statistics_struct_incompatible_nullability() { + // Request a non-nullable child field but statistics provide a nullable field + let field = Field::new( + "s_struct_min", + DataType::Struct(vec![Field::new("a", DataType::Int32, false)].into()), + true, + ); + let required_columns = RequiredColumns::from(vec![( + phys_expr::Column::new("s", 0), + StatisticsType::Min, + field, + )]); + + // statistics return struct with nullable child + let stats_array: ArrayRef = Arc::new(StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef, + )])); + + struct TestStats { + min: ArrayRef, + } + + impl PruningStatistics for TestStats { + fn min_values(&self, column: &Column) -> Option { + if column.name() == "s" { + Some(self.min.clone()) + } else { + None + } + } + + fn max_values(&self, _column: &Column) -> Option { + None + } + + fn num_containers(&self) -> usize { + 1 + } + + fn null_counts(&self, _column: &Column) -> Option { + None + } + + fn row_counts(&self, _column: &Column) -> Option { + None + } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } + } + + let statistics = TestStats { min: stats_array }; + let err = + build_statistics_record_batch(&statistics, &required_columns).unwrap_err(); + assert!( + err.to_string() + .contains("Cannot cast nullable struct field"), + "{}", + err + ); + } + #[test] fn test_build_statistics_no_required_stats() { let required_columns = RequiredColumns::new(); @@ -2540,6 +2775,37 @@ mod tests { "); } + #[test] + fn test_build_statistics_invalid_utf8_input() { + let required_columns = RequiredColumns::from(vec![( + phys_expr::Column::new("s1", 1), + StatisticsType::Min, + Field::new("s1_min", DataType::Utf8, true), + )]); + + let statistics = OneContainerStats { + min_values: Some(Arc::new(BinaryArray::from(vec![ + Some(b"ok".as_ref()), + Some([0xffu8, 0xfeu8].as_ref()), + None, + ]))), + max_values: None, + num_containers: 3, + }; + + let batch = + build_statistics_record_batch(&statistics, &required_columns).unwrap(); + assert_snapshot!(batches_to_string(&[batch]), @r" + +--------+ + | s1_min | + +--------+ + | ok | + | | + | | + +--------+ + "); + } + #[test] fn test_build_statistics_inconsistent_length() { // return an inconsistent length to the actual statistics arrays diff --git a/docs/source/index.rst b/docs/source/index.rst index 2fc7970f094b7..5bd9afe82445d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -131,6 +131,7 @@ To get started, see library-user-guide/using-the-dataframe-api library-user-guide/building-logical-plans library-user-guide/catalogs + library-user-guide/schema_adapter library-user-guide/functions/index library-user-guide/custom-table-providers library-user-guide/table-constraints diff --git a/docs/source/library-user-guide/index.md b/docs/source/library-user-guide/index.md index fd126a1120edf..c7ba66fe81011 100644 --- a/docs/source/library-user-guide/index.md +++ b/docs/source/library-user-guide/index.md @@ -38,6 +38,10 @@ DataFusion is designed to be extensible at all points, including - [x] User Defined `LogicalPlan` nodes - [x] User Defined `ExecutionPlan` nodes +For adapting columns between evolving schemas, see +[Schema Adapter and Column Casting](schema_adapter.md), which explains how +`cast_column` reconciles nested structs and the trade-offs of deep casting. + [user guide]: ../user-guide/example-usage.md [contributor guide]: ../contributor-guide/index.md [docs]: https://docs.rs/datafusion/latest/datafusion/#architecture diff --git a/docs/source/library-user-guide/schema_adapter.md b/docs/source/library-user-guide/schema_adapter.md new file mode 100644 index 0000000000000..5f8eaac155ecf --- /dev/null +++ b/docs/source/library-user-guide/schema_adapter.md @@ -0,0 +1,112 @@ + + +# Schema Adapter and Column Casting + +DataFusion's `SchemaAdapter` maps `RecordBatch`es produced by a data source to the +schema expected by a query. When a field exists in both schemas but their types +differ, the adapter invokes [`cast_column`](../../../datafusion/common/src/nested_struct.rs) +to coerce the column to the target [`Field`] type. `cast_column` recursively +handles nested `Struct` values, inserting `NULL` arrays for fields that are +missing in the source and ignoring extra fields. + +## Casting structs with nullable fields + +```rust +use std::sync::Arc; +use arrow::array::{Int32Array, StructArray}; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::nested_struct::cast_column; +use datafusion_common::format::DEFAULT_CAST_OPTIONS; + +// source schema: { info: { id: Int32 } } +let source_field = Field::new( + "info", + DataType::Struct(vec![Field::new("id", DataType::Int32, true)].into()), + false, +); + +let target_field = Field::new( + "info", + DataType::Struct(vec![ + Field::new("id", DataType::Int64, true), + Field::new("score", DataType::Int32, true), + ].into()), + true, +); + +let id = Arc::new(Int32Array::from(vec![Some(1), None])) as _; +let source_struct = StructArray::from(vec![(source_field.children()[0].clone(), id)]); +let casted = cast_column( + &Arc::new(source_struct) as _, + &target_field, + &DEFAULT_CAST_OPTIONS, +) +.unwrap(); +assert_eq!(casted.data_type(), target_field.data_type()); +``` + +The new `score` field is filled with `NULL` and `id` is promoted to a nullable +`Int64`, demonstrating how nested casting can reconcile schema differences. + +## Adapting a RecordBatch + +```rust +use std::sync::Arc; +use arrow::{array::Int32Array, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch}; +use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory; + +// RecordBatch with `id: Int32` +let source_schema = Schema::new(vec![Field::new("id", DataType::Int32, true)]); +let batch = RecordBatch::try_new( + Arc::new(source_schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2])) as _], +).unwrap(); + +// Target schema expects `id: Int64` and an extra `name: Utf8` column +let target_schema = Schema::new(vec![ + Field::new("id", DataType::Int64, true), + Field::new("name", DataType::Utf8, true), +]); + +let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(target_schema)); +let (mapper, _) = adapter.map_schema(&source_schema).unwrap(); +let adapted_batch = mapper.map_batch(batch).unwrap(); +assert_eq!(adapted_batch.schema().fields().len(), 2); +``` + +When `mapper.map_batch` runs, the adapter calls `cast_column` for any field whose +type differs between source and target schemas. In this example, `id` is cast from +`Int32` to `Int64` and the missing `name` column is filled with `NULL` values. + +## Pitfalls and performance + +- **Field name mismatches**: only matching field names are cast. Extra source + fields are dropped and missing target fields are filled with `NULL`. +- **Non-struct sources**: attempting to cast a non-struct array to a struct + results in an error. +- **Complex type limitations**: `cast_column` only supports `Struct` arrays and + returns an error for other complex types like `List` or `Map`. Future support + may extend [`cast_column`](../../../datafusion/common/src/nested_struct.rs) to + handle these nested types. +- **Nested cost**: each level of nesting requires building new arrays. Deep or + wide structs can increase memory use and CPU time, so avoid unnecessary + casting in hot paths. + +[`field`]: https://docs.rs/arrow/latest/arrow/datatypes/struct.Field.html