diff --git a/Cargo.lock b/Cargo.lock index cfcc4899c96..55a6c25c9fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -382,6 +382,22 @@ dependencies = [ "half", ] +[[package]] +name = "arrow-scalar" +version = "57.0.0" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ord", + "arrow-row", + "arrow-schema", + "half", + "proptest", + "rstest 0.23.0", +] + [[package]] name = "arrow-schema" version = "57.2.0" diff --git a/Cargo.toml b/Cargo.toml index b080335a735..81cd7a0d4f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ members = [ "rust/lance-tools", "rust/compression/fsst", "rust/compression/bitpacking", + "rust/arrow-scalar", ] exclude = ["python", "java/lance-jni"] # Python package needs to be built by maturin. @@ -72,6 +73,7 @@ lance-testing = { version = "=3.0.0-beta.2", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "57.0.0", optional = false, features = ["prettyprint"] } +arrow-scalar = { version = "=57.0.0", path = "./rust/arrow-scalar" } arrow-arith = "57.0.0" arrow-array = "57.0.0" arrow-buffer = "57.0.0" diff --git a/rust/arrow-scalar/Cargo.toml b/rust/arrow-scalar/Cargo.toml new file mode 100644 index 00000000000..c3d3f9181c3 --- /dev/null +++ b/rust/arrow-scalar/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "arrow-scalar" +version = "57.0.0" +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +description = "Arrow scalar type with Ord, Hash, and Eq support" +keywords.workspace = true +categories.workspace = true +rust-version.workspace = true +readme = "README.md" + +[dependencies] +# Note: this is a core crate and we should aim to keep this dependency list +# as minimal as possible. +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-cast = { workspace = true } +arrow-data = { workspace = true } +arrow-row = { workspace = true } +arrow-schema = { workspace = true } +half = { workspace = true } + +[dev-dependencies] +arrow-ord = { workspace = true } +proptest = { workspace = true } +rstest = { workspace = true } + +[lints] +workspace = true diff --git a/rust/arrow-scalar/src/convert.rs b/rust/arrow-scalar/src/convert.rs new file mode 100644 index 00000000000..de783a3a604 --- /dev/null +++ b/rust/arrow-scalar/src/convert.rs @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::sync::Arc; + +use arrow_array::*; +use half::f16; + +use crate::ArrowScalar; + +macro_rules! impl_from_primitive { + ($native_ty:ty, $array_ty:ty) => { + impl From<$native_ty> for ArrowScalar { + fn from(value: $native_ty) -> Self { + let array: ArrayRef = Arc::new(<$array_ty>::from(vec![value])); + Self::try_from_array(array).expect("single-element primitive array is always valid") + } + } + }; +} + +impl_from_primitive!(i8, Int8Array); +impl_from_primitive!(i16, Int16Array); +impl_from_primitive!(i32, Int32Array); +impl_from_primitive!(i64, Int64Array); +impl_from_primitive!(u8, UInt8Array); +impl_from_primitive!(u16, UInt16Array); +impl_from_primitive!(u32, UInt32Array); +impl_from_primitive!(u64, UInt64Array); +impl_from_primitive!(f32, Float32Array); +impl_from_primitive!(f64, Float64Array); + +impl From for ArrowScalar { + fn from(value: bool) -> Self { + let array: ArrayRef = Arc::new(BooleanArray::from(vec![value])); + Self::try_from_array(array).expect("single-element boolean array is always valid") + } +} + +impl From for ArrowScalar { + fn from(value: f16) -> Self { + let array: ArrayRef = Arc::new(Float16Array::from(vec![value])); + Self::try_from_array(array).expect("single-element f16 array is always valid") + } +} + +impl From<&str> for ArrowScalar { + fn from(value: &str) -> Self { + let array: ArrayRef = Arc::new(StringArray::from(vec![value])); + Self::try_from_array(array).expect("single-element string array is always valid") + } +} + +impl From for ArrowScalar { + fn from(value: String) -> Self { + Self::from(value.as_str()) + } +} + +impl From<&[u8]> for ArrowScalar { + fn from(value: &[u8]) -> Self { + let array: ArrayRef = Arc::new(BinaryArray::from_vec(vec![value])); + Self::try_from_array(array).expect("single-element binary array is always valid") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_primitives() { + let s = ArrowScalar::from(42i32); + assert!(!s.is_null()); + assert_eq!(format!("{s}"), "42"); + + let s = ArrowScalar::from(1.5f64); + assert!(!s.is_null()); + + let s = ArrowScalar::from(true); + assert_eq!(format!("{s}"), "true"); + } + + #[test] + fn test_from_string_types() { + let s = ArrowScalar::from("hello"); + assert_eq!(format!("{s}"), "hello"); + + let s = ArrowScalar::from(String::from("world")); + assert_eq!(format!("{s}"), "world"); + } + + #[test] + fn test_from_binary() { + let bytes: &[u8] = &[0xDE, 0xAD]; + let s = ArrowScalar::from(bytes); + assert!(!s.is_null()); + } + + #[test] + fn test_from_f16() { + let s = ArrowScalar::from(f16::from_f32(1.5)); + assert!(!s.is_null()); + } +} diff --git a/rust/arrow-scalar/src/lib.rs b/rust/arrow-scalar/src/lib.rs new file mode 100644 index 00000000000..4bc9a588c97 --- /dev/null +++ b/rust/arrow-scalar/src/lib.rs @@ -0,0 +1,580 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! A scalar type backed by a single-element Arrow array with [`Ord`], [`Hash`], +//! and [`Eq`] support. +//! +//! Comparisons and hashing are delegated to [`arrow_row::OwnedRow`], which +//! provides a correct total ordering for all Arrow types (including proper NaN +//! handling for floats and null ordering). + +mod convert; +pub mod serde; + +use std::cmp::Ordering; +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use arrow_array::{make_array, new_null_array, ArrayRef}; +use arrow_cast::display::ArrayFormatter; +use arrow_data::transform::MutableArrayData; +use arrow_row::{OwnedRow, RowConverter, SortField}; +use arrow_schema::{ArrowError, DataType}; + +type Result = std::result::Result; + +/// A scalar value backed by a length-1 Arrow array. +/// +/// `ArrowScalar` provides [`Eq`], [`Ord`], and [`Hash`] by caching an +/// [`OwnedRow`] at construction time. This means comparisons and hashing are +/// O(1) row-byte operations rather than per-type dispatch. +/// +/// # Cross-type comparison +/// +/// Comparing scalars of different data types produces an arbitrary but +/// consistent ordering based on the underlying row bytes. This is intentional +/// — it allows scalars to be used as keys in sorted collections regardless of +/// type, but the ordering across types is not semantically meaningful. +/// +/// # Examples +/// +/// ``` +/// use arrow_scalar::ArrowScalar; +/// +/// let a = ArrowScalar::from(1i32); +/// let b = ArrowScalar::from(2i32); +/// assert!(a < b); +/// +/// let c = ArrowScalar::from("hello"); +/// assert_eq!(c, ArrowScalar::from("hello")); +/// ``` +pub struct ArrowScalar { + array: ArrayRef, + row: OwnedRow, +} + +impl ArrowScalar { + /// Create a scalar by extracting the element at `offset` from `array`. + pub fn try_new(array: &ArrayRef, offset: usize) -> Result { + if offset >= array.len() { + return Err(ArrowError::InvalidArgumentError( + "Scalar index out of bounds".to_string(), + )); + } + + let data = array.to_data(); + let mut mutable = MutableArrayData::new(vec![&data], true, 1); + mutable.extend(0, offset, offset + 1); + let single = make_array(mutable.freeze()); + Self::try_from_array(single) + } + + /// Create a scalar from a length-1 array. + pub fn try_from_array(array: ArrayRef) -> Result { + if array.len() != 1 { + return Err(ArrowError::InvalidArgumentError(format!( + "ArrowScalar requires a length-1 array, got length {}", + array.len() + ))); + } + + let row = Self::compute_row(&array)?; + Ok(Self { array, row }) + } + + /// Create a null scalar of the given data type. + pub fn new_null(data_type: &DataType) -> Result { + Self::try_from_array(new_null_array(data_type, 1)) + } + + fn compute_row(array: &ArrayRef) -> Result { + let sort_field = SortField::new(array.data_type().clone()); + let converter = RowConverter::new(vec![sort_field])?; + let rows = converter.convert_columns(&[Arc::clone(array)])?; + Ok(rows.row(0).owned()) + } + + /// Returns a reference to the underlying length-1 array. + pub fn as_array(&self) -> &ArrayRef { + &self.array + } + + /// Returns the data type of this scalar. + pub fn data_type(&self) -> &DataType { + self.array.data_type() + } + + /// Returns `true` if this scalar is null. + pub fn is_null(&self) -> bool { + self.array.null_count() == 1 + } +} + +impl PartialEq for ArrowScalar { + fn eq(&self, other: &Self) -> bool { + self.row == other.row + } +} + +impl Eq for ArrowScalar {} + +impl PartialOrd for ArrowScalar { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ArrowScalar { + fn cmp(&self, other: &Self) -> Ordering { + self.row.cmp(&other.row) + } +} + +impl Hash for ArrowScalar { + fn hash(&self, state: &mut H) { + self.row.hash(state); + } +} + +impl fmt::Display for ArrowScalar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_null() { + return write!(f, "null"); + } + let formatter = + ArrayFormatter::try_new(&self.array, &Default::default()).map_err(|_| fmt::Error)?; + write!(f, "{}", formatter.value(0)) + } +} + +impl fmt::Debug for ArrowScalar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ArrowScalar({}: {})", self.data_type(), self) + } +} + +impl Clone for ArrowScalar { + fn clone(&self) -> Self { + Self { + array: Arc::clone(&self.array), + row: self.row.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::{BTreeSet, HashSet}; + use std::sync::Arc; + + use arrow_array::*; + use rstest::rstest; + + use super::*; + + #[test] + fn test_try_new_extracts_element() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30])); + let s = ArrowScalar::try_new(&array, 1).unwrap(); + assert_eq!(format!("{s}"), "20"); + } + + #[test] + fn test_try_new_out_of_bounds() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1])); + assert!(ArrowScalar::try_new(&array, 5).is_err()); + } + + #[test] + fn test_try_from_array_wrong_length() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + assert!(ArrowScalar::try_from_array(array).is_err()); + } + + #[test] + fn test_equality() { + let a = ArrowScalar::from(42i32); + let b = ArrowScalar::from(42i32); + let c = ArrowScalar::from(99i32); + assert_eq!(a, b); + assert_ne!(a, c); + } + + #[test] + fn test_ordering() { + let a = ArrowScalar::from(1i32); + let b = ArrowScalar::from(2i32); + let c = ArrowScalar::from(3i32); + assert!(a < b); + assert!(b < c); + assert_eq!(a.cmp(&a), Ordering::Equal); + } + + #[test] + fn test_hash_consistent_with_eq() { + use std::hash::DefaultHasher; + + let a = ArrowScalar::from(42i32); + let b = ArrowScalar::from(42i32); + let hash_a = { + let mut h = DefaultHasher::new(); + a.hash(&mut h); + h.finish() + }; + let hash_b = { + let mut h = DefaultHasher::new(); + b.hash(&mut h); + h.finish() + }; + assert_eq!(hash_a, hash_b); + } + + #[test] + fn test_in_hashset() { + let mut set = HashSet::new(); + set.insert(ArrowScalar::from(1i32)); + set.insert(ArrowScalar::from(2i32)); + set.insert(ArrowScalar::from(1i32)); + assert_eq!(set.len(), 2); + } + + #[test] + fn test_in_btreeset() { + let mut set = BTreeSet::new(); + set.insert(ArrowScalar::from(3i32)); + set.insert(ArrowScalar::from(1i32)); + set.insert(ArrowScalar::from(2i32)); + let values: Vec<_> = set.iter().map(|s| format!("{s}")).collect(); + assert_eq!(values, vec!["1", "2", "3"]); + } + + #[test] + fn test_null_scalar() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![None])); + let s = ArrowScalar::try_from_array(array).unwrap(); + assert!(s.is_null()); + assert_eq!(format!("{s}"), "null"); + } + + #[test] + fn test_null_sorts_first() { + let null_scalar = { + let array: ArrayRef = Arc::new(Int32Array::from(vec![None])); + ArrowScalar::try_from_array(array).unwrap() + }; + let value_scalar = ArrowScalar::from(0i32); + assert!(null_scalar < value_scalar); + } + + #[rstest] + #[case::float_nan( + ArrowScalar::from(f64::NAN), + ArrowScalar::from(f64::INFINITY), + Ordering::Greater + )] + #[case::float_normal(ArrowScalar::from(1.0f64), ArrowScalar::from(2.0f64), Ordering::Less)] + fn test_float_ordering( + #[case] a: ArrowScalar, + #[case] b: ArrowScalar, + #[case] expected: Ordering, + ) { + assert_eq!(a.cmp(&b), expected); + } + + #[test] + fn test_display_string() { + let s = ArrowScalar::from("hello world"); + assert_eq!(format!("{s}"), "hello world"); + } + + #[test] + fn test_debug() { + let s = ArrowScalar::from(42i32); + let debug = format!("{s:?}"); + assert!(debug.contains("ArrowScalar")); + assert!(debug.contains("42")); + } + + #[test] + fn test_clone() { + let a = ArrowScalar::from(42i32); + let b = a.clone(); + assert_eq!(a, b); + } + + #[test] + fn test_data_type() { + let s = ArrowScalar::from(42i32); + assert_eq!(s.data_type(), &DataType::Int32); + } + + #[test] + fn test_boolean_roundtrip() { + let t = ArrowScalar::from(true); + let f = ArrowScalar::from(false); + assert_eq!(t.data_type(), &DataType::Boolean); + assert!(!t.is_null()); + assert_eq!(format!("{t}"), "true"); + assert_eq!(format!("{f}"), "false"); + + // Extract from multi-element array + let array: ArrayRef = Arc::new(BooleanArray::from(vec![true, false, true])); + let s = ArrowScalar::try_new(&array, 1).unwrap(); + assert_eq!(format!("{s}"), "false"); + assert_eq!(s.data_type(), &DataType::Boolean); + } + + #[test] + fn test_boolean_equality_and_ordering() { + let t1 = ArrowScalar::from(true); + let t2 = ArrowScalar::from(true); + let f1 = ArrowScalar::from(false); + assert_eq!(t1, t2); + assert_ne!(t1, f1); + // false < true in arrow row encoding + assert!(f1 < t1); + } + + #[test] + fn test_boolean_null() { + let array: ArrayRef = Arc::new(BooleanArray::from(vec![None])); + let scalar = ArrowScalar::try_from_array(array).unwrap(); + assert!(scalar.is_null()); + assert_eq!(scalar.data_type(), &DataType::Boolean); + assert_eq!(format!("{scalar}"), "null"); + + // null sorts before false + let f = ArrowScalar::from(false); + assert!(scalar < f); + } + + #[test] + fn test_string_view_roundtrip() { + let array: ArrayRef = Arc::new(StringViewArray::from(vec![ + "hello world, this is a long string view", + ])); + let scalar = ArrowScalar::try_from_array(array).unwrap(); + assert_eq!(scalar.data_type(), &DataType::Utf8View); + assert!(!scalar.is_null()); + assert_eq!( + format!("{scalar}"), + "hello world, this is a long string view" + ); + + // Extract from multi-element array + let array: ArrayRef = Arc::new(StringViewArray::from(vec!["alpha", "beta", "gamma"])); + let s = ArrowScalar::try_new(&array, 1).unwrap(); + assert_eq!(format!("{s}"), "beta"); + assert_eq!(s.data_type(), &DataType::Utf8View); + } + + #[test] + fn test_binary_view_roundtrip() { + let values: Vec<&[u8]> = vec![b"\xDE\xAD\xBE\xEF"]; + let array: ArrayRef = Arc::new(BinaryViewArray::from(values)); + let scalar = ArrowScalar::try_from_array(array).unwrap(); + assert_eq!(scalar.data_type(), &DataType::BinaryView); + assert!(!scalar.is_null()); + + // Extract from multi-element array + let values: Vec<&[u8]> = vec![b"aaa", b"bbb", b"ccc"]; + let array: ArrayRef = Arc::new(BinaryViewArray::from(values)); + let s = ArrowScalar::try_new(&array, 2).unwrap(); + assert_eq!(s.data_type(), &DataType::BinaryView); + } + + #[test] + fn test_string_view_equality_and_ordering() { + let mk = |s: &str| { + let array: ArrayRef = Arc::new(StringViewArray::from(vec![s])); + ArrowScalar::try_from_array(array).unwrap() + }; + let a = mk("apple"); + let b = mk("apple"); + let c = mk("banana"); + assert_eq!(a, b); + assert_ne!(a, c); + assert!(a < c); + } + + #[test] + fn test_binary_view_equality_and_ordering() { + let mk = |b: &[u8]| { + let values: Vec<&[u8]> = vec![b]; + let array: ArrayRef = Arc::new(BinaryViewArray::from(values)); + ArrowScalar::try_from_array(array).unwrap() + }; + let a = mk(b"\x01\x02"); + let b = mk(b"\x01\x02"); + let c = mk(b"\x01\x03"); + assert_eq!(a, b); + assert_ne!(a, c); + assert!(a < c); + } + + #[test] + fn test_string_view_in_collections() { + let mk = |s: &str| { + let array: ArrayRef = Arc::new(StringViewArray::from(vec![s])); + ArrowScalar::try_from_array(array).unwrap() + }; + + let mut hset = HashSet::new(); + hset.insert(mk("foo")); + hset.insert(mk("bar")); + hset.insert(mk("foo")); + assert_eq!(hset.len(), 2); + + let mut bset = BTreeSet::new(); + bset.insert(mk("cherry")); + bset.insert(mk("apple")); + bset.insert(mk("banana")); + let sorted: Vec<_> = bset.iter().map(|s| format!("{s}")).collect(); + assert_eq!(sorted, vec!["apple", "banana", "cherry"]); + } + + #[test] + fn test_string_view_null() { + let array: ArrayRef = Arc::new(StringViewArray::from(vec![Option::<&str>::None])); + let scalar = ArrowScalar::try_from_array(array).unwrap(); + assert!(scalar.is_null()); + assert_eq!(scalar.data_type(), &DataType::Utf8View); + assert_eq!(format!("{scalar}"), "null"); + } + + #[test] + fn test_binary_view_null() { + let array: ArrayRef = Arc::new(BinaryViewArray::from(vec![Option::<&[u8]>::None])); + let scalar = ArrowScalar::try_from_array(array).unwrap(); + assert!(scalar.is_null()); + assert_eq!(scalar.data_type(), &DataType::BinaryView); + } + + #[test] + fn test_cross_type_comparison_is_consistent() { + let int_scalar = ArrowScalar::from(42i32); + let str_scalar = ArrowScalar::from("hello"); + // The ordering is arbitrary but must be consistent + let ord1 = int_scalar.cmp(&str_scalar); + let ord2 = int_scalar.cmp(&str_scalar); + assert_eq!(ord1, ord2); + // And the reverse should be opposite + assert_eq!(str_scalar.cmp(&int_scalar), ord1.reverse()); + } +} + +#[cfg(test)] +mod prop_tests { + use std::sync::Arc; + + use arrow_array::*; + use arrow_ord::sort::sort; + use arrow_schema::SortOptions; + use proptest::prelude::*; + + use super::ArrowScalar; + + /// Generate an arbitrary Arrow array of a randomly chosen type, including + /// nulls. Covers primitives, booleans, string/binary types and their view + /// variants. + fn arbitrary_array() -> BoxedStrategy { + let len = 0..=100usize; + + prop_oneof![ + // --- integer types --- + proptest::collection::vec(proptest::option::of(any::()), len.clone()) + .prop_map(|v| Arc::new(Int8Array::from(v)) as ArrayRef), + proptest::collection::vec(proptest::option::of(any::()), len.clone()) + .prop_map(|v| Arc::new(Int16Array::from(v)) as ArrayRef), + proptest::collection::vec(proptest::option::of(any::()), len.clone()) + .prop_map(|v| Arc::new(Int32Array::from(v)) as ArrayRef), + proptest::collection::vec(proptest::option::of(any::()), len.clone()) + .prop_map(|v| Arc::new(Int64Array::from(v)) as ArrayRef), + proptest::collection::vec(proptest::option::of(any::()), len.clone()) + .prop_map(|v| Arc::new(UInt8Array::from(v)) as ArrayRef), + proptest::collection::vec(proptest::option::of(any::()), len.clone()) + .prop_map(|v| Arc::new(UInt16Array::from(v)) as ArrayRef), + proptest::collection::vec(proptest::option::of(any::()), len.clone()) + .prop_map(|v| Arc::new(UInt32Array::from(v)) as ArrayRef), + proptest::collection::vec(proptest::option::of(any::()), len.clone()) + .prop_map(|v| Arc::new(UInt64Array::from(v)) as ArrayRef), + // --- float types --- + proptest::collection::vec(proptest::option::of(any::()), len.clone()) + .prop_map(|v| Arc::new(Float32Array::from(v)) as ArrayRef), + proptest::collection::vec(proptest::option::of(any::()), len.clone()) + .prop_map(|v| Arc::new(Float64Array::from(v)) as ArrayRef), + // --- boolean --- + proptest::collection::vec(proptest::option::of(any::()), len.clone()) + .prop_map(|v| Arc::new(BooleanArray::from(v)) as ArrayRef), + // --- string types --- + proptest::collection::vec(proptest::option::of(any::()), len.clone()).prop_map( + |v| { + let refs: Vec> = v.iter().map(|o| o.as_deref()).collect(); + Arc::new(StringArray::from(refs)) as ArrayRef + } + ), + proptest::collection::vec(proptest::option::of(any::()), len.clone()).prop_map( + |v| { + let refs: Vec> = v.iter().map(|o| o.as_deref()).collect(); + Arc::new(LargeStringArray::from(refs)) as ArrayRef + } + ), + proptest::collection::vec(proptest::option::of(any::()), len.clone()).prop_map( + |v| { + let refs: Vec> = v.iter().map(|o| o.as_deref()).collect(); + Arc::new(StringViewArray::from(refs)) as ArrayRef + } + ), + // --- binary types --- + proptest::collection::vec( + proptest::option::of(proptest::collection::vec(any::(), 0..50)), + len.clone(), + ) + .prop_map(|v| { + let refs: Vec> = v.iter().map(|o| o.as_deref()).collect(); + Arc::new(BinaryArray::from(refs)) as ArrayRef + }), + proptest::collection::vec( + proptest::option::of(proptest::collection::vec(any::(), 0..50)), + len.clone(), + ) + .prop_map(|v| { + let refs: Vec> = v.iter().map(|o| o.as_deref()).collect(); + Arc::new(LargeBinaryArray::from(refs)) as ArrayRef + }), + proptest::collection::vec( + proptest::option::of(proptest::collection::vec(any::(), 0..50)), + len, + ) + .prop_map(|v| { + let refs: Vec> = v.iter().map(|o| o.as_deref()).collect(); + Arc::new(BinaryViewArray::from(refs)) as ArrayRef + }), + ] + .boxed() + } + + proptest::proptest! { + #[test] + fn sorted_array_produces_sorted_scalars(array in arbitrary_array()) { + let sorted = sort( + &array, + Some(SortOptions { descending: false, nulls_first: true }), + ) + .unwrap(); + + let scalars: Vec = (0..sorted.len()) + .map(|i| ArrowScalar::try_new(&sorted, i).unwrap()) + .collect(); + + for i in 1..scalars.len() { + prop_assert!( + scalars[i - 1] <= scalars[i], + "scalar[{}] ({:?}) should be <= scalar[{}] ({:?})", + i - 1, scalars[i - 1], i, scalars[i], + ); + } + } + } +} diff --git a/rust/arrow-scalar/src/serde.rs b/rust/arrow-scalar/src/serde.rs new file mode 100644 index 00000000000..7a458d13887 --- /dev/null +++ b/rust/arrow-scalar/src/serde.rs @@ -0,0 +1,558 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Binary serialization for [`ArrowScalar`]. +//! +//! Default format (with type prefix): +//! ```text +//! | varint: format_string_len | raw: format_string_bytes | +//! | varint: null_flag (0 = non-null, 1 = null) | +//! | varint: num_buffers | (only if non-null) +//! | varint: buffer_0_len | ... | varint: buffer_{n-1}_len | (only if non-null) +//! | raw: buffer_0 bytes | ... | raw: buffer_{n-1} bytes | (only if non-null) +//! ``` +//! +//! The format string uses the +//! [Arrow C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings) +//! encoding. Use [`EncodeOptions`] / [`DecodeOptions`] to omit the type prefix +//! when the caller already knows the data type. + +use std::borrow::Cow; +use std::sync::Arc; + +use arrow_array::make_array; +use arrow_buffer::Buffer; +use arrow_data::ArrayDataBuilder; +use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; + +use crate::ArrowScalar; + +type Result = std::result::Result; + +/// Options for [`ArrowScalar::encode_with_options`]. +pub struct EncodeOptions { + /// When `true` (the default), the Arrow C Data Interface format string + /// for the scalar's data type is prepended as a varint-length-prefixed + /// UTF-8 string. Set to `false` to omit the type prefix (the caller + /// must then supply the `DataType` at decode time). + pub include_data_type: bool, +} + +impl Default for EncodeOptions { + fn default() -> Self { + Self { + include_data_type: true, + } + } +} + +/// Options for [`ArrowScalar::decode_with_options`]. +#[derive(Default)] +pub struct DecodeOptions<'a> { + /// When `Some`, the data type is taken from this value and the encoded + /// bytes are assumed to contain no type prefix. When `None` (the + /// default), the data type is read from the encoded format-string prefix. + pub data_type: Option<&'a DataType>, +} + +/// Encode a `u64` as a variable-length integer (LEB128). +/// +/// Values below 128 use a single byte; the maximum encoding is 10 bytes. +pub fn encode_varint(out: &mut Vec, mut value: u64) { + loop { + let byte = (value & 0x7F) as u8; + value >>= 7; + if value == 0 { + out.push(byte); + return; + } + out.push(byte | 0x80); + } +} + +/// Decode a variable-length integer (LEB128) from `buf` at the given `offset`. +/// +/// On success, `offset` is advanced past the consumed bytes. +pub fn decode_varint(buf: &[u8], offset: &mut usize) -> Result { + let mut result: u64 = 0; + let mut shift = 0u32; + loop { + if *offset >= buf.len() { + return Err(ArrowError::InvalidArgumentError( + "Invalid varint: unexpected EOF".to_string(), + )); + } + let byte = buf[*offset]; + *offset += 1; + + result |= u64::from(byte & 0x7F) << shift; + if byte & 0x80 == 0 { + return Ok(result); + } + shift += 7; + if shift >= 64 { + return Err(ArrowError::InvalidArgumentError( + "Invalid varint: too many bytes".to_string(), + )); + } + } +} + +/// Convert a [`DataType`] to its Arrow C Data Interface format string. +/// +/// Only non-nested types are supported (nested types are already rejected by +/// [`ArrowScalar::encode`]). +fn data_type_to_format_string(dtype: &DataType) -> Result> { + match dtype { + DataType::Null => Ok("n".into()), + DataType::Boolean => Ok("b".into()), + DataType::Int8 => Ok("c".into()), + DataType::UInt8 => Ok("C".into()), + DataType::Int16 => Ok("s".into()), + DataType::UInt16 => Ok("S".into()), + DataType::Int32 => Ok("i".into()), + DataType::UInt32 => Ok("I".into()), + DataType::Int64 => Ok("l".into()), + DataType::UInt64 => Ok("L".into()), + DataType::Float16 => Ok("e".into()), + DataType::Float32 => Ok("f".into()), + DataType::Float64 => Ok("g".into()), + DataType::Binary => Ok("z".into()), + DataType::LargeBinary => Ok("Z".into()), + DataType::Utf8 => Ok("u".into()), + DataType::LargeUtf8 => Ok("U".into()), + DataType::BinaryView => Ok("vz".into()), + DataType::Utf8View => Ok("vu".into()), + DataType::FixedSizeBinary(n) => Ok(Cow::Owned(format!("w:{n}"))), + DataType::Decimal32(p, s) => Ok(Cow::Owned(format!("d:{p},{s},32"))), + DataType::Decimal64(p, s) => Ok(Cow::Owned(format!("d:{p},{s},64"))), + DataType::Decimal128(p, s) => Ok(Cow::Owned(format!("d:{p},{s}"))), + DataType::Decimal256(p, s) => Ok(Cow::Owned(format!("d:{p},{s},256"))), + DataType::Date32 => Ok("tdD".into()), + DataType::Date64 => Ok("tdm".into()), + DataType::Time32(TimeUnit::Second) => Ok("tts".into()), + DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".into()), + DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".into()), + DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".into()), + DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".into()), + DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".into()), + DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".into()), + DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".into()), + DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(Cow::Owned(format!("tss:{tz}"))), + DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(Cow::Owned(format!("tsm:{tz}"))), + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(Cow::Owned(format!("tsu:{tz}"))), + DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(Cow::Owned(format!("tsn:{tz}"))), + DataType::Duration(TimeUnit::Second) => Ok("tDs".into()), + DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".into()), + DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".into()), + DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".into()), + DataType::Interval(IntervalUnit::YearMonth) => Ok("tiM".into()), + DataType::Interval(IntervalUnit::DayTime) => Ok("tiD".into()), + DataType::Interval(IntervalUnit::MonthDayNano) => Ok("tin".into()), + other => Err(ArrowError::InvalidArgumentError(format!( + "Cannot encode data type as format string: {other:?}" + ))), + } +} + +/// Parse an Arrow C Data Interface format string back to a [`DataType`]. +/// +/// Only non-nested types are supported. +fn format_string_to_data_type(fmt: &str) -> Result { + match fmt { + "n" => Ok(DataType::Null), + "b" => Ok(DataType::Boolean), + "c" => Ok(DataType::Int8), + "C" => Ok(DataType::UInt8), + "s" => Ok(DataType::Int16), + "S" => Ok(DataType::UInt16), + "i" => Ok(DataType::Int32), + "I" => Ok(DataType::UInt32), + "l" => Ok(DataType::Int64), + "L" => Ok(DataType::UInt64), + "e" => Ok(DataType::Float16), + "f" => Ok(DataType::Float32), + "g" => Ok(DataType::Float64), + "z" => Ok(DataType::Binary), + "Z" => Ok(DataType::LargeBinary), + "u" => Ok(DataType::Utf8), + "U" => Ok(DataType::LargeUtf8), + "vz" => Ok(DataType::BinaryView), + "vu" => Ok(DataType::Utf8View), + "tdD" => Ok(DataType::Date32), + "tdm" => Ok(DataType::Date64), + "tts" => Ok(DataType::Time32(TimeUnit::Second)), + "ttm" => Ok(DataType::Time32(TimeUnit::Millisecond)), + "ttu" => Ok(DataType::Time64(TimeUnit::Microsecond)), + "ttn" => Ok(DataType::Time64(TimeUnit::Nanosecond)), + "tDs" => Ok(DataType::Duration(TimeUnit::Second)), + "tDm" => Ok(DataType::Duration(TimeUnit::Millisecond)), + "tDu" => Ok(DataType::Duration(TimeUnit::Microsecond)), + "tDn" => Ok(DataType::Duration(TimeUnit::Nanosecond)), + "tiM" => Ok(DataType::Interval(IntervalUnit::YearMonth)), + "tiD" => Ok(DataType::Interval(IntervalUnit::DayTime)), + "tin" => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + other => { + let parts: Vec<&str> = other.splitn(2, ':').collect(); + match parts.as_slice() { + ["w", num_bytes] => { + let n = num_bytes.parse::().map_err(|_| { + ArrowError::InvalidArgumentError( + "FixedSizeBinary requires an integer byte count".to_string(), + ) + })?; + Ok(DataType::FixedSizeBinary(n)) + } + ["d", extra] => { + let dec_parts: Vec<&str> = extra.splitn(3, ',').collect(); + match dec_parts.as_slice() { + [precision, scale] => { + let p = precision.parse::().map_err(|_| { + ArrowError::InvalidArgumentError( + "Decimal requires an integer precision".to_string(), + ) + })?; + let s = scale.parse::().map_err(|_| { + ArrowError::InvalidArgumentError( + "Decimal requires an integer scale".to_string(), + ) + })?; + Ok(DataType::Decimal128(p, s)) + } + [precision, scale, bits] => { + let p = precision.parse::().map_err(|_| { + ArrowError::InvalidArgumentError( + "Decimal requires an integer precision".to_string(), + ) + })?; + let s = scale.parse::().map_err(|_| { + ArrowError::InvalidArgumentError( + "Decimal requires an integer scale".to_string(), + ) + })?; + match *bits { + "32" => Ok(DataType::Decimal32(p, s)), + "64" => Ok(DataType::Decimal64(p, s)), + "128" => Ok(DataType::Decimal128(p, s)), + "256" => Ok(DataType::Decimal256(p, s)), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unsupported decimal bit width: {bits}" + ))), + } + } + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format string: d:{extra}" + ))), + } + } + ["tss", ""] => Ok(DataType::Timestamp(TimeUnit::Second, None)), + ["tsm", ""] => Ok(DataType::Timestamp(TimeUnit::Millisecond, None)), + ["tsu", ""] => Ok(DataType::Timestamp(TimeUnit::Microsecond, None)), + ["tsn", ""] => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)), + ["tss", tz] => Ok(DataType::Timestamp(TimeUnit::Second, Some(Arc::from(*tz)))), + ["tsm", tz] => Ok(DataType::Timestamp( + TimeUnit::Millisecond, + Some(Arc::from(*tz)), + )), + ["tsu", tz] => Ok(DataType::Timestamp( + TimeUnit::Microsecond, + Some(Arc::from(*tz)), + )), + ["tsn", tz] => Ok(DataType::Timestamp( + TimeUnit::Nanosecond, + Some(Arc::from(*tz)), + )), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unsupported format string: {other:?}" + ))), + } + } + } +} + +impl ArrowScalar { + /// Serialize this scalar to a self-describing binary representation. + /// + /// The data type is encoded as a format-string prefix so that + /// [`decode`](Self::decode) can reconstruct the scalar without external + /// type information. Use [`encode_with_options`](Self::encode_with_options) + /// to omit the prefix when the caller already knows the type. + /// + /// Only non-nested scalars are supported. Null scalars are encoded as a + /// null flag with no buffer data. + pub fn encode(&self) -> Result> { + self.encode_with_options(&EncodeOptions::default()) + } + + /// Serialize this scalar with the given [`EncodeOptions`]. + pub fn encode_with_options(&self, options: &EncodeOptions) -> Result> { + let array = self.as_array(); + let data = array.to_data(); + if !data.child_data().is_empty() { + return Err(ArrowError::InvalidArgumentError( + "Cannot encode nested scalar".to_string(), + )); + } + + let mut out = Vec::with_capacity(64); + + if options.include_data_type { + let fmt = data_type_to_format_string(array.data_type())?; + encode_varint(&mut out, fmt.len() as u64); + out.extend_from_slice(fmt.as_bytes()); + } + + if self.is_null() { + encode_varint(&mut out, 1); // null_flag = 1 + } else { + encode_varint(&mut out, 0); // null_flag = 0 + let buffers = data.buffers(); + encode_varint(&mut out, buffers.len() as u64); + for b in buffers { + encode_varint(&mut out, b.len() as u64); + } + for b in buffers { + out.extend_from_slice(b.as_slice()); + } + } + Ok(out) + } + + /// Deserialize a scalar from the self-describing binary representation + /// produced by [`encode`](Self::encode). + /// + /// The data type is read from the format-string prefix in the encoded + /// bytes. Use [`decode_with_options`](Self::decode_with_options) to supply + /// the type externally when the prefix was omitted at encode time. + pub fn decode(buf: &[u8]) -> Result { + Self::decode_with_options(buf, &DecodeOptions::default()) + } + + /// Deserialize a scalar with the given [`DecodeOptions`]. + pub fn decode_with_options(buf: &[u8], options: &DecodeOptions) -> Result { + let mut offset = 0; + + let data_type = match options.data_type { + Some(dt) => dt.clone(), + None => { + let fmt_len = decode_varint(buf, &mut offset)? as usize; + if offset + fmt_len > buf.len() { + return Err(ArrowError::InvalidArgumentError( + "Invalid scalar buffer: unexpected EOF reading format string".to_string(), + )); + } + let fmt_str = std::str::from_utf8(&buf[offset..offset + fmt_len]).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Invalid format string: not valid UTF-8: {e}" + )) + })?; + offset += fmt_len; + format_string_to_data_type(fmt_str)? + } + }; + + let null_flag = decode_varint(buf, &mut offset)?; + if null_flag == 1 { + if offset != buf.len() { + return Err(ArrowError::InvalidArgumentError( + "Invalid scalar buffer: trailing bytes after null flag".to_string(), + )); + } + return Self::new_null(&data_type); + } + + let num_buffers = decode_varint(buf, &mut offset)? as usize; + + let mut buffer_lens = Vec::with_capacity(num_buffers); + for _ in 0..num_buffers { + buffer_lens.push(decode_varint(buf, &mut offset)? as usize); + } + + let mut buffers = Vec::with_capacity(num_buffers); + for len in &buffer_lens { + if offset + len > buf.len() { + return Err(ArrowError::InvalidArgumentError( + "Invalid scalar buffer: unexpected EOF".to_string(), + )); + } + buffers.push(Buffer::from_vec(buf[offset..offset + len].to_vec())); + offset += len; + } + + if offset != buf.len() { + return Err(ArrowError::InvalidArgumentError( + "Invalid scalar buffer: trailing bytes".to_string(), + )); + } + + let mut builder = ArrayDataBuilder::new(data_type).len(1).null_count(0); + for b in buffers { + builder = builder.add_buffer(b); + } + let array = make_array(builder.build()?); + Self::try_from_array(array) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::{ + ArrayRef, BinaryViewArray, Int32Array, StringArray, StringViewArray, + TimestampMicrosecondArray, + }; + use arrow_schema::DataType; + use rstest::rstest; + + use super::*; + use crate::ArrowScalar; + + #[test] + fn test_varint_roundtrip() { + for value in [0u64, 1, 127, 128, 16383, 16384, u64::MAX] { + let mut buf = Vec::new(); + encode_varint(&mut buf, value); + let mut offset = 0; + let decoded = decode_varint(&buf, &mut offset).unwrap(); + assert_eq!(decoded, value); + assert_eq!(offset, buf.len()); + } + } + + #[test] + fn test_varint_small_is_one_byte() { + let mut buf = Vec::new(); + encode_varint(&mut buf, 42); + assert_eq!(buf.len(), 1); + assert_eq!(buf[0], 42); + } + + #[rstest] + #[case::int32(Arc::new(Int32Array::from(vec![42])) as ArrayRef)] + #[case::string(Arc::new(StringArray::from(vec!["hello"])) as ArrayRef)] + #[case::string_view(Arc::new(StringViewArray::from(vec!["hello world, long string view"])) as ArrayRef)] + #[case::binary_view(Arc::new(BinaryViewArray::from(vec![b"\xDE\xAD\xBE\xEF".as_ref()])) as ArrayRef)] + fn test_encode_decode_roundtrip(#[case] array: ArrayRef) { + let scalar = ArrowScalar::try_from_array(array).unwrap(); + let encoded = scalar.encode().unwrap(); + let decoded = ArrowScalar::decode(&encoded).unwrap(); + assert_eq!(scalar, decoded); + assert_eq!(scalar.data_type(), decoded.data_type()); + } + + #[rstest] + #[case::int32(Arc::new(Int32Array::from(vec![42])) as ArrayRef, DataType::Int32)] + #[case::string(Arc::new(StringArray::from(vec!["hello"])) as ArrayRef, DataType::Utf8)] + #[case::string_view(Arc::new(StringViewArray::from(vec!["hello view"])) as ArrayRef, DataType::Utf8View)] + #[case::binary_view(Arc::new(BinaryViewArray::from(vec![b"\xCA\xFE".as_ref()])) as ArrayRef, DataType::BinaryView)] + fn test_encode_decode_without_type_prefix(#[case] array: ArrayRef, #[case] dt: DataType) { + let scalar = ArrowScalar::try_from_array(array).unwrap(); + let opts = EncodeOptions { + include_data_type: false, + }; + let encoded = scalar.encode_with_options(&opts).unwrap(); + let decode_opts = DecodeOptions { + data_type: Some(&dt), + }; + let decoded = ArrowScalar::decode_with_options(&encoded, &decode_opts).unwrap(); + assert_eq!(scalar, decoded); + } + + #[test] + fn test_null_encode_decode_roundtrip() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![None])); + let scalar = ArrowScalar::try_from_array(array).unwrap(); + assert!(scalar.is_null()); + let encoded = scalar.encode().unwrap(); + let decoded = ArrowScalar::decode(&encoded).unwrap(); + assert!(decoded.is_null()); + assert_eq!(decoded.data_type(), &DataType::Int32); + assert_eq!(scalar, decoded); + } + + #[test] + fn test_null_encode_decode_without_type_prefix() { + let array: ArrayRef = Arc::new(StringArray::from(vec![Option::<&str>::None])); + let scalar = ArrowScalar::try_from_array(array).unwrap(); + let opts = EncodeOptions { + include_data_type: false, + }; + let encoded = scalar.encode_with_options(&opts).unwrap(); + let decode_opts = DecodeOptions { + data_type: Some(&DataType::Utf8), + }; + let decoded = ArrowScalar::decode_with_options(&encoded, &decode_opts).unwrap(); + assert!(decoded.is_null()); + assert_eq!(decoded.data_type(), &DataType::Utf8); + } + + #[test] + fn test_decode_trailing_bytes() { + let scalar = ArrowScalar::from(42i32); + let mut encoded = scalar.encode().unwrap(); + encoded.push(0xFF); + assert!(ArrowScalar::decode(&encoded).is_err()); + } + + #[test] + fn test_encoded_bytes_contain_format_prefix() { + let scalar = ArrowScalar::from(42i32); + let encoded = scalar.encode().unwrap(); + // First byte is varint length of format string "i" (length 1) + assert_eq!(encoded[0], 1); + // Second byte is the format string itself + assert_eq!(encoded[1], b'i'); + } + + #[rstest] + #[case::null(DataType::Null, "n")] + #[case::boolean(DataType::Boolean, "b")] + #[case::int8(DataType::Int8, "c")] + #[case::uint8(DataType::UInt8, "C")] + #[case::int16(DataType::Int16, "s")] + #[case::uint16(DataType::UInt16, "S")] + #[case::int32(DataType::Int32, "i")] + #[case::uint32(DataType::UInt32, "I")] + #[case::int64(DataType::Int64, "l")] + #[case::uint64(DataType::UInt64, "L")] + #[case::float16(DataType::Float16, "e")] + #[case::float32(DataType::Float32, "f")] + #[case::float64(DataType::Float64, "g")] + #[case::binary(DataType::Binary, "z")] + #[case::large_binary(DataType::LargeBinary, "Z")] + #[case::utf8(DataType::Utf8, "u")] + #[case::large_utf8(DataType::LargeUtf8, "U")] + #[case::binary_view(DataType::BinaryView, "vz")] + #[case::utf8_view(DataType::Utf8View, "vu")] + #[case::date32(DataType::Date32, "tdD")] + #[case::date64(DataType::Date64, "tdm")] + #[case::fixed_size_binary(DataType::FixedSizeBinary(16), "w:16")] + #[case::decimal128(DataType::Decimal128(10, 2), "d:10,2")] + #[case::decimal256(DataType::Decimal256(38, 10), "d:38,10,256")] + #[case::timestamp_us_utc( + DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), + "tsu:UTC" + )] + #[case::timestamp_ns_none(DataType::Timestamp(TimeUnit::Nanosecond, None), "tsn:")] + #[case::duration_s(DataType::Duration(TimeUnit::Second), "tDs")] + #[case::interval_ym(DataType::Interval(IntervalUnit::YearMonth), "tiM")] + fn test_format_string_roundtrip(#[case] dt: DataType, #[case] expected_fmt: &str) { + let fmt = data_type_to_format_string(&dt).unwrap(); + assert_eq!(fmt.as_ref(), expected_fmt); + let roundtripped = format_string_to_data_type(&fmt).unwrap(); + assert_eq!(roundtripped, dt); + } + + #[test] + fn test_timestamp_with_tz_roundtrip() { + let array: ArrayRef = Arc::new( + TimestampMicrosecondArray::from(vec![1_000_000]).with_timezone("America/New_York"), + ); + let scalar = ArrowScalar::try_from_array(array).unwrap(); + let encoded = scalar.encode().unwrap(); + let decoded = ArrowScalar::decode(&encoded).unwrap(); + assert_eq!(scalar, decoded); + assert_eq!(scalar.data_type(), decoded.data_type()); + } +}