diff --git a/parquet-variant/src/builder.rs b/parquet-variant/src/builder.rs index 8d0dda4a4e50..b81cacc4bc40 100644 --- a/parquet-variant/src/builder.rs +++ b/parquet-variant/src/builder.rs @@ -113,11 +113,11 @@ fn make_room_for_header(buffer: &mut Vec, start_pos: usize, header_size: usi /// panic!("unexpected variant type") /// }; /// assert_eq!( -/// variant_object.field("first_name").unwrap(), +/// variant_object.field_by_name("first_name").unwrap(), /// Some(Variant::ShortString("Jiaying")) /// ); /// assert_eq!( -/// variant_object.field("last_name").unwrap(), +/// variant_object.field_by_name("last_name").unwrap(), /// Some(Variant::ShortString("Li")) /// ); /// ``` diff --git a/parquet-variant/src/utils.rs b/parquet-variant/src/utils.rs index 0eca30e408de..7a1b9f039937 100644 --- a/parquet-variant/src/utils.rs +++ b/parquet-variant/src/utils.rs @@ -46,9 +46,10 @@ pub(crate) fn map_try_from_slice_error(e: TryFromSliceError) -> ArrowError { ArrowError::InvalidArgumentError(e.to_string()) } -pub(crate) fn first_byte_from_slice(slice: &[u8]) -> Result<&u8, ArrowError> { +pub(crate) fn first_byte_from_slice(slice: &[u8]) -> Result { slice .first() + .copied() .ok_or_else(|| ArrowError::InvalidArgumentError("Received empty bytes".to_string())) } @@ -58,14 +59,14 @@ pub(crate) fn string_from_slice(slice: &[u8], range: Range) -> Result<&st .map_err(|_| ArrowError::InvalidArgumentError("invalid UTF-8 string".to_string())) } -/// Performs a binary search on a slice using a fallible key extraction function. +/// Performs a binary search over a range using a fallible key extraction function; a failed key +/// extraction immediately terminats the search. /// -/// This is similar to the standard library's `binary_search_by`, but allows the key -/// extraction function to fail. If key extraction fails during the search, that error -/// is propagated immediately. +/// This is similar to the standard library's `binary_search_by`, but generalized to ranges instead +/// of slices. /// /// # Arguments -/// * `slice` - The slice to search in +/// * `range` - The range to search in /// * `target` - The target value to search for /// * `key_extractor` - A function that extracts a comparable key from slice elements. /// This function can fail and return an error. @@ -74,28 +75,33 @@ pub(crate) fn string_from_slice(slice: &[u8], range: Range) -> Result<&st /// * `Ok(Ok(index))` - Element found at the given index /// * `Ok(Err(index))` - Element not found, but would be inserted at the given index /// * `Err(e)` - Key extraction failed with error `e` -pub(crate) fn try_binary_search_by( - slice: &[T], +pub(crate) fn try_binary_search_range_by( + range: Range, target: &K, mut key_extractor: F, ) -> Result, E> where K: Ord, - F: FnMut(&T) -> Result, + F: FnMut(usize) -> Result, { - let mut left = 0; - let mut right = slice.len(); - - while left < right { - let mid = (left + right) / 2; - let key = key_extractor(&slice[mid])?; - + let Range { mut start, mut end } = range; + while start < end { + let mid = start + (end - start) / 2; + let key = key_extractor(mid)?; match key.cmp(target) { std::cmp::Ordering::Equal => return Ok(Ok(mid)), - std::cmp::Ordering::Greater => right = mid, - std::cmp::Ordering::Less => left = mid + 1, + std::cmp::Ordering::Greater => end = mid, + std::cmp::Ordering::Less => start = mid + 1, } } - Ok(Err(left)) + Ok(Err(start)) +} + +/// Attempts to prove a fallible iterator is actually infallible in practice, by consuming every +/// element and returning the first error (if any). +pub(crate) fn validate_fallible_iterator( + mut it: impl Iterator>, +) -> Result<(), E> { + it.find(Result::is_err).transpose().map(|_| ()) } diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index 3ebc193678c5..9d3e2488c905 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -19,11 +19,11 @@ use crate::decoder::{ }; use crate::utils::{ array_from_slice, first_byte_from_slice, slice_from_slice, string_from_slice, - try_binary_search_by, + try_binary_search_range_by, validate_fallible_iterator, }; use arrow_schema::ArrowError; use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; -use std::{num::TryFromIntError, ops::Range}; +use std::num::TryFromIntError; #[derive(Clone, Debug, Copy, PartialEq)] enum OffsetSizeBytes { @@ -91,6 +91,7 @@ impl OffsetSizeBytes { } } +/// A parsed version of the variant metadata header byte. #[derive(Clone, Debug, Copy, PartialEq)] pub(crate) struct VariantMetadataHeader { version: u8, @@ -105,6 +106,8 @@ const CORRECT_VERSION_VALUE: u8 = 1; impl VariantMetadataHeader { /// Tries to construct the variant metadata header, which has the form + /// + /// ```text /// 7 6 5 4 3 0 /// +-------+---+---+---------------+ /// header | | | | version | @@ -112,14 +115,14 @@ impl VariantMetadataHeader { /// ^ ^ /// | +-- sorted_strings /// +-- offset_size_minus_one + /// ``` + /// /// The version is a 4-bit value that must always contain the value 1. /// - sorted_strings is a 1-bit value indicating whether dictionary strings are sorted and unique. /// - offset_size_minus_one is a 2-bit value providing the number of bytes per dictionary size and offset field. /// - The actual number of bytes, offset_size, is offset_size_minus_one + 1 - pub(crate) fn try_new(bytes: &[u8]) -> Result { - let header = first_byte_from_slice(bytes)?; - - let version = header & 0x0F; // First four bits + pub(crate) fn try_new(header_byte: u8) -> Result { + let version = header_byte & 0x0F; // First four bits if version != CORRECT_VERSION_VALUE { let err_msg = format!( "The version bytes in the header is not {CORRECT_VERSION_VALUE}, got {:b}", @@ -127,8 +130,8 @@ impl VariantMetadataHeader { ); return Err(ArrowError::InvalidArgumentError(err_msg)); } - let is_sorted = (header & 0x10) != 0; // Fifth bit - let offset_size_minus_one = header >> 6; // Last two bits + let is_sorted = (header_byte & 0x10) != 0; // Fifth bit + let offset_size_minus_one = header_byte >> 6; // Last two bits Ok(Self { version, is_sorted, @@ -153,65 +156,41 @@ impl<'m> VariantMetadata<'m> { self.bytes } + /// Attempts to interpret `bytes` as a variant metadata instance. + /// + /// # Validation + /// + /// This constructor verifies that `bytes` points to a valid variant metadata instance. In + /// particular, all offsets are in-bounds and point to valid utf8 strings. pub fn try_new(bytes: &'m [u8]) -> Result { - let header = VariantMetadataHeader::try_new(bytes)?; + let header_byte = first_byte_from_slice(bytes)?; + let header = VariantMetadataHeader::try_new(header_byte)?; + // Offset 1, index 0 because first element after header is dictionary size let dict_size = header.offset_size.unpack_usize(bytes, 1, 0)?; - // Check that we have the correct metadata length according to dictionary_size, or return - // error early. - // Minimum number of bytes the metadata buffer must contain: - // 1 byte header - // + offset_size-byte `dictionary_size` field - // + (dict_size + 1) offset entries, each `offset_size` bytes. (Table size, essentially) - // 1 + offset_size + (dict_size + 1) * offset_size + // Calculate the starting offset of the dictionary string bytes. + // + // Value header, dict_size (offset_size bytes), and dict_size+1 offsets + // = 1 + offset_size + (dict_size + 1) * offset_size // = (dict_size + 2) * offset_size + 1 - let offset_size = header.offset_size as usize; // Cheap to copy - let dictionary_key_start_byte = dict_size .checked_add(2) - .and_then(|n| n.checked_mul(offset_size)) + .and_then(|n| n.checked_mul(header.offset_size as usize)) .and_then(|n| n.checked_add(1)) .ok_or_else(|| ArrowError::InvalidArgumentError("metadata length overflow".into()))?; - - if bytes.len() < dictionary_key_start_byte { - return Err(ArrowError::InvalidArgumentError( - "Metadata shorter than dictionary_size implies".to_string(), - )); - } - - // Check that all offsets are monotonically increasing - let mut offsets = (0..=dict_size).map(|i| header.offset_size.unpack_usize(bytes, 1, i + 1)); - let Some(Ok(mut end @ 0)) = offsets.next() else { - return Err(ArrowError::InvalidArgumentError( - "First offset is non-zero".to_string(), - )); - }; - - for offset in offsets { - let offset = offset?; - if end >= offset { - return Err(ArrowError::InvalidArgumentError( - "Offsets are not monotonically increasing".to_string(), - )); - } - end = offset; - } - - // Verify the buffer covers the whole dictionary-string section - if end > bytes.len() - dictionary_key_start_byte { - // `prev` holds the last offset seen still - return Err(ArrowError::InvalidArgumentError( - "Last offset does not equal dictionary length".to_string(), - )); - } - - Ok(Self { + println!("dictionary_key_start_byte: {dictionary_key_start_byte}"); + let s = Self { bytes, header, dict_size, dictionary_key_start_byte, - }) + }; + + // Iterate over all string keys in this dictionary in order to validate the offset array and + // prove that the string bytes are all in bounds. Otherwise, `iter` might panic on `unwrap`. + validate_fallible_iterator(s.iter_checked())?; + Ok(s) } /// Whether the dictionary keys are sorted and unique @@ -223,113 +202,92 @@ impl<'m> VariantMetadata<'m> { pub fn dictionary_size(&self) -> usize { self.dict_size } + + /// The variant protocol version pub fn version(&self) -> u8 { self.header.version } - /// Helper method to get the offset start and end range for a key by index. - fn get_offsets_for_key_by(&self, index: usize) -> Result, ArrowError> { - if index >= self.dict_size { - return Err(ArrowError::InvalidArgumentError(format!( - "Index {} out of bounds for dictionary of length {}", - index, self.dict_size - ))); - } - - // Skipping the header byte (setting byte_offset = 1) and the dictionary_size (setting offset_index +1) - let unpack = |i| self.header.offset_size.unpack_usize(self.bytes, 1, i + 1); - Ok(unpack(index)?..unpack(index + 1)?) - } - - /// Get a single offset by index - pub fn get_offset_by(&self, index: usize) -> Result { - if index >= self.dict_size { - return Err(ArrowError::InvalidArgumentError(format!( - "Index {} out of bounds for dictionary of length {}", - index, self.dict_size - ))); - } - + /// Gets an offset array entry by index. + /// + /// This offset is an index into the dictionary, at the boundary between string `i-1` and string + /// `i`. See [`Self::get`] to retrieve a specific dictionary entry. + fn get_offset(&self, i: usize) -> Result { // Skipping the header byte (setting byte_offset = 1) and the dictionary_size (setting offset_index +1) - let unpack = |i| self.header.offset_size.unpack_usize(self.bytes, 1, i + 1); - unpack(index) + let bytes = slice_from_slice(self.bytes, ..self.dictionary_key_start_byte)?; + self.header.offset_size.unpack_usize(bytes, 1, i + 1) } - /// Get the key-name by index - pub fn get_field_by(&self, index: usize) -> Result<&'m str, ArrowError> { - let offset_range = self.get_offsets_for_key_by(index)?; - self.get_field_by_offset(offset_range) + /// Gets a dictionary entry by index + pub fn get(&self, i: usize) -> Result<&'m str, ArrowError> { + let dictionary_keys_bytes = slice_from_slice(self.bytes, self.dictionary_key_start_byte..)?; + let byte_range = self.get_offset(i)?..self.get_offset(i + 1)?; + string_from_slice(dictionary_keys_bytes, byte_range) } - /// Gets the field using an offset (Range) - helper method to keep consistent API. - pub(crate) fn get_field_by_offset(&self, offset: Range) -> Result<&'m str, ArrowError> { - let dictionary_keys_bytes = - slice_from_slice(self.bytes, self.dictionary_key_start_byte..self.bytes.len())?; - let result = string_from_slice(dictionary_keys_bytes, offset)?; - - Ok(result) + /// Get all dictionary entries as an Iterator of strings + pub fn iter(&self) -> impl Iterator + '_ { + // NOTE: It is safe to unwrap because the constructor already made a successful traversal. + self.iter_checked().map(Result::unwrap) } - #[allow(unused)] - pub(crate) fn header(&self) -> VariantMetadataHeader { - self.header + // Fallible iteration over the fields of this dictionary. The constructor traverses the iterator + // to prove it has no errors, so that all other use sites can blindly `unwrap` the result. + fn iter_checked(&self) -> impl Iterator> + '_ { + (0..self.dict_size).map(move |i| self.get(i)) } +} - /// Get the offsets as an iterator - pub fn offsets(&self) -> impl Iterator, ArrowError>> + 'm { - let offset_size = self.header.offset_size; // `Copy` - let bytes = self.bytes; +/// A parsed version of the variant object value header byte. +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct VariantObjectHeader { + field_offset_size: OffsetSizeBytes, + field_id_size: OffsetSizeBytes, + is_large: bool, +} - (0..self.dict_size).map(move |i| { - // This wont be out of bounds as long as dict_size and offsets have been validated - // during construction via `try_new`, as it calls unpack_usize for the - // indices `1..dict_size+1` already. - let start = offset_size.unpack_usize(bytes, 1, i + 1); - let end = offset_size.unpack_usize(bytes, 1, i + 2); +impl VariantObjectHeader { + pub(crate) fn try_new(header_byte: u8) -> Result { + // Parse the header byte to get object parameters + let value_header = header_byte >> 2; + let field_offset_size_minus_one = value_header & 0x03; // Last 2 bits + let field_id_size_minus_one = (value_header >> 2) & 0x03; // Next 2 bits + let is_large = (value_header & 0x10) != 0; // 5th bit - match (start, end) { - (Ok(s), Ok(e)) => Ok(s..e), - (Err(e), _) | (_, Err(e)) => Err(e), - } + Ok(Self { + field_offset_size: OffsetSizeBytes::try_new(field_offset_size_minus_one)?, + field_id_size: OffsetSizeBytes::try_new(field_id_size_minus_one)?, + is_large, }) } - - /// Get all key-names as an Iterator of strings - pub fn fields( - &'m self, - ) -> Result>, ArrowError> { - let iterator = self - .offsets() - .map(move |offset_range| self.get_field_by_offset(offset_range?)); - Ok(iterator) - } } #[derive(Clone, Debug, PartialEq)] -pub(crate) struct VariantObjectHeader { - field_offset_size: OffsetSizeBytes, - field_id_size: OffsetSizeBytes, +pub struct VariantObject<'m, 'v> { + pub metadata: VariantMetadata<'m>, + pub value: &'v [u8], + header: VariantObjectHeader, num_elements: usize, field_ids_start_byte: usize, field_offsets_start_byte: usize, values_start_byte: usize, } -impl VariantObjectHeader { - pub(crate) fn try_new(value: &[u8]) -> Result { - // Parse the header byte to get object parameters - let header = first_byte_from_slice(value)?; - let value_header = header >> 2; - - let field_offset_size_minus_one = value_header & 0x03; // Last 2 bits - let field_id_size_minus_one = (value_header >> 2) & 0x03; // Next 2 bits - let is_large = value_header & 0x10; // 5th bit - - let field_offset_size = OffsetSizeBytes::try_new(field_offset_size_minus_one)?; - let field_id_size = OffsetSizeBytes::try_new(field_id_size_minus_one)?; +impl<'m, 'v> VariantObject<'m, 'v> { + /// Attempts to interpret `value` as a variant object value. + /// + /// # Validation + /// + /// This constructor verifies that `value` points to a valid variant object value. In + /// particular, that all field ids exist in `metadata`, and all offsets are in-bounds and point + /// to valid objects. + // TODO: How to make the validation non-recursive while still making iterators safely infallible?? + pub fn try_new(metadata: VariantMetadata<'m>, value: &'v [u8]) -> Result { + let header_byte = first_byte_from_slice(value)?; + let header = VariantObjectHeader::try_new(header_byte)?; // Determine num_elements size based on is_large flag - let num_elements_size = if is_large != 0 { + let num_elements_size = if header.is_large { OffsetSizeBytes::Four } else { OffsetSizeBytes::One @@ -340,25 +298,19 @@ impl VariantObjectHeader { // Calculate byte offsets for different sections let field_ids_start_byte = 1 + num_elements_size as usize; - let field_offsets_start_byte = field_ids_start_byte + num_elements * field_id_size as usize; + let field_offsets_start_byte = + field_ids_start_byte + num_elements * header.field_id_size as usize; let values_start_byte = - field_offsets_start_byte + (num_elements + 1) * field_offset_size as usize; + field_offsets_start_byte + (num_elements + 1) * header.field_offset_size as usize; - // Verify that the last field offset array entry is inside the value slice - let last_field_offset_byte = - field_offsets_start_byte + (num_elements + 1) * field_offset_size as usize; - if last_field_offset_byte > value.len() { - return Err(ArrowError::InvalidArgumentError(format!( - "Last field offset array entry at offset {} with length {} is outside the value slice of length {}", - last_field_offset_byte, - field_offset_size as usize, - value.len() - ))); - } - - // Verify that the value of the last field offset array entry fits inside the value slice + // Spec says: "The last field_offset points to the byte after the end of the last value" + // + // Use the last offset as a bounds check. The iterator check below doesn't use it -- offsets + // are not monotonic -- so we have to check separately here. let last_field_offset = - field_offset_size.unpack_usize(value, field_offsets_start_byte, num_elements)?; + header + .field_offset_size + .unpack_usize(value, field_offsets_start_byte, num_elements)?; if values_start_byte + last_field_offset > value.len() { return Err(ArrowError::InvalidArgumentError(format!( "Last field offset value {} at offset {} is outside the value slice of length {}", @@ -367,150 +319,137 @@ impl VariantObjectHeader { value.len() ))); } - Ok(Self { - field_offset_size, - field_id_size, + + let s = Self { + metadata, + value, + header, num_elements, field_ids_start_byte, field_offsets_start_byte, values_start_byte, - }) + }; + + // Iterate over all fields of this object in order to validate the field_id and field_offset + // arrays, and also to prove the field values are all in bounds. Otherwise, `iter` might + // panic on `unwrap`. + validate_fallible_iterator(s.iter_checked())?; + Ok(s) } /// Returns the number of key-value pairs in this object - pub(crate) fn num_elements(&self) -> usize { + pub fn len(&self) -> usize { self.num_elements } -} -#[derive(Clone, Debug, PartialEq)] -pub struct VariantObject<'m, 'v> { - pub metadata: VariantMetadata<'m>, - pub value: &'v [u8], - header: VariantObjectHeader, -} + /// Returns true if the object contains no key-value pairs + pub fn is_empty(&self) -> bool { + self.len() == 0 + } -impl<'m, 'v> VariantObject<'m, 'v> { - pub fn try_new(metadata: VariantMetadata<'m>, value: &'v [u8]) -> Result { - Ok(Self { - metadata, - value, - header: VariantObjectHeader::try_new(value)?, - }) + /// Get a field's value by index in `0..self.len()` + pub fn field(&self, i: usize) -> Result, ArrowError> { + let start_offset = self.header.field_offset_size.unpack_usize( + self.value, + self.field_offsets_start_byte, + i, + )?; + let value_bytes = slice_from_slice(self.value, self.values_start_byte + start_offset..)?; + Variant::try_new_with_metadata(self.metadata, value_bytes) } - /// Returns the number of key-value pairs in this object - pub fn len(&self) -> usize { - self.header.num_elements() + /// Get a field's name by index in `0..self.len()` + pub fn field_name(&self, i: usize) -> Result<&'m str, ArrowError> { + let field_id = + self.header + .field_id_size + .unpack_usize(self.value, self.field_ids_start_byte, i)?; + self.metadata.get(field_id) } - /// Returns true if the object contains no key-value pairs - pub fn is_empty(&self) -> bool { - self.len() == 0 + /// Returns an iterator of (name, value) pairs over the fields of this object. + pub fn iter(&self) -> impl Iterator)> + '_ { + // NOTE: It is safe to unwrap because the constructor already made a successful traversal. + self.iter_checked().map(Result::unwrap) } - pub fn fields(&self) -> Result)>, ArrowError> { - let field_list = self.parse_field_list()?; - Ok(field_list.into_iter()) + // Fallible iteration over the fields of this object. The constructor traverses the iterator to + // prove it has no errors, so that all other use sites can blindly `unwrap` the result. + fn iter_checked( + &self, + ) -> impl Iterator), ArrowError>> + '_ { + (0..self.num_elements).map(move |i| Ok((self.field_name(i)?, self.field(i)?))) } - pub fn field(&self, name: &str) -> Result>, ArrowError> { + /// Returns the value of the field with the specified name, if any. + /// + /// `Ok(None)` means the field does not exist; `Err` means the search encountered an error. + pub fn field_by_name(&self, name: &str) -> Result>, ArrowError> { // Binary search through the field IDs of this object to find the requested field name. // // NOTE: This does not require a sorted metadata dictionary, because the variant spec // requires object field ids to be lexically sorted by their corresponding string values, // and probing the dictionary for a field id is always O(1) work. - let (field_ids, field_offsets) = self.parse_field_arrays()?; - let search_result = try_binary_search_by(&field_ids, &name, |&field_id| { - self.metadata.get_field_by(field_id) - })?; + let search_result = + try_binary_search_range_by(0..self.num_elements, &name, |i| self.field_name(i))?; - let Ok(index) = search_result else { - return Ok(None); - }; - let start_offset = field_offsets[index]; - let end_offset = field_offsets[index + 1]; - let value_bytes = slice_from_slice( - self.value, - self.header.values_start_byte + start_offset - ..self.header.values_start_byte + end_offset, - )?; - let variant = Variant::try_new_with_metadata(self.metadata, value_bytes)?; - Ok(Some(variant)) - } - - /// Parse field IDs and field offsets arrays using the cached header - fn parse_field_arrays(&self) -> Result<(Vec, Vec), ArrowError> { - // Parse field IDs - let field_ids = (0..self.header.num_elements) - .map(|i| { - self.header.field_id_size.unpack_usize( - self.value, - self.header.field_ids_start_byte, - i, - ) - }) - .collect::, _>>()?; - debug_assert_eq!(field_ids.len(), self.header.num_elements); - - // Parse field offsets (num_elements + 1 entries) - let field_offsets = (0..=self.header.num_elements) - .map(|i| { - self.header.field_offset_size.unpack_usize( - self.value, - self.header.field_offsets_start_byte, - i, - ) - }) - .collect::, _>>()?; - debug_assert_eq!(field_offsets.len(), self.header.num_elements + 1); - - Ok((field_ids, field_offsets)) - } - - /// Parse all fields into a vector for iteration - fn parse_field_list(&self) -> Result)>, ArrowError> { - let (field_ids, field_offsets) = self.parse_field_arrays()?; - - let mut fields = Vec::with_capacity(self.header.num_elements); - - for i in 0..self.header.num_elements { - let field_id = field_ids[i]; - let field_name = self.metadata.get_field_by(field_id)?; - - let start_offset = field_offsets[i]; - let value_bytes = - slice_from_slice(self.value, self.header.values_start_byte + start_offset..)?; - let variant = Variant::try_new_with_metadata(self.metadata, value_bytes)?; - - fields.push((field_name, variant)); - } - - Ok(fields) + search_result.ok().map(|i| self.field(i)).transpose() } } +/// A parsed version of the variant array value header byte. #[derive(Clone, Debug, PartialEq)] pub(crate) struct VariantListHeader { offset_size: OffsetSizeBytes, is_large: bool, - num_elements: usize, - first_offset_byte: usize, - first_value_byte: usize, } impl VariantListHeader { - pub(crate) fn try_new(value: &[u8]) -> Result { + pub(crate) fn try_new(header_byte: u8) -> Result { // The 6 first bits to the left are the value_header and the 2 bits // to the right are the basic type, so we shift to get only the value_header - let value_header = first_byte_from_slice(value)? >> 2; + let value_header = header_byte >> 2; let is_large = (value_header & 0x04) != 0; // 3rd bit from the right let field_offset_size_minus_one = value_header & 0x03; // Last two bits let offset_size = OffsetSizeBytes::try_new(field_offset_size_minus_one)?; + Ok(Self { + offset_size, + is_large, + }) + } +} + +/// Represents a variant array. +/// +/// NOTE: The "list" naming differs from the variant spec -- which calls it "array" -- in order to be +/// consistent with parquet and arrow type naming. Otherwise, the name would conflict with the +/// `VariantArray : Array` we must eventually define for variant-typed arrow arrays. +#[derive(Clone, Debug, PartialEq)] +pub struct VariantList<'m, 'v> { + pub metadata: VariantMetadata<'m>, + pub value: &'v [u8], + header: VariantListHeader, + num_elements: usize, + first_offset_byte: usize, + first_value_byte: usize, +} + +impl<'m, 'v> VariantList<'m, 'v> { + /// Attempts to interpret `value` as a variant array value. + /// + /// # Validation + /// + /// This constructor verifies that `value` points to a valid variant array value. In particular, + /// that all offsets are in-bounds and point to valid objects. + // TODO: How to make the validation non-recursive while still making iterators safely infallible?? + pub fn try_new(metadata: VariantMetadata<'m>, value: &'v [u8]) -> Result { + let header_byte = first_byte_from_slice(value)?; + let header = VariantListHeader::try_new(header_byte)?; + // The size of the num_elements entry in the array value_data is 4 bytes if // is_large is true, otherwise 1 byte. - let num_elements_size = match is_large { + let num_elements_size = match header.is_large { true => OffsetSizeBytes::Four, false => OffsetSizeBytes::One, }; @@ -527,7 +466,7 @@ impl VariantListHeader { // 2. (num_elements + 1) * offset_size let value_bytes = n_offsets - .checked_mul(offset_size as usize) + .checked_mul(header.offset_size as usize) .ok_or_else(overflow)?; // 3. first_offset_byte + ... @@ -535,89 +474,24 @@ impl VariantListHeader { .checked_add(value_bytes) .ok_or_else(overflow)?; - // Verify that the last offset array entry is inside the value slice - let last_offset_byte = first_offset_byte + n_offsets * offset_size as usize; - if last_offset_byte > value.len() { - return Err(ArrowError::InvalidArgumentError(format!( - "Last offset array entry at offset {} with length {} is outside the value slice of length {}", - last_offset_byte, - offset_size as usize, - value.len() - ))); - } - - // Verify that the value of the last offset array entry fits inside the value slice - let last_offset = offset_size.unpack_usize(value, first_offset_byte, num_elements)?; - if first_value_byte + last_offset > value.len() { - return Err(ArrowError::InvalidArgumentError(format!( - "Last offset value {} at offset {} is outside the value slice of length {}", - last_offset, - first_value_byte, - value.len() - ))); - } - - Ok(Self { - offset_size, - is_large, + let s = Self { + metadata, + value, + header, num_elements, first_offset_byte, first_value_byte, - }) - } - - /// Returns the number of elements in this list - pub(crate) fn num_elements(&self) -> usize { - self.num_elements - } - - /// Returns the offset size in bytes - #[allow(unused)] - pub(crate) fn offset_size(&self) -> usize { - self.offset_size as _ - } - - /// Returns whether this is a large list - #[allow(unused)] - pub(crate) fn is_large(&self) -> bool { - self.is_large - } - - /// Returns the byte offset where the offset array starts - pub(crate) fn first_offset_byte(&self) -> usize { - self.first_offset_byte - } - - /// Returns the byte offset where the values start - pub(crate) fn first_value_byte(&self) -> usize { - self.first_value_byte - } -} - -/// Represents a variant array. -/// -/// NOTE: The "list" naming differs from the variant spec -- which calls it "array" -- in order to be -/// consistent with parquet and arrow type naming. Otherwise, the name would conflict with the -/// `VariantArray : Array` we must eventually define for variant-typed arrow arrays. -#[derive(Clone, Debug, PartialEq)] -pub struct VariantList<'m, 'v> { - pub metadata: VariantMetadata<'m>, - pub value: &'v [u8], - header: VariantListHeader, -} + }; -impl<'m, 'v> VariantList<'m, 'v> { - pub fn try_new(metadata: VariantMetadata<'m>, value: &'v [u8]) -> Result { - Ok(Self { - metadata, - value, - header: VariantListHeader::try_new(value)?, - }) + // Iterate over all values of this array in order to validate the field_offset array and + // prove that the field values are all in bounds. Otherwise, `iter` might panic on `unwrap`. + validate_fallible_iterator(s.iter_checked())?; + Ok(s) } /// Return the length of this array pub fn len(&self) -> usize { - self.header.num_elements() + self.num_elements } /// Is the array of zero length @@ -625,44 +499,41 @@ impl<'m, 'v> VariantList<'m, 'v> { self.len() == 0 } - pub fn values(&self) -> Result>, ArrowError> { - let len = self.len(); - let values = (0..len) - .map(move |i| self.get(i)) - .collect::, _>>()?; - Ok(values.into_iter()) - } - pub fn get(&self, index: usize) -> Result, ArrowError> { - if index >= self.header.num_elements() { + if index >= self.num_elements { return Err(ArrowError::InvalidArgumentError(format!( "Index {} out of bounds for list of length {}", - index, - self.header.num_elements() + index, self.num_elements, ))); } // Skip header and num_elements bytes to read the offsets - let start_field_offset_from_first_value_byte = self.header.offset_size.unpack_usize( - self.value, - self.header.first_offset_byte(), - index, - )?; - let end_field_offset_from_first_value_byte = self.header.offset_size.unpack_usize( - self.value, - self.header.first_offset_byte(), - index + 1, - )?; + let unpack = |i| { + self.header + .offset_size + .unpack_usize(self.value, self.first_offset_byte, i) + }; // Read the value bytes from the offsets let variant_value_bytes = slice_from_slice( self.value, - self.header.first_value_byte() + start_field_offset_from_first_value_byte - ..self.header.first_value_byte() + end_field_offset_from_first_value_byte, + self.first_value_byte + unpack(index)?..self.first_value_byte + unpack(index + 1)?, )?; let variant = Variant::try_new_with_metadata(self.metadata, variant_value_bytes)?; Ok(variant) } + + /// Iterates over the values of this list + pub fn iter(&self) -> impl Iterator> + '_ { + // NOTE: It is safe to unwrap because the constructor already made a successful traversal. + self.iter_checked().map(Result::unwrap) + } + + // Fallible iteration over the fields of this dictionary. The constructor traverses the iterator + // to prove it has no errors, so that all other use sites can blindly `unwrap` the result. + fn iter_checked(&self) -> impl Iterator, ArrowError>> + '_ { + (0..self.len()).map(move |i| self.get(i)) + } } /// Variant value. May contain references to metadata and value @@ -731,7 +602,7 @@ impl<'m, 'v> Variant<'m, 'v> { metadata: VariantMetadata<'m>, value: &'v [u8], ) -> Result { - let value_metadata = *first_byte_from_slice(value)?; + let value_metadata = first_byte_from_slice(value)?; let value_data = slice_from_slice(value, 1..)?; let new_self = match get_basic_type(value_metadata)? { VariantBasicType::Primitive => match get_primitive_type(value_metadata)? { @@ -1527,26 +1398,21 @@ mod tests { let md = VariantMetadata::try_new(bytes).expect("should parse"); assert_eq!(md.dictionary_size(), 2); // Fields - assert_eq!(md.get_field_by(0).unwrap(), "cat"); - assert_eq!(md.get_field_by(1).unwrap(), "dog"); + assert_eq!(md.get(0).unwrap(), "cat"); + assert_eq!(md.get(1).unwrap(), "dog"); // Offsets - assert_eq!(md.get_offset_by(0).unwrap(), 0x00); - assert_eq!(md.get_offset_by(1).unwrap(), 0x03); - // We only have 2 keys, the final offset should not be accessible using this method. - let err = md.get_offset_by(2).unwrap_err(); + assert_eq!(md.get_offset(0).unwrap(), 0x00); + assert_eq!(md.get_offset(1).unwrap(), 0x03); + assert_eq!(md.get_offset(2).unwrap(), 0x06); + let err = md.get_offset(3).unwrap_err(); assert!( - matches!(err, ArrowError::InvalidArgumentError(ref msg) - if msg.contains("Index 2 out of bounds for dictionary of length 2")), + matches!(err, ArrowError::InvalidArgumentError(_)), "unexpected error: {err:?}" ); - let fields: Vec<(usize, &str)> = md - .fields() - .unwrap() - .enumerate() - .map(|(i, r)| (i, r.unwrap())) - .collect(); + + let fields: Vec<(usize, &str)> = md.iter().enumerate().collect(); assert_eq!(fields, vec![(0usize, "cat"), (1usize, "dog")]); } @@ -1566,15 +1432,14 @@ mod tests { let working_md = VariantMetadata::try_new(bytes).expect("should parse"); assert_eq!(working_md.dictionary_size(), 2); - assert_eq!(working_md.get_field_by(0).unwrap(), "a"); - assert_eq!(working_md.get_field_by(1).unwrap(), "b"); + assert_eq!(working_md.get(0).unwrap(), "a"); + assert_eq!(working_md.get(1).unwrap(), "b"); let truncated = &bytes[..bytes.len() - 1]; let err = VariantMetadata::try_new(truncated).unwrap_err(); assert!( - matches!(err, ArrowError::InvalidArgumentError(ref msg) - if msg.contains("Last offset")), + matches!(err, ArrowError::InvalidArgumentError(_)), "unexpected error: {err:?}" ); } @@ -1603,7 +1468,7 @@ mod tests { let err = VariantMetadata::try_new(bytes).unwrap_err(); assert!( - matches!(err, ArrowError::InvalidArgumentError(ref msg) if msg.contains("monotonically")), + matches!(err, ArrowError::InvalidArgumentError(_)), "unexpected error: {err:?}" ); } @@ -1615,7 +1480,7 @@ mod tests { let err = VariantMetadata::try_new(bytes).unwrap_err(); assert!( - matches!(err, ArrowError::InvalidArgumentError(ref msg) if msg.contains("shorter")), + matches!(err, ArrowError::InvalidArgumentError(_)), "unexpected error: {err:?}" ); } @@ -1678,24 +1543,24 @@ mod tests { assert!(!variant_obj.is_empty()); // Test field access - let active_field = variant_obj.field("active").unwrap(); + let active_field = variant_obj.field_by_name("active").unwrap(); assert!(active_field.is_some()); assert_eq!(active_field.unwrap().as_boolean(), Some(true)); - let age_field = variant_obj.field("age").unwrap(); + let age_field = variant_obj.field_by_name("age").unwrap(); assert!(age_field.is_some()); assert_eq!(age_field.unwrap().as_int8(), Some(42)); - let name_field = variant_obj.field("name").unwrap(); + let name_field = variant_obj.field_by_name("name").unwrap(); assert!(name_field.is_some()); assert_eq!(name_field.unwrap().as_string(), Some("hello")); // Test non-existent field - let missing_field = variant_obj.field("missing").unwrap(); + let missing_field = variant_obj.field_by_name("missing").unwrap(); assert!(missing_field.is_none()); // Test fields iterator - let fields: Vec<_> = variant_obj.fields().unwrap().collect(); + let fields: Vec<_> = variant_obj.iter().collect(); assert_eq!(fields.len(), 3); // Fields should be in sorted order: active, age, name @@ -1734,11 +1599,11 @@ mod tests { assert!(variant_obj.is_empty()); // Test field access on empty object - let missing_field = variant_obj.field("anything").unwrap(); + let missing_field = variant_obj.field_by_name("anything").unwrap(); assert!(missing_field.is_none()); // Test fields iterator on empty object - let fields: Vec<_> = variant_obj.fields().unwrap().collect(); + let fields: Vec<_> = variant_obj.iter().collect(); assert_eq!(fields.len(), 0); } @@ -1796,7 +1661,7 @@ mod tests { )); // Test values iterator - let values: Vec<_> = variant_list.values().unwrap().collect(); + let values: Vec<_> = variant_list.iter().collect(); assert_eq!(values.len(), 3); assert_eq!(values[0].as_int8(), Some(42)); assert_eq!(values[1].as_boolean(), Some(true)); @@ -1832,7 +1697,7 @@ mod tests { assert!(out_of_bounds.is_err()); // Test values iterator on empty list - let values: Vec<_> = variant_list.values().unwrap().collect(); + let values: Vec<_> = variant_list.iter().collect(); assert_eq!(values.len(), 0); } diff --git a/parquet-variant/tests/variant_interop.rs b/parquet-variant/tests/variant_interop.rs index 14a9669c2de3..82766a8fbea8 100644 --- a/parquet-variant/tests/variant_interop.rs +++ b/parquet-variant/tests/variant_interop.rs @@ -137,7 +137,7 @@ fn variant_object_primitive() { Variant::ShortString("2025-04-16T12:34:56.78"), ), ]; - let actual_fields: Vec<_> = variant_object.fields().unwrap().collect(); + let actual_fields: Vec<_> = variant_object.iter().collect(); assert_eq!(actual_fields, expected_fields); } #[test] @@ -163,7 +163,7 @@ fn variant_array_primitive() { Variant::Int8(5), Variant::Int8(9), ]; - let actual: Vec<_> = list.values().unwrap().collect(); + let actual: Vec<_> = list.iter().collect(); assert_eq!(actual, expected); // Call `get` for each individual element