Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
import static io.netty.util.internal.PlatformDependent.getInt;
import static io.netty.util.internal.PlatformDependent.getLong;

import org.apache.arrow.memory.BoundsChecking;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.DataSizeRoundingUtil;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;

import io.netty.buffer.ArrowBuf;
import io.netty.util.internal.PlatformDependent;

/**
* Helper class for performing generic operations on a bit vector buffer.
Expand Down Expand Up @@ -319,4 +321,83 @@ static void setBitMaskedByte(ArrowBuf data, int byteIndex, byte bitMask) {
currentByte |= bitMask;
data.setByte(byteIndex, currentByte);
}

/**
* Concat two validity buffers.
* @param input1 the first validity buffer.
* @param numBits1 the number of bits in the first validity buffer.
* @param input2 the second validity buffer.
* @param numBits2 the number of bits in the second validity buffer.
* @param output the output validity buffer. It can be the same one as the first input.
* The caller must make sure the output buffer has enough capacity.
*/
public static void concatBits(ArrowBuf input1, int numBits1, ArrowBuf input2, int numBits2, ArrowBuf output) {
int numBytes1 = DataSizeRoundingUtil.divideBy8Ceil(numBits1);
int numBytes2 = DataSizeRoundingUtil.divideBy8Ceil(numBits2);
int numBytesOut = DataSizeRoundingUtil.divideBy8Ceil(numBits1 + numBits2);

if (BoundsChecking.BOUNDS_CHECKING_ENABLED) {
output.checkBytes(0, numBytesOut);
}

// copy the first bit set
if (input1 != output) {
PlatformDependent.copyMemory(input1.memoryAddress(), output.memoryAddress(), numBytes1);
}

if (bitIndex(numBits1) == 0) {
// The number of bits for the first bit set is a multiple of 8, so the boundary is at byte boundary.
// For this case, we have a shortcut to copy all bytes from the second set after the byte boundary.
PlatformDependent.copyMemory(input2.memoryAddress(), output.memoryAddress() + numBytes1, numBytes2);
return;
}

// the number of bits to fill a full byte after the first input is processed
int numBitsToFill = 8 - bitIndex(numBits1);

// mask to clear high bits
int mask = (1 << (8 - numBitsToFill)) - 1;

int numFullBytes = numBits2 / 8;

int prevByte = output.getByte(numBytes1 - 1) & mask;
for (int i = 0; i < numFullBytes; i++) {
int curByte = input2.getByte(i) & 0xff;

// first fill the bits to a full byte
int byteToFill = (curByte << (8 - numBitsToFill)) & 0xff;
output.setByte(numBytes1 + i - 1, byteToFill | prevByte);

// fill remaining bits in the current byte
// note that it is also the previous byte for the next iteration
prevByte = curByte >>> numBitsToFill;
}

int lastOutputByte = prevByte;

// the number of extra bits for the second input, relative to full bytes
int numTrailingBits = bitIndex(numBits2);

if (numTrailingBits == 0) {
output.setByte(numBytes1 + numFullBytes - 1, lastOutputByte);
return;
}

// process remaining bits from input2
int remByte = input2.getByte(numBytes2 - 1) & 0xff;

int byteToFill = remByte << (8 - numBitsToFill);
lastOutputByte |= byteToFill;

output.setByte(numBytes1 + numFullBytes - 1, lastOutputByte);

if (numTrailingBits > numBitsToFill) {
// clear all bits for the last byte before writing
output.setByte(numBytes1 + numFullBytes, 0);

// some remaining bits cannot be filled in the previous byte
int leftByte = remByte >>> numBitsToFill;
output.setByte(numBytes1 + numFullBytes, leftByte);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,86 @@ public void testAllBitsSet() {
assertFalse(BitVectorHelper.checkAllBitsEqualTo(validityBuffer, bitLength, true));
}
}

@Test
public void testConcatBits() {
try (RootAllocator allocator = new RootAllocator(1024 * 1024)) {
try (ArrowBuf buf1 = allocator.buffer(1024);
ArrowBuf buf2 = allocator.buffer(1024);
ArrowBuf output = allocator.buffer(1024)) {

buf1.setZero(0, buf1.capacity());
buf2.setZero(0, buf2.capacity());

final int maxCount = 100;
for (int i = 0; i < maxCount; i++) {
if (i % 3 == 0) {
BitVectorHelper.setValidityBitToOne(buf1, i);
BitVectorHelper.setValidityBitToOne(buf2, i);
}
}

// test the case where the number of bits for both sets are multiples of 8.
concatAndVerify(buf1, 40, buf2, 48, output);

// only the number of bits in the first set is a multiple of 8
concatAndVerify(buf1, 32, buf2, 47, output);

// only the number of bits in the second set is a multiple of 8
concatAndVerify(buf1, 31, buf2, 48, output);

// neither set has a size that is a multiple of 8
concatAndVerify(buf1, 27, buf2, 52, output);

// the remaining bits in the second set is spread in two bytes
concatAndVerify(buf1, 31, buf2, 55, output);
}
}
}

@Test
public void testConcatBitsInPlace() {
try (RootAllocator allocator = new RootAllocator(1024 * 1024)) {
try (ArrowBuf buf1 = allocator.buffer(1024);
ArrowBuf buf2 = allocator.buffer(1024)) {

buf1.setZero(0, buf1.capacity());
buf2.setZero(0, buf2.capacity());

final int maxCount = 100;
for (int i = 0; i < maxCount; i++) {
if (i % 3 == 0) {
BitVectorHelper.setValidityBitToOne(buf1, i);
BitVectorHelper.setValidityBitToOne(buf2, i);
}
}

// test the case where the number of bits for both sets are multiples of 8.
concatAndVerify(buf1, 40, buf2, 48, buf1);

// only the number of bits in the first set is a multiple of 8
concatAndVerify(buf1, 32, buf2, 47, buf1);

// only the number of bits in the second set is a multiple of 8
concatAndVerify(buf1, 31, buf2, 48, buf1);

// neither set has a size that is a multiple of 8
concatAndVerify(buf1, 27, buf2, 52, buf1);

// the remaining bits in the second set is spread in two bytes
concatAndVerify(buf1, 31, buf2, 55, buf1);
}
}
}

private void concatAndVerify(ArrowBuf buf1, int count1, ArrowBuf buf2, int count2, ArrowBuf output) {
BitVectorHelper.concatBits(buf1, count1, buf2, count2, output);
int outputIdx = 0;
for (int i = 0; i < count1; i++, outputIdx++) {
assertEquals(BitVectorHelper.get(output, outputIdx), BitVectorHelper.get(buf1, i));
}
for (int i = 0; i < count2; i++, outputIdx++) {
assertEquals(BitVectorHelper.get(output, outputIdx), BitVectorHelper.get(buf2, i));
}
}
}