diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ByteUtils.java b/clients/src/main/java/org/apache/kafka/common/utils/ByteUtils.java
index 15868721da9ea..7bd1d9257f84c 100644
--- a/clients/src/main/java/org/apache/kafka/common/utils/ByteUtils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/ByteUtils.java
@@ -390,14 +390,25 @@ public static void writeDouble(double value, ByteBuffer buffer) {
      * Number of bytes needed to encode an integer in unsigned variable-length format.
      *
      * @param value The signed value
+     *
+     * @see #writeUnsignedVarint(int, DataOutput)
      */
     public static int sizeOfUnsignedVarint(int value) {
-        int bytes = 1;
-        while ((value & 0xffffff80) != 0L) {
-            bytes += 1;
-            value >>>= 7;
-        }
-        return bytes;
+        // Protocol buffers varint encoding is variable length, with a minimum of 1 byte
+        // (for zero). The values themselves are not important. What's important here is
+        // that any leading zero bits are dropped from the output. We can use the leading
+        // zero count, computed with a fast intrinsic, to get the output length directly.
+
+        // Test cases verify this matches the output of the original loop logic exactly.
+
+        // return (38 - leadingZeros) / 7 + leadingZeros / 32;
+
+        // The above formula is the implementation, but the division Java emits is suboptimal
+        // for this narrow input range, so we multiply by ~2^19/7 and shift right by 19 instead.
+
+        int leadingZeros = Integer.numberOfLeadingZeros(value);
+        int leadingZerosBelow38DividedBy7 = ((38 - leadingZeros) * 0b10010010010010011) >>> 19;
+        return leadingZerosBelow38DividedBy7 + (leadingZeros >>> 5);
     }
 
     /**
@@ -413,15 +424,18 @@ public static int sizeOfVarint(int value) {
      * Number of bytes needed to encode a long in variable-length format.
      *
      * @param value The signed value
+     * @see #sizeOfUnsignedVarint(int)
      */
     public static int sizeOfVarlong(long value) {
         long v = (value << 1) ^ (value >> 63);
-        int bytes = 1;
-        while ((v & 0xffffffffffffff80L) != 0L) {
-            bytes += 1;
-            v >>>= 7;
-        }
-        return bytes;
+
+        // For implementation notes @see #sizeOfUnsignedVarint(int)
+        // Similar logic is applied to allow for 64bit input -> 1-10 byte output.
+        // return (70 - leadingZeros) / 7 + leadingZeros / 64;
+
+        int leadingZeros = Long.numberOfLeadingZeros(v);
+        int leadingZerosBelow70DividedBy7 = ((70 - leadingZeros) * 0b10010010010010011) >>> 19;
+        return leadingZerosBelow70DividedBy7 + (leadingZeros >>> 6);
     }
 
     private static IllegalArgumentException illegalVarintException(int value) {
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ByteUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ByteUtilsTest.java
index 8f432f7632353..5f855fa4a9c76 100644
--- a/clients/src/test/java/org/apache/kafka/common/utils/ByteUtilsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/utils/ByteUtilsTest.java
@@ -24,6 +24,8 @@
 import java.io.DataOutputStream;
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.function.IntFunction;
+import java.util.function.LongFunction;
 
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -239,6 +241,60 @@ public void testDouble() throws IOException {
         assertDoubleSerde(Double.NEGATIVE_INFINITY, 0xFFF0000000000000L);
     }
 
+    @Test
+    public void testSizeOfUnsignedVarint() {
+        // The old well-known implementation for sizeOfUnsignedVarint
+        IntFunction<Integer> simpleImplementation = (int value) -> {
+            int bytes = 1;
+            while ((value & 0xffffff80) != 0L) {
+                bytes += 1;
+                value >>>= 7;
+            }
+            return bytes;
+        };
+
+        // sample the non-negative range (every 13th value); i >= 0 stops the loop on overflow
+        for (int i = 0; i < Integer.MAX_VALUE && i >= 0; i += 13) {
+            final int actual = ByteUtils.sizeOfUnsignedVarint(i);
+            final int expected = simpleImplementation.apply(i);
+            assertEquals(expected, actual);
+        }
+
+        // the sampled loop never reaches the extremes or values with the sign bit set
+        for (int i : new int[] {Integer.MAX_VALUE, Integer.MIN_VALUE, -1}) {
+            final int actual = ByteUtils.sizeOfUnsignedVarint(i);
+            final int expected = simpleImplementation.apply(i);
+            assertEquals(expected, actual);
+        }
+    }
+
+    @Test
+    public void testSizeOfVarlong() {
+        // The old well-known implementation for sizeOfVarlong
+        LongFunction<Integer> simpleImplementation = (long value) -> {
+            long v = (value << 1) ^ (value >> 63);
+            int bytes = 1;
+            while ((v & 0xffffffffffffff80L) != 0L) {
+                bytes += 1;
+                v >>>= 7;
+            }
+            return bytes;
+        };
+
+        for (long l = 1; l < Long.MAX_VALUE && l >= 0; l = l << 1) {
+            final int expected = simpleImplementation.apply(l);
+            final int actual = ByteUtils.sizeOfVarlong(l);
+            assertEquals(expected, actual);
+        }
+
+        // check zero, negative values (zig-zag path) and the 10-byte extremes as well
+        for (long l : new long[] {0L, -1L, -64L, -65L, Long.MAX_VALUE, Long.MIN_VALUE}) {
+            final int expected = simpleImplementation.apply(l);
+            final int actual = ByteUtils.sizeOfVarlong(l);
+            assertEquals(expected, actual);
+        }
+    }
+
     private void assertUnsignedVarintSerde(int value, byte[] expectedEncoding) throws IOException {
         ByteBuffer buf = ByteBuffer.allocate(32);
         ByteUtils.writeUnsignedVarint(value, buf);
diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/util/ByteUtilsBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/util/ByteUtilsBenchmark.java
new file mode 100644
index 0000000000000..bee8c16a260fa
--- /dev/null
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/util/ByteUtilsBenchmark.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.jmh.util;
+
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.kafka.common.utils.ByteUtils;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.runner.Runner;
+import org.openjdk.jmh.runner.RunnerException;
+import org.openjdk.jmh.runner.options.Options;
+import org.openjdk.jmh.runner.options.OptionsBuilder;
+
+@State(Scope.Benchmark)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@Fork(3)
+@Warmup(iterations = 5, time = 1)
+@Measurement(iterations = 10, time = 1)
+public class ByteUtilsBenchmark {
+    private int inputInt;
+    private long inputLong;
+    @Setup(Level.Iteration)
+    public void setUp() {
+        inputInt = ThreadLocalRandom.current().nextInt();
+        inputLong = ThreadLocalRandom.current().nextLong();
+    }
+
+    @Benchmark
+    public int testSizeOfUnsignedVarint() {
+        return ByteUtils.sizeOfUnsignedVarint(inputInt);
+    }
+
+    @Benchmark
+    public int testSizeOfUnsignedVarintSimple() {
+        int value = inputInt;
+        int bytes = 1;
+        while ((value & 0xffffff80) != 0L) {
+            bytes += 1;
+            value >>>= 7;
+        }
+        return bytes;
+    }
+
+    @Benchmark
+    public int testSizeOfVarlong() {
+        return ByteUtils.sizeOfVarlong(inputLong);
+    }
+
+    @Benchmark
+    public int testSizeOfVarlongSimple() {
+        long v = (inputLong << 1) ^ (inputLong >> 63);
+        int bytes = 1;
+        while ((v & 0xffffffffffffff80L) != 0L) {
+            bytes += 1;
+            v >>>= 7;
+        }
+        return bytes;
+    }
+
+    public static void main(String[] args) throws RunnerException {
+        Options opt = new OptionsBuilder()
+                .include(ByteUtilsBenchmark.class.getSimpleName())
+                .forks(2)
+                .build();
+
+        new Runner(opt).run();
+    }
+}