diff --git a/rust/arrow/src/buffer.rs b/rust/arrow/src/buffer.rs index 3512ae34261..b4745168a9e 100644 --- a/rust/arrow/src/buffer.rs +++ b/rust/arrow/src/buffer.rs @@ -414,6 +414,7 @@ where let remainder_bytes = ceil(left_chunks.remainder_len(), 8); let rem = op(left_chunks.remainder_bits(), right_chunks.remainder_bits()); + // we are counting bits starting from the least significant bit, so to_le_bytes should be correct let rem = &rem.to_le_bytes()[0..remainder_bytes]; result .write_all(rem) @@ -448,6 +449,7 @@ where let remainder_bytes = ceil(left_chunks.remainder_len(), 8); let rem = op(left_chunks.remainder_bits()); + // we are counting bits starting from the least significant bit, so to_le_bytes should be correct let rem = &rem.to_le_bytes()[0..remainder_bytes]; result .write_all(rem) diff --git a/rust/arrow/src/util/bit_chunk_iterator.rs b/rust/arrow/src/util/bit_chunk_iterator.rs index df6caf9a9cd..ec10c6f1d41 100644 --- a/rust/arrow/src/util/bit_chunk_iterator.rs +++ b/rust/arrow/src/util/bit_chunk_iterator.rs @@ -22,8 +22,11 @@ use std::fmt::Debug; pub struct BitChunks<'a> { buffer: &'a Buffer, raw_data: *const u8, - offset: usize, + /// offset inside a byte, guaranteed to be between 0 and 7 (inclusive) + bit_offset: usize, + /// number of complete u64 chunks chunk_len: usize, + /// number of remaining bits, guaranteed to be between 0 and 63 (inclusive) remainder_len: usize, } @@ -32,11 +35,11 @@ impl<'a> BitChunks<'a> { assert!(ceil(offset + len, 8) <= buffer.len() * 8); let byte_offset = offset / 8; - let offset = offset % 8; + let bit_offset = offset % 8; let raw_data = unsafe { buffer.raw_data().add(byte_offset) }; - let chunk_bits = 64; + let chunk_bits = 8 * std::mem::size_of::(); let chunk_len = len / chunk_bits; let remainder_len = len & (chunk_bits - 1); @@ -44,7 +47,7 @@ impl<'a> BitChunks<'a> { BitChunks::<'a> { buffer: &buffer, raw_data, - offset, + bit_offset, chunk_len, remainder_len, } @@ -55,48 +58,52 @@ impl<'a> BitChunks<'a> { pub struct 
BitChunkIterator<'a> { buffer: &'a Buffer, raw_data: *const u8, - offset: usize, + bit_offset: usize, chunk_len: usize, index: usize, } impl<'a> BitChunks<'a> { + /// Returns the number of remaining bits, guaranteed to be between 0 and 63 (inclusive) #[inline] pub const fn remainder_len(&self) -> usize { self.remainder_len } + /// Returns the bitmask of remaining bits #[inline] pub fn remainder_bits(&self) -> u64 { let bit_len = self.remainder_len; if bit_len == 0 { 0 } else { - let byte_len = ceil(bit_len, 8); - - let mut bits = 0; - for i in 0..byte_len { - let byte = unsafe { - std::ptr::read( - self.raw_data - .add(self.chunk_len * std::mem::size_of::() + i), - ) - }; - bits |= (byte as u64) << (i * 8); - } + let bit_offset = self.bit_offset; + // number of bytes to read + // might be one more than sizeof(u64) if the offset is in the middle of a byte + let byte_len = ceil(bit_len + bit_offset, 8); + // pointer to remainder bytes after all complete chunks + let base = unsafe { + self.raw_data + .add(self.chunk_len * std::mem::size_of::()) + }; - let offset = self.offset as u64; + let mut bits = unsafe { std::ptr::read(base) } as u64 >> bit_offset; + for i in 1..byte_len { + let byte = unsafe { std::ptr::read(base.add(i)) }; + bits |= (byte as u64) << (i * 8 - bit_offset); + } - (bits >> offset) & ((1 << bit_len) - 1) + bits & ((1 << bit_len) - 1) } } + /// Returns an iterator over chunks of 64 bits represented as an u64 #[inline] pub const fn iter(&self) -> BitChunkIterator<'a> { BitChunkIterator::<'a> { buffer: self.buffer, raw_data: self.raw_data, - offset: self.offset, + bit_offset: self.bit_offset, chunk_len: self.chunk_len, index: 0, } @@ -117,31 +124,30 @@ impl Iterator for BitChunkIterator<'_> { #[inline] fn next(&mut self) -> Option { - if self.index >= self.chunk_len { + let index = self.index; + if index >= self.chunk_len { return None; } - // cast to *const u64 should be fine since we are using read_unaligned + // cast to *const u64 should be fine 
since we are using read_unaligned below #[allow(clippy::cast_ptr_alignment)] - let current = unsafe { - std::ptr::read_unaligned((self.raw_data as *const u64).add(self.index)) - }; + let raw_data = self.raw_data as *const u64; + + // bit-packed buffers are stored starting with the least-significant byte first + // so when reading as u64 on a big-endian machine, the bytes need to be swapped + let current = unsafe { std::ptr::read_unaligned(raw_data.add(index)).to_le() }; - let combined = if self.offset == 0 { + let combined = if self.bit_offset == 0 { current } else { - // cast to *const u64 should be fine since we are using read_unaligned - #[allow(clippy::cast_ptr_alignment)] - let next = unsafe { - std::ptr::read_unaligned( - (self.raw_data as *const u64).add(self.index + 1), - ) - }; - current >> self.offset - | (next & ((1 << self.offset) - 1)) << (64 - self.offset) + let next = + unsafe { std::ptr::read_unaligned(raw_data.add(index + 1)).to_le() }; + + current >> self.bit_offset + | (next & ((1 << self.bit_offset) - 1)) << (64 - self.bit_offset) }; - self.index += 1; + self.index = index + 1; Some(combined) } @@ -192,7 +198,6 @@ mod tests { let result = bitchunks.into_iter().collect::>(); - //assert_eq!(vec![0b00010000, 0b00100000, 0b01000000, 0b10000000, 0b00000000, 0b00000001, 0b00000010, 0b11110100], result); assert_eq!( vec![0b1111010000000010000000010000000010000000010000000010000000010000], result @@ -214,10 +219,39 @@ mod tests { let result = bitchunks.into_iter().collect::>(); - //assert_eq!(vec![0b00010000, 0b00100000, 0b01000000, 0b10000000, 0b00000000, 0b00000001, 0b00000010, 0b11110100], result); assert_eq!( vec![0b1111010000000010000000010000000010000000010000000010000000010000], result ); } + + #[test] + fn test_iter_unaligned_remainder_bits_across_bytes() { + let input: &[u8] = &[0b00111111, 0b11111100]; + let buffer: Buffer = Buffer::from(input); + + // remainder contains bits from both bytes + // result should be the highest 2 bits from first 
byte followed by lowest 5 bits of second byte + let bitchunks = buffer.bit_chunks(6, 7); + + assert_eq!(7, bitchunks.remainder_len()); + assert_eq!(0b1110000, bitchunks.remainder_bits()); + } + + #[test] + fn test_iter_unaligned_remainder_bits_large() { + let input: &[u8] = &[ + 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, 0b00000000, + 0b11111111, 0b00000000, 0b11111111, + ]; + let buffer: Buffer = Buffer::from(input); + + let bitchunks = buffer.bit_chunks(2, 63); + + assert_eq!(63, bitchunks.remainder_len()); + assert_eq!( + 0b1000000_00111111_11000000_00111111_11000000_00111111_11000000_00111111, + bitchunks.remainder_bits() + ); + } }