diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java b/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java index cd16f720ecf..88813019321 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java @@ -21,11 +21,13 @@ import static io.netty.util.internal.PlatformDependent.getInt; import static io.netty.util.internal.PlatformDependent.getLong; +import org.apache.arrow.memory.BoundsChecking; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.DataSizeRoundingUtil; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import io.netty.buffer.ArrowBuf; +import io.netty.util.internal.PlatformDependent; /** * Helper class for performing generic operations on a bit vector buffer. @@ -319,4 +321,83 @@ static void setBitMaskedByte(ArrowBuf data, int byteIndex, byte bitMask) { currentByte |= bitMask; data.setByte(byteIndex, currentByte); } + + /** + * Concat two validity buffers. + * @param input1 the first validity buffer. + * @param numBits1 the number of bits in the first validity buffer. + * @param input2 the second validity buffer. + * @param numBits2 the number of bits in the second validity buffer. + * @param output the output validity buffer. It can be the same one as the first input. + * The caller must make sure the output buffer has enough capacity. + */ + public static void concatBits(ArrowBuf input1, int numBits1, ArrowBuf input2, int numBits2, ArrowBuf output) { + int numBytes1 = DataSizeRoundingUtil.divideBy8Ceil(numBits1); + int numBytes2 = DataSizeRoundingUtil.divideBy8Ceil(numBits2); + int numBytesOut = DataSizeRoundingUtil.divideBy8Ceil(numBits1 + numBits2); + + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + output.checkBytes(0, numBytesOut); + } + + // copy the first bit set + if (input1 != output) { + PlatformDependent.copyMemory(input1.memoryAddress(), output.memoryAddress(), numBytes1); + } + + if (bitIndex(numBits1) == 0) { + // The number of bits for the first bit set is a multiple of 8, so the boundary is at byte boundary. + // For this case, we have a shortcut to copy all bytes from the second set after the byte boundary. + PlatformDependent.copyMemory(input2.memoryAddress(), output.memoryAddress() + numBytes1, numBytes2); + return; + } + + // the number of bits to fill a full byte after the first input is processed + int numBitsToFill = 8 - bitIndex(numBits1); + + // mask to clear high bits + int mask = (1 << (8 - numBitsToFill)) - 1; + + int numFullBytes = numBits2 / 8; + + int prevByte = output.getByte(numBytes1 - 1) & mask; + for (int i = 0; i < numFullBytes; i++) { + int curByte = input2.getByte(i) & 0xff; + + // first fill the bits to a full byte + int byteToFill = (curByte << (8 - numBitsToFill)) & 0xff; + output.setByte(numBytes1 + i - 1, byteToFill | prevByte); + + // fill remaining bits in the current byte + // note that it is also the previous byte for the next iteration + prevByte = curByte >>> numBitsToFill; + } + + int lastOutputByte = prevByte; + + // the number of extra bits for the second input, relative to full bytes + int numTrailingBits = bitIndex(numBits2); + + if (numTrailingBits == 0) { + output.setByte(numBytes1 + numFullBytes - 1, lastOutputByte); + return; + } + + // process remaining bits from input2 + int remByte = input2.getByte(numBytes2 - 1) & 0xff; + + int byteToFill = remByte << (8 - numBitsToFill); + lastOutputByte |= byteToFill; + + output.setByte(numBytes1 + numFullBytes - 1, lastOutputByte); + + if (numTrailingBits > numBitsToFill) { + // clear all bits for the last byte before writing + output.setByte(numBytes1 + numFullBytes, 0); + + // some remaining bits cannot be filled in the previous byte + int leftByte = remByte >>> numBitsToFill; + output.setByte(numBytes1 + numFullBytes, leftByte); + } + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java b/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java index 9d52427f536..2988932892a 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java @@ -149,4 +149,86 @@ public void testAllBitsSet() { assertFalse(BitVectorHelper.checkAllBitsEqualTo(validityBuffer, bitLength, true)); } } + + @Test + public void testConcatBits() { + try (RootAllocator allocator = new RootAllocator(1024 * 1024)) { + try (ArrowBuf buf1 = allocator.buffer(1024); + ArrowBuf buf2 = allocator.buffer(1024); + ArrowBuf output = allocator.buffer(1024)) { + + buf1.setZero(0, buf1.capacity()); + buf2.setZero(0, buf2.capacity()); + + final int maxCount = 100; + for (int i = 0; i < maxCount; i++) { + if (i % 3 == 0) { + BitVectorHelper.setValidityBitToOne(buf1, i); + BitVectorHelper.setValidityBitToOne(buf2, i); + } + } + + // test the case where the number of bits for both sets are multiples of 8. + concatAndVerify(buf1, 40, buf2, 48, output); + + // only the number of bits in the first set is a multiple of 8 + concatAndVerify(buf1, 32, buf2, 47, output); + + // only the number of bits in the second set is a multiple of 8 + concatAndVerify(buf1, 31, buf2, 48, output); + + // neither set has a size that is a multiple of 8 + concatAndVerify(buf1, 27, buf2, 52, output); + + // the remaining bits in the second set is spread in two bytes + concatAndVerify(buf1, 31, buf2, 55, output); + } + } + } + + @Test + public void testConcatBitsInPlace() { + try (RootAllocator allocator = new RootAllocator(1024 * 1024)) { + try (ArrowBuf buf1 = allocator.buffer(1024); + ArrowBuf buf2 = allocator.buffer(1024)) { + + buf1.setZero(0, buf1.capacity()); + buf2.setZero(0, buf2.capacity()); + + final int maxCount = 100; + for (int i = 0; i < maxCount; i++) { + if (i % 3 == 0) { + BitVectorHelper.setValidityBitToOne(buf1, i); + BitVectorHelper.setValidityBitToOne(buf2, i); + } + } + + // test the case where the number of bits for both sets are multiples of 8. + concatAndVerify(buf1, 40, buf2, 48, buf1); + + // only the number of bits in the first set is a multiple of 8 + concatAndVerify(buf1, 32, buf2, 47, buf1); + + // only the number of bits in the second set is a multiple of 8 + concatAndVerify(buf1, 31, buf2, 48, buf1); + + // neither set has a size that is a multiple of 8 + concatAndVerify(buf1, 27, buf2, 52, buf1); + + // the remaining bits in the second set is spread in two bytes + concatAndVerify(buf1, 31, buf2, 55, buf1); + } + } + } + + private void concatAndVerify(ArrowBuf buf1, int count1, ArrowBuf buf2, int count2, ArrowBuf output) { + BitVectorHelper.concatBits(buf1, count1, buf2, count2, output); + int outputIdx = 0; + for (int i = 0; i < count1; i++, outputIdx++) { + assertEquals(BitVectorHelper.get(output, outputIdx), BitVectorHelper.get(buf1, i)); + } + for (int i = 0; i < count2; i++, outputIdx++) { + assertEquals(BitVectorHelper.get(output, outputIdx), BitVectorHelper.get(buf2, i)); + } + } }