diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
index 3ab6c4f502e7..b2424875d7f3 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
@@ -21,6 +21,7 @@
import java.math.BigDecimal;
import java.math.MathContext;
+import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -140,8 +141,8 @@
* {@link ReadFromKafkaDoFn} will stop reading from any removed {@link TopicPartition} automatically
* by querying Kafka {@link Consumer} APIs. Please note that stopping reading may not happen as soon
* as the {@link TopicPartition} is removed. For example, the removal could happen at the same time
- * when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(java.time.Duration)}. In that
- * case, the {@link ReadFromKafkaDoFn} will still output the fetched records.
+ * when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(Duration)}. In that case, the
+ * {@link ReadFromKafkaDoFn} will still output the fetched records.
*
* <h4>Stop Reading from Stopped {@link TopicPartition}</h4>
*
@@ -199,11 +200,11 @@ private ReadFromKafkaDoFn(
this.checkStopReadingFn = transform.getCheckStopReadingFn();
this.badRecordRouter = transform.getBadRecordRouter();
this.recordTag = recordTag;
- if (transform.getConsumerPollingTimeout() > 0) {
- this.consumerPollingTimeout = transform.getConsumerPollingTimeout();
- } else {
- this.consumerPollingTimeout = DEFAULT_KAFKA_POLL_TIMEOUT;
- }
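+ // Interpret the configured polling timeout as seconds, falling back to the default when the
+ // configured value is not positive.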
+ this.consumerPollingTimeout =
+ Duration.ofSeconds(
+ transform.getConsumerPollingTimeout() > 0
+ ? transform.getConsumerPollingTimeout()
+ : DEFAULT_KAFKA_POLL_TIMEOUT);
}
private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class);
@@ -248,7 +249,7 @@ private static final class SharedStateHolder {
private transient @Nullable LoadingCache<KafkaSourceDescriptor, MovingAvg> avgRecordSizeCache;
private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L;
- @VisibleForTesting final long consumerPollingTimeout;
+ @VisibleForTesting final Duration consumerPollingTimeout;
@VisibleForTesting final DeserializerProvider<K> keyDeserializerProvider;
@VisibleForTesting final DeserializerProvider<V> valueDeserializerProvider;
@VisibleForTesting final Map<String, Object> consumerConfig;
@@ -443,19 +444,27 @@ public ProcessContinuation processElement(
long startOffset = tracker.currentRestriction().getFrom();
long expectedOffset = startOffset;
consumer.seek(kafkaSourceDescriptor.getTopicPartition(), startOffset);
- ConsumerRecords<byte[], byte[]> rawRecords = ConsumerRecords.empty();
long skippedRecords = 0L;
final Stopwatch sw = Stopwatch.createStarted();
- KafkaMetrics kafkaMetrics = KafkaSinkMetrics.kafkaMetrics();
+ final KafkaMetrics kafkaMetrics = KafkaSinkMetrics.kafkaMetrics();
try {
while (true) {
// Fetch the record size accumulator.
final MovingAvg avgRecordSize = avgRecordSizeCache.getUnchecked(kafkaSourceDescriptor);
- rawRecords = poll(consumer, kafkaSourceDescriptor.getTopicPartition(), kafkaMetrics);
- // When there are no records available for the current TopicPartition, self-checkpoint
- // and move to process the next element.
- if (rawRecords.isEmpty()) {
+ // TODO: Remove this timer and use the existing fetch-latency-avg metric.
+ // A consumer will often have prefetches waiting to be returned immediately, in which case
+ // this timer may contribute more latency than it measures.
+ // See https://shipilev.net/blog/2014/nanotrusting-nanotime/ for more information.
+ final Stopwatch pollTimer = Stopwatch.createStarted();
+ // Fetch the next records.
+ final ConsumerRecords<byte[], byte[]> rawRecords =
+ consumer.poll(this.consumerPollingTimeout);
+ kafkaMetrics.updateSuccessfulRpcMetrics(topicPartition.topic(), pollTimer.elapsed());
+
+ // No progress when the polling timeout expired.
+ // Self-checkpoint and move to process the next element.
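+ // KafkaConsumer returns the shared ConsumerRecords.empty() instance when nothing was fetched,
+ // so an identity comparison is enough to detect an empty poll.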
+ if (rawRecords == ConsumerRecords.empty()) {
if (!topicPartitionExists(
kafkaSourceDescriptor.getTopicPartition(),
consumer.partitionsFor(kafkaSourceDescriptor.getTopic()))) {
@@ -466,6 +475,9 @@ public ProcessContinuation processElement(
}
return ProcessContinuation.resume();
}
+
+ // Visible progress within the consumer polling timeout.
+ // Partially or fully claim and process records in this batch.
for (ConsumerRecord<byte[], byte[]> rawRecord : rawRecords) {
// If the Kafka consumer returns a record with an offset that is already processed
// the record can be safely skipped. This is needed because there is a possibility
@@ -500,6 +512,7 @@ public ProcessContinuation processElement(
if (!tracker.tryClaim(rawRecord.offset())) {
return ProcessContinuation.stop();
}
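+ // Track the expected offset as soon as the record is claimed so that it also advances for
+ // records that fail deserialization and are routed to the bad record handler.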
+ expectedOffset = rawRecord.offset() + 1;
try {
KafkaRecord<K, V> kafkaRecord =
new KafkaRecord<>(
@@ -516,7 +529,6 @@ public ProcessContinuation processElement(
+ (rawRecord.value() == null ? 0 : rawRecord.value().length);
avgRecordSize.update(recordSize);
rawSizes.update(recordSize);
- expectedOffset = rawRecord.offset() + 1;
Instant outputTimestamp;
// The outputTimestamp and watermark will be computed by timestampPolicy, where the
// WatermarkEstimator should be a manual one.
@@ -546,6 +558,17 @@ public ProcessContinuation processElement(
}
}
+ // Non-visible progress within the consumer polling timeout.
+ // Claim up to the current position.
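+ // Note: the left-hand side reads the previous expectedOffset before the inline assignment
+ // updates it to the consumer's current position.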
+ if (expectedOffset < (expectedOffset = consumer.position(topicPartition))) {
+ if (!tracker.tryClaim(expectedOffset - 1)) {
+ return ProcessContinuation.stop();
+ }
+ if (timestampPolicy != null) {
+ updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker);
+ }
+ }
+
backlogBytes.set(
(long)
(BigDecimal.valueOf(
@@ -578,36 +601,6 @@ private boolean topicPartitionExists(
.anyMatch(partitionInfo -> partitionInfo.partition() == (topicPartition.partition()));
}
- // see https://github.com/apache/beam/issues/25962
- private ConsumerRecords<byte[], byte[]> poll(
- Consumer<byte[], byte[]> consumer, TopicPartition topicPartition, KafkaMetrics kafkaMetrics) {
- final Stopwatch sw = Stopwatch.createStarted();
- long previousPosition = -1;
- java.time.Duration timeout = java.time.Duration.ofSeconds(this.consumerPollingTimeout);
- java.time.Duration elapsed = java.time.Duration.ZERO;
- while (true) {
- final ConsumerRecords<byte[], byte[]> rawRecords = consumer.poll(timeout.minus(elapsed));
- elapsed = sw.elapsed();
- kafkaMetrics.updateSuccessfulRpcMetrics(
- topicPartition.topic(), java.time.Duration.ofMillis(elapsed.toMillis()));
- if (!rawRecords.isEmpty()) {
- // return as we have found some entries
- return rawRecords;
- }
- if (previousPosition == (previousPosition = consumer.position(topicPartition))) {
- // there was no progress on the offset/position, which indicates end of stream
- return rawRecords;
- }
- if (elapsed.toMillis() >= timeout.toMillis()) {
- // timeout is over
- LOG.warn(
- "No messages retrieved with polling timeout {} seconds. Consider increasing the consumer polling timeout using withConsumerPollingTimeout method.",
- consumerPollingTimeout);
- return rawRecords;
- }
- }
- }
-
private TimestampPolicyContext updateWatermarkManually(
TimestampPolicy<K, V> timestampPolicy,
WatermarkEstimator<Instant> watermarkEstimator,
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
index cbff0f896619..eda9dac7a298 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
@@ -717,14 +717,14 @@ public void testUnbounded() {
@Test
public void testConstructorWithPollTimeout() {
ReadSourceDescriptors<String, String> descriptors = makeReadSourceDescriptor(consumer);
- // default poll timeout = 1 scond
+ // default poll timeout = 2 seconds
ReadFromKafkaDoFn<String, String> dofnInstance = ReadFromKafkaDoFn.create(descriptors, RECORDS);
- Assert.assertEquals(2L, dofnInstance.consumerPollingTimeout);
+ Assert.assertEquals(Duration.ofSeconds(2L), dofnInstance.consumerPollingTimeout);
// updated timeout = 5 seconds
descriptors = descriptors.withConsumerPollingTimeout(5L);
ReadFromKafkaDoFn<String, String> dofnInstanceNew =
ReadFromKafkaDoFn.create(descriptors, RECORDS);
- Assert.assertEquals(5L, dofnInstanceNew.consumerPollingTimeout);
+ Assert.assertEquals(Duration.ofSeconds(5L), dofnInstanceNew.consumerPollingTimeout);
}
private BoundednessVisitor testBoundedness(