diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java index 4f353df865ff..3d228242b8a5 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -81,6 +81,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.Nullable; @@ -686,7 +687,7 @@ private static class UnboundedKafkaReader extends UnboundedReader> availableRecordsQueue = new SynchronousQueue<>(); - private volatile boolean closed = false; + private AtomicBoolean closed = new AtomicBoolean(false); // Backlog support : // Kafka consumer does not have an API to fetch latest offset for topic. We need to seekToEnd() @@ -792,10 +793,10 @@ public PartitionState apply(TopicPartition tp) { private void consumerPollLoop() { // Read in a loop and enqueue the batch of records, if any, to availableRecordsQueue - while (!closed) { + while (!closed.get()) { try { ConsumerRecords records = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis()); - if (!records.isEmpty()) { + if (!records.isEmpty() && !closed.get()) { availableRecordsQueue.put(records); // blocks until dequeued. } } catch (InterruptedException e) { @@ -817,6 +818,7 @@ private void nextBatch() { records = availableRecordsQueue.poll(NEW_RECORDS_POLL_TIMEOUT.getMillis(), TimeUnit.MILLISECONDS); } catch (InterruptedException e) { + Thread.currentThread().interrupt(); LOG.warn("{}: Unexpected", this, e); return; } @@ -1041,11 +1043,32 @@ public long getSplitBacklogBytes() { @Override public void close() throws IOException { - closed = true; - availableRecordsQueue.poll(); // drain unread batch, this unblocks consumer thread. - consumer.wakeup(); + closed.set(true); consumerPollThread.shutdown(); offsetFetcherThread.shutdown(); + + boolean isShutdown = false; + + // Wait for threads to shutdown. Trying this a loop to handle a tiny race where poll thread + // might block to enqueue right after availableRecordsQueue.poll() below. + while (!isShutdown) { + + consumer.wakeup(); + offsetConsumer.wakeup(); + availableRecordsQueue.poll(); // drain unread batch, this unblocks consumer thread. + try { + isShutdown = consumerPollThread.awaitTermination(10, TimeUnit.SECONDS) + && offsetFetcherThread.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); // not expected + } + + if (!isShutdown) { + LOG.warn("An internal thread is taking a long time to shutdown. will retry."); + } + } + Closeables.close(offsetConsumer, true); Closeables.close(consumer, true); }