From 1cb05c995ae9eddcf91f9b1e29c3b7a59d9521bf Mon Sep 17 00:00:00 2001 From: Steven van Rossum Date: Tue, 18 Mar 2025 08:51:58 +0000 Subject: [PATCH 1/5] Improve caching in backlog estimation and processing --- .../beam/sdk/io/kafka/KafkaIOUtils.java | 15 + .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 649 ++++++++++-------- .../sdk/io/kafka/ReadFromKafkaDoFnTest.java | 2 + 3 files changed, 361 insertions(+), 305 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java index 8b778ce5481e..f404c8d9640f 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOUtils.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.kafka; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import java.nio.charset.StandardCharsets; import java.util.HashMap; @@ -130,6 +131,20 @@ static Map getOffsetConsumerConfig( return offsetConsumerConfig; } + static Map overrideBootstrapServersConfig( + Map currentConfig, KafkaSourceDescriptor description) { + checkState( + currentConfig.containsKey(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) + || description.getBootStrapServers() != null); + Map config = new HashMap<>(currentConfig); + if (description.getBootStrapServers() != null && description.getBootStrapServers().size() > 0) { + config.put( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, + String.join(",", description.getBootStrapServers())); + } + return config; + } + /* * Maintains approximate average over last 1000 elements. * Usage is only thread-safe for a single producer and multiple consumers. 
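The new KafkaIOUtils.overrideBootstrapServersConfig helper centralizes the per-descriptor bootstrap server override that ReadFromKafkaDoFn previously carried as a private method. The standalone sketch below mirrors that behavior with a plain List<String> standing in for KafkaSourceDescriptor.getBootStrapServers(), so it compiles without Beam on the classpath; only the config key and the comma-joining come from the hunk above, the class and method names are illustrative.

    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import org.apache.kafka.clients.consumer.ConsumerConfig;

    final class BootstrapServersOverride {
      // Returns a copy of baseConfig in which bootstrap.servers is replaced by the
      // per-descriptor list when that list is non-empty; otherwise the base value is kept.
      static Map<String, Object> override(
          Map<String, Object> baseConfig, List<String> descriptorBootstrapServers) {
        // Same precondition as the helper's checkState: at least one of the two sources
        // must supply bootstrap servers.
        if (!baseConfig.containsKey(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG)
            && descriptorBootstrapServers == null) {
          throw new IllegalStateException("No bootstrap servers configured");
        }
        Map<String, Object> config = new HashMap<>(baseConfig);
        // A non-empty per-descriptor list wins over whatever the base config carries.
        if (descriptorBootstrapServers != null && !descriptorBootstrapServers.isEmpty()) {
          config.put(
              ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG,
              String.join(",", descriptorBootstrapServers));
        }
        return config;
      }
    }

A descriptor that names its own cluster therefore gets consumers pointed at that cluster, while descriptors without one fall back to the pipeline-wide setting.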
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 3ab6c4f502e7..19e8c0e47107 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -19,17 +19,18 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import java.io.Closeable; import java.math.BigDecimal; import java.math.MathContext; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors; import org.apache.beam.sdk.io.kafka.KafkaIOUtils.MovingAvg; @@ -50,19 +51,22 @@ import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator; import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.MonotonicallyIncreasing; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.MemoizingPerInstantiationSerializableSupplier; import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.sdk.util.SerializableSupplier; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalCause; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalNotification; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; @@ -70,6 +74,7 @@ import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigDef; import org.apache.kafka.common.errors.SerializationException; import 
org.apache.kafka.common.serialization.Deserializer; import org.checkerframework.checker.nullness.qual.Nullable; @@ -186,19 +191,98 @@ private static class Bounded extends ReadFromKafkaDoFn { private ReadFromKafkaDoFn( ReadSourceDescriptors transform, TupleTag>> recordTag) { + final SerializableFunction, Consumer> consumerFactoryFn = + transform.getConsumerFactoryFn(); this.consumerConfig = transform.getConsumerConfig(); - this.offsetConsumerConfig = transform.getOffsetConsumerConfig(); this.keyDeserializerProvider = Preconditions.checkArgumentNotNull(transform.getKeyDeserializerProvider()); this.valueDeserializerProvider = Preconditions.checkArgumentNotNull(transform.getValueDeserializerProvider()); - this.consumerFactoryFn = transform.getConsumerFactoryFn(); this.extractOutputTimestampFn = transform.getExtractOutputTimestampFn(); this.createWatermarkEstimatorFn = transform.getCreateWatermarkEstimatorFn(); this.timestampPolicyFactory = transform.getTimestampPolicyFactory(); this.checkStopReadingFn = transform.getCheckStopReadingFn(); this.badRecordRouter = transform.getBadRecordRouter(); this.recordTag = recordTag; + this.avgRecordSizeCacheSupplier = + new MemoizingPerInstantiationSerializableSupplier<>( + () -> + CacheBuilder.newBuilder() + .concurrencyLevel(Runtime.getRuntime().availableProcessors()) + .weakValues() + .build( + new CacheLoader() { + @Override + public MovingAvg load(KafkaSourceDescriptor kafkaSourceDescriptor) + throws Exception { + return new MovingAvg(); + } + })); + this.latestOffsetEstimatorCacheSupplier = + new MemoizingPerInstantiationSerializableSupplier<>( + () -> + CacheBuilder.newBuilder() + .concurrencyLevel(Runtime.getRuntime().availableProcessors()) + .weakValues() + .removalListener( + (RemovalNotification + notification) -> { + final @Nullable KafkaLatestOffsetEstimator value; + if (notification.getCause() == RemovalCause.COLLECTED + && (value = notification.getValue()) != null) { + value.close(); + } + }) + .build( + new CacheLoader() { + @Override + public KafkaLatestOffsetEstimator load( + final KafkaSourceDescriptor sourceDescriptor) { + LOG.info( + "Creating Kafka consumer for offset estimation for {}", + sourceDescriptor); + final Map config = + KafkaIOUtils.overrideBootstrapServersConfig( + consumerConfig, sourceDescriptor); + final Consumer consumer = + consumerFactoryFn.apply(config); + return new KafkaLatestOffsetEstimator( + consumer, sourceDescriptor.getTopicPartition()); + } + })); + this.pollConsumerCacheSupplier = + new MemoizingPerInstantiationSerializableSupplier<>( + () -> + CacheBuilder.newBuilder() + .concurrencyLevel(Runtime.getRuntime().availableProcessors()) + .weakValues() + .removalListener( + (RemovalNotification> + notification) -> { + final @Nullable Consumer value; + if (notification.getCause() == RemovalCause.COLLECTED + && (value = notification.getValue()) != null) { + value.close(); + } + }) + .build( + new CacheLoader>() { + @Override + public Consumer load( + KafkaSourceDescriptor sourceDescriptor) { + LOG.info( + "Creating Kafka consumer for restriction processing for {}", + sourceDescriptor); + final Map config = + KafkaIOUtils.overrideBootstrapServersConfig( + consumerConfig, sourceDescriptor); + final Consumer consumer = + consumerFactoryFn.apply(config); + consumer.assign( + Collections.singleton(sourceDescriptor.getTopicPartition())); + return consumer; + } + })); if (transform.getConsumerPollingTimeout() > 0) { this.consumerPollingTimeout = transform.getConsumerPollingTimeout(); } else { @@ -208,29 +292,10 @@ 
private ReadFromKafkaDoFn( private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class); - /** - * A holder class for all construction time unique instances of {@link ReadFromKafkaDoFn}. Caches - * must run clean up tasks when {@link #teardown()} is called. - */ - private static final class SharedStateHolder { - - private static final Map> - OFFSET_ESTIMATOR_CACHE = new ConcurrentHashMap<>(); - private static final Map> - AVG_RECORD_SIZE_CACHE = new ConcurrentHashMap<>(); - } - - private static final AtomicLong FN_ID = new AtomicLong(); - - // A unique identifier for the instance. Generally unique unless the ID generator overflows. - private final long fnId = FN_ID.getAndIncrement(); - - private final @Nullable Map offsetConsumerConfig; + private static final Joiner COMMA_JOINER = Joiner.on(','); private final @Nullable CheckStopReadingFn checkStopReadingFn; - private final SerializableFunction, Consumer> - consumerFactoryFn; private final @Nullable SerializableFunction, Instant> extractOutputTimestampFn; private final @Nullable SerializableFunction> createWatermarkEstimatorFn; @@ -240,13 +305,19 @@ private static final class SharedStateHolder { private final TupleTag>> recordTag; + private final SerializableSupplier> + avgRecordSizeCacheSupplier; + + private final SerializableSupplier< + LoadingCache> + latestOffsetEstimatorCacheSupplier; + + private final SerializableSupplier>> + pollConsumerCacheSupplier; + // Valid between bundle start and bundle finish. private transient @Nullable Deserializer keyDeserializerInstance = null; private transient @Nullable Deserializer valueDeserializerInstance = null; - private transient @Nullable LoadingCache - offsetEstimatorCache; - - private transient @Nullable LoadingCache avgRecordSizeCache; private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L; @VisibleForTesting final long consumerPollingTimeout; @VisibleForTesting final DeserializerProvider keyDeserializerProvider; @@ -262,82 +333,133 @@ private static final class SharedStateHolder { * fetch backlog. 
*/ private static class KafkaLatestOffsetEstimator - implements GrowableOffsetRangeTracker.RangeEndEstimator { - + implements GrowableOffsetRangeTracker.RangeEndEstimator, Closeable { + private static final AtomicReferenceFieldUpdater + CURRENT_REFRESH_TASK = + (AtomicReferenceFieldUpdater) + AtomicReferenceFieldUpdater.newUpdater( + KafkaLatestOffsetEstimator.class, Runnable.class, "currentRefreshTask"); + private final Executor executor; private final Consumer offsetConsumer; private final TopicPartition topicPartition; - private final Supplier memoizedBacklog; + private long lastRefreshEndOffset; + private long nextRefreshNanos; + private volatile @Nullable Runnable currentRefreshTask; KafkaLatestOffsetEstimator( - Consumer offsetConsumer, TopicPartition topicPartition) { + final Consumer offsetConsumer, final TopicPartition topicPartition) { + this.executor = Executors.newSingleThreadExecutor(); this.offsetConsumer = offsetConsumer; this.topicPartition = topicPartition; - memoizedBacklog = - Suppliers.memoizeWithExpiration( - () -> { - synchronized (offsetConsumer) { - return Preconditions.checkStateNotNull( - offsetConsumer - .endOffsets(Collections.singleton(topicPartition)) - .get(topicPartition), - "No end offset found for partition %s.", - topicPartition); - } - }, - 1, - TimeUnit.SECONDS); + this.lastRefreshEndOffset = -1L; + this.nextRefreshNanos = Long.MIN_VALUE; + this.currentRefreshTask = null; } @Override - protected void finalize() { - try { - Closeables.close(offsetConsumer, true); - LOG.info("Offset Estimator consumer was closed for {}", topicPartition); - } catch (Exception anyException) { - LOG.warn("Failed to close offset consumer for {}", topicPartition); + public long estimate() { + final @Nullable Runnable task = currentRefreshTask; // volatile load (acquire) + + final long currentNanos; + if (task == null + && nextRefreshNanos < (currentNanos = System.nanoTime()) // normal load + && CURRENT_REFRESH_TASK.compareAndSet(this, null, this::refresh)) { // volatile load/store + try { + executor.execute(this::refresh); + } catch (RejectedExecutionException ex) { + LOG.error("Execution of end offset refresh rejected for {}", topicPartition, ex); + nextRefreshNanos = currentNanos + TimeUnit.SECONDS.toNanos(1); // normal store + CURRENT_REFRESH_TASK.lazySet(this, null); // ordered store (release) + } } + + return lastRefreshEndOffset; // normal load } @Override - public long estimate() { - return memoizedBacklog.get(); + public void close() { + offsetConsumer.close(); + } + + private void refresh() { + @Nullable + Long endOffset = + offsetConsumer.endOffsets(Collections.singleton(topicPartition)).get(topicPartition); + if (endOffset == null) { + LOG.warn("No end offset found for partition {}.", topicPartition); + } else { + lastRefreshEndOffset = endOffset; // normal store + } + nextRefreshNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(1); // normal store + CURRENT_REFRESH_TASK.lazySet(this, null); // ordered store (release) } } @GetInitialRestriction public OffsetRange initialRestriction(@Element KafkaSourceDescriptor kafkaSourceDescriptor) { - Map updatedConsumerConfig = - overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); - TopicPartition partition = kafkaSourceDescriptor.getTopicPartition(); - LOG.info("Creating Kafka consumer for initial restriction for {}", kafkaSourceDescriptor); - try (Consumer offsetConsumer = consumerFactoryFn.apply(updatedConsumerConfig)) { - offsetConsumer.assign(ImmutableList.of(partition)); - long startOffset; - 
@Nullable Instant startReadTime = kafkaSourceDescriptor.getStartReadTime(); - if (kafkaSourceDescriptor.getStartReadOffset() != null) { - startOffset = kafkaSourceDescriptor.getStartReadOffset(); - } else if (startReadTime != null) { - startOffset = ConsumerSpEL.offsetForTime(offsetConsumer, partition, startReadTime); - } else { - startOffset = offsetConsumer.position(partition); - } + final Consumer consumer = + pollConsumerCacheSupplier.get().getUnchecked(kafkaSourceDescriptor); + + final long startOffset; + final long stopOffset; + + final @Nullable Long startReadOffset = kafkaSourceDescriptor.getStartReadOffset(); + final @Nullable Instant startReadTime = kafkaSourceDescriptor.getStartReadTime(); + if (startReadOffset != null) { + startOffset = startReadOffset; + } else if (startReadTime != null) { + startOffset = + Preconditions.checkStateNotNull( + consumer + .offsetsForTimes( + Collections.singletonMap( + kafkaSourceDescriptor.getTopicPartition(), startReadTime.getMillis())) + .get(kafkaSourceDescriptor.getTopicPartition())) + .offset(); + } else { + startOffset = consumer.position(kafkaSourceDescriptor.getTopicPartition()); + } - long endOffset = Long.MAX_VALUE; - @Nullable Instant stopReadTime = kafkaSourceDescriptor.getStopReadTime(); - if (kafkaSourceDescriptor.getStopReadOffset() != null) { - endOffset = kafkaSourceDescriptor.getStopReadOffset(); - } else if (stopReadTime != null) { - endOffset = ConsumerSpEL.offsetForTime(offsetConsumer, partition, stopReadTime); - } - new OffsetRange(startOffset, endOffset); - Lineage.getSources() - .add( - "kafka", - ImmutableList.of( - (String) updatedConsumerConfig.get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG), - MoreObjects.firstNonNull(kafkaSourceDescriptor.getTopic(), partition.topic()))); - return new OffsetRange(startOffset, endOffset); + final @Nullable Long stopReadOffset = kafkaSourceDescriptor.getStopReadOffset(); + final @Nullable Instant stopReadTime = kafkaSourceDescriptor.getStopReadTime(); + if (stopReadOffset != null) { + stopOffset = stopReadOffset; + } else if (stopReadTime != null) { + stopOffset = + Preconditions.checkStateNotNull( + consumer + .offsetsForTimes( + Collections.singletonMap( + kafkaSourceDescriptor.getTopicPartition(), stopReadTime.getMillis())) + .get(kafkaSourceDescriptor.getTopicPartition())) + .offset(); + } else { + stopOffset = Long.MAX_VALUE; } + + final OffsetRange initialRestriction = new OffsetRange(startOffset, stopOffset); + Lineage.getSources() + .add( + "kafka", + ImmutableList.of( + Optional.ofNullable( + KafkaIOUtils.overrideBootstrapServersConfig( + consumerConfig, kafkaSourceDescriptor) + .get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG)) + .map( + value -> + (@Nullable List) + ConfigDef.parseType( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, + value, + ConfigDef.Type.LIST)) + .map(ImmutableSet::copyOf) + .map(COMMA_JOINER::join) + .get(), + MoreObjects.firstNonNull( + kafkaSourceDescriptor.getTopic(), + kafkaSourceDescriptor.getTopicPartition().topic()))); + return initialRestriction; } @GetInitialWatermarkEstimatorState @@ -355,13 +477,11 @@ public WatermarkEstimator newWatermarkEstimator( @GetSize public double getSize( - @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange offsetRange) - throws ExecutionException { - // If present, estimates the record size. 
- final LoadingCache avgRecordSizeCache = - Preconditions.checkStateNotNull(this.avgRecordSizeCache); + @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange offsetRange) { + // If present, estimates the record size to offset gap ratio. Compacted topics may hold less + // records than the estimated offset range due to record deletion within a partition. final @Nullable MovingAvg avgRecordSize = - avgRecordSizeCache.getIfPresent(kafkaSourceDescriptor); + avgRecordSizeCacheSupplier.get().getIfPresent(kafkaSourceDescriptor); // The tracker estimates the offset range by subtracting the last claimed position from the // currently observed end offset for the partition belonging to this split. final double estimatedOffsetRange = @@ -377,8 +497,7 @@ public double getSize( @NewTracker public OffsetRangeTracker restrictionTracker( - @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange restriction) - throws ExecutionException { + @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange restriction) { if (restriction.getTo() < Long.MAX_VALUE) { return new OffsetRangeTracker(restriction); } @@ -386,12 +505,9 @@ public OffsetRangeTracker restrictionTracker( // OffsetEstimators are cached for each topic-partition because they hold a stateful connection, // so we want to minimize the amount of connections that we start and track with Kafka. Another // point is that it has a memoized backlog, and this should make that more reusable estimations. - final LoadingCache offsetEstimatorCache = - Preconditions.checkStateNotNull(this.offsetEstimatorCache); - final KafkaLatestOffsetEstimator offsetEstimator = - offsetEstimatorCache.get(kafkaSourceDescriptor); - - return new GrowableOffsetRangeTracker(restriction.getFrom(), offsetEstimator); + return new GrowableOffsetRangeTracker( + restriction.getFrom(), + latestOffsetEstimatorCacheSupplier.get().getUnchecked(kafkaSourceDescriptor)); } @ProcessElement @@ -401,10 +517,11 @@ public ProcessContinuation processElement( WatermarkEstimator watermarkEstimator, MultiOutputReceiver receiver) throws Exception { - final LoadingCache avgRecordSizeCache = - Preconditions.checkStateNotNull(this.avgRecordSizeCache); - final LoadingCache offsetEstimatorCache = - Preconditions.checkStateNotNull(this.offsetEstimatorCache); + final MovingAvg avgRecordSize = avgRecordSizeCacheSupplier.get().get(kafkaSourceDescriptor); + final KafkaLatestOffsetEstimator latestOffsetEstimator = + latestOffsetEstimatorCacheSupplier.get().get(kafkaSourceDescriptor); + final Consumer consumer = + pollConsumerCacheSupplier.get().get(kafkaSourceDescriptor); final Deserializer keyDeserializerInstance = Preconditions.checkStateNotNull(this.keyDeserializerInstance); final Deserializer valueDeserializerInstance = @@ -419,15 +536,12 @@ public ProcessContinuation processElement( METRIC_NAMESPACE, RAW_SIZE_METRIC_PREFIX + "backlogBytes_" + topicPartition.toString()); // Stop processing current TopicPartition when it's time to stop. 
- if (checkStopReadingFn != null - && checkStopReadingFn.apply(kafkaSourceDescriptor.getTopicPartition())) { + if (checkStopReadingFn != null && checkStopReadingFn.apply(topicPartition)) { // Attempt to claim the last element in the restriction, such that the restriction tracker // doesn't throw an exception when checkDone is called tracker.tryClaim(tracker.currentRestriction().getTo() - 1); return ProcessContinuation.stop(); } - Map updatedConsumerConfig = - overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); // If there is a timestampPolicyFactory, create the TimestampPolicy for current // TopicPartition. TimestampPolicy timestampPolicy = null; @@ -437,137 +551,129 @@ public ProcessContinuation processElement( topicPartition, Optional.ofNullable(watermarkEstimator.currentWatermark())); } - LOG.info("Creating Kafka consumer for process continuation for {}", kafkaSourceDescriptor); - try (Consumer consumer = consumerFactoryFn.apply(updatedConsumerConfig)) { - consumer.assign(ImmutableList.of(kafkaSourceDescriptor.getTopicPartition())); - long startOffset = tracker.currentRestriction().getFrom(); - long expectedOffset = startOffset; - consumer.seek(kafkaSourceDescriptor.getTopicPartition(), startOffset); - ConsumerRecords rawRecords = ConsumerRecords.empty(); - long skippedRecords = 0L; - final Stopwatch sw = Stopwatch.createStarted(); - - KafkaMetrics kafkaMetrics = KafkaSinkMetrics.kafkaMetrics(); - try { - while (true) { - // Fetch the record size accumulator. - final MovingAvg avgRecordSize = avgRecordSizeCache.getUnchecked(kafkaSourceDescriptor); - rawRecords = poll(consumer, kafkaSourceDescriptor.getTopicPartition(), kafkaMetrics); - // When there are no records available for the current TopicPartition, self-checkpoint - // and move to process the next element. - if (rawRecords.isEmpty()) { - if (!topicPartitionExists( - kafkaSourceDescriptor.getTopicPartition(), - consumer.partitionsFor(kafkaSourceDescriptor.getTopic()))) { - return ProcessContinuation.stop(); - } - if (timestampPolicy != null) { - updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); - } - return ProcessContinuation.resume(); + long startOffset = tracker.currentRestriction().getFrom(); + long expectedOffset = startOffset; + consumer.resume(Collections.singleton(topicPartition)); + consumer.seek(topicPartition, startOffset); + long skippedRecords = 0L; + final Stopwatch sw = Stopwatch.createStarted(); + + final KafkaMetrics kafkaMetrics = KafkaSinkMetrics.kafkaMetrics(); + try { + while (true) { + final ConsumerRecords rawRecords = + poll(consumer, topicPartition, kafkaMetrics); + // When there are no records available for the current TopicPartition, self-checkpoint + // and move to process the next element. + if (rawRecords.isEmpty()) { + if (!topicPartitionExists( + kafkaSourceDescriptor.getTopicPartition(), + consumer.partitionsFor(kafkaSourceDescriptor.getTopic()))) { + return ProcessContinuation.stop(); } - for (ConsumerRecord rawRecord : rawRecords) { - // If the Kafka consumer returns a record with an offset that is already processed - // the record can be safely skipped. This is needed because there is a possibility - // that the seek() above fails to move the offset to the desired position. In which - // case poll() would return records that are already cnsumed. 
- if (rawRecord.offset() < startOffset) { - // If the start offset is not reached even after skipping the records for 10 seconds - // then the processing is stopped with a backoff to give the Kakfa server some time - // catch up. - if (sw.elapsed().getSeconds() > 10L) { - LOG.error( - "The expected offset ({}) was not reached even after" - + " skipping consumed records for 10 seconds. The offset we could" - + " reach was {}. The processing of this bundle will be attempted" - + " at a later time.", - expectedOffset, - rawRecord.offset()); - return ProcessContinuation.resume() - .withResumeDelay(org.joda.time.Duration.standardSeconds(10L)); - } - skippedRecords++; - continue; - } - if (skippedRecords > 0L) { - LOG.warn( - "{} records were skipped due to seek returning an" - + " earlier position than requested position of {}", - skippedRecords, - expectedOffset); - skippedRecords = 0L; + if (timestampPolicy != null) { + updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); + } + return ProcessContinuation.resume(); + } + for (ConsumerRecord rawRecord : rawRecords) { + // If the Kafka consumer returns a record with an offset that is already processed + // the record can be safely skipped. This is needed because there is a possibility + // that the seek() above fails to move the offset to the desired position. In which + // case poll() would return records that are already cnsumed. + if (rawRecord.offset() < startOffset) { + // If the start offset is not reached even after skipping the records for 10 seconds + // then the processing is stopped with a backoff to give the Kakfa server some time + // catch up. + if (sw.elapsed().getSeconds() > 10L) { + LOG.error( + "The expected offset ({}) was not reached even after" + + " skipping consumed records for 10 seconds. The offset we could" + + " reach was {}. The processing of this bundle will be attempted" + + " at a later time.", + expectedOffset, + rawRecord.offset()); + consumer.pause(Collections.singleton(topicPartition)); + return ProcessContinuation.resume() + .withResumeDelay(org.joda.time.Duration.standardSeconds(10L)); } - if (!tracker.tryClaim(rawRecord.offset())) { - return ProcessContinuation.stop(); + skippedRecords++; + continue; + } + if (skippedRecords > 0L) { + LOG.warn( + "{} records were skipped due to seek returning an" + + " earlier position than requested position of {}", + skippedRecords, + expectedOffset); + skippedRecords = 0L; + } + if (!tracker.tryClaim(rawRecord.offset())) { + consumer.seek(topicPartition, rawRecord.offset()); + consumer.pause(Collections.singleton(topicPartition)); + + return ProcessContinuation.stop(); + } + try { + KafkaRecord kafkaRecord = + new KafkaRecord<>( + rawRecord.topic(), + rawRecord.partition(), + rawRecord.offset(), + ConsumerSpEL.getRecordTimestamp(rawRecord), + ConsumerSpEL.getRecordTimestampType(rawRecord), + ConsumerSpEL.hasHeaders() ? rawRecord.headers() : null, + ConsumerSpEL.deserializeKey(keyDeserializerInstance, rawRecord), + ConsumerSpEL.deserializeValue(valueDeserializerInstance, rawRecord)); + int recordSize = + (rawRecord.key() == null ? 0 : rawRecord.key().length) + + (rawRecord.value() == null ? 0 : rawRecord.value().length); + avgRecordSize.update(recordSize); + rawSizes.update(recordSize); + expectedOffset = rawRecord.offset() + 1; + Instant outputTimestamp; + // The outputTimestamp and watermark will be computed by timestampPolicy, where the + // WatermarkEstimator should be a manual one. 
+ if (timestampPolicy != null) { + TimestampPolicyContext context = + updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); + outputTimestamp = timestampPolicy.getTimestampForRecord(context, kafkaRecord); + } else { + Preconditions.checkStateNotNull(this.extractOutputTimestampFn); + outputTimestamp = extractOutputTimestampFn.apply(kafkaRecord); } - try { - KafkaRecord kafkaRecord = - new KafkaRecord<>( - rawRecord.topic(), - rawRecord.partition(), - rawRecord.offset(), - ConsumerSpEL.getRecordTimestamp(rawRecord), - ConsumerSpEL.getRecordTimestampType(rawRecord), - ConsumerSpEL.hasHeaders() ? rawRecord.headers() : null, - ConsumerSpEL.deserializeKey(keyDeserializerInstance, rawRecord), - ConsumerSpEL.deserializeValue(valueDeserializerInstance, rawRecord)); - int recordSize = - (rawRecord.key() == null ? 0 : rawRecord.key().length) - + (rawRecord.value() == null ? 0 : rawRecord.value().length); - avgRecordSize.update(recordSize); - rawSizes.update(recordSize); - expectedOffset = rawRecord.offset() + 1; - Instant outputTimestamp; - // The outputTimestamp and watermark will be computed by timestampPolicy, where the - // WatermarkEstimator should be a manual one. - if (timestampPolicy != null) { - TimestampPolicyContext context = - updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); - outputTimestamp = timestampPolicy.getTimestampForRecord(context, kafkaRecord); - } else { - Preconditions.checkStateNotNull(this.extractOutputTimestampFn); - outputTimestamp = extractOutputTimestampFn.apply(kafkaRecord); - } - receiver - .get(recordTag) - .outputWithTimestamp(KV.of(kafkaSourceDescriptor, kafkaRecord), outputTimestamp); - } catch (SerializationException e) { - // This exception should only occur during the key and value deserialization when - // creating the Kafka Record - badRecordRouter.route( - receiver, - rawRecord, - null, - e, - "Failure deserializing Key or Value of Kakfa record reading from Kafka"); - if (timestampPolicy != null) { - updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); - } + receiver + .get(recordTag) + .outputWithTimestamp(KV.of(kafkaSourceDescriptor, kafkaRecord), outputTimestamp); + } catch (SerializationException e) { + // This exception should only occur during the key and value deserialization when + // creating the Kafka Record + badRecordRouter.route( + receiver, + rawRecord, + null, + e, + "Failure deserializing Key or Value of Kakfa record reading from Kafka"); + if (timestampPolicy != null) { + updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); } } - - backlogBytes.set( - (long) - (BigDecimal.valueOf( - Preconditions.checkStateNotNull( - offsetEstimatorCache.get(kafkaSourceDescriptor).estimate())) - .subtract(BigDecimal.valueOf(expectedOffset), MathContext.DECIMAL128) - .doubleValue() - * avgRecordSize.get())); - kafkaMetrics.updateBacklogBytes( - kafkaSourceDescriptor.getTopic(), - kafkaSourceDescriptor.getPartition(), - (long) - (BigDecimal.valueOf( - Preconditions.checkStateNotNull( - offsetEstimatorCache.get(kafkaSourceDescriptor).estimate())) - .subtract(BigDecimal.valueOf(expectedOffset), MathContext.DECIMAL128) - .doubleValue() - * avgRecordSize.get())); } - } finally { - kafkaMetrics.flushBufferedMetrics(); + + final long estimatedBacklogBytes = + (long) + (BigDecimal.valueOf(latestOffsetEstimator.estimate()) + .subtract(BigDecimal.valueOf(expectedOffset), MathContext.DECIMAL128) + .doubleValue() + * avgRecordSize.get()); + backlogBytes.set(estimatedBacklogBytes); + 
kafkaMetrics.updateBacklogBytes( + kafkaSourceDescriptor.getTopic(), + kafkaSourceDescriptor.getPartition(), + estimatedBacklogBytes); } + } finally { + kafkaMetrics.flushBufferedMetrics(); } } @@ -628,58 +734,8 @@ public Coder restrictionCoder() { @Setup public void setup() throws Exception { - // Start to track record size. - avgRecordSizeCache = - SharedStateHolder.AVG_RECORD_SIZE_CACHE.computeIfAbsent( - fnId, - k -> { - return CacheBuilder.newBuilder() - .maximumSize(1000L) - .build( - new CacheLoader() { - @Override - public MovingAvg load(KafkaSourceDescriptor kafkaSourceDescriptor) - throws Exception { - return new MovingAvg(); - } - }); - }); keyDeserializerInstance = keyDeserializerProvider.getDeserializer(consumerConfig, true); valueDeserializerInstance = valueDeserializerProvider.getDeserializer(consumerConfig, false); - offsetEstimatorCache = - SharedStateHolder.OFFSET_ESTIMATOR_CACHE.computeIfAbsent( - fnId, - k -> { - final Map consumerConfig = ImmutableMap.copyOf(this.consumerConfig); - final @Nullable Map offsetConsumerConfig = - this.offsetConsumerConfig == null - ? null - : ImmutableMap.copyOf(this.offsetConsumerConfig); - return CacheBuilder.newBuilder() - .weakValues() - .expireAfterAccess(1, TimeUnit.MINUTES) - .build( - new CacheLoader() { - @Override - public KafkaLatestOffsetEstimator load( - KafkaSourceDescriptor kafkaSourceDescriptor) throws Exception { - LOG.info( - "Creating Kafka consumer for offset estimation for {}", - kafkaSourceDescriptor); - - TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition(); - Map updatedConsumerConfig = - overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); - Consumer offsetConsumer = - consumerFactoryFn.apply( - KafkaIOUtils.getOffsetConsumerConfig( - "tracker-" + topicPartition, - offsetConsumerConfig, - updatedConsumerConfig)); - return new KafkaLatestOffsetEstimator(offsetConsumer, topicPartition); - } - }); - }); if (checkStopReadingFn != null) { checkStopReadingFn.setup(); } @@ -687,10 +743,6 @@ public KafkaLatestOffsetEstimator load( @Teardown public void teardown() throws Exception { - final LoadingCache avgRecordSizeCache = - Preconditions.checkStateNotNull(this.avgRecordSizeCache); - final LoadingCache offsetEstimatorCache = - Preconditions.checkStateNotNull(this.offsetEstimatorCache); try { if (valueDeserializerInstance != null) { Closeables.close(valueDeserializerInstance, true); @@ -708,22 +760,9 @@ public void teardown() throws Exception { } // Allow the cache to perform clean up tasks when this instance is about to be deleted. 
- avgRecordSizeCache.cleanUp(); - offsetEstimatorCache.cleanUp(); - } - - private Map overrideBootstrapServersConfig( - Map currentConfig, KafkaSourceDescriptor description) { - checkState( - currentConfig.containsKey(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) - || description.getBootStrapServers() != null); - Map config = new HashMap<>(currentConfig); - if (description.getBootStrapServers() != null && description.getBootStrapServers().size() > 0) { - config.put( - ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, - String.join(",", description.getBootStrapServers())); - } - return config; + avgRecordSizeCacheSupplier.get().cleanUp(); + latestOffsetEstimatorCacheSupplier.get().cleanUp(); + pollConsumerCacheSupplier.get().cleanUp(); } private static Instant ensureTimestampWithinBounds(Instant timestamp) { diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java index cbff0f896619..00bbd523f105 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java @@ -266,12 +266,14 @@ public synchronized List partitionsFor(String partition) { @Override public synchronized void assign(Collection partitions) { assertTrue(Iterables.getOnlyElement(partitions).equals(this.topicPartition)); + super.assign(partitions); } @Override public synchronized void seek(TopicPartition partition, long offset) { assertTrue(partition.equals(this.topicPartition)); this.startOffset = offset; + super.seek(partition, offset); } @Override From dbaecc4823a5b3d644ba42d672521ee11af0029b Mon Sep 17 00:00:00 2001 From: Steven van Rossum Date: Thu, 17 Apr 2025 17:32:37 +0000 Subject: [PATCH 2/5] Add comment to explain the behavior of volatile guard field in KafkaLatestOffsetEstimator --- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 19e8c0e47107..50f8cb1240cd 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -342,10 +342,43 @@ private static class KafkaLatestOffsetEstimator private final Executor executor; private final Consumer offsetConsumer; private final TopicPartition topicPartition; + // TODO(sjvanrossum): Use VarHandle.setOpaque/getOpaque when Java 8 support is dropped private long lastRefreshEndOffset; + // TODO(sjvanrossum): Use VarHandle.setOpaque/getOpaque when Java 8 support is dropped private long nextRefreshNanos; private volatile @Nullable Runnable currentRefreshTask; + /* + Periodic refreshes of lastRefreshEndOffset and nextRefreshNanos are guarded by the volatile + field currentRefreshTask. This guard's correctness depends on specific ordering of reads and + writes (loads and stores). + + To validate the behavior of this guard please read the Java Memory Model (JMM) specification. + For the current context consider the following oversimplifications of the JMM: + - Writes to a non-volatile long or double field are non-atomic. + - Writes to a non-volatile field may never become visible to another core. 
+ - Writes to a volatile field are atomic and will become visible to another core. + - Lazy writes to a volatile field are atomic and will become visible to another core for + reads of that volatile field. + - Writes preceeding writes or lazy writes to a volatile field are visible to another core. + + In short, the contents of this class' guarded fields are visible if the guard field is (lazily) + written last and read first. The contents of the volatile guard may be stale in comparison to + the contents of the guarded fields. For this method it is important that no more than one + thread will schedule a refresh task. Using currentRefreshTask as the guard field ensures that + lastRefreshEndOffset and nextRefreshNanos are at least as stale as currentRefreshTask. + It's fine if lastRefreshEndOffset and nextRefreshNanos are less stale than currentRefreshTask. + + Removing currentRefreshTask by guarding on nextRefreshNanos is possible, but executing + currentRefreshTask == null is practically free (measured in cycles) compared to executing + nextRefreshNanos < System.nanoTime() (measured in nanoseconds). + + Note that the JMM specifies that writes to a long or double are not guaranteed to be atomic. + In practice, every 64-bit JVM will treat them as atomic (and the JMM encourages this). + There's no way to force atomicity without visibility in Java 8 so atomicity guards have been + omitted. Java 9 introduces VarHandle with "opaque" getters/setters which do provide this. + */ + KafkaLatestOffsetEstimator( final Consumer offsetConsumer, final TopicPartition topicPartition) { this.executor = Executors.newSingleThreadExecutor(); From 52a64d80569feb30e5f6b453d327599e214084fa Mon Sep 17 00:00:00 2001 From: Steven van Rossum Date: Tue, 22 Apr 2025 15:45:11 +0000 Subject: [PATCH 3/5] Guard against exceptions in endOffset refresh --- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 50f8cb1240cd..39728a473677 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -415,16 +415,19 @@ public void close() { } private void refresh() { - @Nullable - Long endOffset = - offsetConsumer.endOffsets(Collections.singleton(topicPartition)).get(topicPartition); - if (endOffset == null) { - LOG.warn("No end offset found for partition {}.", topicPartition); - } else { - lastRefreshEndOffset = endOffset; // normal store + try { + @Nullable + Long endOffset = + offsetConsumer.endOffsets(Collections.singleton(topicPartition)).get(topicPartition); + if (endOffset == null) { + LOG.warn("No end offset found for partition {}.", topicPartition); + } else { + lastRefreshEndOffset = endOffset; // normal store + } + nextRefreshNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(1); // normal store + } finally { + CURRENT_REFRESH_TASK.lazySet(this, null); // ordered store (release) } - nextRefreshNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(1); // normal store - CURRENT_REFRESH_TASK.lazySet(this, null); // ordered store (release) } } From 7e1cc54f9a18c7624071f680c0e94f943dd9f9f0 Mon Sep 17 00:00:00 2001 From: Steven van Rossum Date: Tue, 22 Apr 2025 20:20:13 +0000 Subject: [PATCH 4/5] Call cancelIfTimeouted in roundtripElements to 
shutdown lingering pipelines --- .../test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java index adf31dc72b54..a09d46b674ae 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java @@ -322,7 +322,8 @@ private void roundtripElements( .withKeySerializer(IntegerSerializer.class) .withValueSerializer(StringSerializer.class)); - wPipeline.run().waitUntilFinish(Duration.standardSeconds(10)); + final PipelineResult wResult = wPipeline.run(); + cancelIfTimeouted(wResult, wResult.waitUntilFinish(Duration.standardSeconds(10))); rPipeline .apply( @@ -345,7 +346,9 @@ private void roundtripElements( .apply(ParDo.of(new CrashOnExtra(records.values()))) .apply(ParDo.of(new LogFn())); - rPipeline.run().waitUntilFinish(Duration.standardSeconds(options.getReadTimeout())); + final PipelineResult rResult = rPipeline.run(); + cancelIfTimeouted( + rResult, rResult.waitUntilFinish(Duration.standardSeconds(options.getReadTimeout()))); for (String value : records.values()) { kafkaIOITExpectedLogs.verifyError(value); From 9655065c01e54754780914126c02e478310d3cdd Mon Sep 17 00:00:00 2001 From: Steven van Rossum Date: Thu, 24 Apr 2025 23:45:25 +0000 Subject: [PATCH 5/5] Add missing calls to seek and/or pause before return points added in #34202 --- .../java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index fe4da4d50fb6..cd3ad2526d30 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -611,6 +611,8 @@ public ProcessContinuation processElement( // No progress when the polling timeout expired. // Self-checkpoint and move to process the next element. if (rawRecords == ConsumerRecords.empty()) { + consumer.pause(Collections.singleton(topicPartition)); + if (!topicPartitionExists( kafkaSourceDescriptor.getTopicPartition(), consumer.partitionsFor(kafkaSourceDescriptor.getTopic()))) { @@ -712,6 +714,9 @@ public ProcessContinuation processElement( // Claim up to the current position. if (expectedOffset < (expectedOffset = consumer.position(topicPartition))) { if (!tracker.tryClaim(expectedOffset - 1)) { + consumer.seek(topicPartition, expectedOffset - 1); + consumer.pause(Collections.singleton(topicPartition)); + return ProcessContinuation.stop(); } if (timestampPolicy != null) {
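Patch 1 replaces the SharedStateHolder maps keyed by a per-instance ID with three caches wrapped in MemoizingPerInstantiationSerializableSupplier (average record size, latest-offset estimator, poll consumer), each built with weakValues() and, for the consumer-holding caches, a removal listener that closes the client when an entry is reported as collected. The sketch below is a compact mirror of that wiring, written against plain Guava rather than Beam's vendored copy and with a toy AutoCloseable standing in for the Kafka consumer; class and method names here are illustrative.

    import com.google.common.cache.CacheBuilder;
    import com.google.common.cache.CacheLoader;
    import com.google.common.cache.LoadingCache;
    import com.google.common.cache.RemovalCause;
    import com.google.common.cache.RemovalNotification;

    final class ClosingResourceCache {
      // Toy stand-in for the per-partition Kafka consumer held by the real caches.
      static final class Resource implements AutoCloseable {
        final String key;
        Resource(String key) { this.key = key; }
        @Override public void close() { System.out.println("closed " + key); }
      }

      static LoadingCache<String, Resource> build() {
        return CacheBuilder.newBuilder()
            .concurrencyLevel(Runtime.getRuntime().availableProcessors())
            // Values are weakly referenced: once nothing else holds the Resource, GC may
            // reclaim it and the cache records the eviction with cause COLLECTED.
            .weakValues()
            .removalListener(
                (RemovalNotification<String, Resource> notification) -> {
                  final Resource value = notification.getValue();
                  // Mirrors the patch: close only for GC-driven evictions that still
                  // expose a value on the notification.
                  if (notification.getCause() == RemovalCause.COLLECTED && value != null) {
                    value.close();
                  }
                })
            .build(
                new CacheLoader<String, Resource>() {
                  @Override
                  public Resource load(String key) {
                    return new Resource(key); // created lazily on first access per key
                  }
                });
      }

      public static void main(String[] args) {
        LoadingCache<String, Resource> cache = build();
        Resource r = cache.getUnchecked("topic-0"); // same call shape the DoFn methods use
        System.out.println("loaded " + r.key);
      }
    }

Because values are weakly held, an entry lives only as long as some in-flight restriction still references it, and teardown() only needs to call cleanUp() on each cache instead of tracking a shared registry.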
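The heart of the backlog-estimation change (patches 1 through 3) is the reworked KafkaLatestOffsetEstimator: estimate() returns the last cached end offset immediately and uses a volatile guard field, updated through an AtomicReferenceFieldUpdater CAS, so at most one thread hands a refresh to a single-thread executor roughly once per second; patch 3 then clears the guard in a finally block so a failed endOffsets call cannot wedge the estimator. The sketch below strips that pattern down to a LongSupplier standing in for consumer.endOffsets(...), so it runs without Kafka; the class and field names are illustrative.

    import java.util.concurrent.Executor;
    import java.util.concurrent.Executors;
    import java.util.concurrent.RejectedExecutionException;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
    import java.util.function.LongSupplier;

    final class SingleFlightEstimator {
      private static final AtomicReferenceFieldUpdater<SingleFlightEstimator, Runnable>
          REFRESH_TASK =
              AtomicReferenceFieldUpdater.newUpdater(
                  SingleFlightEstimator.class, Runnable.class, "refreshTask");

      private final Executor executor = Executors.newSingleThreadExecutor();
      private final LongSupplier fetchEndOffset; // stands in for consumer.endOffsets(...)
      private long lastEndOffset = -1L;          // guarded by refreshTask, may be read stale
      private long nextRefreshNanos = Long.MIN_VALUE;
      private volatile Runnable refreshTask;     // null means no refresh is in flight

      SingleFlightEstimator(LongSupplier fetchEndOffset) {
        this.fetchEndOffset = fetchEndOffset;
      }

      long estimate() {
        final Runnable task = refreshTask; // volatile read pairs with the lazySet in refresh()
        final long now;
        if (task == null
            && nextRefreshNanos < (now = System.nanoTime())
            && REFRESH_TASK.compareAndSet(this, null, this::refresh)) { // one winner per interval
          try {
            executor.execute(this::refresh);
          } catch (RejectedExecutionException ex) {
            // Back off and release the guard so a later estimate() can try again.
            nextRefreshNanos = now + TimeUnit.SECONDS.toNanos(1);
            REFRESH_TASK.lazySet(this, null);
          }
        }
        return lastEndOffset; // never blocks on the broker; staleness is bounded by the interval
      }

      private void refresh() {
        try {
          lastEndOffset = fetchEndOffset.getAsLong();
          nextRefreshNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(1);
        } finally {
          REFRESH_TASK.lazySet(this, null); // always clear the guard, even if the fetch throws
        }
      }
    }

Compared with the previous Suppliers.memoizeWithExpiration approach, the calling thread never synchronizes on the consumer or waits for a broker round trip; it only pays for a volatile read and, once per interval, a CAS.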
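Patch 4 keeps the integration test from leaking pipelines by capturing each PipelineResult and passing the outcome of waitUntilFinish to cancelIfTimeouted. That helper lives elsewhere in KafkaIOIT and is not shown in the hunk; a plausible shape, relying only on the documented behavior that waitUntilFinish(Duration) returns null when the timeout elapses before a terminal state, and on PipelineResult and IOException already being imported by the test class, would be:

    // Hypothetical reconstruction of the helper called above; the real method in KafkaIOIT
    // may differ in name or signature.
    private static void cancelIfTimeouted(PipelineResult result, PipelineResult.State state)
        throws IOException {
      if (state == null) {
        // waitUntilFinish(Duration) returned null, i.e. the pipeline did not reach a
        // terminal state within the timeout, so shut it down explicitly.
        result.cancel();
      }
    }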