diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProduceRequestResult.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProduceRequestResult.java
index 9077b107ab03e..d444fb6eaa416 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProduceRequestResult.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProduceRequestResult.java
@@ -20,6 +20,10 @@
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.record.RecordBatch;
 
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Queue;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
@@ -34,6 +38,14 @@ public class ProduceRequestResult {
     private final CountDownLatch latch = new CountDownLatch(1);
     private final TopicPartition topicPartition;
 
+    /**
+     * List of dependent ProduceRequestResults created when this batch is split.
+     * When a batch is too large to send, it's split into multiple smaller batches.
+     * The original batch's ProduceRequestResult tracks all the split batches here
+     * so that flush() can wait for all splits to complete via awaitAllDependents().
+     */
+    private final List<ProduceRequestResult> dependentResults = new ArrayList<>();
+
     private volatile Long baseOffset = null;
     private volatile long logAppendTime = RecordBatch.NO_TIMESTAMP;
     private volatile Function<Integer, RuntimeException> errorsByIndex;
@@ -41,7 +53,7 @@ public class ProduceRequestResult {
     /**
      * Create an instance of this class.
      *
-     * @param topicPartition The topic and partition to which this record set was sent was sent
+     * @param topicPartition The topic and partition to which this record set was sent
      */
    public ProduceRequestResult(TopicPartition topicPartition) {
         this.topicPartition = topicPartition;
@@ -70,7 +82,29 @@ public void done() {
     }
 
     /**
-     * Await the completion of this request
+     * Add a dependent ProduceRequestResult.
+     * This is used when a batch is split into multiple batches - in some cases, such as flush(),
+     * the original batch's result should not be considered complete until all split batches have completed.
+     *
+     * @param dependentResult The dependent result to wait for
+     */
+    public void addDependent(ProduceRequestResult dependentResult) {
+        synchronized (dependentResults) {
+            dependentResults.add(dependentResult);
+        }
+    }
+
+    /**
+     * Await the completion of this request.
+     *
+     * This only waits for THIS request's latch, not for dependent results.
+     * When a batch is split into multiple batches, dependent results are created and tracked
+     * separately, but this method does not wait for them. Individual record futures automatically
+     * handle waiting for their respective split batch via {@link FutureRecordMetadata#chain(FutureRecordMetadata)},
+     * which redirects the future to point to the correct split batch's result.
+     *
+     * For flush() semantics that require waiting for all dependent results, use
+     * {@link #awaitAllDependents()}.
      */
     public void await() throws InterruptedException {
         latch.await();
@@ -86,6 +120,34 @@ public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
         return latch.await(timeout, unit);
     }
 
+    /**
+     * Await the completion of this request and all dependent requests.
+     *
+     * This method is used by flush() to ensure all split batches have completed before
+     * returning. It waits for every dependent {@link ProduceRequestResult} that
+     * was created when the batch was split.
+     *
+     * @throws InterruptedException if the thread is interrupted while waiting
+     */
+    public void awaitAllDependents() throws InterruptedException {
+        Queue<ProduceRequestResult> toWait = new ArrayDeque<>();
+        toWait.add(this);
+
+        while (!toWait.isEmpty()) {
+            ProduceRequestResult current = toWait.poll();
+
+            // first wait for THIS result's latch to be released
+            current.latch.await();
+
+            // then add all dependent split batches to the queue.
+            // we synchronize only to take a consistent snapshot of the dependents;
+            // the actual waiting happens outside the lock.
+            synchronized (current.dependentResults) {
+                toWait.addAll(current.dependentResults);
+            }
+        }
+    }
+
     /**
      * The base offset for the request (the first offset in the record set)
      */
@@ -127,6 +189,15 @@ public TopicPartition topicPartition() {
 
     /**
      * Has the request completed?
+     *
+     * This method only checks whether THIS request has completed, not its dependent results.
+     * When a batch is split into multiple batches, the dependent split batches are tracked
+     * separately. Individual record futures handle waiting for their respective split
+     * batch via {@link FutureRecordMetadata#chain(FutureRecordMetadata)}, which updates the
+     * {@code nextRecordMetadata} pointer to follow the correct split batch.
+     *
+     * For flush() semantics that require waiting for all dependent results, use
+     * {@link #awaitAllDependents()}.
      */
     public boolean completed() {
         return this.latch.getCount() == 0L;
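Reviewer note: a minimal sketch of the semantics the two new methods add, outside the real producer path. Not part of the patch; the TopicPartition, variable names, and the error mapping are illustrative, and the usual producer-internals imports are assumed.

    static void dependentSemantics() throws InterruptedException {
        TopicPartition tp = new TopicPartition("t", 0);
        ProduceRequestResult parent = new ProduceRequestResult(tp);
        ProduceRequestResult splitA = new ProduceRequestResult(tp);
        ProduceRequestResult splitB = new ProduceRequestResult(tp);

        parent.addDependent(splitA);
        parent.addDependent(splitB);
        parent.set(ProduceResponse.INVALID_OFFSET, RecordBatch.NO_TIMESTAMP, i -> new RecordBatchTooLargeException());
        parent.done();

        boolean done = parent.completed(); // true: only the parent's own latch is checked
        parent.await();                    // returns immediately for the same reason

        splitA.set(0L, RecordBatch.NO_TIMESTAMP, null); // complete both dependents,
        splitA.done();                                  // as the Sender would after
        splitB.set(0L, RecordBatch.NO_TIMESTAMP, null); // the re-sent batches succeed
        splitB.done();
        parent.awaitAllDependents(); // returns only now; would block while any dependent is open
    }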
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
index 5619819dde72e..c4f9c0f7f08d3 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
@@ -321,10 +321,16 @@ private void completeFutureAndFireCallbacks(
     }
 
     public Deque<ProducerBatch> split(int splitBatchSize) {
-        Deque<ProducerBatch> batches = new ArrayDeque<>();
-        MemoryRecords memoryRecords = recordsBuilder.build();
+        RecordBatch recordBatch = validateAndGetRecordBatch();
+        Deque<ProducerBatch> batches = splitRecordsIntoBatches(recordBatch, splitBatchSize);
+        finalizeSplitBatches(batches);
+        return batches;
+    }
 
+    private RecordBatch validateAndGetRecordBatch() {
+        MemoryRecords memoryRecords = recordsBuilder.build();
         Iterator<MutableRecordBatch> recordBatchIter = memoryRecords.batches().iterator();
+
         if (!recordBatchIter.hasNext())
             throw new IllegalStateException("Cannot split an empty producer batch.");
 
@@ -336,6 +342,11 @@ public Deque<ProducerBatch> split(int splitBatchSize) {
         if (recordBatchIter.hasNext())
             throw new IllegalArgumentException("A producer batch should only have one record batch.");
 
+        return recordBatch;
+    }
+
+    private Deque<ProducerBatch> splitRecordsIntoBatches(RecordBatch recordBatch, int splitBatchSize) {
+        Deque<ProducerBatch> batches = new ArrayDeque<>();
         Iterator<Thunk> thunkIter = thunks.iterator();
         // We always allocate batch size because we are already splitting a big batch.
         // And we also retain the create time of the original batch.
@@ -362,9 +373,23 @@ public Deque<ProducerBatch> split(int splitBatchSize) {
             batch.closeForRecordAppends();
         }
 
+        return batches;
+    }
+
+    private void finalizeSplitBatches(Deque<ProducerBatch> batches) {
+        // Chain all split batch ProduceRequestResults to the original batch's produceFuture.
+        // This ensures the original batch's future is not considered complete until all
+        // split batches have completed.
+        for (ProducerBatch splitBatch : batches) {
+            produceFuture.addDependent(splitBatch.produceFuture);
+        }
+
         produceFuture.set(ProduceResponse.INVALID_OFFSET, NO_TIMESTAMP, index -> new RecordBatchTooLargeException());
         produceFuture.done();
+        assignProducerStateToBatches(batches);
+    }
+
+    private void assignProducerStateToBatches(Deque<ProducerBatch> batches) {
         if (hasSequence()) {
             int sequence = baseSequence();
             ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId(), producerEpoch());
@@ -373,7 +398,6 @@ public Deque<ProducerBatch> split(int splitBatchSize) {
                 sequence += newBatch.recordCount;
             }
         }
-        return batches;
     }
 
     private ProducerBatch createBatchOffAccumulatorForRecord(Record record, int batchSize) {
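Reviewer note: a sketch of the control flow after the refactored split(), mirroring what RecordAccumulator.splitAndReenqueue() already does internally. Illustrative only; "accum", "bigBatch", "time", and "splitBatchSize" are assumed to come from the surrounding test or sender context.

    static void splitAndRequeue(RecordAccumulator accum, ProducerBatch bigBatch, Time time, int splitBatchSize) {
        // After split(): each record's FutureRecordMetadata has been chain()ed to its new
        // split batch (so an individual send() future waits on the correct split), and
        // bigBatch's produceFuture is done() with RecordBatchTooLargeException as the
        // per-record error, while holding each split's ProduceRequestResult as a dependent.
        Deque<ProducerBatch> splits = bigBatch.split(splitBatchSize);
        for (ProducerBatch split : splits)
            accum.reenqueue(split, time.milliseconds());
    }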
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
index f0c2719db9612..d3c774cb6f5a8 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
@@ -1076,8 +1076,13 @@ public void awaitFlushCompletion() throws InterruptedException {
             // We must be careful not to hold a reference to the ProduceBatch(s) so that garbage
             // collection can occur on the contents.
             // The sender will remove ProducerBatch(s) from the original incomplete collection.
+            //
+            // We use awaitAllDependents() here instead of await() to ensure that if any batch
+            // was split into multiple batches, we wait for all of the split batches to complete.
+            // This is required to guarantee that all records sent before flush() are fully
+            // complete, including records in split batches.
             for (ProduceRequestResult result : this.incomplete.requestResults())
-                result.await();
+                result.awaitAllDependents();
         } finally {
             this.flushesInProgress.decrementAndGet();
         }
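Reviewer note: the user-visible contract this one-line change protects, in terms of the public producer API. Illustrative only; assumes a configured KafkaProducer and an oversized payload "hugeValue".

    static void flushContract(KafkaProducer<byte[], byte[]> producer, byte[] hugeValue) {
        producer.send(new ProducerRecord<>("topic", hugeValue)); // broker may reply MESSAGE_TOO_LARGE,
                                                                 // and the producer splits and re-sends the batch
        producer.flush(); // previously this could return once the original (failed) batch was marked
                          // done; with awaitAllDependents() it also waits for the re-sent split batches
    }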
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
index 750440d2595a5..d6ac75a37a0e6 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
@@ -39,6 +39,7 @@
 import org.apache.kafka.common.record.MemoryRecordsBuilder;
 import org.apache.kafka.common.record.MutableRecordBatch;
 import org.apache.kafka.common.record.Record;
+import org.apache.kafka.common.record.RecordBatch;
 import org.apache.kafka.common.record.TimestampType;
 import org.apache.kafka.common.requests.MetadataResponse;
 import org.apache.kafka.common.requests.MetadataResponse.PartitionMetadata;
@@ -71,7 +72,9 @@
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -1066,6 +1069,41 @@ public void testSplitAndReenqueue() throws ExecutionException, InterruptedExcept
         assertEquals(1, future2.get().offset());
     }
 
+    // Verifies hasRoomFor() behaviour:
+    // it accepts the first record regardless of size,
+    // but rejects any subsequent record once the buffer is full
+    @Test
+    public void testHasRoomForAllowsOversizedFirstRecordButRejectsSubsequentRecords() {
+        long now = time.milliseconds();
+        int smallBatchSize = 1024;
+
+        // Create a large record that exceeds the batch size limit
+        byte[] largeValue = new byte[4 * 1024]; // 4KB > 1KB
+
+        // Create a small buffer that cannot fit the large record
+        ByteBuffer buffer = ByteBuffer.allocate(smallBatchSize);
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, Compression.NONE, TimestampType.CREATE_TIME, 0L);
+
+        // existing behaviour: hasRoomFor() should return true for the first record regardless of size
+        boolean hasRoomForFirst = builder.hasRoomFor(now, ByteBuffer.wrap(key), ByteBuffer.wrap(largeValue), Record.EMPTY_HEADERS);
+        assertTrue(hasRoomForFirst, "hasRoomFor() should return true for first record regardless of size when numRecords == 0");
+
+        // append the first oversized record - should succeed
+        builder.append(now, ByteBuffer.wrap(key), ByteBuffer.wrap(largeValue), Record.EMPTY_HEADERS);
+        assertEquals(1, builder.numRecords(), "Should have successfully appended the first oversized record");
+
+        // now try another large record when numRecords > 0
+        boolean hasRoomForSecond = builder.hasRoomFor(now, ByteBuffer.wrap(key), ByteBuffer.wrap(largeValue), Record.EMPTY_HEADERS);
+        assertFalse(hasRoomForSecond, "hasRoomFor() should return false for oversized record when numRecords > 0");
+
+        // A smaller record that would normally fit should also be rejected,
+        // because the oversized first record has already exhausted the buffer
+        byte[] smallValue = new byte[100]; // Small record
+        boolean hasRoomForSmall = builder.hasRoomFor(now, ByteBuffer.wrap(key), ByteBuffer.wrap(smallValue), Record.EMPTY_HEADERS);
+        assertFalse(hasRoomForSmall, "hasRoomFor() should return false for any record when buffer is full from oversized first record");
+    }
+
record"); + } + @Test public void testSplitBatchOffAccumulator() throws InterruptedException { long seed = System.currentTimeMillis(); @@ -1790,4 +1828,65 @@ public void testSplitAndReenqueuePreventInfiniteRecursion() throws InterruptedEx // Verify all original records are accounted for (no data loss) assertEquals(100, keyFoundMap.size(), "All original 100 records should be present after splitting"); } + + @Test + public void testProduceRequestResultAwaitAllDependents() throws Exception { + ProduceRequestResult parent = new ProduceRequestResult(tp1); + + // make two dependent ProduceRequestResults -- mimicking split batches + ProduceRequestResult dependent1 = new ProduceRequestResult(tp1); + ProduceRequestResult dependent2 = new ProduceRequestResult(tp1); + + // add dependents + parent.addDependent(dependent1); + parent.addDependent(dependent2); + + parent.set(0L, RecordBatch.NO_TIMESTAMP, null); + parent.done(); + + // parent.completed() should return true (only checks latch) + assertTrue(parent.completed(), "Parent should be completed after done()"); + + // awaitAllDependents() should block because dependents are not complete + final AtomicBoolean awaitCompleted = new AtomicBoolean(false); + final AtomicReference awaitException = new AtomicReference<>(); + + // to prove awaitAllDependents() is blocking, we run it in a separate thread + Thread awaitThread = new Thread(() -> { + try { + parent.awaitAllDependents(); + awaitCompleted.set(true); + } catch (Exception e) { + awaitException.set(e); + } + }); + awaitThread.start(); + Thread.sleep(5); + + // verify awaitAllDependents() is blocking + assertFalse(awaitCompleted.get(), + "awaitAllDependents() should block because dependents are not complete"); + + // now complete the first dependent + dependent1.set(0L, RecordBatch.NO_TIMESTAMP, null); + dependent1.done(); + + Thread.sleep(5); + + // this should still be blocking because dependent2 is not complete + assertFalse(awaitCompleted.get(), + "awaitAllDependents() should still block because dependent2 is not complete"); + + // now complete the second dependent + dependent2.set(0L, RecordBatch.NO_TIMESTAMP, null); + dependent2.done(); + + // now awaitAllDependents() should complete + awaitThread.join(5000); + + assertNull(awaitException.get(), "awaitAllDependents() should not throw exception"); + assertTrue(awaitCompleted.get(), + "awaitAllDependents() should complete after all dependents are done"); + assertFalse(awaitThread.isAlive(), "await thread should have completed"); + } } diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/RecordAccumulatorFlushBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/RecordAccumulatorFlushBenchmark.java new file mode 100644 index 0000000000000..605d76abbe70d --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/RecordAccumulatorFlushBenchmark.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.jmh.producer;
+
+import org.apache.kafka.clients.MetadataSnapshot;
+import org.apache.kafka.clients.producer.internals.BufferPool;
+import org.apache.kafka.clients.producer.internals.ProducerBatch;
+import org.apache.kafka.clients.producer.internals.RecordAccumulator;
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.compress.Compression;
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.record.DefaultRecordBatch;
+import org.apache.kafka.common.record.Record;
+import org.apache.kafka.common.requests.MetadataResponse;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
+
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+@State(Scope.Benchmark)
+@Fork(value = 1)
+@Warmup(iterations = 5)
+@Measurement(iterations = 15)
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+public class RecordAccumulatorFlushBenchmark {
+
+    private static final String TOPIC = "test";
+    private static final int PARTITION = 0;
+    private static final int BATCH_SIZE = 1024;
+    private static final long TOTAL_SIZE = 10 * 1024 * 1024;
+
+    @Param({"5000", "10000"})
+    private int numRecords;
+
+    private RecordAccumulator accum;
+    private Metrics metrics;
+    private TopicPartition tp;
+    private Time time;
+
+    @Setup(Level.Invocation)
+    public void setup() throws InterruptedException {
+        tp = new TopicPartition(TOPIC, PARTITION);
+        time = new MockTime();
+        metrics = new Metrics(time);
+
+        Cluster cluster = createTestCluster();
+        accum = createRecordAccumulator();
+
+        appendRecords(cluster);
+        prepareFlush();
+    }
+
+    @TearDown(Level.Invocation)
+    public void tearDown() {
+        deallocateBatches();
+        if (metrics != null) {
+            metrics.close();
+        }
+    }
+
+    @Benchmark
+    public void measureFlushCompletion() throws InterruptedException {
+        accum.awaitFlushCompletion();
+    }
+
+    private Cluster createTestCluster() {
+        Node node = new Node(0, "localhost", 1111);
+        MetadataResponse.PartitionMetadata partMetadata = new MetadataResponse.PartitionMetadata(
+            Errors.NONE,
+            tp,
+            Optional.of(node.id()),
+            Optional.empty(),
+            null,
+            null,
+            null
+        );
+
+        Map<Integer, Node> nodes = Stream.of(node).collect(Collectors.toMap(Node::id, Function.identity()));
+        MetadataSnapshot metadataCache = new MetadataSnapshot(
+            null,
+            nodes,
+            Collections.singletonList(partMetadata),
+            Collections.emptySet(),
+            Collections.emptySet(),
+            Collections.emptySet(),
+            null,
+            Collections.emptyMap()
+        );
+        return metadataCache.cluster();
+    }
+
+    private RecordAccumulator createRecordAccumulator() {
+        return new RecordAccumulator(
+            new LogContext(),
+            BATCH_SIZE + DefaultRecordBatch.RECORD_BATCH_OVERHEAD,
+            Compression.NONE,
+            Integer.MAX_VALUE, // lingerMs
+            100L, // retryBackoffMs
+            1000L, // retryBackoffMaxMs
+            3200, // deliveryTimeoutMs
+            metrics,
+            "producer-metrics",
+            time,
+            null,
+            new BufferPool(TOTAL_SIZE, BATCH_SIZE, metrics, time, "producer-metrics")
+        );
+    }
+
+    private void appendRecords(Cluster cluster) throws InterruptedException {
+        byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+        byte[] value = "value".getBytes(StandardCharsets.UTF_8);
+
+        for (int i = 0; i < numRecords; i++) {
+            accum.append(
+                TOPIC,
+                PARTITION,
+                0L,
+                key,
+                value,
+                Record.EMPTY_HEADERS,
+                null,
+                1000L,
+                time.milliseconds(),
+                cluster
+            );
+        }
+    }
+
+    private void prepareFlush() {
+        accum.beginFlush();
+
+        // Complete all batches to mimic successful sends
+        List<ProducerBatch> batches = new ArrayList<>(accum.getDeque(tp));
+        for (ProducerBatch batch : batches) {
+            batch.complete(0L, time.milliseconds());
+        }
+    }
+
+    private void deallocateBatches() {
+        if (accum != null && tp != null) {
+            List<ProducerBatch> batches = new ArrayList<>(accum.getDeque(tp));
+            for (ProducerBatch batch : batches) {
+                accum.deallocate(batch);
+            }
+        }
+    }
+}
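Reviewer note: for a quick local run, the benchmark can also be driven through the standard JMH Java API rather than the jmh-benchmarks module's script wrapper. A hypothetical convenience runner (the class name and the parameter override are illustrative, not part of the patch):

    package org.apache.kafka.jmh.producer;

    import org.openjdk.jmh.runner.Runner;
    import org.openjdk.jmh.runner.options.Options;
    import org.openjdk.jmh.runner.options.OptionsBuilder;

    // Hypothetical runner -- not part of the patch.
    public class RecordAccumulatorFlushBenchmarkRunner {
        public static void main(String[] args) throws Exception {
            Options opts = new OptionsBuilder()
                .include(RecordAccumulatorFlushBenchmark.class.getSimpleName())
                .param("numRecords", "5000") // narrow the @Param matrix for a quick run
                .build();
            new Runner(opts).run();
        }
    }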