diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java b/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java index a91f534a8dd85..ad4f14a8c3adc 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java @@ -110,7 +110,7 @@ public void writeUnsignedVarint(int i) { @Override public void writeByteBuffer(ByteBuffer src) { - buf.put(src); + buf.put(src.duplicate()); } @Override diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Message.java b/clients/src/main/java/org/apache/kafka/common/protocol/Message.java index 2a313ff9fdcaf..75641d3c0e2d8 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/Message.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Message.java @@ -47,7 +47,20 @@ public interface Message { * If the specified version is too new to be supported * by this software. */ - int size(ObjectSerializationCache cache, short version); + default int size(ObjectSerializationCache cache, short version) { + MessageSizeAccumulator size = new MessageSizeAccumulator(); + addSize(size, cache, version); + return size.totalSize(); + } + + /** + * Add the size of this message to an accumulator. + * + * @param size The size accumulator to add to + * @param cache The serialization size cache to populate. + * @param version The version to use. + */ + void addSize(MessageSizeAccumulator size, ObjectSerializationCache cache, short version); /** * Writes out this message to the given Writable. diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/MessageSizeAccumulator.java b/clients/src/main/java/org/apache/kafka/common/protocol/MessageSizeAccumulator.java new file mode 100644 index 0000000000000..9c187e017a376 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/MessageSizeAccumulator.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol; + +/** + * Helper class which facilitates zero-copy network transmission. See {@link SendBuilder}. + */ +public class MessageSizeAccumulator { + private int totalSize = 0; + private int zeroCopySize = 0; + + /** + * Get the total size of the message. + * + * @return total size in bytes + */ + public int totalSize() { + return totalSize; + } + + /** + * Get the total "zero-copy" size of the message. 
This is the summed + * total of all fields which either have a type of 'bytes' with + * 'zeroCopy' enabled, or a type of 'records'. + * + * @return total size of zero-copy data in the message + */ + public int zeroCopySize() { + return zeroCopySize; + } + + public void addZeroCopyBytes(int size) { + zeroCopySize += size; + totalSize += size; + } + + public void addBytes(int size) { + totalSize += size; + } + + public void add(MessageSizeAccumulator size) { + this.totalSize += size.totalSize; + this.zeroCopySize += size.zeroCopySize; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/ObjectSerializationCache.java b/clients/src/main/java/org/apache/kafka/common/protocol/ObjectSerializationCache.java index 98addf443ff0f..208b0567737ee 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/ObjectSerializationCache.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/ObjectSerializationCache.java @@ -36,14 +36,12 @@ public ObjectSerializationCache() { this.map = new IdentityHashMap<>(); } - public void setArraySizeInBytes(Object o, int size) { - map.put(o, Integer.valueOf(size)); + public void setArraySizeInBytes(Object o, Integer size) { + map.put(o, size); } - public int getArraySizeInBytes(Object o) { - Object value = map.get(o); - Integer sizeInBytes = (Integer) value; - return sizeInBytes; + public Integer getArraySizeInBytes(Object o) { + return (Integer) map.get(o); } public void cacheSerializedValue(Object o, byte[] val) { diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java b/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java index 653f88c767331..805b3c7316612 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java @@ -17,9 +17,9 @@ package org.apache.kafka.common.protocol; -import org.apache.kafka.common.protocol.types.RawTaggedField; - import org.apache.kafka.common.UUID; +import org.apache.kafka.common.protocol.types.RawTaggedField; +import org.apache.kafka.common.record.MemoryRecords; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -54,6 +54,16 @@ default List<RawTaggedField> readUnknownTaggedField(List<RawTaggedField> unknown return unknowns; } + default MemoryRecords readRecords(int length) { + if (length < 0) { + // no records + return null; + } else { + ByteBuffer recordsBuffer = readByteBuffer(length); + return MemoryRecords.readableRecords(recordsBuffer); + } + } + /** * Read a UUID with the most significant digits first. */ diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/RecordsReadable.java b/clients/src/main/java/org/apache/kafka/common/protocol/RecordsReadable.java deleted file mode 100644 index 6458fe18307b9..0000000000000 --- a/clients/src/main/java/org/apache/kafka/common/protocol/RecordsReadable.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.kafka.common.protocol; - -import org.apache.kafka.common.record.BaseRecords; -import org.apache.kafka.common.record.MemoryRecords; -import org.apache.kafka.common.utils.ByteUtils; - -import java.nio.ByteBuffer; - -/** - * Implementation of Readable which reads from a byte buffer and can read records as {@link MemoryRecords} - * - * @see org.apache.kafka.common.requests.FetchResponse - */ -public class RecordsReadable implements Readable { - private final ByteBuffer buf; - - public RecordsReadable(ByteBuffer buf) { - this.buf = buf; - } - - @Override - public byte readByte() { - return buf.get(); - } - - @Override - public short readShort() { - return buf.getShort(); - } - - @Override - public int readInt() { - return buf.getInt(); - } - - @Override - public long readLong() { - return buf.getLong(); - } - - @Override - public double readDouble() { - return ByteUtils.readDouble(buf); - } - - @Override - public void readArray(byte[] arr) { - buf.get(arr); - } - - @Override - public int readUnsignedVarint() { - return ByteUtils.readUnsignedVarint(buf); - } - - @Override - public ByteBuffer readByteBuffer(int length) { - ByteBuffer res = buf.slice(); - res.limit(length); - - buf.position(buf.position() + length); - - return res; - } - - @Override - public int readVarint() { - return ByteUtils.readVarint(buf); - } - - @Override - public long readVarlong() { - return ByteUtils.readVarlong(buf); - } - - public BaseRecords readRecords(int length) { - if (length < 0) { - // no records - return null; - } else { - ByteBuffer recordsBuffer = readByteBuffer(length); - return MemoryRecords.readableRecords(recordsBuffer); - } - } - -} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/RecordsWritable.java b/clients/src/main/java/org/apache/kafka/common/protocol/RecordsWritable.java deleted file mode 100644 index 9d49129f00e6f..0000000000000 --- a/clients/src/main/java/org/apache/kafka/common/protocol/RecordsWritable.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.kafka.common.protocol; - -import org.apache.kafka.common.network.ByteBufferSend; -import org.apache.kafka.common.network.Send; -import org.apache.kafka.common.record.BaseRecords; -import org.apache.kafka.common.utils.ByteUtils; - -import java.io.DataOutput; -import java.nio.ByteBuffer; -import java.util.function.Consumer; - -/** - * Implementation of Writable which produces a sequence of {@link Send} objects. This allows for deferring the transfer - * of data from a record-set's file channel to the eventual socket channel. - * - * Excepting {@link #writeRecords(BaseRecords)}, calls to the write methods on this class will append to a byte array - * according to the format specified in {@link DataOutput}. When a call is made to writeRecords, any previously written - * bytes will be flushed as a new {@link ByteBufferSend} to the given Send consumer. After flushing the pending bytes, - * another Send is passed to the consumer which wraps the underlying record-set's transfer logic. - * - * For example, - * - * <pre>
- *     recordsWritable.writeInt(10);
- *     recordsWritable.writeRecords(records1);
- *     recordsWritable.writeInt(20);
- *     recordsWritable.writeRecords(records2);
- *     recordsWritable.writeInt(30);
- *     recordsWritable.flush();
- * </pre>
- * - * Will pass 5 Send objects to the consumer given in the constructor. Care must be taken by callers to flush any - * pending bytes at the end of the writing sequence to ensure everything is flushed to the consumer. This class is - * intended to be used with {@link org.apache.kafka.common.record.MultiRecordsSend}. - * - * @see org.apache.kafka.common.requests.FetchResponse - */ -public class RecordsWritable implements Writable { - private final String dest; - private final Consumer<Send> sendConsumer; - private final ByteBuffer buffer; - private int mark; - - public RecordsWritable(String dest, int messageSizeExcludingRecords, Consumer<Send> sendConsumer) { - this.dest = dest; - this.sendConsumer = sendConsumer; - this.buffer = ByteBuffer.allocate(messageSizeExcludingRecords); - this.mark = 0; - } - - @Override - public void writeByte(byte val) { - buffer.put(val); - } - - @Override - public void writeShort(short val) { - buffer.putShort(val); - } - - @Override - public void writeInt(int val) { - buffer.putInt(val); - } - - @Override - public void writeLong(long val) { - buffer.putLong(val); - } - - @Override - public void writeDouble(double val) { - ByteUtils.writeDouble(val, buffer); - } - - @Override - public void writeByteArray(byte[] arr) { - buffer.put(arr); - } - - @Override - public void writeUnsignedVarint(int i) { - ByteUtils.writeUnsignedVarint(i, buffer); - } - - @Override - public void writeByteBuffer(ByteBuffer src) { - buffer.put(src); - } - - @Override - public void writeVarint(int i) { - ByteUtils.writeVarint(i, buffer); - } - - @Override - public void writeVarlong(long i) { - ByteUtils.writeVarlong(i, buffer); - } - - public void writeRecords(BaseRecords records) { - flush(); - sendConsumer.accept(records.toSend(dest)); - } - - /** - * Flush any pending bytes as a ByteBufferSend - */ - public void flush() { - int end = buffer.position(); - int len = end - mark; - - if (len > 0) { - int limit = buffer.limit(); - - // Set the desired absolute position and limit before slicing - buffer.position(mark); - buffer.limit(end); - ByteBuffer slice = buffer.slice(); - - // Restore absolute position and limit on original buffer - buffer.limit(limit); - buffer.position(end); - - // Update the mark to the end of the slice we just took - mark = end; - - ByteBufferSend send = new ByteBufferSend(dest, slice); - sendConsumer.accept(send); - } - } -} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/SendBuilder.java b/clients/src/main/java/org/apache/kafka/common/protocol/SendBuilder.java new file mode 100644 index 0000000000000..18074fd5cb9d5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/SendBuilder.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.network.ByteBufferSend; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.record.BaseRecords; +import org.apache.kafka.common.record.MultiRecordsSend; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.requests.ResponseHeader; +import org.apache.kafka.common.utils.ByteUtils; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Queue; + +/** + * This class provides a way to build {@link Send} objects for network transmission + * from generated {@link org.apache.kafka.common.protocol.ApiMessage} types without + * allocating new space for "zero-copy" fields (see {@link #writeByteBuffer(ByteBuffer)} + * and {@link #writeRecords(BaseRecords)}). + * + * See {@link org.apache.kafka.common.requests.EnvelopeRequest#toSend(String, RequestHeader)} + * for example usage. + */ +public class SendBuilder implements Writable { + private final Queue<Send> sends = new ArrayDeque<>(); + private final ByteBuffer buffer; + private final String destinationId; + + SendBuilder(String destinationId, int size) { + this.destinationId = destinationId; + this.buffer = ByteBuffer.allocate(size); + this.buffer.mark(); + } + + private void flushCurrentBuffer() { + int latestPosition = buffer.position(); + buffer.reset(); + + if (latestPosition > buffer.position()) { + buffer.limit(latestPosition); + addByteBufferSend(buffer.slice()); + buffer.position(latestPosition); + buffer.limit(buffer.capacity()); + buffer.mark(); + } + } + + private void addByteBufferSend(ByteBuffer buffer) { + sends.add(new ByteBufferSend(destinationId, buffer)); + } + + @Override + public void writeByte(byte val) { + buffer.put(val); + } + + @Override + public void writeShort(short val) { + buffer.putShort(val); + } + + @Override + public void writeInt(int val) { + buffer.putInt(val); + } + + @Override + public void writeLong(long val) { + buffer.putLong(val); + } + + @Override + public void writeDouble(double val) { + buffer.putDouble(val); + } + + @Override + public void writeByteArray(byte[] arr) { + buffer.put(arr); + } + + @Override + public void writeUnsignedVarint(int i) { + ByteUtils.writeUnsignedVarint(i, buffer); + } + + /** + * Write a byte buffer. The reference to the underlying buffer will + * be retained in the result of {@link #build()}. + * + * @param buf the buffer to write + */ + @Override + public void writeByteBuffer(ByteBuffer buf) { + flushCurrentBuffer(); + addByteBufferSend(buf.duplicate()); + } + + @Override + public void writeVarint(int i) { + ByteUtils.writeVarint(i, buffer); + } + + @Override + public void writeVarlong(long i) { + ByteUtils.writeVarlong(i, buffer); + } + + /** + * Write a record set. The underlying record data will be retained + * in the result of {@link #build()}. See {@link BaseRecords#toSend(String)}.
+ * + * @param records the records to write + */ + @Override + public void writeRecords(BaseRecords records) { + flushCurrentBuffer(); + sends.add(records.toSend(destinationId)); + } + + public Send build() { + flushCurrentBuffer(); + return new MultiRecordsSend(destinationId, sends); + } + + public static Send buildRequestSend( + String destination, + RequestHeader header, + ApiMessage apiRequest + ) { + return buildSend( + destination, + header.data(), + header.headerVersion(), + apiRequest, + header.apiVersion() + ); + } + + public static Send buildResponseSend( + String destination, + ResponseHeader header, + ApiMessage apiResponse, + short apiVersion + ) { + return buildSend( + destination, + header.data(), + header.headerVersion(), + apiResponse, + apiVersion + ); + } + + private static Send buildSend( + String destination, + ApiMessage header, + short headerVersion, + ApiMessage apiMessage, + short apiVersion + ) { + ObjectSerializationCache serializationCache = new ObjectSerializationCache(); + MessageSizeAccumulator messageSize = new MessageSizeAccumulator(); + + header.addSize(messageSize, serializationCache, headerVersion); + apiMessage.addSize(messageSize, serializationCache, apiVersion); + + int totalSize = messageSize.totalSize(); + int sizeExcludingZeroCopyFields = totalSize - messageSize.zeroCopySize(); + + SendBuilder builder = new SendBuilder(destination, sizeExcludingZeroCopyFields + 4); + builder.writeInt(totalSize); + header.write(builder, serializationCache, headerVersion); + apiMessage.write(builder, serializationCache, apiVersion); + return builder.build(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Writable.java b/clients/src/main/java/org/apache/kafka/common/protocol/Writable.java index fd3e2b85c8bd9..ac66509741136 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/Writable.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Writable.java @@ -18,6 +18,8 @@ package org.apache.kafka.common.protocol; import org.apache.kafka.common.UUID; +import org.apache.kafka.common.record.BaseRecords; +import org.apache.kafka.common.record.MemoryRecords; import java.nio.ByteBuffer; @@ -33,6 +35,15 @@ public interface Writable { void writeVarint(int i); void writeVarlong(long i); + default void writeRecords(BaseRecords records) { + if (records instanceof MemoryRecords) { + MemoryRecords memRecords = (MemoryRecords) records; + writeByteBuffer(memRecords.buffer()); + } else { + throw new UnsupportedOperationException("Unsupported record type " + records.getClass()); + } + } + default void writeUUID(UUID uuid) { writeLong(uuid.getMostSignificantBits()); writeLong(uuid.getLeastSignificantBits()); diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AbstractResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AbstractResponse.java index 0f90e25dab25f..2ff143c0eb4ab 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/AbstractResponse.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/AbstractResponse.java @@ -40,14 +40,6 @@ protected Send toSend(String destination, ResponseHeader header, short apiVersio return new NetworkSend(destination, RequestUtils.serialize(header.toStruct(), toStruct(apiVersion))); } - /** - * Used for forwarding response serialization, typically {@link #toSend(String, ResponseHeader, short)} - * should be used instead. 
- */ - public ByteBuffer serialize(short version, ResponseHeader responseHeader) { - return RequestUtils.serialize(responseHeader.toStruct(), toStruct(version)); - } - /** * Visible for testing, typically {@link #toSend(String, ResponseHeader, short)} should be used instead. */ @@ -88,16 +80,11 @@ protected void updateErrorCounts(Map errorCounts, Errors error) protected abstract Struct toStruct(short version); - public ByteBuffer serializeBody(short version) { - Struct dataStruct = toStruct(version); - ByteBuffer buffer = ByteBuffer.allocate(dataStruct.sizeOf()); - dataStruct.writeTo(buffer); - buffer.flip(); - - return buffer; - } - - public static AbstractResponse deserializeBody(ByteBuffer byteBuffer, RequestHeader header) { + /** + * Parse a response from the provided buffer. The buffer is expected to hold both + * the {@link ResponseHeader} as well as the response payload. + */ + public static AbstractResponse parseResponse(ByteBuffer byteBuffer, RequestHeader header) { ApiKeys apiKey = header.apiKey(); short apiVersion = header.apiVersion(); diff --git a/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeRequest.java index 4a95b32f1799c..58f2d245c530d 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeRequest.java @@ -18,6 +18,8 @@ import org.apache.kafka.common.message.EnvelopeRequestData; import org.apache.kafka.common.message.EnvelopeResponseData; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.SendBuilder; import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.protocol.types.Struct; @@ -36,7 +38,7 @@ public Builder(ByteBuffer requestData, super(ApiKeys.ENVELOPE); this.data = new EnvelopeRequestData() .setRequestData(requestData) - .setRequestPrincipal(ByteBuffer.wrap(serializedPrincipal)) + .setRequestPrincipal(serializedPrincipal) .setClientHostAddress(clientAddress); } @@ -71,10 +73,8 @@ public byte[] clientAddress() { return data.clientHostAddress(); } - public byte[] principalData() { - byte[] serializedPrincipal = new byte[data.requestPrincipal().limit()]; - data.requestPrincipal().get(serializedPrincipal); - return serializedPrincipal; + public byte[] requestPrincipal() { + return data.requestPrincipal(); } @Override @@ -91,4 +91,14 @@ public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { public static EnvelopeRequest parse(ByteBuffer buffer, short version) { return new EnvelopeRequest(ApiKeys.ENVELOPE.parseRequest(version, buffer), version); } + + public EnvelopeRequestData data() { + return data; + } + + @Override + public Send toSend(String destination, RequestHeader header) { + return SendBuilder.buildRequestSend(destination, header, this.data); + } + } diff --git a/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeResponse.java index aadf987cc3083..d95325b2214c3 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeResponse.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeResponse.java @@ -17,6 +17,8 @@ package org.apache.kafka.common.requests; import org.apache.kafka.common.message.EnvelopeResponseData; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.SendBuilder; 
import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.protocol.types.Struct; @@ -59,4 +61,14 @@ protected Struct toStruct(short version) { public Errors error() { return Errors.forCode(data.errorCode()); } + + public EnvelopeResponseData data() { + return data; + } + + @Override + protected Send toSend(String destination, ResponseHeader header, short apiVersion) { + return SendBuilder.buildResponseSend(destination, header, this.data, apiVersion); + } + } diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java index e9c0394e7c986..33bacbe10d319 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java @@ -18,21 +18,16 @@ import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.message.FetchResponseData; -import org.apache.kafka.common.message.ResponseHeaderData; -import org.apache.kafka.common.network.ByteBufferSend; import org.apache.kafka.common.network.Send; import org.apache.kafka.common.protocol.ByteBufferAccessor; import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.protocol.ObjectSerializationCache; -import org.apache.kafka.common.protocol.RecordsReadable; -import org.apache.kafka.common.protocol.RecordsWritable; +import org.apache.kafka.common.protocol.SendBuilder; import org.apache.kafka.common.protocol.types.Struct; import org.apache.kafka.common.record.BaseRecords; import org.apache.kafka.common.record.MemoryRecords; -import org.apache.kafka.common.record.MultiRecordsSend; import java.nio.ByteBuffer; -import java.util.ArrayDeque; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; @@ -294,37 +289,7 @@ public Struct toStruct(short version) { @Override public Send toSend(String dest, ResponseHeader responseHeader, short apiVersion) { - // Generate the Sends for the response fields and records - ArrayDeque<Send> sends = new ArrayDeque<>(); - ObjectSerializationCache cache = new ObjectSerializationCache(); - int totalRecordSize = data.responses().stream() - .flatMap(fetchableTopicResponse -> fetchableTopicResponse.partitionResponses().stream()) - .mapToInt(fetchablePartitionResponse -> fetchablePartitionResponse.recordSet().sizeInBytes()) - .sum(); - int totalMessageSize = data.size(cache, apiVersion); - - RecordsWritable writer = new RecordsWritable(dest, totalMessageSize - totalRecordSize, sends::add); - data.write(writer, cache, apiVersion); - writer.flush(); - - // Compute the total size of all the Sends and write it out along with the header in the first Send - ResponseHeaderData responseHeaderData = responseHeader.data(); - - int headerSize = responseHeaderData.size(cache, responseHeader.headerVersion()); - int bodySize = Math.toIntExact(sends.stream().mapToLong(Send::size).sum()); - - ByteBuffer buffer = ByteBuffer.allocate(headerSize + 4); - ByteBufferAccessor headerWriter = new ByteBufferAccessor(buffer); - - // Write out the size and header - buffer.putInt(headerSize + bodySize); - responseHeaderData.write(headerWriter, cache, responseHeader.headerVersion()); - - // Rewind the buffer and set this as the first Send in the MultiRecordsSend - buffer.rewind(); - sends.addFirst(new ByteBufferSend(dest, buffer)); - - return new MultiRecordsSend(dest, sends); + return SendBuilder.buildResponseSend(dest, responseHeader, this.data, apiVersion); } public Errors error() { @@ -355,7
+320,7 @@ public Map<Errors, Integer> errorCounts() { public static FetchResponse<MemoryRecords> parse(ByteBuffer buffer, short version) { FetchResponseData fetchResponseData = new FetchResponseData(); - RecordsReadable reader = new RecordsReadable(buffer); + ByteBufferAccessor reader = new ByteBufferAccessor(buffer); fetchResponseData.read(reader, version); return new FetchResponse<>(fetchResponseData); } diff --git a/clients/src/main/java/org/apache/kafka/common/requests/RequestContext.java b/clients/src/main/java/org/apache/kafka/common/requests/RequestContext.java index d7f94ca7d261c..7136584d93a88 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/RequestContext.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/RequestContext.java @@ -106,11 +106,30 @@ public RequestAndSize parseRequest(ByteBuffer buffer) { } } - public Send buildResponse(AbstractResponse body) { + /** + * Build a {@link Send} for direct transmission of the provided response + * over the network. + */ + public Send buildResponseSend(AbstractResponse body) { ResponseHeader responseHeader = header.toResponseHeader(); return body.toSend(connectionId, responseHeader, apiVersion()); } + /** + * Serialize a response into a {@link ByteBuffer}. This is used when the response + * will be encapsulated in an {@link EnvelopeResponse}. The buffer will contain + * both the serialized {@link ResponseHeader} as well as the bytes from the response. + * There is no `size` prefix, unlike the output from {@link #buildResponseSend(AbstractResponse)}. + * + * Note that envelope requests are reserved only for APIs which have set the + * {@link ApiKeys#forwardable} flag. Notably the `Fetch` API cannot be forwarded, + * so we do not lose the benefit of "zero copy" transfers from disk. + */ + public ByteBuffer buildResponseEnvelopePayload(AbstractResponse body) { + ResponseHeader responseHeader = header.toResponseHeader(); + return RequestUtils.serialize(responseHeader.toStruct(), body.toStruct(header.apiVersion())); + } + private boolean isUnsupportedApiVersionsRequest() { return header.apiKey() == API_VERSIONS && !API_VERSIONS.isVersionSupported(header.apiVersion()); } diff --git a/clients/src/main/java/org/apache/kafka/common/requests/RequestHeader.java b/clients/src/main/java/org/apache/kafka/common/requests/RequestHeader.java index c1a822fa6dc44..d145f49280852 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/RequestHeader.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/RequestHeader.java @@ -89,10 +89,11 @@ public ResponseHeader toResponseHeader() { public static RequestHeader parse(ByteBuffer buffer) { short apiKey = -1; try { + int position = buffer.position(); apiKey = buffer.getShort(); short apiVersion = buffer.getShort(); short headerVersion = ApiKeys.forId(apiKey).requestHeaderVersion(apiVersion); - buffer.rewind(); + buffer.position(position); return new RequestHeader(new RequestHeaderData( new ByteBufferAccessor(buffer), headerVersion), headerVersion); } catch (UnsupportedVersionException e) { diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ResponseHeader.java b/clients/src/main/java/org/apache/kafka/common/requests/ResponseHeader.java index 118e5d3506d72..3fe7d9ebf4c7b 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/ResponseHeader.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/ResponseHeader.java @@ -21,6 +21,7 @@ import org.apache.kafka.common.protocol.types.Struct; import java.nio.ByteBuffer; +import java.util.Objects; /** * A
response header in the kafka protocol. @@ -67,4 +68,18 @@ public static ResponseHeader parse(ByteBuffer buffer, short headerVersion) { new ResponseHeaderData(new ByteBufferAccessor(buffer), headerVersion), headerVersion); } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ResponseHeader that = (ResponseHeader) o; + return headerVersion == that.headerVersion && + Objects.equals(data, that.data); + } + + @Override + public int hashCode() { + return Objects.hash(data, headerVersion); + } } diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java index 81e701bb3bc78..924ae4431d9ad 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java @@ -588,7 +588,7 @@ else if (!apiVersionsRequest.isValid()) * {@link #sendAuthenticationFailureResponse()} is called. */ private void buildResponseOnAuthenticateFailure(RequestContext context, AbstractResponse response) { - authenticationFailureSend = context.buildResponse(response); + authenticationFailureSend = context.buildResponseSend(response); } /** @@ -602,7 +602,7 @@ private void sendAuthenticationFailureResponse() throws IOException { } private void sendKafkaResponse(RequestContext context, AbstractResponse response) throws IOException { - sendKafkaResponse(context.buildResponse(response)); + sendKafkaResponse(context.buildResponseSend(response)); } private void sendKafkaResponse(Send send) throws IOException { diff --git a/clients/src/main/resources/common/message/EnvelopeRequest.json b/clients/src/main/resources/common/message/EnvelopeRequest.json index c416a4aba5e70..a1aa760a29b18 100644 --- a/clients/src/main/resources/common/message/EnvelopeRequest.json +++ b/clients/src/main/resources/common/message/EnvelopeRequest.json @@ -23,7 +23,7 @@ "fields": [ { "name": "RequestData", "type": "bytes", "versions": "0+", "zeroCopy": true, "about": "The embedded request header and data."}, - { "name": "RequestPrincipal", "type": "bytes", "versions": "0+", "zeroCopy": true, "nullableVersions": "0+", + { "name": "RequestPrincipal", "type": "bytes", "versions": "0+", "nullableVersions": "0+", "about": "Value of the initial client principal when the request is redirected by a broker." }, { "name": "ClientHostAddress", "type": "bytes", "versions": "0+", "about": "The original client's address in bytes." 
} diff --git a/clients/src/test/java/org/apache/kafka/common/message/RecordsSerdeTest.java b/clients/src/test/java/org/apache/kafka/common/message/RecordsSerdeTest.java index dbb4b4a080b9d..f3bc05d23da97 100644 --- a/clients/src/test/java/org/apache/kafka/common/message/RecordsSerdeTest.java +++ b/clients/src/test/java/org/apache/kafka/common/message/RecordsSerdeTest.java @@ -16,22 +16,17 @@ */ package org.apache.kafka.common.message; -import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ByteBufferAccessor; import org.apache.kafka.common.protocol.ObjectSerializationCache; -import org.apache.kafka.common.protocol.RecordsReadable; -import org.apache.kafka.common.protocol.RecordsWritable; import org.apache.kafka.common.protocol.types.Schema; import org.apache.kafka.common.protocol.types.Struct; import org.apache.kafka.common.record.CompressionType; import org.apache.kafka.common.record.MemoryRecords; -import org.apache.kafka.common.record.MultiRecordsSend; import org.apache.kafka.common.record.SimpleRecord; -import org.apache.kafka.common.requests.ByteBufferChannel; import org.junit.Test; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.ArrayDeque; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; @@ -97,7 +92,7 @@ private SimpleRecordsMessageData deserializeThroughStruct(ByteBuffer buffer, sho } private SimpleRecordsMessageData deserialize(ByteBuffer buffer, short version) { - RecordsReadable readable = new RecordsReadable(buffer); + ByteBufferAccessor readable = new ByteBufferAccessor(buffer); return new SimpleRecordsMessageData(readable, version); } @@ -109,22 +104,14 @@ private ByteBuffer serializeThroughStruct(SimpleRecordsMessageData message, shor return buffer; } - private ByteBuffer serialize(SimpleRecordsMessageData message, short version) throws IOException { - ArrayDeque<Send> sends = new ArrayDeque<>(); + private ByteBuffer serialize(SimpleRecordsMessageData message, short version) { ObjectSerializationCache cache = new ObjectSerializationCache(); int totalMessageSize = message.size(cache, version); - - int recordsSize = message.recordSet() == null ? 0 : message.recordSet().sizeInBytes(); - RecordsWritable writer = new RecordsWritable("0", - totalMessageSize - recordsSize, sends::add); + ByteBuffer buffer = ByteBuffer.allocate(totalMessageSize); + ByteBufferAccessor writer = new ByteBufferAccessor(buffer); message.write(writer, cache, version); - writer.flush(); - - MultiRecordsSend send = new MultiRecordsSend("0", sends); - ByteBufferChannel channel = new ByteBufferChannel(send.size()); - send.writeTo(channel); - channel.close(); - return channel.buffer(); + buffer.flip(); + return buffer; } } diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/RecordsWritableTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/RecordsWritableTest.java deleted file mode 100644 index 937da4d449611..0000000000000 --- a/clients/src/test/java/org/apache/kafka/common/protocol/RecordsWritableTest.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.kafka.common.protocol; - -import org.apache.kafka.common.network.ByteBufferSend; -import org.apache.kafka.common.network.Send; -import org.junit.Assert; -import org.junit.Test; - -import java.util.ArrayDeque; -import java.util.Queue; - -public class RecordsWritableTest { - @Test - public void testBufferSlice() { - Queue<Send> sends = new ArrayDeque<>(); - RecordsWritable writer = new RecordsWritable("dest", 10000 /* enough for tests */, sends::add); - for (int i = 0; i < 4; i++) { - writer.writeInt(i); - } - writer.flush(); - Assert.assertEquals(sends.size(), 1); - ByteBufferSend send = (ByteBufferSend) sends.remove(); - Assert.assertEquals(send.size(), 16); - Assert.assertEquals(send.remaining(), 16); - - // No new data, flush shouldn't do anything - writer.flush(); - Assert.assertEquals(sends.size(), 0); - - // Cause the buffer to expand a few times - for (int i = 0; i < 100; i++) { - writer.writeInt(i); - } - writer.flush(); - Assert.assertEquals(sends.size(), 1); - send = (ByteBufferSend) sends.remove(); - Assert.assertEquals(send.size(), 400); - Assert.assertEquals(send.remaining(), 400); - - writer.writeByte((byte) 5); - writer.flush(); - Assert.assertEquals(sends.size(), 1); - send = (ByteBufferSend) sends.remove(); - Assert.assertEquals(send.size(), 1); - Assert.assertEquals(send.remaining(), 1); - } -} diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/SendBuilderTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/SendBuilderTest.java new file mode 100644 index 0000000000000..e06fa87b7edd9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/protocol/SendBuilderTest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SendBuilderTest { + + @Test + public void testZeroCopyByteBuffer() { + byte[] data = Utils.utf8("foo"); + ByteBuffer zeroCopyBuffer = ByteBuffer.wrap(data); + SendBuilder builder = new SendBuilder("a", 8); + + builder.writeInt(5); + builder.writeByteBuffer(zeroCopyBuffer); + builder.writeInt(15); + Send send = builder.build(); + + // Overwrite the original buffer in order to prove the data was not copied + byte[] overwrittenData = Utils.utf8("bar"); + assertEquals(data.length, overwrittenData.length); + zeroCopyBuffer.rewind(); + zeroCopyBuffer.put(overwrittenData); + zeroCopyBuffer.rewind(); + + ByteBuffer buffer = TestUtils.toBuffer(send); + assertEquals(8 + data.length, buffer.remaining()); + assertEquals(5, buffer.getInt()); + assertEquals("bar", getString(buffer, data.length)); + assertEquals(15, buffer.getInt()); + } + + @Test + public void testWriteByteBufferRespectsPosition() { + byte[] data = Utils.utf8("yolo"); + assertEquals(4, data.length); + + ByteBuffer buffer = ByteBuffer.wrap(data); + SendBuilder builder = new SendBuilder("a", 0); + + buffer.limit(2); + builder.writeByteBuffer(buffer); + assertEquals(0, buffer.position()); + + buffer.position(2); + buffer.limit(4); + builder.writeByteBuffer(buffer); + assertEquals(2, buffer.position()); + + Send send = builder.build(); + ByteBuffer readBuffer = TestUtils.toBuffer(send); + assertEquals("yolo", getString(readBuffer, 4)); + } + + @Test + public void testZeroCopyRecords() { + ByteBuffer buffer = ByteBuffer.allocate(128); + MemoryRecords records = createRecords(buffer, "foo"); + + SendBuilder builder = new SendBuilder("a", 8); + builder.writeInt(5); + builder.writeRecords(records); + builder.writeInt(15); + Send send = builder.build(); + + // Overwrite the original buffer in order to prove the data was not copied + buffer.rewind(); + MemoryRecords overwrittenRecords = createRecords(buffer, "bar"); + + ByteBuffer readBuffer = TestUtils.toBuffer(send); + assertEquals(5, readBuffer.getInt()); + assertEquals(overwrittenRecords, getRecords(readBuffer, records.sizeInBytes())); + assertEquals(15, readBuffer.getInt()); + } + + private String getString(ByteBuffer buffer, int size) { + byte[] readData = new byte[size]; + buffer.get(readData); + return Utils.utf8(readData); + } + + private MemoryRecords getRecords(ByteBuffer buffer, int size) { + int initialPosition = buffer.position(); + int initialLimit = buffer.limit(); + int recordsLimit = initialPosition + size; + + buffer.limit(recordsLimit); + MemoryRecords records = MemoryRecords.readableRecords(buffer.slice()); + + buffer.position(recordsLimit); + buffer.limit(initialLimit); + return records; + } + + private MemoryRecords createRecords(ByteBuffer buffer, String value) { + MemoryRecordsBuilder recordsBuilder = MemoryRecords.builder( + buffer, + CompressionType.NONE, + TimestampType.CREATE_TIME, + 0L + ); + recordsBuilder.append(new SimpleRecord(Utils.utf8(value))); + return recordsBuilder.build(); 
+ } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/MultiRecordsSendTest.java b/clients/src/test/java/org/apache/kafka/common/record/MultiRecordsSendTest.java index 38d381c75e5db..81f007420d8c2 100644 --- a/clients/src/test/java/org/apache/kafka/common/record/MultiRecordsSendTest.java +++ b/clients/src/test/java/org/apache/kafka/common/record/MultiRecordsSendTest.java @@ -69,7 +69,7 @@ private NonOverflowingByteBufferChannel(long size) { } @Override - public long write(ByteBuffer[] srcs) throws IOException { + public long write(ByteBuffer[] srcs) { // Instead of overflowing, this channel refuses additional writes once the buffer is full, // which allows us to test the MultiRecordsSend behavior on a per-send basis. if (!buffer().hasRemaining()) diff --git a/clients/src/test/java/org/apache/kafka/common/requests/AbstractResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/AbstractResponseTest.java deleted file mode 100644 index 5932431b7d4bb..0000000000000 --- a/clients/src/test/java/org/apache/kafka/common/requests/AbstractResponseTest.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.kafka.common.requests; - -import org.apache.kafka.common.message.CreateTopicsResponseData; -import org.apache.kafka.common.protocol.ApiKeys; -import org.apache.kafka.common.protocol.Errors; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -public class AbstractResponseTest { - - @Test - public void testResponseSerde() { - CreateTopicsResponseData.CreatableTopicResultCollection collection = - new CreateTopicsResponseData.CreatableTopicResultCollection(); - collection.add(new CreateTopicsResponseData.CreatableTopicResult() - .setTopicConfigErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()) - .setNumPartitions(5)); - CreateTopicsResponse createTopicsResponse = new CreateTopicsResponse( - new CreateTopicsResponseData() - .setThrottleTimeMs(10) - .setTopics(collection) - ); - - final short version = (short) (CreateTopicsResponseData.SCHEMAS.length - 1); - final RequestHeader header = new RequestHeader(ApiKeys.CREATE_TOPICS, version, "client", 4); - - final EnvelopeResponse envelopeResponse = new EnvelopeResponse( - createTopicsResponse.serialize(version, header.toResponseHeader()), - Errors.NONE - ); - - CreateTopicsResponse extractedResponse = (CreateTopicsResponse) CreateTopicsResponse.deserializeBody( - envelopeResponse.responseData(), header); - assertEquals(createTopicsResponse.data(), extractedResponse.data()); - } -} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannel.java b/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannel.java index 0446df97c322a..44bb358159eb0 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannel.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannel.java @@ -16,7 +16,6 @@ */ package org.apache.kafka.common.requests; -import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.GatheringByteChannel; @@ -31,24 +30,25 @@ public ByteBufferChannel(long size) { } @Override - public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { + public long write(ByteBuffer[] srcs, int offset, int length) { int position = buf.position(); for (int i = 0; i < length; i++) { ByteBuffer src = srcs[i].duplicate(); - if (i == 0) - src.position(offset); + if (i == 0) { + src.position(src.position() + offset); + } buf.put(src); } return buf.position() - position; } @Override - public long write(ByteBuffer[] srcs) throws IOException { + public long write(ByteBuffer[] srcs) { return write(srcs, 0, srcs.length); } @Override - public int write(ByteBuffer src) throws IOException { + public int write(ByteBuffer src) { int position = buf.position(); buf.put(src); return buf.position() - position; @@ -60,7 +60,7 @@ public boolean isOpen() { } @Override - public void close() throws IOException { + public void close() { buf.flip(); closed = true; } diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannelTest.java b/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannelTest.java new file mode 100644 index 0000000000000..f8798b8a51cf8 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannelTest.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.utils.Utils; +import org.junit.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ByteBufferChannelTest { + + @Test + public void testWriteBufferArrayWithNonZeroPosition() { + byte[] data = Utils.utf8("hello"); + ByteBuffer buffer = ByteBuffer.allocate(32); + buffer.position(10); + buffer.put(data); + + int limit = buffer.position(); + buffer.position(10); + buffer.limit(limit); + + ByteBufferChannel channel = new ByteBufferChannel(buffer.remaining()); + ByteBuffer[] buffers = new ByteBuffer[] {buffer}; + channel.write(buffers); + channel.close(); + ByteBuffer channelBuffer = channel.buffer(); + assertEquals(data.length, channelBuffer.remaining()); + assertEquals("hello", Utils.utf8(channelBuffer)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeRequestTest.java index c573bc15c56bd..af2627640230b 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeRequestTest.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeRequestTest.java @@ -17,10 +17,16 @@ package org.apache.kafka.common.requests; import org.apache.kafka.common.message.EnvelopeRequestData; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; import org.apache.kafka.common.security.auth.KafkaPrincipal; import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder; +import org.apache.kafka.test.TestUtils; import org.junit.Test; +import java.io.IOException; +import java.net.InetAddress; import java.nio.ByteBuffer; import static org.junit.Assert.assertEquals; @@ -35,6 +41,29 @@ public void testGetPrincipal() { EnvelopeRequest.Builder requestBuilder = new EnvelopeRequest.Builder(ByteBuffer.allocate(0), kafkaPrincipalBuilder.serialize(kafkaPrincipal), "client-address".getBytes()); EnvelopeRequest request = requestBuilder.build(EnvelopeRequestData.HIGHEST_SUPPORTED_VERSION); - assertEquals(kafkaPrincipal, kafkaPrincipalBuilder.deserialize(request.principalData())); + assertEquals(kafkaPrincipal, kafkaPrincipalBuilder.deserialize(request.requestPrincipal())); } + + @Test + public void testToSend() throws IOException { + for (short version = ApiKeys.ENVELOPE.oldestVersion(); version <= ApiKeys.ENVELOPE.latestVersion(); version++) { + ByteBuffer requestData = ByteBuffer.wrap("foobar".getBytes()); + RequestHeader header = new RequestHeader(ApiKeys.ENVELOPE, version, "clientId", 15); + EnvelopeRequest request = new EnvelopeRequest.Builder( + requestData, + "principal".getBytes(), + InetAddress.getLocalHost().getAddress() + ).build(version); + + Send send = request.toSend("a", header); + ByteBuffer buffer = TestUtils.toBuffer(send); + 
assertEquals(send.size() - 4, buffer.getInt()); + assertEquals(header, RequestHeader.parse(buffer)); + + EnvelopeRequestData parsedRequestData = new EnvelopeRequestData(); + parsedRequestData.read(new ByteBufferAccessor(buffer), version); + assertEquals(request.data(), parsedRequestData); + } + } + } diff --git a/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeResponseTest.java new file mode 100644 index 0000000000000..76c20092692d9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeResponseTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.EnvelopeResponseData; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class EnvelopeResponseTest { + + @Test + public void testToSend() { + for (short version = ApiKeys.ENVELOPE.oldestVersion(); version <= ApiKeys.ENVELOPE.latestVersion(); version++) { + ByteBuffer responseData = ByteBuffer.wrap("foobar".getBytes()); + EnvelopeResponse response = new EnvelopeResponse(responseData, Errors.NONE); + short headerVersion = ApiKeys.ENVELOPE.responseHeaderVersion(version); + ResponseHeader header = new ResponseHeader(15, headerVersion); + + Send send = response.toSend("a", header, version); + ByteBuffer buffer = TestUtils.toBuffer(send); + assertEquals(send.size() - 4, buffer.getInt()); + assertEquals(header, ResponseHeader.parse(buffer, headerVersion)); + + EnvelopeResponseData parsedResponseData = new EnvelopeResponseData(); + parsedResponseData.read(new ByteBufferAccessor(buffer), version); + assertEquals(response.data(), parsedResponseData); + } + } + +} \ No newline at end of file diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java index 13e068cf266d2..47d49e618e9ab 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java @@ -18,6 +18,7 @@ import org.apache.kafka.common.message.ApiVersionsResponseData; import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionsResponseKeyCollection; +import org.apache.kafka.common.message.CreateTopicsResponseData; import 
org.apache.kafka.common.network.ClientInformation; import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.network.Send; @@ -57,7 +58,7 @@ public void testSerdeUnsupportedApiVersionRequest() throws Exception { ApiVersionsRequest request = (ApiVersionsRequest) requestAndSize.request; assertTrue(request.hasUnsupportedRequestVersion()); - Send send = context.buildResponse(new ApiVersionsResponse(new ApiVersionsResponseData() + Send send = context.buildResponseSend(new ApiVersionsResponse(new ApiVersionsResponseData() .setThrottleTimeMs(0) .setErrorCode(Errors.UNSUPPORTED_VERSION.code()) .setApiKeys(new ApiVersionsResponseKeyCollection()))); @@ -78,4 +79,30 @@ public void testSerdeUnsupportedApiVersionRequest() throws Exception { assertEquals(Errors.UNSUPPORTED_VERSION.code(), response.data.errorCode()); assertTrue(response.data.apiKeys().isEmpty()); } + + @Test + public void testEnvelopeResponseSerde() throws Exception { + CreateTopicsResponseData.CreatableTopicResultCollection collection = + new CreateTopicsResponseData.CreatableTopicResultCollection(); + collection.add(new CreateTopicsResponseData.CreatableTopicResult() + .setTopicConfigErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()) + .setNumPartitions(5)); + CreateTopicsResponseData expectedResponse = new CreateTopicsResponseData() + .setThrottleTimeMs(10) + .setTopics(collection); + + int correlationId = 15; + String clientId = "clientId"; + RequestHeader header = new RequestHeader(ApiKeys.CREATE_TOPICS, ApiKeys.CREATE_TOPICS.latestVersion(), + clientId, correlationId); + + RequestContext context = new RequestContext(header, "0", InetAddress.getLocalHost(), + KafkaPrincipal.ANONYMOUS, new ListenerName("ssl"), SecurityProtocol.SASL_SSL, + ClientInformation.EMPTY, true); + + ByteBuffer buffer = context.buildResponseEnvelopePayload(new CreateTopicsResponse(expectedResponse)); + CreateTopicsResponse parsedResponse = (CreateTopicsResponse) AbstractResponse.parseResponse(buffer, header); + assertEquals(expectedResponse, parsedResponse.data()); + } + } diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestHeaderTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestHeaderTest.java index 2842ae8e019ce..a3fa922359b4d 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/RequestHeaderTest.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestHeaderTest.java @@ -73,4 +73,19 @@ public void testRequestHeaderV2() { RequestHeader deserialized = RequestHeader.parse(buffer); assertEquals(header, deserialized); } + + @Test + public void parseHeaderFromBufferWithNonZeroPosition() { + ByteBuffer buffer = ByteBuffer.allocate(64); + buffer.position(10); + + RequestHeader header = new RequestHeader(ApiKeys.FIND_COORDINATOR, (short) 1, "", 10); + header.toStruct().writeTo(buffer); + int limit = buffer.position(); + buffer.position(10); + buffer.limit(limit); + + RequestHeader parsed = RequestHeader.parse(buffer); + assertEquals(header, parsed); + } } diff --git a/clients/src/test/java/org/apache/kafka/test/TestUtils.java b/clients/src/test/java/org/apache/kafka/test/TestUtils.java index 383c550c89f33..f006a20ba2dda 100644 --- a/clients/src/test/java/org/apache/kafka/test/TestUtils.java +++ b/clients/src/test/java/org/apache/kafka/test/TestUtils.java @@ -25,9 +25,11 @@ import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.internals.Topic; import org.apache.kafka.common.network.NetworkReceive; +import 
org.apache.kafka.common.network.Send; import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.requests.ByteBufferChannel; import org.apache.kafka.common.requests.MetadataResponse; import org.apache.kafka.common.requests.RequestHeader; import org.apache.kafka.common.utils.Exit; @@ -531,6 +533,17 @@ public static ByteBuffer toBuffer(Struct struct) { return buffer; } + public static ByteBuffer toBuffer(Send send) { + ByteBufferChannel channel = new ByteBufferChannel(send.size()); + try { + assertEquals(send.size(), send.writeTo(channel)); + channel.close(); + return channel.buffer(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + public static Set generateRandomTopicPartitions(int numTopic, int numPartitionPerTopic) { Set tps = new HashSet<>(); for (int i = 0; i < numTopic; i++) { diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala b/core/src/main/scala/kafka/network/RequestChannel.scala index 87e6bfdccf2ee..064c0e85e7f2d 100644 --- a/core/src/main/scala/kafka/network/RequestChannel.scala +++ b/core/src/main/scala/kafka/network/RequestChannel.scala @@ -113,13 +113,11 @@ object RequestChannel extends Logging { def buildResponseSend(abstractResponse: AbstractResponse): Send = { envelope match { case Some(request) => - val envelopeResponse = new EnvelopeResponse( - abstractResponse.serialize(header.apiVersion, header.toResponseHeader), - Errors.NONE - ) - request.context.buildResponse(envelopeResponse) + val responseBytes = context.buildResponseEnvelopePayload(abstractResponse) + val envelopeResponse = new EnvelopeResponse(responseBytes, Errors.NONE) + request.context.buildResponseSend(envelopeResponse) case None => - context.buildResponse(abstractResponse) + context.buildResponseSend(abstractResponse) } } diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index dcfc08d75dc3f..1914c7bd4c084 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -1026,7 +1026,7 @@ private[kafka] class Processor(val id: Int, envelopeRequest: RequestChannel.Request, envelopeResponse: EnvelopeResponse ): Unit = { - val envelopResponseSend = envelopeRequest.context.buildResponse(envelopeResponse) + val envelopResponseSend = envelopeRequest.context.buildResponseSend(envelopeResponse) enqueueResponse(new RequestChannel.SendResponse( envelopeRequest, envelopResponseSend, @@ -1042,7 +1042,7 @@ private[kafka] class Processor(val id: Int, val envelope = envelopeRequest.body[EnvelopeRequest] try { principalSerde.map { serde => - serde.deserialize(envelope.principalData()) + serde.deserialize(envelope.requestPrincipal()) } } catch { case e: Exception => diff --git a/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala b/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala index b5e1978361032..d0de33ad15878 100644 --- a/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala +++ b/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala @@ -129,6 +129,7 @@ class BrokerToControllerChannelManagerImpl(metadataCache: kafka.server.MetadataC requestQueue.put(BrokerToControllerQueueItem(request, callback)) requestThread.wakeup() } + } case class BrokerToControllerQueueItem(request: AbstractRequest.Builder[_ <: AbstractRequest], diff --git 
a/core/src/main/scala/kafka/server/ForwardingManager.scala b/core/src/main/scala/kafka/server/ForwardingManager.scala index 9b137bbe8ddf9..218a7e9ca424f 100644 --- a/core/src/main/scala/kafka/server/ForwardingManager.scala +++ b/core/src/main/scala/kafka/server/ForwardingManager.scala @@ -65,7 +65,7 @@ class ForwardingManager(metadataCache: kafka.server.MetadataCache, debug(s"Forwarded request $request failed with an error in envelope response $envelopeError") request.body[AbstractRequest].getErrorResponse(Errors.UNKNOWN_SERVER_ERROR.exception()) } else { - AbstractResponse.deserializeBody(envelopeResponse.responseData, request.header) + AbstractResponse.parseResponse(envelopeResponse.responseData, request.header) } responseCallback(response) } diff --git a/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala index 675e66a6b7817..a8f7cd93838c3 100644 --- a/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala +++ b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala @@ -92,7 +92,7 @@ class TestRaftRequestHandler( val response = responseOpt match { case Some(response) => - val responseSend = request.context.buildResponse(response) + val responseSend = request.context.buildResponseSend(response) val responseString = if (RequestChannel.isRequestLoggingEnabled) Some(response.toString(request.context.apiVersion)) else None diff --git a/core/src/test/scala/unit/kafka/server/CreateTopicsRequestWithForwardingTest.scala b/core/src/test/scala/unit/kafka/server/CreateTopicsRequestWithForwardingTest.scala index b27d5d34ba417..afd8110deca97 100644 --- a/core/src/test/scala/unit/kafka/server/CreateTopicsRequestWithForwardingTest.scala +++ b/core/src/test/scala/unit/kafka/server/CreateTopicsRequestWithForwardingTest.scala @@ -25,18 +25,18 @@ import org.junit.Test import scala.jdk.CollectionConverters._ -class CreateTopicsRequestWithForwardingTest extends CreateTopicsRequestTest { +class CreateTopicsRequestWithForwardingTest extends AbstractCreateTopicsRequestTest { override def brokerPropertyOverrides(properties: Properties): Unit = { - super.brokerPropertyOverrides(properties) properties.put(KafkaConfig.EnableMetadataQuorumProp, true.toString) } @Test - override def testNotController(): Unit = { + def testForwardToController(): Unit = { val req = topicsReq(Seq(topicReq("topic1"))) val response = sendCreateTopicRequest(req, notControllerSocketServer) // With forwarding enabled, the request should be forwarded to the active controller.
assertEquals(Map(Errors.NONE -> 1), response.errorCounts().asScala) } + } diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala index 192cca970280a..216b845c8cf86 100644 --- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala +++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala @@ -327,7 +327,7 @@ class KafkaApisTest { assertEquals(Errors.NONE, response.error) - val innerResponse = AbstractResponse.deserializeBody(response.responseData(), + val innerResponse = AbstractResponse.parseResponse(response.responseData(), requestHeader).asInstanceOf[AlterConfigsResponse] val responseMap = innerResponse.data.responses().asScala.map { resourceResponse => diff --git a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java index 9dcaa761386d4..67e6b8226eb0e 100644 --- a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java +++ b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java @@ -109,7 +109,7 @@ private void generateClass(Optional topLevelMessageSpec, buffer.printf("%n"); generateClassToStruct(className, struct, parentVersions); buffer.printf("%n"); - generateClassSize(className, struct, parentVersions); + generateClassMessageSize(className, struct, parentVersions); if (isSetElement) { buffer.printf("%n"); generateClassEquals(className, struct, true); @@ -628,18 +628,8 @@ private void generateVariableLengthReader(Versions fieldFlexibleVersions, buffer.printf("%snewBytes%s", assignmentPrefix, assignmentSuffix); } } else if (type.isRecords()) { - headerGenerator.addImport(MessageGenerator.RECORDS_READABLE_CLASS); - buffer.printf("if (_readable instanceof RecordsReadable) {%n"); - buffer.incrementIndent(); - buffer.printf("%s((RecordsReadable) _readable).readRecords(%s)%s", + buffer.printf("%s_readable.readRecords(%s)%s", assignmentPrefix, lengthVar, assignmentSuffix); - buffer.decrementIndent(); - buffer.printf("} else {%n"); - buffer.incrementIndent(); - buffer.printf("throw new RuntimeException(\"Cannot read records from " + - "reader of class: \" + _readable.getClass().getSimpleName());%n"); - buffer.decrementIndent(); - buffer.printf("}%n"); } else if (type.isArray()) { FieldType.ArrayType arrayType = (FieldType.ArrayType) type; if (isStructArrayWithKeys) { @@ -1142,16 +1132,7 @@ private void generateVariableLengthWriter(Versions fieldFlexibleVersions, buffer.printf("_writable.writeByteArray(%s);%n", name); } } else if (type.isRecords()) { - headerGenerator.addImport(MessageGenerator.RECORDS_WRITABLE_CLASS); - buffer.printf("if (_writable instanceof RecordsWritable) {%n"); - buffer.incrementIndent(); - buffer.printf("((RecordsWritable) _writable).writeRecords(%s);%n", name); - buffer.decrementIndent(); - buffer.printf("} else {%n"); - buffer.incrementIndent(); - buffer.printf("throw new RuntimeException(\"Cannot write records to writer of class: \" + _writable.getClass().getSimpleName());%n"); - buffer.decrementIndent(); - buffer.printf("}%n"); + buffer.printf("_writable.writeRecords(%s);%n", name); } else if (type.isArray()) { FieldType.ArrayType arrayType = (FieldType.ArrayType) type; FieldType elementType = arrayType.elementType(); @@ -1332,10 +1313,10 @@ private void generateTaggedFieldToMap(FieldSpec field, Versions versions) { private void generateFieldToObjectArray(FieldSpec field) { FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); 
FieldType elementType = arrayType.elementType(); - String boxdElementType = elementType.isStruct() ? "Struct" : + String boxedElementType = elementType.isStruct() ? "Struct" : elementType.getBoxedJavaType(headerGenerator); buffer.printf("%s[] _nestedObjects = new %s[%s.size()];%n", - boxdElementType, boxdElementType, field.camelCaseName()); + boxedElementType, boxedElementType, field.camelCaseName()); buffer.printf("int i = 0;%n"); buffer.printf("for (%s element : this.%s) {%n", arrayType.elementType().getBoxedJavaType(headerGenerator), field.camelCaseName()); @@ -1349,13 +1330,17 @@ private void generateFieldToObjectArray(FieldSpec field) { buffer.printf("}%n"); } - private void generateClassSize(String className, StructSpec struct, - Versions parentVersions) { + private void generateClassMessageSize( + String className, + StructSpec struct, + Versions parentVersions + ) { headerGenerator.addImport(MessageGenerator.OBJECT_SERIALIZATION_CACHE_CLASS); + headerGenerator.addImport(MessageGenerator.MESSAGE_SIZE_ACCUMULATOR_CLASS); buffer.printf("@Override%n"); - buffer.printf("public int size(ObjectSerializationCache _cache, short _version) {%n"); + buffer.printf("public void addSize(MessageSizeAccumulator _size, ObjectSerializationCache _cache, short _version) {%n"); buffer.incrementIndent(); - buffer.printf("int _size = 0, _numTaggedFields = 0;%n"); + buffer.printf("int _numTaggedFields = 0;%n"); VersionConditional.forVersions(parentVersions, struct.versions()). allowMembershipCheckAlwaysFalse(false). ifNotMember(__ -> { @@ -1384,9 +1369,9 @@ private void generateClassSize(String className, StructSpec struct, buffer.printf("for (RawTaggedField _field : _unknownTaggedFields) {%n"); buffer.incrementIndent(); headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); - buffer.printf("_size += ByteUtils.sizeOfUnsignedVarint(_field.tag());%n"); - buffer.printf("_size += ByteUtils.sizeOfUnsignedVarint(_field.size());%n"); - buffer.printf("_size += _field.size();%n"); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_field.tag()));%n"); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_field.size()));%n"); + buffer.printf("_size.addBytes(_field.size());%n"); buffer.decrementIndent(); buffer.printf("}%n"); buffer.decrementIndent(); @@ -1397,10 +1382,9 @@ private void generateClassSize(String className, StructSpec struct, }). ifMember(__ -> { headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); - buffer.printf("_size += ByteUtils.sizeOfUnsignedVarint(_numTaggedFields);%n"); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_numTaggedFields));%n"); }). generate(buffer); - buffer.printf("return _size;%n"); buffer.decrementIndent(); buffer.printf("}%n"); } @@ -1417,28 +1401,29 @@ private void generateVariableLengthArrayElementSize(Versions flexibleVersions, generateStringToBytes(fieldName); VersionConditional.forVersions(flexibleVersions, versions). ifNotMember(__ -> { - buffer.printf("_arraySize += _stringBytes.length + 2;%n"); + buffer.printf("_size.addBytes(_stringBytes.length + 2);%n"); }). ifMember(__ -> { headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); - buffer.printf("_arraySize += _stringBytes.length + " + - "ByteUtils.sizeOfUnsignedVarint(_stringBytes.length + 1);%n"); + buffer.printf("_size.addBytes(_stringBytes.length + " + + "ByteUtils.sizeOfUnsignedVarint(_stringBytes.length + 1));%n"); }). 
generate(buffer); } else if (type instanceof FieldType.BytesFieldType) { + buffer.printf("_size.addBytes(%s.length);%n", fieldName); VersionConditional.forVersions(flexibleVersions, versions). ifNotMember(__ -> { - buffer.printf("_arraySize += %s.length + 4;%n", fieldName); + buffer.printf("_size.addBytes(4);%n"); }). ifMember(__ -> { headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); - buffer.printf("_arraySize += %s.length + " + - "ByteUtils.sizeOfUnsignedVarint(%s.length + 1);%n", - fieldName, fieldName); + buffer.printf("_size.addBytes(" + + "ByteUtils.sizeOfUnsignedVarint(%s.length + 1));%n", + fieldName); }). generate(buffer); } else if (type instanceof FieldType.StructType) { - buffer.printf("_arraySize += %s.size(_cache, _version);%n", fieldName); + buffer.printf("%s.addSize(_size, _cache, _version);%n", fieldName); } else { throw new RuntimeException("Unsupported type " + type); } @@ -1463,16 +1448,16 @@ private void generateFixedLengthFieldSize(FieldSpec field, "this.", field.nullableVersions()); buffer.incrementIndent(); buffer.printf("_numTaggedFields++;%n"); - buffer.printf("_size += %d;%n", + buffer.printf("_size.addBytes(%d);%n", MessageGenerator.sizeOfUnsignedVarint(field.tag().get())); // Account for the tagged field prefix length. - buffer.printf("_size += %d;%n", + buffer.printf("_size.addBytes(%d);%n", MessageGenerator.sizeOfUnsignedVarint(field.type().fixedLength().get())); - buffer.printf("_size += %d;%n", field.type().fixedLength().get()); + buffer.printf("_size.addBytes(%d);%n", field.type().fixedLength().get()); buffer.decrementIndent(); buffer.printf("}%n"); } else { - buffer.printf("_size += %d;%n", field.type().fixedLength().get()); + buffer.printf("_size.addBytes(%d);%n", field.type().fixedLength().get()); } } @@ -1489,12 +1474,12 @@ private void generateVariableLengthFieldSize(FieldSpec field, ifMember(__ -> { if (tagged) { buffer.printf("_numTaggedFields++;%n"); - buffer.printf("_size += %d;%n", + buffer.printf("_size.addBytes(%d);%n", MessageGenerator.sizeOfUnsignedVarint(field.tag().get())); - buffer.printf("_size += %d;%n", MessageGenerator.sizeOfUnsignedVarint( + buffer.printf("_size.addBytes(%d);%n", MessageGenerator.sizeOfUnsignedVarint( MessageGenerator.sizeOfUnsignedVarint(0))); } - buffer.printf("_size += %d;%n", MessageGenerator.sizeOfUnsignedVarint(0)); + buffer.printf("_size.addBytes(%d);%n", MessageGenerator.sizeOfUnsignedVarint(0)); }). ifNotMember(__ -> { if (tagged) { @@ -1502,9 +1487,9 @@ private void generateVariableLengthFieldSize(FieldSpec field, " should not be present in non-flexible versions."); } if (field.type().isString()) { - buffer.printf("_size += 2;%n"); + buffer.printf("_size.addBytes(2);%n"); } else { - buffer.printf("_size += 4;%n"); + buffer.printf("_size.addBytes(4);%n"); } }). 
generate(buffer); @@ -1518,7 +1503,7 @@ private void generateVariableLengthFieldSize(FieldSpec field, buffer.incrementIndent(); } buffer.printf("_numTaggedFields++;%n"); - buffer.printf("_size += %d;%n", + buffer.printf("_size.addBytes(%d);%n", MessageGenerator.sizeOfUnsignedVarint(field.tag().get())); } if (field.type().isString()) { @@ -1529,11 +1514,11 @@ private void generateVariableLengthFieldSize(FieldSpec field, if (tagged) { buffer.printf("int _stringPrefixSize = " + "ByteUtils.sizeOfUnsignedVarint(_stringBytes.length + 1);%n"); - buffer.printf("_size += _stringBytes.length + _stringPrefixSize + " + - "ByteUtils.sizeOfUnsignedVarint(_stringPrefixSize);%n"); + buffer.printf("_size.addBytes(_stringBytes.length + _stringPrefixSize + " + + "ByteUtils.sizeOfUnsignedVarint(_stringPrefixSize));%n"); } else { - buffer.printf("_size += _stringBytes.length + " + - "ByteUtils.sizeOfUnsignedVarint(_stringBytes.length + 1);%n"); + buffer.printf("_size.addBytes(_stringBytes.length + " + + "ByteUtils.sizeOfUnsignedVarint(_stringBytes.length + 1));%n"); } }). ifNotMember(__ -> { @@ -1541,25 +1526,27 @@ private void generateVariableLengthFieldSize(FieldSpec field, throw new RuntimeException("Tagged field " + field.name() + " should not be present in non-flexible versions."); } - buffer.printf("_size += _stringBytes.length + 2;%n"); + buffer.printf("_size.addBytes(_stringBytes.length + 2);%n"); }). generate(buffer); } else if (field.type().isArray()) { - buffer.printf("int _arraySize = 0;%n"); + if (tagged) { + buffer.printf("int _sizeBeforeArray = _size.totalSize();%n"); + } VersionConditional.forVersions(fieldFlexibleVersions(field), possibleVersions). ifMember(__ -> { headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); - buffer.printf("_arraySize += ByteUtils.sizeOfUnsignedVarint(%s.size() + 1);%n", + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(%s.size() + 1));%n", field.camelCaseName()); }). ifNotMember(__ -> { - buffer.printf("_arraySize += 4;%n"); + buffer.printf("_size.addBytes(4);%n"); }). generate(buffer); FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); FieldType elementType = arrayType.elementType(); if (elementType.fixedLength().isPresent()) { - buffer.printf("_arraySize += %s.size() * %d;%n", + buffer.printf("_size.addBytes(%s.size() * %d);%n", field.camelCaseName(), elementType.fixedLength().get()); } else if (elementType instanceof FieldType.ArrayType) { @@ -1579,58 +1566,60 @@ private void generateVariableLengthFieldSize(FieldSpec field, } if (tagged) { headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("int _arraySize = _size.totalSize() - _sizeBeforeArray;%n"); buffer.printf("_cache.setArraySizeInBytes(%s, _arraySize);%n", field.camelCaseName()); - buffer.printf("_size += _arraySize + ByteUtils.sizeOfUnsignedVarint(_arraySize);%n"); - } else { - buffer.printf("_size += _arraySize;%n"); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_arraySize));%n"); } } else if (field.type().isBytes()) { + if (tagged) { + buffer.printf("int _sizeBeforeBytes = _size.totalSize();%n"); + } if (field.zeroCopy()) { - buffer.printf("int _bytesSize = %s.remaining();%n", field.camelCaseName()); + buffer.printf("_size.addZeroCopyBytes(%s.remaining());%n", field.camelCaseName()); } else { - buffer.printf("int _bytesSize = %s.length;%n", field.camelCaseName()); + buffer.printf("_size.addBytes(%s.length);%n", field.camelCaseName()); } VersionConditional.forVersions(fieldFlexibleVersions(field), possibleVersions). 
ifMember(__ -> { headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); if (field.zeroCopy()) { - buffer.printf("_bytesSize += " + - "ByteUtils.sizeOfUnsignedVarint(%s.remaining() + 1);%n", field.camelCaseName()); + buffer.printf("_size.addBytes(" + + "ByteUtils.sizeOfUnsignedVarint(%s.remaining() + 1));%n", field.camelCaseName()); } else { - buffer.printf("_bytesSize += ByteUtils.sizeOfUnsignedVarint(%s.length + 1);%n", + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(%s.length + 1));%n", field.camelCaseName()); } }). ifNotMember(__ -> { - buffer.printf("_bytesSize += 4;%n"); + buffer.printf("_size.addBytes(4);%n"); }). generate(buffer); if (tagged) { headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); - buffer.printf("_size += _bytesSize + ByteUtils.sizeOfUnsignedVarint(_bytesSize);%n"); - } else { - buffer.printf("_size += _bytesSize;%n"); + buffer.printf("int _bytesSize = _size.totalSize() - _sizeBeforeBytes;%n"); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_bytesSize));%n"); } } else if (field.type().isRecords()) { - buffer.printf("int _recordsSize = %s.sizeInBytes();%n", field.camelCaseName()); + buffer.printf("_size.addZeroCopyBytes(%s.sizeInBytes());%n", field.camelCaseName()); VersionConditional.forVersions(fieldFlexibleVersions(field), possibleVersions). ifMember(__ -> { headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); - buffer.printf("_recordsSize += " + - "ByteUtils.sizeOfUnsignedVarint(%s.sizeInBytes() + 1);%n", field.camelCaseName()); + buffer.printf("_size.addBytes(" + + "ByteUtils.sizeOfUnsignedVarint(%s.sizeInBytes() + 1));%n", field.camelCaseName()); }). ifNotMember(__ -> { - buffer.printf("_recordsSize += 4;%n"); + buffer.printf("_size.addBytes(4);%n"); }). generate(buffer); - buffer.printf("_size += _recordsSize;%n"); } else if (field.type().isStruct()) { - buffer.printf("int size = this.%s.size(_cache, _version);%n", field.camelCaseName()); + buffer.printf("int _sizeBeforeStruct = _size.totalSize();%n"); + buffer.printf("this.%s.addSize(_size, _cache, _version);%n", field.camelCaseName()); + buffer.printf("int _structSize = _size.totalSize() - _sizeBeforeStruct;%n"); + if (tagged) { - buffer.printf("_size += ByteUtils.sizeOfUnsignedVarint(size);%n"); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_structSize));%n"); } - buffer.printf("_size += size;%n"); } else { throw new RuntimeException("unhandled type " + field.type()); } diff --git a/generator/src/main/java/org/apache/kafka/message/MessageGenerator.java b/generator/src/main/java/org/apache/kafka/message/MessageGenerator.java index 4b5848bddc7d4..507d53a39c7f3 100644 --- a/generator/src/main/java/org/apache/kafka/message/MessageGenerator.java +++ b/generator/src/main/java/org/apache/kafka/message/MessageGenerator.java @@ -59,12 +59,8 @@ public final class MessageGenerator { static final String READABLE_CLASS = "org.apache.kafka.common.protocol.Readable"; - static final String RECORDS_READABLE_CLASS = "org.apache.kafka.common.protocol.RecordsReadable"; - static final String WRITABLE_CLASS = "org.apache.kafka.common.protocol.Writable"; - static final String RECORDS_WRITABLE_CLASS = "org.apache.kafka.common.protocol.RecordsWritable"; - static final String ARRAYS_CLASS = "java.util.Arrays"; static final String OBJECTS_CLASS = "java.util.Objects"; @@ -116,6 +112,8 @@ public final class MessageGenerator { static final String OBJECT_SERIALIZATION_CACHE_CLASS =
"org.apache.kafka.common.protocol.ObjectSerializationCache"; + static final String MESSAGE_SIZE_ACCUMULATOR_CLASS = "org.apache.kafka.common.protocol.MessageSizeAccumulator"; + static final String RAW_TAGGED_FIELD_CLASS = "org.apache.kafka.common.protocol.types.RawTaggedField"; static final String RAW_TAGGED_FIELD_WRITER_CLASS = "org.apache.kafka.common.protocol.types.RawTaggedFieldWriter";