diff --git a/rust/lance-encoding/src/encodings/logical/primitive.rs b/rust/lance-encoding/src/encodings/logical/primitive.rs index 9c3e54ce3af..580adc72252 100644 --- a/rust/lance-encoding/src/encodings/logical/primitive.rs +++ b/rust/lance-encoding/src/encodings/logical/primitive.rs @@ -2619,6 +2619,91 @@ impl VariableFullZipDecoder { decoder } + fn slice_batch_data_and_rebase_offsets_typed( + data: &LanceBuffer, + offsets: &LanceBuffer, + ) -> Result<(LanceBuffer, LanceBuffer)> + where + T: arrow_buffer::ArrowNativeType + + Copy + + PartialOrd + + std::ops::Sub + + std::fmt::Display + + TryInto, + { + let offsets_slice = offsets.borrow_to_typed_slice::(); + let offsets_slice = offsets_slice.as_ref(); + if offsets_slice.is_empty() { + return Err(Error::Internal { + message: "Variable offsets cannot be empty".to_string(), + location: location!(), + }); + } + + let base = offsets_slice[0]; + let end = *offsets_slice.last().unwrap(); + if end < base { + return Err(Error::Internal { + message: format!( + "Invalid variable offsets: end ({end}) is less than base ({base})" + ), + location: location!(), + }); + } + + let data_start = base.try_into().map_err(|_| Error::Internal { + message: format!("Variable offset ({base}) does not fit into usize"), + location: location!(), + })?; + let data_end = end.try_into().map_err(|_| Error::Internal { + message: format!("Variable offset ({end}) does not fit into usize"), + location: location!(), + })?; + if data_end > data.len() { + return Err(Error::Internal { + message: format!( + "Invalid variable offsets: end ({data_end}) exceeds data len ({})", + data.len() + ), + location: location!(), + }); + } + + let mut rebased_offsets = Vec::with_capacity(offsets_slice.len()); + for &offset in offsets_slice { + if offset < base { + return Err(Error::Internal { + message: format!( + "Invalid variable offsets: offset ({offset}) is less than base ({base})" + ), + location: location!(), + }); + } + rebased_offsets.push(offset - base); + } + + let sliced_data = data.slice_with_length(data_start, data_end - data_start); + // Copy into a compact buffer so each output batch owns only what it references. + let sliced_data = LanceBuffer::copy_slice(&sliced_data); + let rebased_offsets = LanceBuffer::reinterpret_vec(rebased_offsets); + Ok((sliced_data, rebased_offsets)) + } + + fn slice_batch_data_and_rebase_offsets( + data: &LanceBuffer, + offsets: &LanceBuffer, + bits_per_offset: u8, + ) -> Result<(LanceBuffer, LanceBuffer)> { + match bits_per_offset { + 32 => Self::slice_batch_data_and_rebase_offsets_typed::(data, offsets), + 64 => Self::slice_batch_data_and_rebase_offsets_typed::(data, offsets), + _ => Err(Error::Internal { + message: format!("Unsupported bits_per_offset={bits_per_offset}"), + location: location!(), + }), + } + } + unsafe fn parse_length(data: &[u8], bits_per_offset: u8) -> u64 { match bits_per_offset { 8 => *data.get_unchecked(0) as u64, @@ -2746,20 +2831,14 @@ impl StructuralPageDecoder for VariableFullZipDecoder { let start = self.current_idx; let end = start + num_rows as usize; - // This might seem a little peculiar. We are returning the entire data for every single - // batch. This is because the offsets are relative to the start of the data. In other words - // imagine we have a data buffer that is 100 bytes long and the offsets are [0, 10, 20, 30, 40] - // and we return in batches of two. The second set of offsets will be [20, 30, 40]. - // - // So either we pay for a copy to normalize the offsets or we just return the entire data buffer - // which is slightly cheaper. - let data = self.data.clone(); - let offset_start = self.offset_starts[start]; let offset_end = self.offset_starts[end] + (self.bits_per_offset as usize / 8); let offsets = self .offsets .slice_with_length(offset_start, offset_end - offset_start); + // Keep each batch's variable data buffer bounded to the selected rows. + let (data, offsets) = + Self::slice_batch_data_and_rebase_offsets(&self.data, &offsets, self.bits_per_offset)?; let repdef_start = self.repdef_starts[start]; let repdef_end = self.repdef_starts[end]; @@ -5067,8 +5146,9 @@ mod tests { ChunkInstructions, DataBlock, DecodeMiniBlockTask, FixedPerValueDecompressor, FixedWidthDataBlock, FullZipCacheableState, FullZipDecodeDetails, FullZipRepIndexDetails, FullZipScheduler, MiniBlockRepIndex, PerValueDecompressor, PreambleAction, - StructuralPageScheduler, + StructuralPageScheduler, VariableFullZipDecoder, }; + use crate::buffer::LanceBuffer; use crate::compression::DefaultDecompressionStrategy; use crate::constants::{STRUCTURAL_ENCODING_META_KEY, STRUCTURAL_ENCODING_MINIBLOCK}; use crate::data::BlockInfo; @@ -5465,6 +5545,44 @@ mod tests { check(2..3, 2..4, 5..7); } + #[test] + fn test_slice_batch_data_and_rebase_offsets_u32() { + let data = LanceBuffer::copy_slice(b"0123456789abcdefghij"); + let offsets = LanceBuffer::reinterpret_vec(vec![6_u32, 8_u32, 8_u32, 12_u32]); + + let (sliced_data, normalized_offsets) = + VariableFullZipDecoder::slice_batch_data_and_rebase_offsets(&data, &offsets, 32) + .unwrap(); + + assert_eq!(sliced_data.as_ref(), b"6789ab"); + let normalized = normalized_offsets.borrow_to_typed_slice::(); + assert_eq!(normalized.as_ref(), &[0, 2, 2, 6]); + } + + #[test] + fn test_slice_batch_data_and_rebase_offsets_u64() { + let data = LanceBuffer::copy_slice(b"abcdefghijklmnopqrstuvwxyz"); + let offsets = LanceBuffer::reinterpret_vec(vec![10_u64, 12_u64, 16_u64, 20_u64]); + + let (sliced_data, normalized_offsets) = + VariableFullZipDecoder::slice_batch_data_and_rebase_offsets(&data, &offsets, 64) + .unwrap(); + + assert_eq!(sliced_data.as_ref(), b"klmnopqrst"); + let normalized = normalized_offsets.borrow_to_typed_slice::(); + assert_eq!(normalized.as_ref(), &[0, 2, 6, 10]); + } + + #[test] + fn test_slice_batch_data_and_rebase_offsets_rejects_invalid_offsets() { + let data = LanceBuffer::copy_slice(b"abcd"); + let offsets = LanceBuffer::reinterpret_vec(vec![3_u32, 2_u32]); + + let err = VariableFullZipDecoder::slice_batch_data_and_rebase_offsets(&data, &offsets, 32) + .expect_err("offset end before start should error"); + assert!(err.to_string().contains("less than base")); + } + #[test] fn test_schedule_instructions() { // Convert repetition index to bytes for testing