diff --git a/parquet-variant/Cargo.toml b/parquet-variant/Cargo.toml index 60c9b316cdf1..41b127ef14e6 100644 --- a/parquet-variant/Cargo.toml +++ b/parquet-variant/Cargo.toml @@ -31,6 +31,7 @@ edition = { workspace = true } rust-version = { workspace = true } [dependencies] +arrow-schema = "55.1.0" [lib] diff --git a/parquet-variant/src/decoder.rs b/parquet-variant/src/decoder.rs new file mode 100644 index 000000000000..80d6947c3da6 --- /dev/null +++ b/parquet-variant/src/decoder.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+use arrow_schema::ArrowError; +use std::array::TryFromSliceError; + +use crate::utils::{array_from_slice, first_byte_from_slice, string_from_slice}; + +#[derive(Debug, Clone, Copy)] +pub enum VariantBasicType { + Primitive = 0, + ShortString = 1, + Object = 2, + Array = 3, +} + +#[derive(Debug, Clone, Copy)] +pub enum VariantPrimitiveType { + Null = 0, + BooleanTrue = 1, + BooleanFalse = 2, + Int8 = 3, + // TODO: Add types for the rest of primitives, once API is agreed upon + String = 16, +} + +/// Extracts the basic type from a header byte +pub(crate) fn get_basic_type(header: u8) -> Result { + // See https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#value-encoding + let basic_type = header & 0x03; // Basic type is encoded in the first 2 bits + let basic_type = match basic_type { + 0 => VariantBasicType::Primitive, + 1 => VariantBasicType::ShortString, + 2 => VariantBasicType::Object, + 3 => VariantBasicType::Array, + _ => { + //NOTE: A 2-bit value has a max of 4 different values (0-3), hence this is unreachable as we + // masked `basic_type` with 0x03 above. + unreachable!(); + } + }; + Ok(basic_type) +} + +impl TryFrom for VariantPrimitiveType { + type Error = ArrowError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(VariantPrimitiveType::Null), + 1 => Ok(VariantPrimitiveType::BooleanTrue), + 2 => Ok(VariantPrimitiveType::BooleanFalse), + 3 => Ok(VariantPrimitiveType::Int8), + // TODO: Add types for the rest, once API is agreed upon + 16 => Ok(VariantPrimitiveType::String), + _ => Err(ArrowError::InvalidArgumentError(format!( + "unknown primitive type: {}", + value + ))), + } + } +} +/// Extract the primitive type from a Variant value-header byte +pub(crate) fn get_primitive_type(header: u8) -> Result { + // last 6 bits contain the primitive-type, see spec + VariantPrimitiveType::try_from(header >> 2) +} + +/// To be used in `map_err` when unpacking an integer from a slice of bytes. 
+fn map_try_from_slice_error(e: TryFromSliceError) -> ArrowError { + ArrowError::InvalidArgumentError(e.to_string()) +} + +/// Decodes an Int8 from the value section of a variant. +pub(crate) fn decode_int8(value: &[u8]) -> Result { + let value = i8::from_le_bytes(array_from_slice(value, 1)?); + Ok(value) +} + +/// Decodes a long string from the value section of a variant. +pub(crate) fn decode_long_string(value: &[u8]) -> Result<&str, ArrowError> { + let len = u32::from_le_bytes(array_from_slice(value, 1)?) as usize; + let string = string_from_slice(value, 5..5 + len)?; + Ok(string) +} + +/// Decodes a short string from the value section of a variant. +pub(crate) fn decode_short_string(value: &[u8]) -> Result<&str, ArrowError> { + let len = (first_byte_from_slice(value)? >> 2) as usize; + + let string = string_from_slice(value, 1..1 + len)?; + Ok(string) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_i8() -> Result<(), ArrowError> { + let value = [ + 0 | 3 << 2, // Primitive type for i8 + 42, + ]; + let result = decode_int8(&value)?; + assert_eq!(result, 42); + Ok(()) + } + + #[test] + fn test_short_string() -> Result<(), ArrowError> { + let value = [ + 1 | 5 << 2, // Basic type for short string | length of short string + 'H' as u8, + 'e' as u8, + 'l' as u8, + 'l' as u8, + 'o' as u8, + 'o' as u8, + ]; + let result = decode_short_string(&value)?; + assert_eq!(result, "Hello"); + Ok(()) + } + + #[test] + fn test_string() -> Result<(), ArrowError> { + let value = [ + 0 | 16 << 2, // Basic type for short string | length of short string + 5, + 0, + 0, + 0, // Length of string + 'H' as u8, + 'e' as u8, + 'l' as u8, + 'l' as u8, + 'o' as u8, + 'o' as u8, + ]; + let result = decode_long_string(&value)?; + assert_eq!(result, "Hello"); + Ok(()) + } +} diff --git a/parquet-variant/src/lib.rs b/parquet-variant/src/lib.rs index 6289f86a263f..a31187daeb69 100644 --- a/parquet-variant/src/lib.rs +++ b/parquet-variant/src/lib.rs @@ -26,3 +26,16 @@ //! 
If you are interested in helping, you can find more information on the GitHub [Variant issue] //! //! [Variant issue]: https://github.com/apache/arrow-rs/issues/6736 + +// TODO: dead code removal +#[allow(dead_code)] +mod decoder; +// TODO: dead code removal +#[allow(dead_code)] +mod variant; +// TODO: dead code removal +#[allow(dead_code)] +mod utils; + +#[cfg(test)] +mod test_variant; diff --git a/parquet-variant/src/test_variant.rs b/parquet-variant/src/test_variant.rs new file mode 100644 index 000000000000..07c9eaf9c6f0 --- /dev/null +++ b/parquet-variant/src/test_variant.rs @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! End-to-end check: (almost) every sample from apache/parquet-testing/variant +//! can be parsed into our `Variant`. 
+ +// NOTE: We keep this file separate rather than a test mod inside variant.rs because it should be +// moved to the test folder later +use std::fs; +use std::path::{Path, PathBuf}; + +use crate::variant::{Variant, VariantMetadata}; +use arrow_schema::ArrowError; + +fn cases_dir() -> PathBuf { + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("..") + .join("parquet-testing") + .join("variant") +} + +fn load_case(name: &str) -> Result<(Vec, Vec), ArrowError> { + let root = cases_dir(); + let meta = fs::read(root.join(format!("{name}.metadata")))?; + let val = fs::read(root.join(format!("{name}.value")))?; + Ok((meta, val)) +} + +fn get_primitive_cases() -> Vec<(&'static str, Variant<'static, 'static>)> { + vec![ + ("primitive_boolean_false", Variant::BooleanFalse), + ("primitive_boolean_true", Variant::BooleanTrue), + ("primitive_int8", Variant::Int8(42)), + // Using the From trait + ("primitive_string", Variant::from("This string is longer than 64 bytes and therefore does not fit in a short_string and it also includes several non ascii characters such as 🐢, 💖, ♥\u{fe0f}, 🎣 and 🤦!!")), + // Using the From trait + ("short_string", Variant::from("Less than 64 bytes (❤\u{fe0f} with utf8)")), + // TODO Reenable when https://github.com/apache/parquet-testing/issues/81 is fixed + // ("primitive_null", Variant::Null), + ] +} + +fn get_non_primitive_cases() -> Vec<&'static str> { + vec!["object_primitive", "array_primitive"] +} + +#[test] +fn variant_primitive() -> Result<(), ArrowError> { + let cases = get_primitive_cases(); + for (case, want) in cases { + let (metadata_bytes, value) = load_case(case)?; + let metadata = VariantMetadata::try_new(&metadata_bytes)?; + let got = Variant::try_new(&metadata, &value)?; + assert_eq!(got, want); + } + Ok(()) +} + +#[test] +fn variant_non_primitive() -> Result<(), ArrowError> { + let cases = get_non_primitive_cases(); + for case in cases { + let (metadata, value) = load_case(case)?; + let metadata = 
VariantMetadata::try_new(&metadata)?; + let variant = Variant::try_new(&metadata, &value)?; + match case { + "object_primitive" => { + assert!(matches!(variant, Variant::Object(_))); + assert_eq!(metadata.dictionary_size(), 7); + let dict_val = metadata.get_field_by(0)?; + assert_eq!(dict_val, "int_field"); + } + "array_primitive" => match variant { + Variant::Array(arr) => { + let v = arr.get(0)?; + assert!(matches!(v, Variant::Int8(2))); + let v = arr.get(1)?; + assert!(matches!(v, Variant::Int8(1))); + } + _ => panic!("expected an array"), + }, + _ => unreachable!(), + } + } + Ok(()) +} diff --git a/parquet-variant/src/utils.rs b/parquet-variant/src/utils.rs new file mode 100644 index 000000000000..85feb0bcb1c9 --- /dev/null +++ b/parquet-variant/src/utils.rs @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+use std::{array::TryFromSliceError, ops::Range, str}; + +use arrow_schema::ArrowError; + +use std::fmt::Debug; +use std::slice::SliceIndex; + +#[inline] + +pub(crate) fn slice_from_slice + Clone + Debug>( + bytes: &[u8], + index: I, +) -> Result<&I::Output, ArrowError> { + bytes.get(index.clone()).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Tried to extract byte(s) {index:?} from {}-byte buffer", + bytes.len(), + )) + }) +} +pub(crate) fn array_from_slice( + bytes: &[u8], + offset: usize, +) -> Result<[u8; N], ArrowError> { + let bytes = slice_from_slice(bytes, offset..offset + N)?; + bytes.try_into().map_err(map_try_from_slice_error) +} + +/// To be used in `map_err` when unpacking an integer from a slice of bytes. +pub(crate) fn map_try_from_slice_error(e: TryFromSliceError) -> ArrowError { + ArrowError::InvalidArgumentError(e.to_string()) +} + +pub(crate) fn first_byte_from_slice(slice: &[u8]) -> Result<&u8, ArrowError> { + slice + .get(0) + .ok_or_else(|| ArrowError::InvalidArgumentError("Received empty bytes".to_string())) +} + +/// Helper to get a &str from a slice based on range, if it's valid or an error otherwise +pub(crate) fn string_from_slice(slice: &[u8], range: Range) -> Result<&str, ArrowError> { + str::from_utf8(slice_from_slice(slice, range)?) + .map_err(|_| ArrowError::InvalidArgumentError("invalid UTF-8 string".to_string())) +} diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs new file mode 100644 index 000000000000..cf9f51acc72d --- /dev/null +++ b/parquet-variant/src/variant.rs @@ -0,0 +1,719 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +use crate::decoder::{ + self, get_basic_type, get_primitive_type, VariantBasicType, VariantPrimitiveType, +}; +use crate::utils::{array_from_slice, first_byte_from_slice, slice_from_slice, string_from_slice}; +use arrow_schema::ArrowError; +use std::{ + num::TryFromIntError, + ops::{Index, Range}, +}; + +#[derive(Clone, Debug, Copy, PartialEq)] +enum OffsetSizeBytes { + One = 1, + Two = 2, + Three = 3, + Four = 4, +} + +impl OffsetSizeBytes { + /// Build from the `offset_size_minus_one` bits (see spec). + fn try_new(offset_size_minus_one: u8) -> Result { + use OffsetSizeBytes::*; + let result = match offset_size_minus_one { + 0 => One, + 1 => Two, + 2 => Three, + 3 => Four, + _ => { + return Err(ArrowError::InvalidArgumentError( + "offset_size_minus_one must be 0–3".to_string(), + )) + } + }; + Ok(result) + } + + /// Return one unsigned little-endian value from `bytes`. + /// + /// * `bytes` – the Variant-metadata buffer. + /// * `byte_offset` – number of bytes to skip **before** reading the first + /// value (usually `1` to move past the header byte). + /// * `offset_index` – 0-based index **after** the skip + /// (`0` is the first value, `1` the next, …). + /// + /// Each value is `self as usize` bytes wide (1, 2, 3 or 4). + /// Three-byte values are zero-extended to 32 bits before the final + /// fallible cast to `usize`. 
fn unpack_usize( + &self, + bytes: &[u8], + byte_offset: usize, // how many bytes to skip + offset_index: usize, // which offset in an array of offsets + ) -> Result<usize, ArrowError> {
+ /// - The actual number of bytes, offset_size, is offset_size_minus_one + 1 + pub fn try_new(bytes: &[u8]) -> Result { + let header = first_byte_from_slice(bytes)?; + + let version = header & 0x0F; // First four bits + if version != CORRECT_VERSION_VALUE { + let err_msg = format!( + "The version bytes in the header is not {CORRECT_VERSION_VALUE}, got {:b}", + version + ); + return Err(ArrowError::InvalidArgumentError(err_msg)); + } + let is_sorted = (header & 0x10) != 0; // Fifth bit + let offset_size_minus_one = header >> 6; // Last two bits + Ok(Self { + version, + is_sorted, + offset_size: OffsetSizeBytes::try_new(offset_size_minus_one)?, + }) + } +} + +#[derive(Clone, Copy, Debug, PartialEq)] +/// Encodes the Variant Metadata, see the Variant spec file for more information +pub struct VariantMetadata<'m> { + bytes: &'m [u8], + header: VariantMetadataHeader, + dict_size: usize, + dictionary_key_start_byte: usize, +} + +impl<'m> VariantMetadata<'m> { + /// View the raw bytes (needed by very low-level decoders) + #[inline] + pub const fn as_bytes(&self) -> &'m [u8] { + self.bytes + } + + pub fn try_new(bytes: &'m [u8]) -> Result { + let header = VariantMetadataHeader::try_new(bytes)?; + // Offset 1, index 0 because first element after header is dictionary size + let dict_size = header.offset_size.unpack_usize(bytes, 1, 0)?; + + // Check that we have the correct metadata length according to dictionary_size, or return + // error early. + // Minimum number of bytes the metadata buffer must contain: + // 1 byte header + // + offset_size-byte `dictionary_size` field + // + (dict_size + 1) offset entries, each `offset_size` bytes. 
(Table size, essentially) + // 1 + offset_size + (dict_size + 1) * offset_size + // = (dict_size + 2) * offset_size + 1 + let offset_size = header.offset_size as usize; // Cheap to copy + + let dictionary_key_start_byte = dict_size + .checked_add(2) + .and_then(|n| n.checked_mul(offset_size)) + .and_then(|n| n.checked_add(1)) + .ok_or_else(|| ArrowError::InvalidArgumentError("metadata length overflow".into()))?; + + if bytes.len() < dictionary_key_start_byte { + return Err(ArrowError::InvalidArgumentError( + "Metadata shorter than dictionary_size implies".to_string(), + )); + } + + // Check that all offsets are monotonically increasing + let mut offsets = (0..=dict_size).map(|i| header.offset_size.unpack_usize(bytes, 1, i + 1)); + let Some(Ok(mut end @ 0)) = offsets.next() else { + return Err(ArrowError::InvalidArgumentError( + "First offset is non-zero".to_string(), + )); + }; + + for offset in offsets { + let offset = offset?; + if end >= offset { + return Err(ArrowError::InvalidArgumentError( + "Offsets are not monotonically increasing".to_string(), + )); + } + end = offset; + } + + // Verify the buffer covers the whole dictionary-string section + if end > bytes.len() - dictionary_key_start_byte { + // `prev` holds the last offset seen still + return Err(ArrowError::InvalidArgumentError( + "Last offset does not equal dictionary length".to_string(), + )); + } + + Ok(Self { + bytes, + header, + dict_size, + dictionary_key_start_byte, + }) + } + + /// Whether the dictionary keys are sorted and unique + pub fn is_sorted(&self) -> bool { + self.header.is_sorted + } + + /// Get the dictionary size + pub fn dictionary_size(&self) -> usize { + self.dict_size + } + pub fn version(&self) -> u8 { + self.header.version + } + + /// Helper method to get the offset start and end range for a key by index. 
fn get_offsets_for_key_by(&self, index: usize) -> Result<Range<usize>, ArrowError> { + if index >= self.dict_size { + return Err(ArrowError::InvalidArgumentError(format!( + "Index {} out of bounds for dictionary of length {}", + index, self.dict_size + ))); + } + + // Skipping the header byte (setting byte_offset = 1) and the dictionary_size (setting offset_index +1) + let unpack = |i| self.header.offset_size.unpack_usize(self.bytes, 1, i + 1); + Ok(unpack(index)?..unpack(index + 1)?) + } + + /// Get a single offset by index + pub fn get_offset_by(&self, index: usize) -> Result<usize, ArrowError> {
pub(crate) fn get_field_by_offset(&self, offset: Range<usize>) -> Result<&'m str, ArrowError> { + let dictionary_keys_bytes = + slice_from_slice(self.bytes, self.dictionary_key_start_byte..self.bytes.len())?; + let result = string_from_slice(dictionary_keys_bytes, offset)?; + + Ok(result) + } + + pub fn header(&self) -> VariantMetadataHeader { + self.header + } + + /// Get the offsets as an iterator + pub fn offsets(&self) -> impl Iterator<Item = Result<Range<usize>, ArrowError>> + 'm { + let offset_size = self.header.offset_size; // `Copy` + let bytes = self.bytes; + + let iterator = (0..self.dict_size).map(move |i| { + // This wont be out of bounds as long as dict_size and offsets have been validated + // during construction via `try_new`, as it calls unpack_usize for the + // indices `1..dict_size+1` already. + let start = offset_size.unpack_usize(bytes, 1, i + 1); + let end = offset_size.unpack_usize(bytes, 1, i + 2); + + match (start, end) { + (Ok(s), Ok(e)) => Ok(s..e), + (Err(e), _) | (_, Err(e)) => Err(e), + } + }); + + iterator + } + + /// Get all key-names as an Iterator of strings + pub fn fields( + &'m self, + ) -> Result<impl Iterator<Item = Result<&'m str, ArrowError>>, ArrowError> { + let iterator = self + .offsets() + .map(move |offset_range| self.get_field_by_offset(offset_range?)); + Ok(iterator) + } +} + +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct VariantObject<'m, 'v> { + pub metadata: &'m VariantMetadata<'m>, + pub value: &'v [u8], +} +impl<'m, 'v> VariantObject<'m, 'v> { + pub fn fields(&self) -> Result<impl Iterator<Item = (&'m str, Variant<'m, 'v>)>, ArrowError> { + todo!(); + #[allow(unreachable_code)] // Just to infer the return type + Ok(vec![].into_iter()) + } + pub fn field(&self, _name: &'m str) -> Result<Option<Variant<'m, 'v>>, ArrowError> { + todo!() + } +} + +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct VariantArray<'m, 'v> { + pub metadata: &'m VariantMetadata<'m>, + pub value: &'v [u8], +} + +impl<'m, 'v> VariantArray<'m, 'v> { + pub fn len(&self) -> usize { + todo!() + } + + pub fn values(&self) -> Result<impl Iterator<Item = Variant<'m, 'v>>, ArrowError> {
// Just to infer the return type + Ok(vec![].into_iter()) + } + + pub fn get(&self, index: usize) -> Result, ArrowError> { + // The 6 first bits to the left are the value_header and the 2 bits + // to the right are the basic type, so we shift to get only the value_header + let value_header = first_byte_from_slice(self.value)? >> 2; + let is_large = (value_header & 0x04) != 0; // 3rd bit from the right + let field_offset_size_minus_one = value_header & 0x03; // Last two bits + let offset_size = OffsetSizeBytes::try_new(field_offset_size_minus_one)?; + // The size of the num_elements entry in the array value_data is 4 bytes if + // is_large is true, otherwise 1 byte. + let num_elements_size = match is_large { + true => OffsetSizeBytes::Four, + false => OffsetSizeBytes::One, + }; + // Skip the header byte to read the num_elements + // The size of the num_elements entry in the array value_data is 4 bytes if + // is_large is true, otherwise 1 byte. + let num_elements = num_elements_size.unpack_usize(self.value, 1, 0)?; + let first_offset_byte = 1 + num_elements_size as usize; + + let overflow = + || ArrowError::InvalidArgumentError("Variant value_byte_length overflow".into()); + + // 1. num_elements + 1 + let n_offsets = num_elements.checked_add(1).ok_or_else(overflow)?; + + // 2. (num_elements + 1) * offset_size + let value_bytes = n_offsets + .checked_mul(offset_size as usize) + .ok_or_else(overflow)?; + + // 3. first_offset_byte + ... 
+ let first_value_byte = first_offset_byte + .checked_add(value_bytes) + .ok_or_else(overflow)?; + + // Skip header and num_elements bytes to read the offsets + let start_field_offset_from_first_value_byte = + offset_size.unpack_usize(self.value, first_offset_byte, index)?; + let end_field_offset_from_first_value_byte = + offset_size.unpack_usize(self.value, first_offset_byte, index + 1)?; + + // Read the value bytes from the offsets + let variant_value_bytes = slice_from_slice( + self.value, + first_value_byte + start_field_offset_from_first_value_byte + ..first_value_byte + end_field_offset_from_first_value_byte, + )?; + let variant = Variant::try_new(self.metadata, variant_value_bytes)?; + Ok(variant) + } +} + +// impl<'m, 'v> Index for VariantArray<'m, 'v> { +// type Output = Variant<'m, 'v>; +// +// } + +/// Variant value. May contain references to metadata and value +#[derive(Clone, Debug, Copy, PartialEq)] +pub enum Variant<'m, 'v> { + // TODO: Add types for the rest of the primitive types, once API is agreed upon + Null, + Int8(i8), + + BooleanTrue, + BooleanFalse, + + // Note: only need the *value* buffer + String(&'v str), + ShortString(&'v str), + + // need both metadata & value + Object(VariantObject<'m, 'v>), + Array(VariantArray<'m, 'v>), +} + +impl<'m, 'v> Variant<'m, 'v> { + /// Parse the buffers and return the appropriate variant. + pub fn try_new(metadata: &'m VariantMetadata, value: &'v [u8]) -> Result { + let header = *first_byte_from_slice(value)?; + let new_self = match get_basic_type(header)? { + VariantBasicType::Primitive => match get_primitive_type(header)? 
{ + VariantPrimitiveType::Null => Variant::Null, + VariantPrimitiveType::Int8 => Variant::Int8(decoder::decode_int8(value)?), + VariantPrimitiveType::BooleanTrue => Variant::BooleanTrue, + VariantPrimitiveType::BooleanFalse => Variant::BooleanFalse, + // TODO: Add types for the rest, once API is agreed upon + VariantPrimitiveType::String => { + Variant::String(decoder::decode_long_string(value)?) + } + }, + VariantBasicType::ShortString => { + Variant::ShortString(decoder::decode_short_string(value)?) + } + VariantBasicType::Object => Variant::Object(VariantObject { metadata, value }), + VariantBasicType::Array => Variant::Array(VariantArray { metadata, value }), + }; + Ok(new_self) + } + + pub fn as_null(&self) -> Option<()> { + matches!(self, Variant::Null).then_some(()) + } + + pub fn as_boolean(&self) -> Option { + match self { + Variant::BooleanTrue => Some(true), + Variant::BooleanFalse => Some(false), + _ => None, + } + } + + pub fn as_string(&'v self) -> Option<&'v str> { + match self { + Variant::String(s) | Variant::ShortString(s) => Some(s), + _ => None, + } + } + + pub fn as_int8(&self) -> Option { + match *self { + Variant::Int8(i) => Some(i), + // TODO: Add branches for type-widening/shortening when implemting rest of primitives for int + // Variant::Int16(i) => i.try_into().ok(), + // ... + _ => None, + } + } + + pub fn metadata(&self) -> Option<&'m VariantMetadata> { + match self { + Variant::Object(VariantObject { metadata, .. }) + | Variant::Array(VariantArray { metadata, .. 
}) => Some(*metadata), + _ => None, + } + } +} + +impl<'m, 'v> From for Variant<'m, 'v> { + fn from(value: i8) -> Self { + Variant::Int8(value) + } +} + +impl<'m, 'v> From for Variant<'m, 'v> { + fn from(value: bool) -> Self { + match value { + true => Variant::BooleanTrue, + false => Variant::BooleanFalse, + } + } +} + +impl<'m, 'v> From<&'v str> for Variant<'m, 'v> { + fn from(value: &'v str) -> Self { + if value.len() < 64 { + Variant::ShortString(value) + } else { + Variant::String(value) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_offset() { + assert_eq!(OffsetSizeBytes::try_new(0).unwrap(), OffsetSizeBytes::One); + assert_eq!(OffsetSizeBytes::try_new(1).unwrap(), OffsetSizeBytes::Two); + assert_eq!(OffsetSizeBytes::try_new(2).unwrap(), OffsetSizeBytes::Three); + assert_eq!(OffsetSizeBytes::try_new(3).unwrap(), OffsetSizeBytes::Four); + + // everything outside 0-3 must error + assert!(OffsetSizeBytes::try_new(4).is_err()); + assert!(OffsetSizeBytes::try_new(255).is_err()); + } + + #[test] + fn unpack_usize_all_widths() { + // One-byte offsets + let buf_one = [0x01u8, 0xAB, 0xCD]; + assert_eq!( + OffsetSizeBytes::One.unpack_usize(&buf_one, 0, 0).unwrap(), + 0x01 + ); + assert_eq!( + OffsetSizeBytes::One.unpack_usize(&buf_one, 0, 2).unwrap(), + 0xCD + ); + + // Two-byte offsets (little-endian 0x1234, 0x5678) + let buf_two = [0x34, 0x12, 0x78, 0x56]; + assert_eq!( + OffsetSizeBytes::Two.unpack_usize(&buf_two, 0, 0).unwrap(), + 0x1234 + ); + assert_eq!( + OffsetSizeBytes::Two.unpack_usize(&buf_two, 0, 1).unwrap(), + 0x5678 + ); + + // Three-byte offsets (0x030201 and 0x0000FF) + let buf_three = [0x01, 0x02, 0x03, 0xFF, 0x00, 0x00]; + assert_eq!( + OffsetSizeBytes::Three + .unpack_usize(&buf_three, 0, 0) + .unwrap(), + 0x0302_01 + ); + assert_eq!( + OffsetSizeBytes::Three + .unpack_usize(&buf_three, 0, 1) + .unwrap(), + 0x0000_FF + ); + + // Four-byte offsets (0x12345678, 0x90ABCDEF) + let buf_four = [0x78, 0x56, 0x34, 0x12, 0xEF, 
0xCD, 0xAB, 0x90]; + assert_eq!( + OffsetSizeBytes::Four.unpack_usize(&buf_four, 0, 0).unwrap(), + 0x1234_5678 + ); + assert_eq!( + OffsetSizeBytes::Four.unpack_usize(&buf_four, 0, 1).unwrap(), + 0x90AB_CDEF + ); + } + + #[test] + fn unpack_usize_out_of_bounds() { + let tiny = [0x00u8]; // deliberately too short + assert!(OffsetSizeBytes::Two.unpack_usize(&tiny, 0, 0).is_err()); + assert!(OffsetSizeBytes::Three.unpack_usize(&tiny, 0, 0).is_err()); + } + + #[test] + fn unpack_simple() { + let buf = [ + 0x41, // header + 0x02, 0x00, // dictionary_size = 2 + 0x00, 0x00, // offset[0] = 0 + 0x05, 0x00, // offset[1] = 5 + 0x09, 0x00, // offset[2] = 9 + ]; + + let width = OffsetSizeBytes::Two; + + // dictionary_size starts immediately after the header + let dict_size = width.unpack_usize(&buf, 1, 0).unwrap(); + assert_eq!(dict_size, 2); + + let first = width.unpack_usize(&buf, 1, 1).unwrap(); + assert_eq!(first, 0); + + let second = width.unpack_usize(&buf, 1, 2).unwrap(); + assert_eq!(second, 5); + + let third = width.unpack_usize(&buf, 1, 3).unwrap(); + assert_eq!(third, 9); + + let err = width.unpack_usize(&buf, 1, 4); + assert!(err.is_err()) + } + + /// `"cat"`, `"dog"` – valid metadata + #[test] + fn try_new_ok_inline() { + let bytes = &[ + 0b0000_0001, // header, offset_size_minus_one=0 and version=1 + 0x02, // dictionary_size (2 strings) + 0x00, + 0x03, + 0x06, + b'c', + b'a', + b't', + b'd', + b'o', + b'g', + ]; + + let md = VariantMetadata::try_new(bytes).expect("should parse"); + assert_eq!(md.dictionary_size(), 2); + // Fields + assert_eq!(md.get_field_by(0).unwrap(), "cat"); + assert_eq!(md.get_field_by(1).unwrap(), "dog"); + + // Offsets + assert_eq!(md.get_offset_by(0).unwrap(), 0x00); + assert_eq!(md.get_offset_by(1).unwrap(), 0x03); + // We only have 2 keys, the final offset should not be accessible using this method. 
+ let err = md.get_offset_by(2).unwrap_err(); + + assert!( + matches!(err, ArrowError::InvalidArgumentError(ref msg) + if msg.contains("Index 2 out of bounds for dictionary of length 2")), + "unexpected error: {err:?}" + ); + let fields: Vec<(usize, &str)> = md + .fields() + .unwrap() + .enumerate() + .map(|(i, r)| (i, r.unwrap())) + .collect(); + assert_eq!(fields, vec![(0usize, "cat"), (1usize, "dog")]); + } + + /// Too short buffer test (missing one required offset). + /// Should error with “metadata shorter than dictionary_size implies”. + #[test] + fn try_new_missing_last_value() { + let bytes = &[ + 0b0000_0001, // header, offset_size_minus_one=0 and version=1 + 0x02, // dictionary_size = 2 + 0x00, + 0x01, + 0x02, + b'a', + b'b', // <-- we'll remove this + ]; + + let working_md = VariantMetadata::try_new(bytes).expect("should parse"); + assert_eq!(working_md.dictionary_size(), 2); + assert_eq!(working_md.get_field_by(0).unwrap(), "a"); + assert_eq!(working_md.get_field_by(1).unwrap(), "b"); + + let truncated = &bytes[..bytes.len() - 1]; + + let err = VariantMetadata::try_new(truncated).unwrap_err(); + assert!( + matches!(err, ArrowError::InvalidArgumentError(ref msg) + if msg.contains("Last offset")), + "unexpected error: {err:?}" + ); + } + + #[test] + fn try_new_fails_non_monotonic() { + // 'cat', 'dog', 'lamb' + let bytes = &[ + 0b0000_0001, // header, offset_size_minus_one=0 and version=1 + 0x03, // dictionary_size + 0x00, + 0x02, + 0x01, // Doesn't increase monotonically + 0x10, + b'c', + b'a', + b't', + b'd', + b'o', + b'g', + b'l', + b'a', + b'm', + b'b', + ]; + + let err = VariantMetadata::try_new(bytes).unwrap_err(); + assert!( + matches!(err, ArrowError::InvalidArgumentError(ref msg) if msg.contains("monotonically")), + "unexpected error: {err:?}" + ); + } + + #[test] + fn try_new_truncated_offsets_inline() { + // Missing final offset + let bytes = &[0b0000_0001, 0x02, 0x00, 0x01]; + + let err = VariantMetadata::try_new(bytes).unwrap_err(); + 
assert!( + matches!(err, ArrowError::InvalidArgumentError(ref msg) if msg.contains("shorter")), + "unexpected error: {err:?}" + ); + } +}