From 591cf74e4728cace34d9a4a9683dcf14cdef97d9 Mon Sep 17 00:00:00 2001 From: liyafan82 Date: Wed, 6 Nov 2019 18:54:49 +0800 Subject: [PATCH 1/4] [ARROW-7072][Java] Support concating validity bits efficiently --- .../apache/arrow/vector/BitVectorHelper.java | 71 ++++++++++ .../arrow/vector/TestBitVectorHelper.java | 130 ++++++++++++++++++ 2 files changed, 201 insertions(+) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java b/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java index cd16f720ecf..5a70ef59e5a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java @@ -21,11 +21,13 @@ import static io.netty.util.internal.PlatformDependent.getInt; import static io.netty.util.internal.PlatformDependent.getLong; +import org.apache.arrow.memory.BoundsChecking; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.DataSizeRoundingUtil; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import io.netty.buffer.ArrowBuf; +import io.netty.util.internal.PlatformDependent; /** * Helper class for performing generic operations on a bit vector buffer. @@ -319,4 +321,73 @@ static void setBitMaskedByte(ArrowBuf data, int byteIndex, byte bitMask) { currentByte |= bitMask; data.setByte(byteIndex, currentByte); } + + /** + * Concat two validity buffers. + * @param input1 the first validity buffer. + * @param numBits1 the number of bits in the first validity buffer. + * @param input2 the second validity buffer. + * @param numBits2 the number of bits in the second validity buffer. + * @param output the ouput validity buffer. It can be the same one as the first input. + */ + public static void concatBits(ArrowBuf input1, int numBits1, ArrowBuf input2, int numBits2, ArrowBuf output) { + int numBytes1 = DataSizeRoundingUtil.divideBy8Ceil(numBits1); + int numBytes2 = DataSizeRoundingUtil.divideBy8Ceil(numBits2); + int numBytesOut = DataSizeRoundingUtil.divideBy8Ceil(numBits1 + numBits2); + + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + output.checkBytes(0, numBytesOut); + } + + // copy the first bit set + if (input1 != output) { + PlatformDependent.copyMemory(input1.memoryAddress(), output.memoryAddress(), numBytes1); + } + output.setZero(numBytes1, numBytesOut); + + if (numBits1 % 8 == 0) { + PlatformDependent.copyMemory(input2.memoryAddress(), output.memoryAddress() + numBytes1, numBytes2); + return; + } + + // the number of bits to fill a full byte after the first input is processed + int numBitsToFill = 8 - (numBits1 % 8); + + // the number of extra bits for the second input, relative to full bytes + int numRemainingBits = numBits2 % 8; + + int numFullBytes = numBits2 / 8; + + for (int i = 0; i < numFullBytes; i++) { + int prevByte = output.getByte(numBytes1 + i - 1) & 0xff; + byte curByte = input2.getByte(i); + + // first fill the bits to a full byte + int byteToFill = (curByte & 0xff) << (8 - numBitsToFill); + output.setByte(numBytes1 + i - 1, byteToFill | prevByte); + + // fill remaining bits in the current byte + int remByte = (curByte & 0xff) >>> numBitsToFill; + output.setByte(numBytes1 + i, remByte); + } + + // process remaining bits from input2 + if (numRemainingBits > 0) { + byte remByte = input2.getByte(numBytes2 - 1); + + // the last byte to fill in the output + byte curOutputByte = output.getByte(numBytes1 + numFullBytes - 1); + byte byteToFill = (byte) ((remByte & 0xff) << (8 - numBitsToFill)); + output.setByte(numBytes1 + numFullBytes - 1, curOutputByte | byteToFill); + + if (numRemainingBits > numBitsToFill) { + // some remaining bits cannot be filled in the previous byte + byte leftByte = (byte) ((remByte & 0xff) >>> numBitsToFill); + + // clear high bits + int mask = (1 << (numBitsToFill - numRemainingBits)) - 1; + output.setByte(numBytes1 + numFullBytes, leftByte & mask); + } + } + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java b/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java index 9d52427f536..e833e30e490 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java @@ -149,4 +149,134 @@ public void testAllBitsSet() { assertFalse(BitVectorHelper.checkAllBitsEqualTo(validityBuffer, bitLength, true)); } } + + @Test + public void testConcatBits() { + try (RootAllocator allocator = new RootAllocator(1024 * 1024)) { + try (ArrowBuf buf1 = allocator.buffer(1024); + ArrowBuf buf2 = allocator.buffer(1024); + ArrowBuf output = allocator.buffer(1024)) { + + buf1.setZero(0, buf1.capacity()); + buf2.setZero(0, buf2.capacity()); + + final int count = 100; + for (int i = 0; i < count; i++) { + if (i % 3 == 0) { + BitVectorHelper.setValidityBitToOne(buf1, i); + BitVectorHelper.setValidityBitToOne(buf2, i); + } + } + + BitVectorHelper.concatBits(buf1, count, buf2, count, output); + + // validate results + for (int i = 0; i < count * 2; i++) { + int result = BitVectorHelper.get(output, i); + if (i < count) { + assertEquals(i % 3 == 0 ? 1 : 0, result); + } else { + assertEquals((i - count) % 3 == 0 ? 1 : 0, result); + } + } + } + + try (ArrowBuf buf1 = allocator.buffer(1024); + ArrowBuf buf2 = allocator.buffer(1024); + ArrowBuf output = allocator.buffer(1024)) { + + buf1.setZero(0, buf1.capacity()); + buf2.setZero(0, buf2.capacity()); + + final int count1 = 100; + final int count2 = 102; + for (int i = 0; i < count1 || i < count2; i++) { + if (i % 3 != 0) { + if (i < count1) { + BitVectorHelper.setValidityBitToOne(buf1, i); + } + if (i < count2) { + BitVectorHelper.setValidityBitToOne(buf2, i); + } + } + } + + BitVectorHelper.concatBits(buf1, count1, buf2, count2, output); + + // validate results + for (int i = 0; i < count1 + count2; i++) { + int result = BitVectorHelper.get(output, i); + if (i < count1) { + assertEquals(i % 3 != 0 ? 1 : 0, result); + } else { + assertEquals((i - count1) % 3 != 0 ? 1 : 0, result); + } + } + } + } + } + + @Test + public void testConcatBitsInPlace() { + try (RootAllocator allocator = new RootAllocator(1024 * 1024)) { + try (ArrowBuf buf1 = allocator.buffer(1024); + ArrowBuf buf2 = allocator.buffer(1024)) { + + buf1.setZero(0, buf1.capacity()); + buf2.setZero(0, buf2.capacity()); + + final int count = 100; + for (int i = 0; i < count; i++) { + if (i % 3 == 0) { + BitVectorHelper.setValidityBitToOne(buf1, i); + BitVectorHelper.setValidityBitToOne(buf2, i); + } + } + + BitVectorHelper.concatBits(buf1, count, buf2, count, buf1); + + // validate results + for (int i = 0; i < count * 2; i++) { + int result = BitVectorHelper.get(buf1, i); + if (i < count) { + assertEquals(i % 3 == 0 ? 1 : 0, result); + } else { + assertEquals((i - count) % 3 == 0 ? 1 : 0, result); + } + } + } + + try (ArrowBuf buf1 = allocator.buffer(1024); + ArrowBuf buf2 = allocator.buffer(1024)) { + + buf1.setZero(0, buf1.capacity()); + buf2.setZero(0, buf2.capacity()); + + final int count1 = 99; + final int count2 = 102; + for (int i = 0; i < count1 || i < count2; i++) { + if (i % 3 != 0) { + if (i < count1) { + BitVectorHelper.setValidityBitToOne(buf1, i); + } + if (i < count2) { + BitVectorHelper.setValidityBitToOne(buf2, i); + } + } + } + + BitVectorHelper.concatBits(buf1, count1, buf2, count2, buf1); + + // validate results + for (int i = 0; i < count1 + count2; i++) { + int result = BitVectorHelper.get(buf1, i); + if (i < count1) { + assertEquals(i % 3 != 0 ? 1 : 0, result); + } else { + assertEquals((i - count1) % 3 != 0 ? 1 : 0, result); + } + } + } + } + } } From 4f1b9af8bb41ae01512729b69c03e542b4ac50db Mon Sep 17 00:00:00 2001 From: liyafan82 Date: Tue, 12 Nov 2019 18:38:15 +0800 Subject: [PATCH 2/4] [ARROW-7072][Java] Resolve comments --- .../apache/arrow/vector/BitVectorHelper.java | 45 ++++--- .../arrow/vector/TestBitVectorHelper.java | 117 ++++++------------ 2 files changed, 62 insertions(+), 100 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java b/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java index 5a70ef59e5a..d35b9613bda 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java @@ -343,27 +343,30 @@ public static void concatBits(ArrowBuf input1, int numBits1, ArrowBuf input2, in if (input1 != output) { PlatformDependent.copyMemory(input1.memoryAddress(), output.memoryAddress(), numBytes1); } - output.setZero(numBytes1, numBytesOut); - if (numBits1 % 8 == 0) { + if (bitIndex(numBits1) == 0) { + // The number of bits for the first bit set is a multiple of 8, so the boundary is at byte boundary. + // For this case, we have a shortcut to copy all bytes from the second set after the byte boundary. PlatformDependent.copyMemory(input2.memoryAddress(), output.memoryAddress() + numBytes1, numBytes2); return; } // the number of bits to fill a full byte after the first input is processed - int numBitsToFill = 8 - (numBits1 % 8); - - // the number of extra bits for the second input, relative to full bytes - int numRemainingBits = numBits2 % 8; + int numBitsToFill = 8 - bitIndex(numBits1); + int mask = (1 << (8 - numBitsToFill)) - 1; int numFullBytes = numBits2 / 8; for (int i = 0; i < numFullBytes; i++) { int prevByte = output.getByte(numBytes1 + i - 1) & 0xff; - byte curByte = input2.getByte(i); + + // clear high bits + prevByte &= mask; + + int curByte = input2.getByte(i) & 0xff; // first fill the bits to a full byte - int byteToFill = (curByte & 0xff) << (8 - numBitsToFill); + int byteToFill = ((curByte & 0xff) << (8 - numBitsToFill)) & 0xff; output.setByte(numBytes1 + i - 1, byteToFill | prevByte); // fill remaining bits in the current byte @@ -371,23 +374,29 @@ public static void concatBits(ArrowBuf input1, int numBits1, ArrowBuf input2, in output.setByte(numBytes1 + i, remByte); } + // clear high bits for the previous byte, as it may be the last byte + int curOutputByte = output.getByte(numBytes1 + numFullBytes - 1) & 0xff; + curOutputByte &= mask; + + // the number of extra bits for the second input, relative to full bytes + int numRemainingBits = bitIndex(numBits2); + // process remaining bits from input2 if (numRemainingBits > 0) { - byte remByte = input2.getByte(numBytes2 - 1); + int remByte = input2.getByte(numBytes2 - 1) & 0xff; - // the last byte to fill in the output - byte curOutputByte = output.getByte(numBytes1 + numFullBytes - 1); - byte byteToFill = (byte) ((remByte & 0xff) << (8 - numBitsToFill)); - output.setByte(numBytes1 + numFullBytes - 1, curOutputByte | byteToFill); + int byteToFill = (remByte & 0xff) << (8 - numBitsToFill); + curOutputByte |= byteToFill; if (numRemainingBits > numBitsToFill) { - // some remaining bits cannot be filled in the previous byte - byte leftByte = (byte) ((remByte & 0xff) >>> numBitsToFill); + // clear all bits for the last byte before writing + output.setByte(numBytes1 + numFullBytes, 0); - // clear high bits - int mask = (1 << (numBitsToFill - numRemainingBits)) - 1; - output.setByte(numBytes1 + numFullBytes, leftByte & mask); + // some remaining bits cannot be filled in the previous byte + int leftByte = (byte) ((remByte & 0xff) >>> numBitsToFill) & 0xff; + output.setByte(numBytes1 + numFullBytes, leftByte); } } + output.setByte(numBytes1 + numFullBytes - 1, curOutputByte); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java b/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java index e833e30e490..880b7de394b 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java @@ -160,58 +160,28 @@ public void testConcatBits() { buf1.setZero(0, buf1.capacity()); buf2.setZero(0, buf2.capacity()); - final int count = 100; - for (int i = 0; i < count; i++) { + final int maxCount = 100; + for (int i = 0; i < maxCount; i++) { if (i % 3 == 0) { BitVectorHelper.setValidityBitToOne(buf1, i); BitVectorHelper.setValidityBitToOne(buf2, i); } } - BitVectorHelper.concatBits(buf1, count, buf2, count, output); + // test the case where the number of bits for both sets are multiples of 8. + concatAndVerify(buf1, 40, buf2, 48, output); - // validate results - for (int i = 0; i < count * 2; i++) { - int result = BitVectorHelper.get(output, i); - if (i < count) { - assertEquals(i % 3 == 0 ? 1 : 0, result); - } else { - assertEquals((i - count) % 3 == 0 ? 1 : 0, result); - } - } - } + // only the number of bits in the first set is a multiple of 8 + concatAndVerify(buf1, 32, buf2, 47, output); - try (ArrowBuf buf1 = allocator.buffer(1024); - ArrowBuf buf2 = allocator.buffer(1024); - ArrowBuf output = allocator.buffer(1024)) { + // only the number of bits in the second set is a multiple of 8 + concatAndVerify(buf1, 31, buf2, 48, output); - buf1.setZero(0, buf1.capacity()); - buf2.setZero(0, buf2.capacity()); + // neither set has a size that is a multiple of 8 + concatAndVerify(buf1, 27, buf2, 52, output); - final int count1 = 100; - final int count2 = 102; - for (int i = 0; i < count1 || i < count2; i++) { - if (i % 3 != 0) { - if (i < count1) { - BitVectorHelper.setValidityBitToOne(buf1, i); - } - if (i < count2) { - BitVectorHelper.setValidityBitToOne(buf2, i); - } - } - } - - BitVectorHelper.concatBits(buf1, count1, buf2, count2, output); - - // validate results - for (int i = 0; i < count1 + count2; i++) { - int result = BitVectorHelper.get(output, i); - if (i < count1) { - assertEquals(i % 3 != 0 ? 1 : 0, result); - } else { - assertEquals((i - count1) % 3 != 0 ? 1 : 0, result); - } - } + // the remaining bits in the second set is spread in two bytes + concatAndVerify(buf1, 31, buf2, 55, output); } } } @@ -225,57 +195,40 @@ public void testConcatBitsInPlace() { buf1.setZero(0, buf1.capacity()); buf2.setZero(0, buf2.capacity()); - final int count = 100; - for (int i = 0; i < count; i++) { + final int maxCount = 100; + for (int i = 0; i < maxCount; i++) { if (i % 3 == 0) { BitVectorHelper.setValidityBitToOne(buf1, i); BitVectorHelper.setValidityBitToOne(buf2, i); } } - BitVectorHelper.concatBits(buf1, count, buf2, count, buf1); + // test the case where the number of bits for both sets are multiples of 8. + concatAndVerify(buf1, 40, buf2, 48, buf1); - // validate results - for (int i = 0; i < count * 2; i++) { - int result = BitVectorHelper.get(buf1, i); - if (i < count) { - assertEquals(i % 3 == 0 ? 1 : 0, result); - } else { - assertEquals((i - count) % 3 == 0 ? 1 : 0, result); - } - } - } + // only the number of bits in the first set is a multiple of 8 + concatAndVerify(buf1, 32, buf2, 47, buf1); - try (ArrowBuf buf1 = allocator.buffer(1024); - ArrowBuf buf2 = allocator.buffer(1024)) { - - buf1.setZero(0, buf1.capacity()); - buf2.setZero(0, buf2.capacity()); + // only the number of bits in the second set is a multiple of 8 + concatAndVerify(buf1, 31, buf2, 48, buf1); - final int count1 = 99; - final int count2 = 102; - for (int i = 0; i < count1 || i < count2; i++) { - if (i % 3 != 0) { - if (i < count1) { - BitVectorHelper.setValidityBitToOne(buf1, i); - } - if (i < count2) { - BitVectorHelper.setValidityBitToOne(buf2, i); - } - } - } + // neither set has a size that is a multiple of 8 + concatAndVerify(buf1, 27, buf2, 52, buf1); - BitVectorHelper.concatBits(buf1, count1, buf2, count2, buf1); + // the remaining bits in the second set is spread in two bytes + concatAndVerify(buf1, 31, buf2, 55, buf1); + } + } + } - // validate results - for (int i = 0; i < count1 + count2; i++) { - int result = BitVectorHelper.get(buf1, i); - if (i < count1) { - assertEquals(i % 3 != 0 ? 1 : 0, result); - } else { - assertEquals((i - count1) % 3 != 0 ? 1 : 0, result); - } - } + private void concatAndVerify(ArrowBuf buf1, int count1, ArrowBuf buf2, int count2, ArrowBuf output) { + BitVectorHelper.concatBits(buf1, count1, buf2, count2, output); + for (int i = 0; i < count1 + count2; i++) { + int result = BitVectorHelper.get(output, i); + if (i < count1) { + assertEquals(i % 3 == 0 ? 1 : 0, result); + } else { + assertEquals((i - count1) % 3 == 0 ? 1 : 0, result); } } } From d7bcfaaa3e968399e1bda3ba44f4c739385fbb95 Mon Sep 17 00:00:00 2001 From: liyafan82 Date: Fri, 15 Nov 2019 20:44:04 +0800 Subject: [PATCH 3/4] [ARROW-7072][Java] Further improve the performance --- .../apache/arrow/vector/BitVectorHelper.java | 25 ++++++++----------- .../arrow/vector/TestBitVectorHelper.java | 13 +++++----- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java b/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java index d35b9613bda..6c3e653e709 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java @@ -353,30 +353,27 @@ public static void concatBits(ArrowBuf input1, int numBits1, ArrowBuf input2, in // the number of bits to fill a full byte after the first input is processed int numBitsToFill = 8 - bitIndex(numBits1); + + // mask to clear high bits int mask = (1 << (8 - numBitsToFill)) - 1; int numFullBytes = numBits2 / 8; + int prevByte = output.getByte(numBytes1 - 1) & mask; for (int i = 0; i < numFullBytes; i++) { - int prevByte = output.getByte(numBytes1 + i - 1) & 0xff; - - // clear high bits - prevByte &= mask; - int curByte = input2.getByte(i) & 0xff; // first fill the bits to a full byte - int byteToFill = ((curByte & 0xff) << (8 - numBitsToFill)) & 0xff; + int byteToFill = (curByte << (8 - numBitsToFill)) & 0xff; output.setByte(numBytes1 + i - 1, byteToFill | prevByte); // fill remaining bits in the current byte - int remByte = (curByte & 0xff) >>> numBitsToFill; - output.setByte(numBytes1 + i, remByte); + // note that it is also the previous byte for the next iteration + prevByte = curByte >>> numBitsToFill; } // clear high bits for the previous byte, as it may be the last byte - int curOutputByte = output.getByte(numBytes1 + numFullBytes - 1) & 0xff; - curOutputByte &= mask; + int lastOutputByte = prevByte; // the number of extra bits for the second input, relative to full bytes int numRemainingBits = bitIndex(numBits2); @@ -385,18 +382,18 @@ public static void concatBits(ArrowBuf input1, int numBits1, ArrowBuf input2, in if (numRemainingBits > 0) { int remByte = input2.getByte(numBytes2 - 1) & 0xff; - int byteToFill = (remByte & 0xff) << (8 - numBitsToFill); - curOutputByte |= byteToFill; + int byteToFill = remByte << (8 - numBitsToFill); + lastOutputByte |= byteToFill; if (numRemainingBits > numBitsToFill) { // clear all bits for the last byte before writing output.setByte(numBytes1 + numFullBytes, 0); // some remaining bits cannot be filled in the previous byte - int leftByte = (byte) ((remByte & 0xff) >>> numBitsToFill) & 0xff; + int leftByte = remByte >>> numBitsToFill; output.setByte(numBytes1 + numFullBytes, leftByte); } } - output.setByte(numBytes1 + numFullBytes - 1, curOutputByte); + output.setByte(numBytes1 + numFullBytes - 1, lastOutputByte); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java b/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java index 880b7de394b..2988932892a 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java @@ -223,13 +223,12 @@ public void testConcatBitsInPlace() { private void concatAndVerify(ArrowBuf buf1, int count1, ArrowBuf buf2, int count2, ArrowBuf output) { BitVectorHelper.concatBits(buf1, count1, buf2, count2, output); - for (int i = 0; i < count1 + count2; i++) { - int result = BitVectorHelper.get(output, i); - if (i < count1) { - assertEquals(i % 3 == 0 ? 1 : 0, result); - } else { - assertEquals((i - count1) % 3 == 0 ? 1 : 0, result); - } + int outputIdx = 0; + for (int i = 0; i < count1; i++, outputIdx++) { + assertEquals(BitVectorHelper.get(output, outputIdx), BitVectorHelper.get(buf1, i)); + } + for (int i = 0; i < count2; i++, outputIdx++) { + assertEquals(BitVectorHelper.get(output, outputIdx), BitVectorHelper.get(buf2, i)); } } } From 70c11737f84fcc627a55d45557c0644be45f0861 Mon Sep 17 00:00:00 2001 From: liyafan82 Date: Thu, 21 Nov 2019 15:00:15 +0800 Subject: [PATCH 4/4] [ARROW-7072][Java] Resolve comments --- .../apache/arrow/vector/BitVectorHelper.java | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java b/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java index 6c3e653e709..88813019321 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BitVectorHelper.java @@ -328,7 +328,8 @@ static void setBitMaskedByte(ArrowBuf data, int byteIndex, byte bitMask) { * @param numBits1 the number of bits in the first validity buffer. * @param input2 the second validity buffer. * @param numBits2 the number of bits in the second validity buffer. - * @param output the ouput validity buffer. It can be the same one as the first input. + * @param output the output validity buffer. It can be the same one as the first input. + * The caller must make sure the output buffer has enough capacity. */ public static void concatBits(ArrowBuf input1, int numBits1, ArrowBuf input2, int numBits2, ArrowBuf output) { int numBytes1 = DataSizeRoundingUtil.divideBy8Ceil(numBits1); @@ -372,28 +373,31 @@ public static void concatBits(ArrowBuf input1, int numBits1, ArrowBuf input2, in prevByte = curByte >>> numBitsToFill; } - // clear high bits for the previous byte, as it may be the last byte int lastOutputByte = prevByte; // the number of extra bits for the second input, relative to full bytes - int numRemainingBits = bitIndex(numBits2); + int numTrailingBits = bitIndex(numBits2); + + if (numTrailingBits == 0) { + output.setByte(numBytes1 + numFullBytes - 1, lastOutputByte); + return; + } // process remaining bits from input2 - if (numRemainingBits > 0) { - int remByte = input2.getByte(numBytes2 - 1) & 0xff; + int remByte = input2.getByte(numBytes2 - 1) & 0xff; - int byteToFill = remByte << (8 - numBitsToFill); - lastOutputByte |= byteToFill; + int byteToFill = remByte << (8 - numBitsToFill); + lastOutputByte |= byteToFill; - if (numRemainingBits > numBitsToFill) { - // clear all bits for the last byte before writing - output.setByte(numBytes1 + numFullBytes, 0); + output.setByte(numBytes1 + numFullBytes - 1, lastOutputByte); - // some remaining bits cannot be filled in the previous byte - int leftByte = remByte >>> numBitsToFill; - output.setByte(numBytes1 + numFullBytes, leftByte); - } + if (numTrailingBits > numBitsToFill) { + // clear all bits for the last byte before writing + output.setByte(numBytes1 + numFullBytes, 0); + + // some remaining bits cannot be filled in the previous byte + int leftByte = remByte >>> numBitsToFill; + output.setByte(numBytes1 + numFullBytes, leftByte); } - output.setByte(numBytes1 + numFullBytes - 1, lastOutputByte); } }