From 430ddfcab73570ab67466a53aa9f3f4430b6abe3 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Wed, 28 Jan 2026 19:11:28 -0800 Subject: [PATCH] add an arrow-scalar crate with a scalar definition --- Cargo.lock | 13 + Cargo.toml | 2 + rust/arrow-scalar/Cargo.toml | 29 ++ rust/arrow-scalar/src/bytes.rs | 482 +++++++++++++++++++++++ rust/arrow-scalar/src/cmp.rs | 477 +++++++++++++++++++++++ rust/arrow-scalar/src/convert.rs | 632 +++++++++++++++++++++++++++++++ rust/arrow-scalar/src/display.rs | 319 ++++++++++++++++ rust/arrow-scalar/src/lib.rs | 47 +++ rust/arrow-scalar/src/scalar.rs | 341 +++++++++++++++++ 9 files changed, 2342 insertions(+) create mode 100644 rust/arrow-scalar/Cargo.toml create mode 100644 rust/arrow-scalar/src/bytes.rs create mode 100644 rust/arrow-scalar/src/cmp.rs create mode 100644 rust/arrow-scalar/src/convert.rs create mode 100644 rust/arrow-scalar/src/display.rs create mode 100644 rust/arrow-scalar/src/lib.rs create mode 100644 rust/arrow-scalar/src/scalar.rs diff --git a/Cargo.lock b/Cargo.lock index cfcc4899c96..92242412c68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -382,6 +382,19 @@ dependencies = [ "half", ] +[[package]] +name = "arrow-scalar" +version = "3.0.0-beta.2" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "half", + "rstest 0.23.0", +] + [[package]] name = "arrow-schema" version = "57.2.0" diff --git a/Cargo.toml b/Cargo.toml index b080335a735..25040ff369a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = [ + "rust/arrow-scalar", "rust/examples", "rust/lance", "rust/lance-arrow", @@ -80,6 +81,7 @@ arrow-data = "57.0.0" arrow-ipc = { version = "57.0.0", features = ["zstd"] } arrow-ord = "57.0.0" arrow-row = "57.0.0" +arrow-scalar = { version = "57.0.0", path = "./rust/arrow-scalar" } arrow-schema = "57.0.0" arrow-select = "57.0.0" async-recursion = "1.0" diff --git a/rust/arrow-scalar/Cargo.toml b/rust/arrow-scalar/Cargo.toml new file mode 
100644 index 00000000000..c12597d6e2c --- /dev/null +++ b/rust/arrow-scalar/Cargo.toml @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +[package] +name = "arrow-scalar" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +readme.workspace = true +description = "Scalar value representation for Apache Arrow types" +keywords = ["arrow", "scalar", "data-format"] +categories = ["data-structures"] +rust-version.workspace = true + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +arrow-select = { workspace = true } +half = { workspace = true } + +[dev-dependencies] +rstest = { workspace = true } + +[lints] +workspace = true diff --git a/rust/arrow-scalar/src/bytes.rs b/rust/arrow-scalar/src/bytes.rs new file mode 100644 index 00000000000..7698ddebb39 --- /dev/null +++ b/rust/arrow-scalar/src/bytes.rs @@ -0,0 +1,482 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Byte serialization and deserialization for Scalar values. +//! +//! Format for primitives: +//! ```text +//! | is_null (1 byte) | value bytes (if not null) | +//! ``` +//! +//! Format for variable-length types (Utf8, Binary, Utf8View, BinaryView, FixedSizeBinary): +//! ```text +//! | is_null (1 byte) | u32_le length | value bytes | +//! ``` +//! +//! Format for large variable-length types (LargeUtf8, LargeBinary): +//! ```text +//! | is_null (1 byte) | u64_le length | value bytes | +//! ``` + +use arrow_buffer::{i256, Buffer}; +use arrow_schema::{ArrowError, DataType, IntervalUnit}; +use half::f16; + +use crate::Scalar; + +type Result = std::result::Result; + +const NULL_MARKER: u8 = 0; +const NON_NULL_MARKER: u8 = 1; + +/// Serializes a scalar value to bytes. 
+impl Scalar { + /// Converts this scalar to a byte representation. + /// + /// The format is designed for simple serialization of scalar values, + /// primarily for use in indexing and storage scenarios. + pub fn to_bytes(&self) -> Vec { + let mut buf = Vec::new(); + self.write_bytes(&mut buf); + buf + } + + fn write_bytes(&self, buf: &mut Vec) { + use Scalar::*; + + match self { + Null => buf.push(NULL_MARKER), + Boolean(v) => write_opt_primitive(buf, v.map(u8::from)), + Int8(v) => write_opt_primitive(buf, v.map(|x| x as u8)), + Int16(v) => write_opt_le(buf, *v), + Int32(v) => write_opt_le(buf, *v), + Int64(v) => write_opt_le(buf, *v), + UInt8(v) => write_opt_primitive(buf, *v), + UInt16(v) => write_opt_le(buf, *v), + UInt32(v) => write_opt_le(buf, *v), + UInt64(v) => write_opt_le(buf, *v), + Float16(v) => write_opt_le(buf, v.map(|f| f.to_bits())), + Float32(v) => write_opt_le(buf, v.map(|f| f.to_bits())), + Float64(v) => write_opt_le(buf, v.map(|f| f.to_bits())), + Decimal128(v, _, _) => write_opt_le(buf, *v), + Decimal256(v, _, _) => match v { + None => buf.push(NULL_MARKER), + Some(val) => { + buf.push(NON_NULL_MARKER); + buf.extend_from_slice(&val.to_le_bytes()); + } + }, + Utf8(v) | Utf8View(v) | Binary(v) | BinaryView(v) => write_opt_bytes(buf, v.as_deref()), + LargeUtf8(v) | LargeBinary(v) => write_opt_bytes_large(buf, v.as_deref()), + FixedSizeBinary(_, v) => write_opt_bytes(buf, v.as_deref()), + Date32(v) => write_opt_le(buf, *v), + Date64(v) => write_opt_le(buf, *v), + Time32(v, _) => write_opt_le(buf, *v), + Time64(v, _) => write_opt_le(buf, *v), + Timestamp(v, _, _) => write_opt_le(buf, *v), + Duration(v, _) => write_opt_le(buf, *v), + IntervalYearMonth(v) => write_opt_le(buf, *v), + IntervalDayTime(v) => write_opt_le(buf, *v), + IntervalMonthDayNano(v) => write_opt_le(buf, *v), + List(_) + | LargeList(_) + | FixedSizeList(_) + | Struct(_, _) + | Map(_) + | Dictionary(_, _) => { + panic!( + "Complex types (List, Struct, Map, Dictionary) do not support 
byte serialization" + ); + } + } + } + + /// Deserializes a scalar from bytes given its data type. + pub fn from_bytes(data_type: &DataType, bytes: &[u8]) -> Result { + let mut offset = 0; + Self::read_bytes(data_type, bytes, &mut offset) + } + + fn read_bytes(data_type: &DataType, bytes: &[u8], offset: &mut usize) -> Result { + match data_type { + DataType::Null => { + read_null_marker(bytes, offset)?; + Ok(Self::Null) + } + DataType::Boolean => { + let v = read_opt_primitive(bytes, offset)?; + Ok(Self::Boolean(v.map(|b| b != 0))) + } + DataType::Int8 => { + let v = read_opt_primitive(bytes, offset)?; + Ok(Self::Int8(v.map(|b| b as i8))) + } + DataType::Int16 => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Int16(v)) + } + DataType::Int32 => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Int32(v)) + } + DataType::Int64 => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Int64(v)) + } + DataType::UInt8 => { + let v = read_opt_primitive(bytes, offset)?; + Ok(Self::UInt8(v)) + } + DataType::UInt16 => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::UInt16(v)) + } + DataType::UInt32 => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::UInt32(v)) + } + DataType::UInt64 => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::UInt64(v)) + } + DataType::Float16 => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Float16(v.map(f16::from_bits))) + } + DataType::Float32 => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Float32(v.map(f32::from_bits))) + } + DataType::Float64 => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Float64(v.map(f64::from_bits))) + } + DataType::Decimal128(precision, scale) => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Decimal128(v, *precision, *scale)) + } + DataType::Decimal256(precision, scale) => { + let is_null = read_null_marker(bytes, offset)?; + if is_null { + Ok(Self::Decimal256(None, *precision, *scale)) + } else { + let val_bytes = read_exact(bytes, offset, 32)?; + let val = 
i256::from_le_bytes(val_bytes.try_into().unwrap()); + Ok(Self::Decimal256(Some(val), *precision, *scale)) + } + } + DataType::Utf8 => { + let v = read_opt_bytes(bytes, offset)?; + if let Some(b) = &v { + std::str::from_utf8(b).map_err(|e| { + ArrowError::InvalidArgumentError(format!("Invalid UTF-8: {}", e)) + })?; + } + Ok(Self::Utf8(v.map(Buffer::from))) + } + DataType::LargeUtf8 => { + let v = read_opt_bytes_large(bytes, offset)?; + if let Some(b) = &v { + std::str::from_utf8(b).map_err(|e| { + ArrowError::InvalidArgumentError(format!("Invalid UTF-8: {}", e)) + })?; + } + Ok(Self::LargeUtf8(v.map(Buffer::from))) + } + DataType::Utf8View => { + let v = read_opt_bytes(bytes, offset)?; + if let Some(b) = &v { + std::str::from_utf8(b).map_err(|e| { + ArrowError::InvalidArgumentError(format!("Invalid UTF-8: {}", e)) + })?; + } + Ok(Self::Utf8View(v.map(Buffer::from))) + } + DataType::Binary => { + let v = read_opt_bytes(bytes, offset)?; + Ok(Self::Binary(v.map(Buffer::from))) + } + DataType::LargeBinary => { + let v = read_opt_bytes_large(bytes, offset)?; + Ok(Self::LargeBinary(v.map(Buffer::from))) + } + DataType::BinaryView => { + let v = read_opt_bytes(bytes, offset)?; + Ok(Self::BinaryView(v.map(Buffer::from))) + } + DataType::FixedSizeBinary(size) => { + let v = read_opt_bytes(bytes, offset)?; + Ok(Self::FixedSizeBinary(*size, v.map(Buffer::from))) + } + DataType::Date32 => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Date32(v)) + } + DataType::Date64 => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Date64(v)) + } + DataType::Time32(unit) => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Time32(v, *unit)) + } + DataType::Time64(unit) => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Time64(v, *unit)) + } + DataType::Timestamp(unit, tz) => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Timestamp(v, *unit, tz.clone())) + } + DataType::Duration(unit) => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::Duration(v, *unit)) + 
} + DataType::Interval(IntervalUnit::YearMonth) => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::IntervalYearMonth(v)) + } + DataType::Interval(IntervalUnit::DayTime) => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::IntervalDayTime(v)) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let v = read_opt_le::(bytes, offset)?; + Ok(Self::IntervalMonthDayNano(v)) + } + _ => Err(ArrowError::NotYetImplemented(format!( + "Byte deserialization not implemented for {:?}", + data_type + ))), + } + } +} + +// Helper functions for writing + +fn write_opt_primitive(buf: &mut Vec, v: Option) { + match v { + None => buf.push(NULL_MARKER), + Some(val) => { + buf.push(NON_NULL_MARKER); + buf.push(val); + } + } +} + +fn write_opt_le(buf: &mut Vec, v: Option) { + match v { + None => buf.push(NULL_MARKER), + Some(val) => { + buf.push(NON_NULL_MARKER); + buf.extend_from_slice(&val.to_le_bytes_vec()); + } + } +} + +fn write_opt_bytes(buf: &mut Vec, v: Option<&[u8]>) { + match v { + None => buf.push(NULL_MARKER), + Some(bytes) => { + buf.push(NON_NULL_MARKER); + buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + buf.extend_from_slice(bytes); + } + } +} + +fn write_opt_bytes_large(buf: &mut Vec, v: Option<&[u8]>) { + match v { + None => buf.push(NULL_MARKER), + Some(bytes) => { + buf.push(NON_NULL_MARKER); + buf.extend_from_slice(&(bytes.len() as u64).to_le_bytes()); + buf.extend_from_slice(bytes); + } + } +} + +// Helper functions for reading + +fn read_null_marker(bytes: &[u8], offset: &mut usize) -> Result { + if *offset >= bytes.len() { + return Err(ArrowError::InvalidArgumentError( + "Unexpected end of bytes".to_string(), + )); + } + let marker = bytes[*offset]; + *offset += 1; + Ok(marker == NULL_MARKER) +} + +fn read_exact<'a>(bytes: &'a [u8], offset: &mut usize, len: usize) -> Result<&'a [u8]> { + if *offset + len > bytes.len() { + return Err(ArrowError::InvalidArgumentError( + "Unexpected end of bytes".to_string(), + )); + } + let slice = 
&bytes[*offset..*offset + len]; + *offset += len; + Ok(slice) +} + +fn read_opt_primitive(bytes: &[u8], offset: &mut usize) -> Result> { + let is_null = read_null_marker(bytes, offset)?; + if is_null { + Ok(None) + } else { + if *offset >= bytes.len() { + return Err(ArrowError::InvalidArgumentError( + "Unexpected end of bytes".to_string(), + )); + } + let val = bytes[*offset]; + *offset += 1; + Ok(Some(val)) + } +} + +fn read_opt_le(bytes: &[u8], offset: &mut usize) -> Result> { + let is_null = read_null_marker(bytes, offset)?; + if is_null { + Ok(None) + } else { + let size = std::mem::size_of::(); + let val_bytes = read_exact(bytes, offset, size)?; + Ok(Some(T::from_le_bytes_slice(val_bytes))) + } +} + +fn read_opt_bytes(bytes: &[u8], offset: &mut usize) -> Result>> { + let is_null = read_null_marker(bytes, offset)?; + if is_null { + Ok(None) + } else { + let len_bytes = read_exact(bytes, offset, 4)?; + let len = u32::from_le_bytes(len_bytes.try_into().unwrap()) as usize; + let val_bytes = read_exact(bytes, offset, len)?; + Ok(Some(val_bytes.to_vec())) + } +} + +fn read_opt_bytes_large(bytes: &[u8], offset: &mut usize) -> Result>> { + let is_null = read_null_marker(bytes, offset)?; + if is_null { + Ok(None) + } else { + let len_bytes = read_exact(bytes, offset, 8)?; + let len = u64::from_le_bytes(len_bytes.try_into().unwrap()) as usize; + let val_bytes = read_exact(bytes, offset, len)?; + Ok(Some(val_bytes.to_vec())) + } +} + +// Traits for generic le bytes conversion + +trait ToLeBytes { + fn to_le_bytes_vec(&self) -> Vec; +} + +macro_rules! impl_to_le_bytes { + ($($t:ty),*) => { + $( + impl ToLeBytes for $t { + fn to_le_bytes_vec(&self) -> Vec { + self.to_le_bytes().to_vec() + } + } + )* + }; +} + +impl_to_le_bytes!(i16, i32, i64, i128, u16, u32, u64); + +trait FromLeBytes: Sized { + fn from_le_bytes_slice(bytes: &[u8]) -> Self; +} + +macro_rules! 
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::TimeUnit;
    use rstest::rstest;
    use std::sync::Arc;

    // Float literals deliberately avoid values close to PI / E so that the
    // deny-by-default `clippy::approx_constant` lint does not fire.
    #[rstest]
    #[case::null(Scalar::Null, DataType::Null)]
    #[case::bool_true(Scalar::Boolean(Some(true)), DataType::Boolean)]
    #[case::bool_false(Scalar::Boolean(Some(false)), DataType::Boolean)]
    #[case::bool_null(Scalar::Boolean(None), DataType::Boolean)]
    #[case::int32(Scalar::Int32(Some(42)), DataType::Int32)]
    #[case::int32_neg(Scalar::Int32(Some(-42)), DataType::Int32)]
    #[case::int32_null(Scalar::Int32(None), DataType::Int32)]
    #[case::int64(Scalar::Int64(Some(1234567890123)), DataType::Int64)]
    #[case::float32(Scalar::Float32(Some(1.5)), DataType::Float32)]
    #[case::float64(Scalar::Float64(Some(-2.25)), DataType::Float64)]
    #[case::float64_nan(Scalar::Float64(Some(f64::NAN)), DataType::Float64)]
    #[case::utf8(Scalar::Utf8(Some(Buffer::from(b"hello".as_ref()))), DataType::Utf8)]
    #[case::utf8_empty(Scalar::Utf8(Some(Buffer::from(b"".as_ref()))), DataType::Utf8)]
    #[case::utf8_null(Scalar::Utf8(None), DataType::Utf8)]
    #[case::large_utf8(
        Scalar::LargeUtf8(Some(Buffer::from(b"hello large".as_ref()))),
        DataType::LargeUtf8
    )]
    #[case::large_utf8_null(Scalar::LargeUtf8(None), DataType::LargeUtf8)]
    #[case::binary(Scalar::Binary(Some(Buffer::from(vec![1u8, 2, 3]))), DataType::Binary)]
    #[case::large_binary(
        Scalar::LargeBinary(Some(Buffer::from(vec![4u8, 5, 6]))),
        DataType::LargeBinary
    )]
    #[case::large_binary_null(Scalar::LargeBinary(None), DataType::LargeBinary)]
    #[case::date32(Scalar::Date32(Some(19000)), DataType::Date32)]
    fn test_round_trip(#[case] scalar: Scalar, #[case] data_type: DataType) {
        let bytes = scalar.to_bytes();
        let decoded = Scalar::from_bytes(&data_type, &bytes).unwrap();

        // `Scalar`'s `PartialEq` compares floats by bit pattern, so NaN values
        // round-trip under a plain equality check as well.
        assert_eq!(scalar, decoded);
    }

    #[test]
    fn test_decimal128_round_trip() {
        let scalar = Scalar::Decimal128(Some(12345678901234567890), 38, 10);
        let bytes = scalar.to_bytes();
        let decoded = Scalar::from_bytes(&DataType::Decimal128(38, 10), &bytes).unwrap();
        assert_eq!(scalar, decoded);
    }

    #[test]
    fn test_timestamp_round_trip() {
        let scalar = Scalar::Timestamp(
            Some(1234567890123456789),
            TimeUnit::Nanosecond,
            Some(Arc::from("UTC")),
        );
        let bytes = scalar.to_bytes();
        let decoded = Scalar::from_bytes(
            &DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("UTC"))),
            &bytes,
        )
        .unwrap();
        assert_eq!(scalar, decoded);
    }

    #[test]
    fn test_truncated_input_is_error() {
        // No marker byte at all.
        assert!(Scalar::from_bytes(&DataType::Int32, &[]).is_err());
        // Non-null marker followed by only two of the four payload bytes.
        assert!(Scalar::from_bytes(&DataType::Int32, &[NON_NULL_MARKER, 0, 0]).is_err());
        // Length prefix announces five bytes but only one follows.
        assert!(Scalar::from_bytes(&DataType::Utf8, &[NON_NULL_MARKER, 5, 0, 0, 0, b'a']).is_err());
    }

    #[test]
    fn test_invalid_utf8_is_error() {
        let bytes = [NON_NULL_MARKER, 1, 0, 0, 0, 0xFF];
        assert!(Scalar::from_bytes(&DataType::Utf8, &bytes).is_err());
        // The same payload is acceptable as plain binary.
        assert!(Scalar::from_bytes(&DataType::Binary, &bytes).is_ok());
    }
}
- Floats use total_cmp() for ordering + +use std::cmp::Ordering; +use std::hash::{Hash, Hasher}; + +use crate::Scalar; + +impl PartialEq for Scalar { + fn eq(&self, other: &Self) -> bool { + use Scalar::*; + match (self, other) { + (Null, Null) => true, + (Boolean(a), Boolean(b)) => a == b, + (Int8(a), Int8(b)) => a == b, + (Int16(a), Int16(b)) => a == b, + (Int32(a), Int32(b)) => a == b, + (Int64(a), Int64(b)) => a == b, + (UInt8(a), UInt8(b)) => a == b, + (UInt16(a), UInt16(b)) => a == b, + (UInt32(a), UInt32(b)) => a == b, + (UInt64(a), UInt64(b)) => a == b, + (Float16(a), Float16(b)) => match (a, b) { + (Some(x), Some(y)) => x.to_bits() == y.to_bits(), + (None, None) => true, + _ => false, + }, + (Float32(a), Float32(b)) => match (a, b) { + (Some(x), Some(y)) => x.to_bits() == y.to_bits(), + (None, None) => true, + _ => false, + }, + (Float64(a), Float64(b)) => match (a, b) { + (Some(x), Some(y)) => x.to_bits() == y.to_bits(), + (None, None) => true, + _ => false, + }, + (Decimal128(a, p1, s1), Decimal128(b, p2, s2)) => a == b && p1 == p2 && s1 == s2, + (Decimal256(a, p1, s1), Decimal256(b, p2, s2)) => a == b && p1 == p2 && s1 == s2, + (Utf8(a), Utf8(b)) => a == b, + (LargeUtf8(a), LargeUtf8(b)) => a == b, + (Utf8View(a), Utf8View(b)) => a == b, + (Binary(a), Binary(b)) => a == b, + (LargeBinary(a), LargeBinary(b)) => a == b, + (BinaryView(a), BinaryView(b)) => a == b, + (FixedSizeBinary(s1, a), FixedSizeBinary(s2, b)) => s1 == s2 && a == b, + (Date32(a), Date32(b)) => a == b, + (Date64(a), Date64(b)) => a == b, + (Time32(a, u1), Time32(b, u2)) => a == b && u1 == u2, + (Time64(a, u1), Time64(b, u2)) => a == b && u1 == u2, + (Timestamp(a, u1, tz1), Timestamp(b, u2, tz2)) => a == b && u1 == u2 && tz1 == tz2, + (Duration(a, u1), Duration(b, u2)) => a == b && u1 == u2, + (IntervalYearMonth(a), IntervalYearMonth(b)) => a == b, + (IntervalDayTime(a), IntervalDayTime(b)) => a == b, + (IntervalMonthDayNano(a), IntervalMonthDayNano(b)) => a == b, + (List(a), List(b)) => 
a == b, + (LargeList(a), LargeList(b)) => a == b, + (FixedSizeList(a), FixedSizeList(b)) => a == b, + (Struct(f1, v1), Struct(f2, v2)) => f1 == f2 && v1 == v2, + (Map(a), Map(b)) => a == b, + (Dictionary(k1, v1), Dictionary(k2, v2)) => k1 == k2 && v1 == v2, + _ => false, + } + } +} + +impl Eq for Scalar {} + +impl PartialOrd for Scalar { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Scalar { + fn cmp(&self, other: &Self) -> Ordering { + use Scalar::*; + + // Helper macro for comparing Option where T: Ord + macro_rules! cmp_opt { + ($a:expr, $b:expr) => { + match ($a, $b) { + (None, None) => Ordering::Equal, + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + (Some(x), Some(y)) => x.cmp(y), + } + }; + } + + // Helper macro for comparing Option using total_cmp + macro_rules! cmp_float { + ($a:expr, $b:expr) => { + match ($a, $b) { + (None, None) => Ordering::Equal, + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + (Some(x), Some(y)) => x.total_cmp(y), + } + }; + } + + // Helper macro for comparing Option (Buffer doesn't impl Ord) + macro_rules! 
cmp_opt_buf { + ($a:expr, $b:expr) => { + match ($a, $b) { + (None, None) => Ordering::Equal, + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + (Some(x), Some(y)) => x.as_slice().cmp(y.as_slice()), + } + }; + } + + match (self, other) { + (Null, Null) => Ordering::Equal, + (Null, _) => { + if other.is_null() { + Ordering::Equal + } else { + Ordering::Less + } + } + (_, Null) => { + if self.is_null() { + Ordering::Equal + } else { + Ordering::Greater + } + } + + (Boolean(a), Boolean(b)) => cmp_opt!(a, b), + + (Int8(a), Int8(b)) => cmp_opt!(a, b), + (Int16(a), Int16(b)) => cmp_opt!(a, b), + (Int32(a), Int32(b)) => cmp_opt!(a, b), + (Int64(a), Int64(b)) => cmp_opt!(a, b), + + (UInt8(a), UInt8(b)) => cmp_opt!(a, b), + (UInt16(a), UInt16(b)) => cmp_opt!(a, b), + (UInt32(a), UInt32(b)) => cmp_opt!(a, b), + (UInt64(a), UInt64(b)) => cmp_opt!(a, b), + + (Float16(a), Float16(b)) => cmp_float!(a, b), + (Float32(a), Float32(b)) => cmp_float!(a, b), + (Float64(a), Float64(b)) => cmp_float!(a, b), + + (Decimal128(a, p1, s1), Decimal128(b, p2, s2)) => { + if p1 != p2 || s1 != s2 { + panic!( + "Cannot compare Decimal128 with different precision/scale: ({}, {}) vs ({}, {})", + p1, s1, p2, s2 + ); + } + cmp_opt!(a, b) + } + (Decimal256(a, p1, s1), Decimal256(b, p2, s2)) => { + if p1 != p2 || s1 != s2 { + panic!( + "Cannot compare Decimal256 with different precision/scale: ({}, {}) vs ({}, {})", + p1, s1, p2, s2 + ); + } + cmp_opt!(a, b) + } + + (Utf8(a), Utf8(b)) => cmp_opt_buf!(a, b), + (LargeUtf8(a), LargeUtf8(b)) => cmp_opt_buf!(a, b), + (Utf8View(a), Utf8View(b)) => cmp_opt_buf!(a, b), + // Allow comparing different string types + (Utf8(a) | LargeUtf8(a) | Utf8View(a), Utf8(b) | LargeUtf8(b) | Utf8View(b)) => { + cmp_opt_buf!(a, b) + } + + (Binary(a), Binary(b)) => cmp_opt_buf!(a, b), + (LargeBinary(a), LargeBinary(b)) => cmp_opt_buf!(a, b), + (BinaryView(a), BinaryView(b)) => cmp_opt_buf!(a, b), + // Allow comparing different binary types + ( + 
Binary(a) | LargeBinary(a) | BinaryView(a), + Binary(b) | LargeBinary(b) | BinaryView(b), + ) => cmp_opt_buf!(a, b), + + (FixedSizeBinary(s1, a), FixedSizeBinary(s2, b)) => { + if s1 != s2 { + panic!( + "Cannot compare FixedSizeBinary with different sizes: {} vs {}", + s1, s2 + ); + } + cmp_opt_buf!(a, b) + } + + (Date32(a), Date32(b)) => cmp_opt!(a, b), + (Date64(a), Date64(b)) => cmp_opt!(a, b), + + (Time32(a, u1), Time32(b, u2)) => { + if u1 != u2 { + panic!( + "Cannot compare Time32 with different units: {:?} vs {:?}", + u1, u2 + ); + } + cmp_opt!(a, b) + } + (Time64(a, u1), Time64(b, u2)) => { + if u1 != u2 { + panic!( + "Cannot compare Time64 with different units: {:?} vs {:?}", + u1, u2 + ); + } + cmp_opt!(a, b) + } + + (Timestamp(a, u1, _), Timestamp(b, u2, _)) => { + if u1 != u2 { + panic!( + "Cannot compare Timestamp with different units: {:?} vs {:?}", + u1, u2 + ); + } + cmp_opt!(a, b) + } + + (Duration(a, u1), Duration(b, u2)) => { + if u1 != u2 { + panic!( + "Cannot compare Duration with different units: {:?} vs {:?}", + u1, u2 + ); + } + cmp_opt!(a, b) + } + + (IntervalYearMonth(a), IntervalYearMonth(b)) => cmp_opt!(a, b), + (IntervalDayTime(a), IntervalDayTime(b)) => cmp_opt!(a, b), + (IntervalMonthDayNano(a), IntervalMonthDayNano(b)) => cmp_opt!(a, b), + + // Complex types - compare by array equality or panic + (List(a), List(b)) => { + if a == b { + Ordering::Equal + } else { + panic!("Cannot order List scalars") + } + } + (LargeList(a), LargeList(b)) => { + if a == b { + Ordering::Equal + } else { + panic!("Cannot order LargeList scalars") + } + } + (FixedSizeList(a), FixedSizeList(b)) => { + if a == b { + Ordering::Equal + } else { + panic!("Cannot order FixedSizeList scalars") + } + } + (Struct(f1, v1), Struct(f2, v2)) => { + if f1 != f2 { + panic!("Cannot compare Struct with different fields"); + } + for (a, b) in v1.iter().zip(v2.iter()) { + match a.cmp(b) { + Ordering::Equal => continue, + ord => return ord, + } + } + Ordering::Equal + } + 
(Map(a), Map(b)) => { + if a == b { + Ordering::Equal + } else { + panic!("Cannot order Map scalars") + } + } + (Dictionary(_, v1), Dictionary(_, v2)) => v1.cmp(v2), + + // Mismatched types + (a, b) => panic!( + "Cannot compare scalars of different types: {:?} vs {:?}", + a.data_type(), + b.data_type() + ), + } + } +} + +impl Hash for Scalar { + fn hash(&self, state: &mut H) { + use Scalar::*; + std::mem::discriminant(self).hash(state); + match self { + Null => {} + Boolean(v) => v.hash(state), + Int8(v) => v.hash(state), + Int16(v) => v.hash(state), + Int32(v) => v.hash(state), + Int64(v) => v.hash(state), + UInt8(v) => v.hash(state), + UInt16(v) => v.hash(state), + UInt32(v) => v.hash(state), + UInt64(v) => v.hash(state), + Float16(v) => v.map(|f| f.to_bits()).hash(state), + Float32(v) => v.map(|f| f.to_bits()).hash(state), + Float64(v) => v.map(|f| f.to_bits()).hash(state), + Decimal128(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state); + } + Decimal256(v, p, s) => { + // i256 doesn't implement Hash, so we hash its bytes + if let Some(val) = v { + val.to_le_bytes().hash(state); + } else { + 0u8.hash(state); + } + p.hash(state); + s.hash(state); + } + Utf8(v) | LargeUtf8(v) | Utf8View(v) | Binary(v) | LargeBinary(v) | BinaryView(v) => { + v.as_ref().map(|b| b.as_slice()).hash(state) + } + FixedSizeBinary(s, v) => { + s.hash(state); + v.as_ref().map(|b| b.as_slice()).hash(state); + } + Date32(v) => v.hash(state), + Date64(v) => v.hash(state), + Time32(v, u) => { + v.hash(state); + u.hash(state); + } + Time64(v, u) => { + v.hash(state); + u.hash(state); + } + Timestamp(v, u, tz) => { + v.hash(state); + u.hash(state); + tz.hash(state); + } + Duration(v, u) => { + v.hash(state); + u.hash(state); + } + IntervalYearMonth(v) => v.hash(state), + IntervalDayTime(v) => v.hash(state), + IntervalMonthDayNano(v) => v.hash(state), + // For complex types, we hash their array data + List(arr) | LargeList(arr) | FixedSizeList(arr) | Map(arr) => { + 
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_buffer::Buffer;
    use std::collections::HashSet;

    /// Shorthand for a non-null `Float64` scalar.
    fn f64_scalar(value: f64) -> Scalar {
        Scalar::Float64(Some(value))
    }

    /// Shorthand for a non-null `Int32` scalar.
    fn i32_scalar(value: i32) -> Scalar {
        Scalar::Int32(Some(value))
    }

    #[test]
    fn test_eq_nulls() {
        // Nulls of the same type are equal, an untyped and a typed null are not.
        assert_eq!(Scalar::Null, Scalar::Null);
        assert_eq!(Scalar::Int32(None), Scalar::Int32(None));
        assert_ne!(Scalar::Null, Scalar::Int32(None));
    }

    #[test]
    fn test_eq_floats_nan() {
        assert_eq!(f64_scalar(f64::NAN), f64_scalar(f64::NAN));

        // Bitwise comparison tells the two zeros apart.
        assert_ne!(f64_scalar(-0.0), f64_scalar(0.0));
    }

    #[test]
    fn test_ord_nulls_first() {
        let null = Scalar::Int32(None);
        let (one, two) = (i32_scalar(1), i32_scalar(2));

        assert_eq!(null.cmp(&one), Ordering::Less);
        assert_eq!(null.cmp(&two), Ordering::Less);
        assert_eq!(one.cmp(&two), Ordering::Less);
    }

    #[test]
    fn test_ord_floats_nan() {
        let nan = f64_scalar(f64::NAN);

        // NaN should be greater than everything in total_cmp
        for smaller in [f64::INFINITY, 1.0, f64::NEG_INFINITY] {
            assert_eq!(nan.cmp(&f64_scalar(smaller)), Ordering::Greater);
        }
    }

    #[test]
    fn test_hash_consistency() {
        use std::hash::DefaultHasher;

        let digest = |s: &Scalar| {
            let mut hasher = DefaultHasher::new();
            s.hash(&mut hasher);
            hasher.finish()
        };

        assert_eq!(digest(&i32_scalar(42)), digest(&i32_scalar(42)));
        assert_eq!(digest(&f64_scalar(f64::NAN)), digest(&f64_scalar(f64::NAN)));
    }

    #[test]
    fn test_hash_set() {
        // The duplicate `1` collapses into a single entry.
        let mut set: HashSet<Scalar> = [1, 2, 1].into_iter().map(i32_scalar).collect();
        assert_eq!(set.len(), 2);

        // A typed null is a distinct member.
        set.insert(Scalar::Int32(None));
        assert_eq!(set.len(), 3);
    }

    #[test]
    fn test_ord_strings() {
        let utf8 = |text: &[u8]| Scalar::Utf8(Some(Buffer::from(text)));
        assert_eq!(utf8(b"aaa").cmp(&utf8(b"bbb")), Ordering::Less);
    }

    #[test]
    #[should_panic(expected = "Cannot compare scalars of different types")]
    fn test_ord_different_types_panics() {
        let narrow = i32_scalar(1);
        let wide = Scalar::Int64(Some(1));
        let _ = narrow.cmp(&wide);
    }
}
Extracts a scalar value from an array at the given index. +pub fn try_from_array(array: &dyn Array, index: usize) -> Result { + if index >= array.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Index {} out of bounds for array of length {}", + index, + array.len() + ))); + } + + if array.is_null(index) { + return Ok(Scalar::null_for_type(array.data_type())); + } + + match array.data_type() { + DataType::Null => Ok(Scalar::Null), + DataType::Boolean => { + let arr = array.as_boolean(); + Ok(Scalar::Boolean(Some(arr.value(index)))) + } + DataType::Int8 => { + let arr = array.as_primitive::(); + Ok(Scalar::Int8(Some(arr.value(index)))) + } + DataType::Int16 => { + let arr = array.as_primitive::(); + Ok(Scalar::Int16(Some(arr.value(index)))) + } + DataType::Int32 => { + let arr = array.as_primitive::(); + Ok(Scalar::Int32(Some(arr.value(index)))) + } + DataType::Int64 => { + let arr = array.as_primitive::(); + Ok(Scalar::Int64(Some(arr.value(index)))) + } + DataType::UInt8 => { + let arr = array.as_primitive::(); + Ok(Scalar::UInt8(Some(arr.value(index)))) + } + DataType::UInt16 => { + let arr = array.as_primitive::(); + Ok(Scalar::UInt16(Some(arr.value(index)))) + } + DataType::UInt32 => { + let arr = array.as_primitive::(); + Ok(Scalar::UInt32(Some(arr.value(index)))) + } + DataType::UInt64 => { + let arr = array.as_primitive::(); + Ok(Scalar::UInt64(Some(arr.value(index)))) + } + DataType::Float16 => { + let arr = array.as_primitive::(); + Ok(Scalar::Float16(Some(arr.value(index)))) + } + DataType::Float32 => { + let arr = array.as_primitive::(); + Ok(Scalar::Float32(Some(arr.value(index)))) + } + DataType::Float64 => { + let arr = array.as_primitive::(); + Ok(Scalar::Float64(Some(arr.value(index)))) + } + DataType::Decimal128(precision, scale) => { + let arr = array.as_primitive::(); + Ok(Scalar::Decimal128( + Some(arr.value(index)), + *precision, + *scale, + )) + } + DataType::Decimal256(precision, scale) => { + let arr = array.as_primitive::(); 
+ Ok(Scalar::Decimal256( + Some(arr.value(index)), + *precision, + *scale, + )) + } + DataType::Utf8 => { + let arr = array.as_string::(); + let offsets = arr.value_offsets(); + let start = offsets[index] as usize; + let end = offsets[index + 1] as usize; + let buf = arr.values().slice_with_length(start, end - start); + Ok(Scalar::Utf8(Some(buf))) + } + DataType::LargeUtf8 => { + let arr = array.as_string::(); + let offsets = arr.value_offsets(); + let start = offsets[index] as usize; + let end = offsets[index + 1] as usize; + let buf = arr.values().slice_with_length(start, end - start); + Ok(Scalar::LargeUtf8(Some(buf))) + } + DataType::Utf8View => { + let arr = array.as_string_view(); + Ok(Scalar::Utf8View(Some(Buffer::from( + arr.value(index).as_bytes(), + )))) + } + DataType::Binary => { + let arr = array.as_binary::(); + let offsets = arr.value_offsets(); + let start = offsets[index] as usize; + let end = offsets[index + 1] as usize; + let buf = arr.values().slice_with_length(start, end - start); + Ok(Scalar::Binary(Some(buf))) + } + DataType::LargeBinary => { + let arr = array.as_binary::(); + let offsets = arr.value_offsets(); + let start = offsets[index] as usize; + let end = offsets[index + 1] as usize; + let buf = arr.values().slice_with_length(start, end - start); + Ok(Scalar::LargeBinary(Some(buf))) + } + DataType::BinaryView => { + let arr = array.as_binary_view(); + Ok(Scalar::BinaryView(Some(Buffer::from(arr.value(index))))) + } + DataType::FixedSizeBinary(size) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::InvalidArgumentError("Expected FixedSizeBinaryArray".to_string()) + })?; + let offset = arr.value_offset(index) as usize; + let length = arr.value_length() as usize; + let buf = arr.to_data().buffers()[0].slice_with_length(offset, length); + Ok(Scalar::FixedSizeBinary(*size, Some(buf))) + } + DataType::Date32 => { + let arr = array.as_primitive::(); + Ok(Scalar::Date32(Some(arr.value(index)))) + } + 
DataType::Date64 => { + let arr = array.as_primitive::(); + Ok(Scalar::Date64(Some(arr.value(index)))) + } + DataType::Time32(TimeUnit::Second) => { + let arr = array.as_primitive::(); + Ok(Scalar::Time32(Some(arr.value(index)), TimeUnit::Second)) + } + DataType::Time32(TimeUnit::Millisecond) => { + let arr = array.as_primitive::(); + Ok(Scalar::Time32( + Some(arr.value(index)), + TimeUnit::Millisecond, + )) + } + DataType::Time32(unit) => Err(ArrowError::InvalidArgumentError(format!( + "Invalid time unit for Time32: {:?}", + unit + ))), + DataType::Time64(TimeUnit::Microsecond) => { + let arr = array.as_primitive::(); + Ok(Scalar::Time64( + Some(arr.value(index)), + TimeUnit::Microsecond, + )) + } + DataType::Time64(TimeUnit::Nanosecond) => { + let arr = array.as_primitive::(); + Ok(Scalar::Time64(Some(arr.value(index)), TimeUnit::Nanosecond)) + } + DataType::Time64(unit) => Err(ArrowError::InvalidArgumentError(format!( + "Invalid time unit for Time64: {:?}", + unit + ))), + DataType::Timestamp(TimeUnit::Second, tz) => { + let arr = array.as_primitive::(); + Ok(Scalar::Timestamp( + Some(arr.value(index)), + TimeUnit::Second, + tz.clone(), + )) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + let arr = array.as_primitive::(); + Ok(Scalar::Timestamp( + Some(arr.value(index)), + TimeUnit::Millisecond, + tz.clone(), + )) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + let arr = array.as_primitive::(); + Ok(Scalar::Timestamp( + Some(arr.value(index)), + TimeUnit::Microsecond, + tz.clone(), + )) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + let arr = array.as_primitive::(); + Ok(Scalar::Timestamp( + Some(arr.value(index)), + TimeUnit::Nanosecond, + tz.clone(), + )) + } + DataType::Duration(TimeUnit::Second) => { + let arr = array.as_primitive::(); + Ok(Scalar::Duration(Some(arr.value(index)), TimeUnit::Second)) + } + DataType::Duration(TimeUnit::Millisecond) => { + let arr = array.as_primitive::(); + Ok(Scalar::Duration( + 
Some(arr.value(index)), + TimeUnit::Millisecond, + )) + } + DataType::Duration(TimeUnit::Microsecond) => { + let arr = array.as_primitive::(); + Ok(Scalar::Duration( + Some(arr.value(index)), + TimeUnit::Microsecond, + )) + } + DataType::Duration(TimeUnit::Nanosecond) => { + let arr = array.as_primitive::(); + Ok(Scalar::Duration( + Some(arr.value(index)), + TimeUnit::Nanosecond, + )) + } + DataType::Interval(IntervalUnit::YearMonth) => { + let arr = array.as_primitive::(); + Ok(Scalar::IntervalYearMonth(Some(arr.value(index)))) + } + DataType::Interval(IntervalUnit::DayTime) => { + let arr = array.as_primitive::(); + let val = arr.value(index); + // IntervalDayTime is stored as days (lower 32 bits) and ms (upper 32 bits) + let combined = ((val.milliseconds as i64) << 32) | (val.days as i64 & 0xFFFFFFFF); + Ok(Scalar::IntervalDayTime(Some(combined))) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let arr = array.as_primitive::(); + let val = arr.value(index); + // IntervalMonthDayNano: months (lower 32), days (next 32), nanos (upper 64) + let combined = ((val.nanoseconds as i128) << 64) + | ((val.days as i128 & 0xFFFFFFFF) << 32) + | (val.months as i128 & 0xFFFFFFFF); + Ok(Scalar::IntervalMonthDayNano(Some(combined))) + } + DataType::List(_) => { + let arr = extract_scalar_element(array, index)?; + Ok(Scalar::List(arr)) + } + DataType::LargeList(_) => { + let arr = extract_scalar_element(array, index)?; + Ok(Scalar::LargeList(arr)) + } + DataType::FixedSizeList(_, _) => { + let arr = extract_scalar_element(array, index)?; + Ok(Scalar::FixedSizeList(arr)) + } + DataType::Struct(fields) => { + let struct_arr = array.as_struct(); + let values = struct_arr + .columns() + .iter() + .map(|col| try_from_array(col.as_ref(), index)) + .collect::>>()?; + Ok(Scalar::Struct(fields.clone(), values)) + } + DataType::Map(_, _) => { + let arr = extract_scalar_element(array, index)?; + Ok(Scalar::Map(arr)) + } + DataType::Dictionary(key_type, _) => { + let dict = 
array.as_any_dictionary(); + let key_idx = dict.keys().as_primitive::().value(index) as usize; + let value_scalar = try_from_array(dict.values().as_ref(), key_idx)?; + Ok(Scalar::Dictionary(key_type.clone(), Box::new(value_scalar))) + } + dt => Err(ArrowError::NotYetImplemented(format!( + "Scalar conversion not implemented for {:?}", + dt + ))), + } +} + +/// Extracts a single element from an array as a length-1 array. +fn extract_scalar_element(array: &dyn Array, index: usize) -> Result { + let data = array.to_data(); + let mut mutable = MutableArrayData::new(vec![&data], true, 1); + mutable.extend(0, index, index + 1); + Ok(arrow_array::make_array(mutable.freeze())) +} + +impl Scalar { + /// Converts this scalar to a length-1 Arrow array. + pub fn to_array(&self) -> ArrayRef { + match self { + Self::Null => arrow_array::new_null_array(&DataType::Null, 1), + Self::Boolean(v) => Arc::new(BooleanArray::from(vec![*v])), + Self::Int8(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::Int16(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::Int32(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::Int64(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::UInt8(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::UInt16(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::UInt32(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::UInt64(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::Float16(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::Float32(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::Float64(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::Decimal128(v, precision, scale) => { + let arr = PrimitiveArray::::from(vec![*v]) + .with_precision_and_scale(*precision, *scale) + .expect("Invalid decimal precision/scale"); + Arc::new(arr) + } + Self::Decimal256(v, precision, scale) => { + let arr = PrimitiveArray::::from(vec![*v]) + .with_precision_and_scale(*precision, *scale) + 
.expect("Invalid decimal precision/scale"); + Arc::new(arr) + } + Self::Utf8(v) => match v { + Some(buf) => { + let offsets = + OffsetBuffer::new(ScalarBuffer::from(vec![0i32, buf.len() as i32])); + Arc::new(GenericByteArray::::new( + offsets, + buf.clone(), + None, + )) + } + None => arrow_array::new_null_array(&DataType::Utf8, 1), + }, + Self::LargeUtf8(v) => match v { + Some(buf) => { + let offsets = + OffsetBuffer::new(ScalarBuffer::from(vec![0i64, buf.len() as i64])); + Arc::new(GenericByteArray::::new( + offsets, + buf.clone(), + None, + )) + } + None => arrow_array::new_null_array(&DataType::LargeUtf8, 1), + }, + Self::Utf8View(v) => { + let s = v.as_deref().map(|b| { + std::str::from_utf8(b).expect("Utf8View scalar must contain valid UTF-8") + }); + Arc::new(GenericByteViewArray::::from(vec![s])) + } + Self::Binary(v) => match v { + Some(buf) => { + let offsets = + OffsetBuffer::new(ScalarBuffer::from(vec![0i32, buf.len() as i32])); + Arc::new(GenericByteArray::::new( + offsets, + buf.clone(), + None, + )) + } + None => arrow_array::new_null_array(&DataType::Binary, 1), + }, + Self::LargeBinary(v) => match v { + Some(buf) => { + let offsets = + OffsetBuffer::new(ScalarBuffer::from(vec![0i64, buf.len() as i64])); + Arc::new( + GenericByteArray::::new( + offsets, + buf.clone(), + None, + ), + ) + } + None => arrow_array::new_null_array(&DataType::LargeBinary, 1), + }, + Self::BinaryView(v) => Arc::new(GenericByteViewArray::::from(vec![ + v.as_deref() + ])), + Self::FixedSizeBinary(size, v) => { + let arr = match v { + Some(buf) => FixedSizeBinaryArray::try_from_sparse_iter_with_size( + std::iter::once(Some(buf.as_ref())), + *size, + ) + .expect("Invalid fixed size binary"), + None => FixedSizeBinaryArray::try_from_sparse_iter_with_size( + std::iter::once(None::<&[u8]>), + *size, + ) + .expect("Invalid fixed size binary"), + }; + Arc::new(arr) + } + Self::Date32(v) => Arc::new(PrimitiveArray::::from(vec![*v])), + Self::Date64(v) => 
Arc::new(PrimitiveArray::::from(vec![*v])), + Self::Time32(v, unit) => match unit { + TimeUnit::Second => Arc::new(PrimitiveArray::::from(vec![*v])), + TimeUnit::Millisecond => { + Arc::new(PrimitiveArray::::from(vec![*v])) + } + _ => panic!("Invalid time unit for Time32: {:?}", unit), + }, + Self::Time64(v, unit) => match unit { + TimeUnit::Microsecond => { + Arc::new(PrimitiveArray::::from(vec![*v])) + } + TimeUnit::Nanosecond => { + Arc::new(PrimitiveArray::::from(vec![*v])) + } + _ => panic!("Invalid time unit for Time64: {:?}", unit), + }, + Self::Timestamp(v, unit, tz) => match unit { + TimeUnit::Second => Arc::new( + PrimitiveArray::::from(vec![*v]) + .with_timezone_opt(tz.clone()), + ), + TimeUnit::Millisecond => Arc::new( + PrimitiveArray::::from(vec![*v]) + .with_timezone_opt(tz.clone()), + ), + TimeUnit::Microsecond => Arc::new( + PrimitiveArray::::from(vec![*v]) + .with_timezone_opt(tz.clone()), + ), + TimeUnit::Nanosecond => Arc::new( + PrimitiveArray::::from(vec![*v]) + .with_timezone_opt(tz.clone()), + ), + }, + Self::Duration(v, unit) => match unit { + TimeUnit::Second => Arc::new(PrimitiveArray::::from(vec![*v])), + TimeUnit::Millisecond => { + Arc::new(PrimitiveArray::::from(vec![*v])) + } + TimeUnit::Microsecond => { + Arc::new(PrimitiveArray::::from(vec![*v])) + } + TimeUnit::Nanosecond => { + Arc::new(PrimitiveArray::::from(vec![*v])) + } + }, + Self::IntervalYearMonth(v) => { + Arc::new(PrimitiveArray::::from(vec![*v])) + } + Self::IntervalDayTime(v) => { + let v = v.map(|combined| { + let days = (combined & 0xFFFFFFFF) as i32; + let ms = (combined >> 32) as i32; + arrow_buffer::IntervalDayTime::new(days, ms) + }); + Arc::new(PrimitiveArray::::from(vec![v])) + } + Self::IntervalMonthDayNano(v) => { + let v = v.map(|combined| { + let months = (combined & 0xFFFFFFFF) as i32; + let days = ((combined >> 32) & 0xFFFFFFFF) as i32; + let ns = (combined >> 64) as i64; + arrow_buffer::IntervalMonthDayNano::new(months, days, ns) + }); + 
Arc::new(PrimitiveArray::::from(vec![v])) + } + Self::List(arr) => arr.clone(), + Self::LargeList(arr) => arr.clone(), + Self::FixedSizeList(arr) => arr.clone(), + Self::Struct(fields, values) => { + let arrays: Vec = values.iter().map(|v| v.to_array()).collect(); + Arc::new(StructArray::new(fields.clone(), arrays, None)) + } + Self::Map(arr) => arr.clone(), + Self::Dictionary(key_type, value) => { + let value_arr = value.to_array(); + match key_type.as_ref() { + DataType::Int8 => { + let keys = PrimitiveArray::::from(vec![Some(0i8)]); + Arc::new( + arrow_array::DictionaryArray::try_new(keys, value_arr) + .expect("Invalid dictionary"), + ) + } + DataType::Int16 => { + let keys = PrimitiveArray::::from(vec![Some(0i16)]); + Arc::new( + arrow_array::DictionaryArray::try_new(keys, value_arr) + .expect("Invalid dictionary"), + ) + } + DataType::Int32 => { + let keys = PrimitiveArray::::from(vec![Some(0i32)]); + Arc::new( + arrow_array::DictionaryArray::try_new(keys, value_arr) + .expect("Invalid dictionary"), + ) + } + DataType::Int64 => { + let keys = PrimitiveArray::::from(vec![Some(0i64)]); + Arc::new( + arrow_array::DictionaryArray::try_new(keys, value_arr) + .expect("Invalid dictionary"), + ) + } + DataType::UInt8 => { + let keys = PrimitiveArray::::from(vec![Some(0u8)]); + Arc::new( + arrow_array::DictionaryArray::try_new(keys, value_arr) + .expect("Invalid dictionary"), + ) + } + DataType::UInt16 => { + let keys = PrimitiveArray::::from(vec![Some(0u16)]); + Arc::new( + arrow_array::DictionaryArray::try_new(keys, value_arr) + .expect("Invalid dictionary"), + ) + } + DataType::UInt32 => { + let keys = PrimitiveArray::::from(vec![Some(0u32)]); + Arc::new( + arrow_array::DictionaryArray::try_new(keys, value_arr) + .expect("Invalid dictionary"), + ) + } + DataType::UInt64 => { + let keys = PrimitiveArray::::from(vec![Some(0u64)]); + Arc::new( + arrow_array::DictionaryArray::try_new(keys, value_arr) + .expect("Invalid dictionary"), + ) + } + _ => panic!("Invalid 
dictionary key type: {:?}", key_type), + } + } + } + } +} + +/// Converts an iterator of scalars to an Arrow array. +/// +/// All scalars must have the same data type. +pub fn iter_to_array(iter: impl Iterator) -> Result { + let scalars: Vec = iter.collect(); + if scalars.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "Cannot create array from empty iterator".to_string(), + )); + } + + let arrays: Vec = scalars.iter().map(|s| s.to_array()).collect(); + let refs: Vec<&dyn Array> = arrays.iter().map(|a| a.as_ref()).collect(); + arrow_select::concat::concat(&refs) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{Int32Array, StringArray}; + use arrow_buffer::Buffer; + use rstest::rstest; + + #[rstest] + #[case::int32(Arc::new(Int32Array::from(vec![Some(42), None, Some(7)])) as ArrayRef, 0, Scalar::Int32(Some(42)))] + #[case::int32_null(Arc::new(Int32Array::from(vec![Some(42), None, Some(7)])) as ArrayRef, 1, Scalar::Int32(None))] + #[case::string(Arc::new(StringArray::from(vec![Some("hello"), None])) as ArrayRef, 0, Scalar::Utf8(Some(Buffer::from("hello".as_bytes()))))] + fn test_try_from_array( + #[case] array: ArrayRef, + #[case] index: usize, + #[case] expected: Scalar, + ) { + let result = try_from_array(array.as_ref(), index).unwrap(); + assert_eq!(result.data_type(), expected.data_type()); + assert_eq!(result.is_null(), expected.is_null()); + } + + #[test] + fn test_round_trip_primitives() { + let original: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + for i in 0..original.len() { + let scalar = try_from_array(original.as_ref(), i).unwrap(); + let arr = scalar.to_array(); + assert_eq!(arr.len(), 1); + let back = try_from_array(arr.as_ref(), 0).unwrap(); + assert_eq!(back.data_type(), scalar.data_type()); + } + } + + #[test] + fn test_iter_to_array() { + let scalars = vec![ + Scalar::Int32(Some(1)), + Scalar::Int32(Some(2)), + Scalar::Int32(None), + Scalar::Int32(Some(4)), + ]; + let arr = 
iter_to_array(scalars.into_iter()).unwrap(); + assert_eq!(arr.len(), 4); + assert_eq!(arr.null_count(), 1); + } +} diff --git a/rust/arrow-scalar/src/display.rs b/rust/arrow-scalar/src/display.rs new file mode 100644 index 00000000000..87e68f6f80b --- /dev/null +++ b/rust/arrow-scalar/src/display.rs @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Display formatting for Scalar values. + +use std::fmt::{Display, Formatter, Result}; + +use crate::Scalar; + +impl Display for Scalar { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + use Scalar::*; + + match self { + Null => write!(f, "NULL"), + Boolean(None) => write!(f, "NULL"), + Boolean(Some(v)) => write!(f, "{}", v), + Int8(None) => write!(f, "NULL"), + Int8(Some(v)) => write!(f, "{}", v), + Int16(None) => write!(f, "NULL"), + Int16(Some(v)) => write!(f, "{}", v), + Int32(None) => write!(f, "NULL"), + Int32(Some(v)) => write!(f, "{}", v), + Int64(None) => write!(f, "NULL"), + Int64(Some(v)) => write!(f, "{}", v), + UInt8(None) => write!(f, "NULL"), + UInt8(Some(v)) => write!(f, "{}", v), + UInt16(None) => write!(f, "NULL"), + UInt16(Some(v)) => write!(f, "{}", v), + UInt32(None) => write!(f, "NULL"), + UInt32(Some(v)) => write!(f, "{}", v), + UInt64(None) => write!(f, "NULL"), + UInt64(Some(v)) => write!(f, "{}", v), + Float16(None) => write!(f, "NULL"), + Float16(Some(v)) => write!(f, "{}", v), + Float32(None) => write!(f, "NULL"), + Float32(Some(v)) => write!(f, "{}", v), + Float64(None) => write!(f, "NULL"), + Float64(Some(v)) => write!(f, "{}", v), + Decimal128(None, _, _) => write!(f, "NULL"), + Decimal128(Some(v), precision, scale) => write_decimal(f, *v, *precision, *scale), + Decimal256(None, _, _) => write!(f, "NULL"), + Decimal256(Some(v), _precision, scale) => { + // Convert i256 to string representation + let s = format!("{}", v); + // Apply scale + if *scale > 0 { + let scale = *scale as usize; + if s.len() <= scale { + 
write!(f, "0.{:0>width$}", s, width = scale) + } else { + let (int_part, frac_part) = s.split_at(s.len() - scale); + write!(f, "{}.{}", int_part, frac_part) + } + } else { + write!(f, "{}", s) + } + } + Utf8(None) | LargeUtf8(None) | Utf8View(None) => write!(f, "NULL"), + Utf8(Some(v)) | LargeUtf8(Some(v)) | Utf8View(Some(v)) => { + let s = std::str::from_utf8(v).expect("Utf8 scalar must contain valid UTF-8"); + write!(f, "\"{}\"", s) + } + Binary(None) | LargeBinary(None) | BinaryView(None) => write!(f, "NULL"), + Binary(Some(v)) | LargeBinary(Some(v)) | BinaryView(Some(v)) => write_hex(f, v), + FixedSizeBinary(_, None) => write!(f, "NULL"), + FixedSizeBinary(_, Some(v)) => write_hex(f, v), + Date32(None) => write!(f, "NULL"), + Date32(Some(days)) => { + // Days since Unix epoch + let date = chrono_date_from_days(*days as i64); + write!(f, "{}", date) + } + Date64(None) => write!(f, "NULL"), + Date64(Some(ms)) => { + // Milliseconds since Unix epoch + let date = chrono_date_from_ms(*ms); + write!(f, "{}", date) + } + Time32(None, _) => write!(f, "NULL"), + Time32(Some(v), unit) => { + let (h, m, s, ns) = time_parts_from_unit(*v as i64, unit); + write!(f, "{:02}:{:02}:{:02}.{:09}", h, m, s, ns) + } + Time64(None, _) => write!(f, "NULL"), + Time64(Some(v), unit) => { + let (h, m, s, ns) = time_parts_from_unit(*v, unit); + write!(f, "{:02}:{:02}:{:02}.{:09}", h, m, s, ns) + } + Timestamp(None, _, _) => write!(f, "NULL"), + Timestamp(Some(v), unit, tz) => { + let ns = match unit { + arrow_schema::TimeUnit::Second => *v * 1_000_000_000, + arrow_schema::TimeUnit::Millisecond => *v * 1_000_000, + arrow_schema::TimeUnit::Microsecond => *v * 1_000, + arrow_schema::TimeUnit::Nanosecond => *v, + }; + let secs = ns / 1_000_000_000; + let subsec_ns = (ns % 1_000_000_000) as u32; + + if let Some(tz) = tz { + write!(f, "{}T{:09} {}", secs, subsec_ns, tz) + } else { + write!(f, "{}T{:09}", secs, subsec_ns) + } + } + Duration(None, _) => write!(f, "NULL"), + Duration(Some(v), 
unit) => { + let label = match unit { + arrow_schema::TimeUnit::Second => "s", + arrow_schema::TimeUnit::Millisecond => "ms", + arrow_schema::TimeUnit::Microsecond => "us", + arrow_schema::TimeUnit::Nanosecond => "ns", + }; + write!(f, "{}{}", v, label) + } + IntervalYearMonth(None) => write!(f, "NULL"), + IntervalYearMonth(Some(v)) => { + let years = v / 12; + let months = v % 12; + write!(f, "{}y{}m", years, months) + } + IntervalDayTime(None) => write!(f, "NULL"), + IntervalDayTime(Some(v)) => { + let days = (*v & 0xFFFFFFFF) as i32; + let ms = (*v >> 32) as i32; + write!(f, "{}d{}ms", days, ms) + } + IntervalMonthDayNano(None) => write!(f, "NULL"), + IntervalMonthDayNano(Some(v)) => { + let months = (*v & 0xFFFFFFFF) as i32; + let days = ((*v >> 32) & 0xFFFFFFFF) as i32; + let ns = (*v >> 64) as i64; + write!(f, "{}m{}d{}ns", months, days, ns) + } + List(arr) => { + if arr.is_null(0) { + write!(f, "NULL") + } else { + write!(f, "[list]") + } + } + LargeList(arr) => { + if arr.is_null(0) { + write!(f, "NULL") + } else { + write!(f, "[large_list]") + } + } + FixedSizeList(arr) => { + if arr.is_null(0) { + write!(f, "NULL") + } else { + write!(f, "[fixed_size_list]") + } + } + Struct(fields, values) => { + if values.is_empty() { + write!(f, "NULL") + } else { + write!(f, "{{")?; + for (i, (field, value)) in fields.iter().zip(values.iter()).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}: {}", field.name(), value)?; + } + write!(f, "}}") + } + } + Map(arr) => { + if arr.is_null(0) { + write!(f, "NULL") + } else { + write!(f, "[map]") + } + } + Dictionary(_, value) => write!(f, "{}", value), + } + } +} + +fn write_decimal(f: &mut Formatter<'_>, value: i128, _precision: u8, scale: i8) -> Result { + if scale <= 0 { + write!(f, "{}", value) + } else { + let scale = scale as usize; + let is_neg = value < 0; + let abs_value = value.unsigned_abs(); + let s = format!("{}", abs_value); + + let result = if s.len() <= scale { + format!("0.{:0>width$}", s, 
width = scale) + } else { + let (int_part, frac_part) = s.split_at(s.len() - scale); + format!("{}.{}", int_part, frac_part) + }; + + if is_neg { + write!(f, "-{}", result) + } else { + write!(f, "{}", result) + } + } +} + +fn write_hex(f: &mut Formatter<'_>, bytes: &[u8]) -> Result { + write!(f, "0x")?; + for b in bytes { + write!(f, "{:02x}", b)?; + } + Ok(()) +} + +fn chrono_date_from_days(days: i64) -> String { + // Unix epoch is 1970-01-01 + // Simple calculation without chrono dependency + let epoch_days = 719_468i64; // Days from year 0 to 1970-01-01 + let total_days = epoch_days + days; + + // Algorithm from https://howardhinnant.github.io/date_algorithms.html + let era = if total_days >= 0 { + total_days / 146097 + } else { + (total_days - 146096) / 146097 + }; + let doe = (total_days - era * 146097) as u32; + let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; + let y = yoe as i64 + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + let mp = (5 * doy + 2) / 153; + let d = doy - (153 * mp + 2) / 5 + 1; + let m = if mp < 10 { mp + 3 } else { mp - 9 }; + let year = if m <= 2 { y + 1 } else { y }; + + format!("{:04}-{:02}-{:02}", year, m, d) +} + +fn chrono_date_from_ms(ms: i64) -> String { + let days = ms / (24 * 60 * 60 * 1000); + chrono_date_from_days(days) +} + +fn time_parts_from_unit(value: i64, unit: &arrow_schema::TimeUnit) -> (u32, u32, u32, u32) { + let ns = match unit { + arrow_schema::TimeUnit::Second => value * 1_000_000_000, + arrow_schema::TimeUnit::Millisecond => value * 1_000_000, + arrow_schema::TimeUnit::Microsecond => value * 1_000, + arrow_schema::TimeUnit::Nanosecond => value, + }; + + let total_secs = ns / 1_000_000_000; + let subsec_ns = (ns % 1_000_000_000) as u32; + + let h = (total_secs / 3600) as u32; + let m = ((total_secs % 3600) / 60) as u32; + let s = (total_secs % 60) as u32; + + (h, m, s, subsec_ns) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_buffer::Buffer; + + #[test] + fn 
test_display_primitives() { + assert_eq!(format!("{}", Scalar::Null), "NULL"); + assert_eq!(format!("{}", Scalar::Boolean(Some(true))), "true"); + assert_eq!(format!("{}", Scalar::Boolean(Some(false))), "false"); + assert_eq!(format!("{}", Scalar::Boolean(None)), "NULL"); + assert_eq!(format!("{}", Scalar::Int32(Some(42))), "42"); + assert_eq!(format!("{}", Scalar::Int32(Some(-42))), "-42"); + assert_eq!(format!("{}", Scalar::Int32(None)), "NULL"); + } + + #[test] + fn test_display_floats() { + assert_eq!(format!("{}", Scalar::Float64(Some(3.14))), "3.14"); + assert_eq!(format!("{}", Scalar::Float64(None)), "NULL"); + } + + #[test] + fn test_display_strings() { + assert_eq!( + format!("{}", Scalar::Utf8(Some(Buffer::from("hello".as_bytes())))), + "\"hello\"" + ); + assert_eq!(format!("{}", Scalar::Utf8(None)), "NULL"); + } + + #[test] + fn test_display_binary() { + assert_eq!( + format!("{}", Scalar::Binary(Some(Buffer::from(vec![0xABu8, 0xCD])))), + "0xabcd" + ); + } + + #[test] + fn test_display_decimal() { + assert_eq!( + format!("{}", Scalar::Decimal128(Some(12345), 10, 2)), + "123.45" + ); + assert_eq!( + format!("{}", Scalar::Decimal128(Some(-12345), 10, 2)), + "-123.45" + ); + assert_eq!(format!("{}", Scalar::Decimal128(Some(45), 10, 2)), "0.45"); + } + + #[test] + fn test_display_date() { + // 2020-01-01 is 18262 days since Unix epoch + assert_eq!(format!("{}", Scalar::Date32(Some(18262))), "2020-01-01"); + } +} diff --git a/rust/arrow-scalar/src/lib.rs b/rust/arrow-scalar/src/lib.rs new file mode 100644 index 00000000000..451776c3339 --- /dev/null +++ b/rust/arrow-scalar/src/lib.rs @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Arrow Scalar - A scalar value representation for Apache Arrow types. +//! +//! This crate provides a `Scalar` enum for representing single Arrow values, +//! similar to DataFusion's `ScalarValue` but without DataFusion dependencies. +//! +//! # Features +//! 
+//! - Represents all Arrow primitive and complex types +//! - Converts between Arrow arrays and scalar values +//! - Byte serialization for storage and indexing +//! - Implements `Eq`, `Ord`, `Hash` with proper null/NaN handling +//! +//! # Example +//! +//! ``` +//! use arrow_scalar::{Scalar, try_from_array}; +//! use arrow_array::{Int32Array, Array}; +//! use std::sync::Arc; +//! +//! // Create a scalar from an array element +//! let array = Int32Array::from(vec![Some(1), None, Some(3)]); +//! let scalar = try_from_array(&array, 0).unwrap(); +//! assert_eq!(scalar, Scalar::Int32(Some(1))); +//! +//! // Convert back to an array +//! let arr = scalar.to_array(); +//! assert_eq!(arr.len(), 1); +//! ``` +//! +//! # Comparison Semantics (designed to match DataFusion's ScalarValue) +//! +//! - `NULL == NULL` for equality +//! - `NaN == NaN` using total_cmp semantics for floats +//! - Nulls sort first (less than all non-null values) +//! - Floats use `total_cmp()` for ordering + +mod bytes; +mod cmp; +mod convert; +mod display; +mod scalar; + +pub use convert::{iter_to_array, try_from_array}; +pub use scalar::Scalar; diff --git a/rust/arrow-scalar/src/scalar.rs b/rust/arrow-scalar/src/scalar.rs new file mode 100644 index 00000000000..b7076f864c6 --- /dev/null +++ b/rust/arrow-scalar/src/scalar.rs @@ -0,0 +1,341 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Core Scalar enum definition for representing single Arrow values. + +use std::sync::Arc; + +use arrow_array::ArrayRef; +use arrow_buffer::{i256, Buffer}; +use arrow_schema::{DataType, Field, Fields, IntervalUnit, TimeUnit}; +use half::f16; + +/// A single value in Arrow format. +/// +/// This enum represents a single scalar value for each Arrow data type. +/// It is similar to DataFusion's `ScalarValue` but without DataFusion dependencies. +/// +/// For primitive types, the value is wrapped in `Option` where `None` represents null. 
+/// For complex types like List and Struct, null values are represented within the +/// contained arrays or vectors. +#[derive(Clone, Debug)] +pub enum Scalar { + /// Null value with unknown type + Null, + + /// Boolean value + Boolean(Option), + + /// Signed 8-bit integer + Int8(Option), + /// Signed 16-bit integer + Int16(Option), + /// Signed 32-bit integer + Int32(Option), + /// Signed 64-bit integer + Int64(Option), + + /// Unsigned 8-bit integer + UInt8(Option), + /// Unsigned 16-bit integer + UInt16(Option), + /// Unsigned 32-bit integer + UInt32(Option), + /// Unsigned 64-bit integer + UInt64(Option), + + /// 16-bit floating point + Float16(Option), + /// 32-bit floating point + Float32(Option), + /// 64-bit floating point + Float64(Option), + + /// 128-bit decimal with precision and scale + Decimal128(Option, u8, i8), + /// 256-bit decimal with precision and scale + Decimal256(Option, u8, i8), + + /// UTF-8 encoded string + Utf8(Option), + /// UTF-8 encoded string with 64-bit offsets + LargeUtf8(Option), + /// UTF-8 encoded string view + Utf8View(Option), + + /// Variable-length binary data + Binary(Option), + /// Variable-length binary data with 64-bit offsets + LargeBinary(Option), + /// Binary view + BinaryView(Option), + /// Fixed-size binary data (size in bytes, value) + FixedSizeBinary(i32, Option), + + /// Days since Unix epoch + Date32(Option), + /// Milliseconds since Unix epoch + Date64(Option), + + /// Time of day with specified unit (only Second and Millisecond valid) + Time32(Option, TimeUnit), + /// Time of day with specified unit (only Microsecond and Nanosecond valid) + Time64(Option, TimeUnit), + + /// Timestamp with time unit and optional timezone + Timestamp(Option, TimeUnit, Option>), + + /// Duration with time unit + Duration(Option, TimeUnit), + + /// Interval in months + IntervalYearMonth(Option), + /// Interval in days and milliseconds (stored as i64: days in lower 32 bits, ms in upper 32 bits) + IntervalDayTime(Option), + /// 
Interval in months, days, and nanoseconds (stored as i128) + IntervalMonthDayNano(Option), + + /// List array (stored as a length-1 array for the element) + List(ArrayRef), + /// Large list array (stored as a length-1 array for the element) + LargeList(ArrayRef), + /// Fixed-size list array (stored as a length-1 array for the element) + FixedSizeList(ArrayRef), + + /// Struct with named fields + Struct(Fields, Vec), + + /// Map array (stored as a length-1 array) + Map(ArrayRef), + + /// Dictionary-encoded value (key type, value) + Dictionary(Box, Box), +} + +impl Scalar { + /// Returns the Arrow [`DataType`] for this scalar. + pub fn data_type(&self) -> DataType { + match self { + Self::Null => DataType::Null, + Self::Boolean(_) => DataType::Boolean, + Self::Int8(_) => DataType::Int8, + Self::Int16(_) => DataType::Int16, + Self::Int32(_) => DataType::Int32, + Self::Int64(_) => DataType::Int64, + Self::UInt8(_) => DataType::UInt8, + Self::UInt16(_) => DataType::UInt16, + Self::UInt32(_) => DataType::UInt32, + Self::UInt64(_) => DataType::UInt64, + Self::Float16(_) => DataType::Float16, + Self::Float32(_) => DataType::Float32, + Self::Float64(_) => DataType::Float64, + Self::Decimal128(_, precision, scale) => DataType::Decimal128(*precision, *scale), + Self::Decimal256(_, precision, scale) => DataType::Decimal256(*precision, *scale), + Self::Utf8(_) => DataType::Utf8, + Self::LargeUtf8(_) => DataType::LargeUtf8, + Self::Utf8View(_) => DataType::Utf8View, + Self::Binary(_) => DataType::Binary, + Self::LargeBinary(_) => DataType::LargeBinary, + Self::BinaryView(_) => DataType::BinaryView, + Self::FixedSizeBinary(size, _) => DataType::FixedSizeBinary(*size), + Self::Date32(_) => DataType::Date32, + Self::Date64(_) => DataType::Date64, + Self::Time32(_, unit) => DataType::Time32(*unit), + Self::Time64(_, unit) => DataType::Time64(*unit), + Self::Timestamp(_, unit, tz) => DataType::Timestamp(*unit, tz.clone()), + Self::Duration(_, unit) => DataType::Duration(*unit), + 
Self::IntervalYearMonth(_) => DataType::Interval(IntervalUnit::YearMonth), + Self::IntervalDayTime(_) => DataType::Interval(IntervalUnit::DayTime), + Self::IntervalMonthDayNano(_) => DataType::Interval(IntervalUnit::MonthDayNano), + Self::List(arr) => { + let list_arr = arr + .as_any() + .downcast_ref::() + .expect("List scalar must contain ListArray"); + DataType::List(Arc::new(Field::new("item", list_arr.value_type(), true))) + } + Self::LargeList(arr) => { + let list_arr = arr + .as_any() + .downcast_ref::() + .expect("LargeList scalar must contain LargeListArray"); + DataType::LargeList(Arc::new(Field::new("item", list_arr.value_type(), true))) + } + Self::FixedSizeList(arr) => { + let list_arr = arr + .as_any() + .downcast_ref::() + .expect("FixedSizeList scalar must contain FixedSizeListArray"); + DataType::FixedSizeList( + Arc::new(Field::new("item", list_arr.value_type(), true)), + list_arr.value_length(), + ) + } + Self::Struct(fields, _) => DataType::Struct(fields.clone()), + Self::Map(arr) => { + let map_arr = arr + .as_any() + .downcast_ref::() + .expect("Map scalar must contain MapArray"); + DataType::Map(map_arr.entries().fields().first().unwrap().clone(), false) + } + Self::Dictionary(key_type, value) => { + DataType::Dictionary(key_type.clone(), Box::new(value.data_type())) + } + } + } + + /// Returns `true` if this scalar is null. 
+ pub fn is_null(&self) -> bool { + match self { + Self::Null => true, + Self::Boolean(v) => v.is_none(), + Self::Int8(v) => v.is_none(), + Self::Int16(v) => v.is_none(), + Self::Int32(v) => v.is_none(), + Self::Int64(v) => v.is_none(), + Self::UInt8(v) => v.is_none(), + Self::UInt16(v) => v.is_none(), + Self::UInt32(v) => v.is_none(), + Self::UInt64(v) => v.is_none(), + Self::Float16(v) => v.is_none(), + Self::Float32(v) => v.is_none(), + Self::Float64(v) => v.is_none(), + Self::Decimal128(v, _, _) => v.is_none(), + Self::Decimal256(v, _, _) => v.is_none(), + Self::Utf8(v) => v.is_none(), + Self::LargeUtf8(v) => v.is_none(), + Self::Utf8View(v) => v.is_none(), + Self::Binary(v) => v.is_none(), + Self::LargeBinary(v) => v.is_none(), + Self::BinaryView(v) => v.is_none(), + Self::FixedSizeBinary(_, v) => v.is_none(), + Self::Date32(v) => v.is_none(), + Self::Date64(v) => v.is_none(), + Self::Time32(v, _) => v.is_none(), + Self::Time64(v, _) => v.is_none(), + Self::Timestamp(v, _, _) => v.is_none(), + Self::Duration(v, _) => v.is_none(), + Self::IntervalYearMonth(v) => v.is_none(), + Self::IntervalDayTime(v) => v.is_none(), + Self::IntervalMonthDayNano(v) => v.is_none(), + Self::List(arr) => arr.is_null(0), + Self::LargeList(arr) => arr.is_null(0), + Self::FixedSizeList(arr) => arr.is_null(0), + Self::Struct(_, values) => values.is_empty(), + Self::Map(arr) => arr.is_null(0), + Self::Dictionary(_, v) => v.is_null(), + } + } + + /// Returns an estimate of the memory size in bytes. 
+ pub fn size(&self) -> usize { + std::mem::size_of::() + + match self { + Self::Null + | Self::Boolean(_) + | Self::Int8(_) + | Self::Int16(_) + | Self::Int32(_) + | Self::Int64(_) + | Self::UInt8(_) + | Self::UInt16(_) + | Self::UInt32(_) + | Self::UInt64(_) + | Self::Float16(_) + | Self::Float32(_) + | Self::Float64(_) + | Self::Decimal128(_, _, _) + | Self::Decimal256(_, _, _) + | Self::Date32(_) + | Self::Date64(_) + | Self::Time32(_, _) + | Self::Time64(_, _) + | Self::Timestamp(_, _, _) + | Self::Duration(_, _) + | Self::IntervalYearMonth(_) + | Self::IntervalDayTime(_) + | Self::IntervalMonthDayNano(_) => 0, + Self::Utf8(v) + | Self::LargeUtf8(v) + | Self::Utf8View(v) + | Self::Binary(v) + | Self::LargeBinary(v) + | Self::BinaryView(v) + | Self::FixedSizeBinary(_, v) => v.as_ref().map(|b| b.len()).unwrap_or(0), + Self::List(arr) + | Self::LargeList(arr) + | Self::FixedSizeList(arr) + | Self::Map(arr) => arr.get_array_memory_size(), + Self::Struct(fields, values) => { + fields.iter().map(|f| f.size()).sum::() + + values.iter().map(|v| v.size()).sum::() + } + Self::Dictionary(_, v) => v.size(), + } + } + + /// Creates a null scalar for the given data type. 
pub fn null_for_type(data_type: &DataType) -> Self {
    // Array-backed variants wrap a single-row, all-null array of the
    // requested type.
    let null_array = || arrow_array::new_null_array(data_type, 1);

    match data_type {
        // Nested types.
        DataType::List(_) => Self::List(null_array()),
        DataType::LargeList(_) => Self::LargeList(null_array()),
        DataType::FixedSizeList(_, _) => Self::FixedSizeList(null_array()),
        DataType::Map(_, _) => Self::Map(null_array()),
        DataType::Struct(fields) => Self::Struct(
            fields.clone(),
            // One null child per field, built recursively.
            fields
                .iter()
                .map(|field| Self::null_for_type(field.data_type()))
                .collect(),
        ),
        DataType::Dictionary(key_type, value_type) => {
            let null_value = Self::null_for_type(value_type);
            Self::Dictionary(key_type.clone(), Box::new(null_value))
        }

        // Types without parameters.
        DataType::Null => Self::Null,
        DataType::Boolean => Self::Boolean(None),
        DataType::Int8 => Self::Int8(None),
        DataType::Int16 => Self::Int16(None),
        DataType::Int32 => Self::Int32(None),
        DataType::Int64 => Self::Int64(None),
        DataType::UInt8 => Self::UInt8(None),
        DataType::UInt16 => Self::UInt16(None),
        DataType::UInt32 => Self::UInt32(None),
        DataType::UInt64 => Self::UInt64(None),
        DataType::Float16 => Self::Float16(None),
        DataType::Float32 => Self::Float32(None),
        DataType::Float64 => Self::Float64(None),
        DataType::Utf8 => Self::Utf8(None),
        DataType::LargeUtf8 => Self::LargeUtf8(None),
        DataType::Utf8View => Self::Utf8View(None),
        DataType::Binary => Self::Binary(None),
        DataType::LargeBinary => Self::LargeBinary(None),
        DataType::BinaryView => Self::BinaryView(None),
        DataType::Date32 => Self::Date32(None),
        DataType::Date64 => Self::Date64(None),

        // Types whose parameters are carried over into the scalar.
        DataType::Decimal128(precision, scale) => Self::Decimal128(None, *precision, *scale),
        DataType::Decimal256(precision, scale) => Self::Decimal256(None, *precision, *scale),
        DataType::FixedSizeBinary(width) => Self::FixedSizeBinary(*width, None),
        DataType::Time32(unit) => Self::Time32(None, *unit),
        DataType::Time64(unit) => Self::Time64(None, *unit),
        DataType::Timestamp(unit, tz) => Self::Timestamp(None, *unit, tz.clone()),
        DataType::Duration(unit) => Self::Duration(None, *unit),
        DataType::Interval(unit) => match unit {
            IntervalUnit::YearMonth => Self::IntervalYearMonth(None),
            IntervalUnit::DayTime => Self::IntervalDayTime(None),
            IntervalUnit::MonthDayNano => Self::IntervalMonthDayNano(None),
        },

        // Any other data type has no dedicated variant here.
        _ => Self::Null,
    }
}
}