diff --git a/parquet-variant-compute/src/cast_to_variant.rs b/parquet-variant-compute/src/cast_to_variant.rs index 874b734466cb..bce3a427a0f9 100644 --- a/parquet-variant-compute/src/cast_to_variant.rs +++ b/parquet-variant-compute/src/cast_to_variant.rs @@ -42,9 +42,9 @@ macro_rules! primitive_conversion { } /// Convert the input array to a `VariantArray` row by row, using `method` -/// to downcast the generic array to a specific array type and `cast_fn` -/// to transform each element to a type compatible with Variant -macro_rules! cast_conversion { +/// requiring a generic type to downcast the generic array to a specific +/// array type and `cast_fn` to transform each element to a type compatible with Variant +macro_rules! generic_conversion { ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{ let array = $input.$method::<$t>(); for i in 0..array.len() { @@ -58,7 +58,10 @@ macro_rules! cast_conversion { }}; } -macro_rules! cast_conversion_nongeneric { +/// Convert the input array to a `VariantArray` row by row, using `method` +/// not requiring a generic type to downcast the generic array to a specific +/// array type and `cast_fn` to transform each element to a type compatible with Variant +macro_rules! non_generic_conversion { ($method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{ let array = $input.$method(); for i in 0..array.len() { @@ -126,14 +129,18 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { let input_type = input.data_type(); // todo: handle other types like Boolean, Strings, Date, Timestamp, etc. match input_type { + DataType::Boolean => { + non_generic_conversion!(as_boolean, |v| v, input, builder); + } + DataType::Binary => { - cast_conversion!(BinaryType, as_bytes, |v| v, input, builder); + generic_conversion!(BinaryType, as_bytes, |v| v, input, builder); } DataType::LargeBinary => { - cast_conversion!(LargeBinaryType, as_bytes, |v| v, input, builder); + generic_conversion!(LargeBinaryType, as_bytes, |v| v, input, builder); } DataType::BinaryView => { - cast_conversion!(BinaryViewType, as_byte_view, |v| v, input, builder); + generic_conversion!(BinaryViewType, as_byte_view, |v| v, input, builder); } DataType::Int8 => { primitive_conversion!(Int8Type, input, builder); @@ -160,7 +167,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { primitive_conversion!(UInt64Type, input, builder); } DataType::Float16 => { - cast_conversion!( + generic_conversion!( Float16Type, as_primitive, |v: f16| -> f32 { v.into() }, @@ -175,7 +182,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { primitive_conversion!(Float64Type, input, builder); } DataType::Decimal32(_, scale) => { - cast_conversion!( + generic_conversion!( Decimal32Type, as_primitive, |v| decimal_to_variant_decimal!(v, scale, i32, VariantDecimal4), @@ -184,7 +191,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { ); } DataType::Decimal64(_, scale) => { - cast_conversion!( + generic_conversion!( Decimal64Type, as_primitive, |v| decimal_to_variant_decimal!(v, scale, i64, VariantDecimal8), @@ -193,7 +200,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { ); } DataType::Decimal128(_, scale) => { - cast_conversion!( + generic_conversion!( Decimal128Type, as_primitive, |v| decimal_to_variant_decimal!(v, scale, i128, VariantDecimal16), @@ -202,7 +209,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { ); } DataType::Decimal256(_, scale) => { - cast_conversion!( + generic_conversion!( Decimal256Type, as_primitive, |v: i256| { @@ -220,7 +227,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { ); } DataType::FixedSizeBinary(_) => { - cast_conversion_nongeneric!(as_fixed_size_binary, |v| v, input, builder); + non_generic_conversion!(as_fixed_size_binary, |v| v, input, builder); } dt => { return Err(ArrowError::CastError(format!( @@ -239,7 +246,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { mod tests { use super::*; use arrow::array::{ - ArrayRef, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + ArrayRef, BooleanArray, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, FixedSizeBinaryBuilder, Float16Array, Float32Array, Float64Array, GenericByteBuilder, GenericByteViewBuilder, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, @@ -340,6 +347,18 @@ mod tests { ); } + #[test] + fn test_cast_to_variant_bool() { + run_test( + Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)])), + vec![ + Some(Variant::BooleanTrue), + None, + Some(Variant::BooleanFalse), + ], + ); + } + #[test] fn test_cast_to_variant_int8() { run_test(