diff --git a/parquet-variant-compute/src/variant_get.rs b/parquet-variant-compute/src/variant_get.rs index e95c774f2926..c3c4e2a2290d 100644 --- a/parquet-variant-compute/src/variant_get.rs +++ b/parquet-variant-compute/src/variant_get.rs @@ -24,9 +24,9 @@ use arrow::{ use arrow_schema::{ArrowError, DataType, FieldRef}; use parquet_variant::{VariantPath, VariantPathElement}; -use crate::VariantArray; use crate::variant_array::BorrowedShreddingState; use crate::variant_to_arrow::make_variant_to_arrow_row_builder; +use crate::{VariantArray, VariantType, unshred_variant}; use arrow::array::AsArray; use std::sync::Arc; @@ -125,7 +125,23 @@ fn shredded_get_path( // Helper that shreds a VariantArray to a specific type. let shred_basic_variant = |target: VariantArray, path: VariantPath<'_>, as_field: Option<&Field>| { - let as_type = as_field.map(|f| f.data_type()); + let requested_variant = + as_field.is_some_and(Field::has_valid_extension_type::); + let target = if requested_variant { + unshred_variant(&target)? + } else { + target + }; + + if requested_variant && path.is_empty() { + return Ok(ArrayRef::from(target)); + } + + let as_type = if requested_variant { + None + } else { + as_field.map(|f| f.data_type()) + }; let mut builder = make_variant_to_arrow_row_builder( target.metadata_field(), path, @@ -171,6 +187,16 @@ fn shredded_get_path( } ShreddedPathStep::Missing => { let num_rows = input.len(); + if as_field.is_some_and(Field::has_valid_extension_type::) { + let all_nulls = Some(arrow::buffer::NullBuffer::from(vec![false; num_rows])); + let arr = VariantArray::from_parts( + input.metadata_field().clone(), + None, + None, + all_nulls, + ); + return Ok(ArrayRef::from(arr)); + } let arr = match as_field.map(|f| f.data_type()) { Some(data_type) => array::new_null_array(data_type, num_rows), None => Arc::new(array::NullArray::new(num_rows)) as _, @@ -214,30 +240,32 @@ fn shredded_get_path( // // For shredded/partially-shredded targets (`typed_value` present), recurse into each field // separately to take advantage of deeper shredding in child fields. - if let DataType::Struct(fields) = as_field.data_type() { - if target.typed_value_field().is_none() { - return shred_basic_variant(target, VariantPath::default(), Some(as_field)); - } - - let children = fields - .iter() - .map(|field| { - shredded_get_path( - &target, - &[VariantPathElement::from(field.name().as_str())], - Some(field), - cast_options, - ) - }) - .collect::>>()?; - - let struct_nulls = target.nulls().cloned(); + if !as_field.has_valid_extension_type::() { + if let DataType::Struct(fields) = as_field.data_type() { + if target.typed_value_field().is_none() { + return shred_basic_variant(target, VariantPath::default(), Some(as_field)); + } - return Ok(Arc::new(StructArray::try_new( - fields.clone(), - children, - struct_nulls, - )?)); + let children = fields + .iter() + .map(|field| { + shredded_get_path( + &target, + &[VariantPathElement::from(field.name().as_str())], + Some(field), + cast_options, + ) + }) + .collect::>>()?; + + let struct_nulls = target.nulls().cloned(); + + return Ok(Arc::new(StructArray::try_new( + fields.clone(), + children, + struct_nulls, + )?)); + } } // Not a struct, so directly shred the variant as the requested type @@ -2053,6 +2081,63 @@ mod test { println!("Nested path 'a.x' result: {:?}", result); } + #[test] + fn test_variant_get_as_variant_from_unshredded_input() { + let (unshredded, _) = create_variant_get_as_variant_test_data(); + assert_variant_field_extraction_returns_unshredded_variant(&unshredded); + } + + #[test] + fn test_variant_get_as_variant_from_shredded_input() { + let (_, shredded) = create_variant_get_as_variant_test_data(); + assert_variant_field_extraction_returns_unshredded_variant(&shredded); + } + + fn create_variant_get_as_variant_test_data() -> (ArrayRef, ArrayRef) { + let input_json: ArrayRef = Arc::new(StringArray::from(vec![ + Some(r#"{"field_name": {"k": 100000}}"#), + Some(r#"{"field_name": {"k": "s"}}"#), + ])); + + let unshredded = ArrayRef::from(json_to_variant(&input_json).unwrap()); + let unshredded_variant = VariantArray::try_new(&unshredded).unwrap(); + + let as_type = DataType::Struct(Fields::from(vec![Field::new( + "field_name", + DataType::Struct(Fields::from(vec![Field::new("k", DataType::Int32, true)])), + true, + )])); + let shredded = ArrayRef::from(shred_variant(&unshredded_variant, &as_type).unwrap()); + + (unshredded, shredded) + } + + fn assert_variant_field_extraction_returns_unshredded_variant(input: &ArrayRef) { + let variant_output = VariantArray::try_new(input).unwrap().field("result"); + let options = GetOptions::new_with_path(VariantPath::try_from("field_name").unwrap()) + .with_as_type(Some(FieldRef::from(variant_output))); + + let result = variant_get(input, options).unwrap(); + let result_variant = VariantArray::try_new(&result).unwrap(); + + assert!(result_variant.typed_value_field().is_none()); + assert!(result_variant.value_field().is_some()); + + let expected_json: ArrayRef = Arc::new(StringArray::from(vec![ + Some(r#"{"k":100000}"#), + Some(r#"{"k":"s"}"#), + ])); + let expected = json_to_variant(&expected_json).unwrap(); + + assert_eq!(result_variant.len(), expected.len()); + for i in 0..result_variant.len() { + assert_eq!(result_variant.is_null(i), expected.is_null(i)); + if !result_variant.is_null(i) { + assert_eq!(result_variant.value(i), expected.value(i)); + } + } + } + /// Create test data for depth 0 (direct field access) /// [{"x": 42}, {"x": "foo"}, {"y": 10}] fn create_depth_0_test_data() -> ArrayRef {