From f04eb99b51e36ab989638788a5fd86ff7f421034 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Tue, 10 Dec 2019 17:48:31 +0900 Subject: [PATCH 1/6] [SPARK-21869][SS] Revise Kafka producer pool to implement 'expire' correctly --- .../sql/kafka010/CachedKafkaProducer.scala | 128 ----------- .../spark/sql/kafka010/KafkaDataWriter.scala | 23 +- .../spark/sql/kafka010/KafkaWriteTask.scala | 20 +- .../apache/spark/sql/kafka010/package.scala | 7 + .../producer/CachedKafkaProducer.scala | 41 ++++ .../producer/InternalKafkaProducerPool.scala | 204 ++++++++++++++++++ .../kafka010/CachedKafkaProducerSuite.scala | 77 ------- .../apache/spark/sql/kafka010/KafkaTest.scala | 3 +- .../InternalKafkaProducerPoolSuite.scala | 192 +++++++++++++++++ 9 files changed, 469 insertions(+), 226 deletions(-) delete mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/CachedKafkaProducer.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPoolSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala deleted file mode 100644 index fc177cdc9037e..0000000000000 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.{util => ju} -import java.util.concurrent.{ConcurrentMap, ExecutionException, TimeUnit} - -import com.google.common.cache._ -import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} -import org.apache.kafka.clients.producer.KafkaProducer -import scala.collection.JavaConverters._ -import scala.util.control.NonFatal - -import org.apache.spark.SparkEnv -import org.apache.spark.internal.Logging -import org.apache.spark.kafka010.{KafkaConfigUpdater, KafkaRedactionUtil} - -private[kafka010] object CachedKafkaProducer extends Logging { - - private type Producer = KafkaProducer[Array[Byte], Array[Byte]] - - private val defaultCacheExpireTimeout = TimeUnit.MINUTES.toMillis(10) - - private lazy val cacheExpireTimeout: Long = Option(SparkEnv.get) - .map(_.conf.get(PRODUCER_CACHE_TIMEOUT)) - .getOrElse(defaultCacheExpireTimeout) - - private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] { - override def load(config: Seq[(String, Object)]): Producer = { - createKafkaProducer(config) - } - } - - private val removalListener = new RemovalListener[Seq[(String, Object)], Producer]() { - override def onRemoval( - notification: RemovalNotification[Seq[(String, Object)], Producer]): Unit = { - val paramsSeq: Seq[(String, Object)] = notification.getKey - val producer: Producer = notification.getValue - if (log.isDebugEnabled()) { - val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq) - logDebug(s"Evicting kafka producer $producer params: $redactedParamsSeq, " + - s"due to ${notification.getCause}") - } - close(paramsSeq, producer) - } - } - - private lazy val guavaCache: LoadingCache[Seq[(String, Object)], Producer] = - CacheBuilder.newBuilder().expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS) - .removalListener(removalListener) - .build[Seq[(String, Object)], Producer](cacheLoader) - - private def createKafkaProducer(paramsSeq: Seq[(String, Object)]): Producer = { - val kafkaProducer: Producer = new Producer(paramsSeq.toMap.asJava) - if (log.isDebugEnabled()) { - val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq) - logDebug(s"Created a new instance of KafkaProducer for $redactedParamsSeq.") - } - kafkaProducer - } - - /** - * Get a cached KafkaProducer for a given configuration. If matching KafkaProducer doesn't - * exist, a new KafkaProducer will be created. KafkaProducer is thread safe, it is best to keep - * one instance per specified kafkaParams. - */ - private[kafka010] def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = { - val updatedKafkaProducerConfiguration = - KafkaConfigUpdater("executor", kafkaParams.asScala.toMap) - .setAuthenticationConfigIfNeeded() - .build() - val paramsSeq: Seq[(String, Object)] = paramsToSeq(updatedKafkaProducerConfiguration) - try { - guavaCache.get(paramsSeq) - } catch { - case e @ (_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError) - if e.getCause != null => - throw e.getCause - } - } - - private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = { - val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1) - paramsSeq - } - - /** For explicitly closing kafka producer */ - private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = { - val paramsSeq = paramsToSeq(kafkaParams) - guavaCache.invalidate(paramsSeq) - } - - /** Auto close on cache evict */ - private def close(paramsSeq: Seq[(String, Object)], producer: Producer): Unit = { - try { - if (log.isInfoEnabled()) { - val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq) - logInfo(s"Closing the KafkaProducer with params: ${redactedParamsSeq.mkString("\n")}.") - } - producer.close() - } catch { - case NonFatal(e) => logWarning("Error while closing kafka producer.", e) - } - } - - private[kafka010] def clear(): Unit = { - logInfo("Cleaning up guava cache.") - guavaCache.invalidateAll() - } - - // Intended for testing purpose only. - private def getAsMap: ConcurrentMap[Seq[(String, Object)], Producer] = guavaCache.asMap() -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala index 9a2b369933616..63863a6cc6d6f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala @@ -22,6 +22,7 @@ import java.{util => ju} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.kafka010.producer.{CachedKafkaProducer, InternalKafkaProducerPool} /** * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we @@ -44,11 +45,14 @@ private[kafka010] class KafkaDataWriter( inputSchema: Seq[Attribute]) extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { - private lazy val producer = CachedKafkaProducer.getOrCreate(producerParams) + private var producer: Option[CachedKafkaProducer] = None def write(row: InternalRow): Unit = { checkForErrors() - sendRow(row, producer) + if (producer.isEmpty) { + producer = Some(InternalKafkaProducerPool.acquire(producerParams)) + } + producer.foreach { p => sendRow(row, p.producer) } } def commit(): WriterCommitMessage = { @@ -56,22 +60,15 @@ private[kafka010] class KafkaDataWriter( // This requires flushing and then checking that no callbacks produced errors. // We also check for errors before to fail as soon as possible - the check is cheap. checkForErrors() - producer.flush() + producer.foreach(_.producer.flush()) checkForErrors() KafkaDataWriterCommitMessage } def abort(): Unit = {} - def close(): Unit = {} - - /** explicitly invalidate producer from pool. only for testing. */ - private[kafka010] def invalidateProducer(): Unit = { - checkForErrors() - if (producer != null) { - producer.flush() - checkForErrors() - CachedKafkaProducer.close(producerParams) - } + def close(): Unit = { + producer.foreach(InternalKafkaProducerPool.release) + producer = None } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 8b907065af1d0..fddba3f0f9919 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -27,6 +27,7 @@ import org.apache.kafka.common.header.internals.RecordHeader import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, UnsafeProjection} +import org.apache.spark.sql.kafka010.producer.{CachedKafkaProducer, InternalKafkaProducerPool} import org.apache.spark.sql.types.BinaryType /** @@ -39,25 +40,30 @@ private[kafka010] class KafkaWriteTask( inputSchema: Seq[Attribute], topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { // used to synchronize with Kafka callbacks - private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ + private var producer: Option[CachedKafkaProducer] = None /** * Writes key value data out to topics. */ def execute(iterator: Iterator[InternalRow]): Unit = { - producer = CachedKafkaProducer.getOrCreate(producerConfiguration) + producer = Some(InternalKafkaProducerPool.acquire(producerConfiguration)) + val internalProducer = producer.get.producer while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - sendRow(currentRow, producer) + sendRow(currentRow, internalProducer) } } def close(): Unit = { - checkForErrors() - if (producer != null) { - producer.flush() + try { checkForErrors() - producer = null + producer.foreach { p => + p.producer.flush() + checkForErrors() + } + } finally { + producer.foreach(InternalKafkaProducerPool.release) + producer = None } } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala index 6f6ae55fc4971..460bb8bd34ec6 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala @@ -32,6 +32,13 @@ package object kafka010 { // scalastyle:ignore .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("10m") + private[kafka010] val PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL = + ConfigBuilder("spark.kafka.producer.cache.evictorThreadRunInterval") + .doc("The interval of time between runs of the idle evictor thread for producer pool. " + + "When non-positive, no idle evictor thread will be run.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("1m") + private[kafka010] val CONSUMER_CACHE_CAPACITY = ConfigBuilder("spark.kafka.consumer.cache.capacity") .doc("The maximum number of consumers cached. Please note it's a soft limit" + diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/CachedKafkaProducer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/CachedKafkaProducer.scala new file mode 100644 index 0000000000000..6682944fe4f4e --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/CachedKafkaProducer.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.producer + +import java.{util => ju} + +import scala.util.control.NonFatal + +import org.apache.kafka.clients.producer.KafkaProducer + +import org.apache.spark.internal.Logging + +private[kafka010] class CachedKafkaProducer( + val cacheKey: Seq[(String, Object)], + val producer: KafkaProducer[Array[Byte], Array[Byte]]) extends Logging { + val id: String = ju.UUID.randomUUID().toString + + private[producer] def close(): Unit = { + try { + logInfo(s"Closing the KafkaProducer with id: $id.") + producer.close() + } catch { + case NonFatal(e) => logWarning("Error while closing kafka producer.", e) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala new file mode 100644 index 0000000000000..9384713af250f --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.producer + +import java.{util => ju} +import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.kafka.clients.producer.KafkaProducer + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.kafka010.{KafkaConfigUpdater, KafkaRedactionUtil} +import org.apache.spark.sql.kafka010.{PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL, PRODUCER_CACHE_TIMEOUT} +import org.apache.spark.util.{Clock, ShutdownHookManager, SystemClock, ThreadUtils, Utils} + +/** + * Provides object pool for [[CachedKafkaProducer]] which is grouped by + * [[org.apache.spark.sql.kafka010.producer.InternalKafkaProducerPool.CacheKey]]. + */ +private[producer] class InternalKafkaProducerPool( + executorService: ScheduledExecutorService, + val clock: Clock, + conf: SparkConf) extends Logging { + import InternalKafkaProducerPool._ + + def this(sparkConf: SparkConf) = { + this(ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "kafka-producer-cache-evictor"), new SystemClock, sparkConf) + } + + /** exposed for testing */ + private[producer] val cacheExpireTimeout: Long = conf.get(PRODUCER_CACHE_TIMEOUT) + + private val evictorThreadRunIntervalMillis = conf.get(PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL) + + @GuardedBy("this") + private val cache = new mutable.HashMap[CacheKey, CachedProducerEntry] + + private def startEvictorThread(): Option[ScheduledFuture[_]] = { + if (evictorThreadRunIntervalMillis > 0) { + val future = executorService.scheduleAtFixedRate(() => { + Utils.tryLogNonFatalError(evictExpired()) + }, 0, evictorThreadRunIntervalMillis, TimeUnit.MILLISECONDS) + Some(future) + } else { + None + } + } + + private var scheduled = startEvictorThread() + + /** + * Get a cached KafkaProducer for a given configuration. If matching KafkaProducer doesn't + * exist, a new KafkaProducer will be created. KafkaProducer is thread safe, it is best to keep + * one instance per specified kafkaParams. + */ + private[producer] def acquire(kafkaParams: ju.Map[String, Object]): CachedKafkaProducer = { + val updatedKafkaProducerConfiguration = + KafkaConfigUpdater("executor", kafkaParams.asScala.toMap) + .setAuthenticationConfigIfNeeded() + .build() + val paramsSeq: Seq[(String, Object)] = paramsToSeq(updatedKafkaProducerConfiguration) + synchronized { + val entry = cache.getOrElseUpdate(paramsSeq, { + val producer = createKafkaProducer(paramsSeq) + val cachedProducer = new CachedKafkaProducer(paramsSeq, producer) + new CachedProducerEntry(cachedProducer, clock, cacheExpireTimeout) + }) + entry.handleBorrowed() + entry.producer + } + } + + private[producer] def release(producer: CachedKafkaProducer): Unit = { + def closeProducerNotInCache(producer: CachedKafkaProducer): Unit = { + logWarning(s"Released producer ${producer.id} is not a member of the cache. Closing.") + producer.close() + } + + synchronized { + cache.get(producer.cacheKey) match { + case Some(entry) if entry.producer.id == producer.id => entry.handleReturned() + case _ => closeProducerNotInCache(producer) + } + } + } + + private[producer] def shutdown(): Unit = { + ThreadUtils.shutdown(executorService) + } + + /** exposed for testing. */ + private[producer] def reset(): Unit = synchronized { + scheduled.foreach(_.cancel(true)) + cache.foreach { case (k, v) => + cache.remove(k) + v.producer.close() + } + scheduled = startEvictorThread() + } + + /** exposed for testing */ + private[producer] def getAsMap: Map[CacheKey, CachedProducerEntry] = cache.toMap + + private def evictExpired(): Unit = { + val producers = new mutable.ArrayBuffer[CachedProducerEntry]() + synchronized { + cache.filter { case (_, v) => v.expired }.foreach { case (k, v) => + cache.remove(k) + producers += v + } + } + producers.foreach { _.producer.close() } + } + + private def createKafkaProducer(paramsSeq: Seq[(String, Object)]): Producer = { + val kafkaProducer: Producer = new Producer(paramsSeq.toMap.asJava) + if (log.isDebugEnabled()) { + val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq) + logDebug(s"Created a new instance of KafkaProducer for $redactedParamsSeq.") + } + kafkaProducer + } + + private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = { + val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1) + paramsSeq + } +} + +private[kafka010] object InternalKafkaProducerPool extends Logging { + private val pool = new InternalKafkaProducerPool(SparkEnv.get.conf) + + private type CacheKey = Seq[(String, Object)] + private type Producer = KafkaProducer[Array[Byte], Array[Byte]] + + ShutdownHookManager.addShutdownHook { () => + try { + pool.shutdown() + } catch { + case e: Throwable => + logWarning("Ignoring Exception while shutting down pools from shutdown hook", e) + } + } + + /** + * This class is used as metadata of producer pool, and shouldn't be exposed to the public. + * This class assumes thread-safety is guaranteed by the caller. + */ + private[producer] class CachedProducerEntry( + val producer: CachedKafkaProducer, + clock: Clock, + cacheExpireTimeout: Long) { + private var _refCount: Long = 0L + private var _expireAt: Long = Long.MaxValue + + /** exposed for testing */ + private[producer] def refCount: Long = _refCount + private[producer] def expireAt: Long = _expireAt + + def handleBorrowed(): Unit = { + _refCount += 1 + _expireAt = Long.MaxValue + } + + def handleReturned(): Unit = { + _refCount -= 1 + if (_refCount <= 0) { + _expireAt = clock.getTimeMillis() + cacheExpireTimeout + } + } + + def expired: Boolean = _refCount <= 0 && _expireAt < clock.getTimeMillis() + } + + def acquire(kafkaParams: ju.Map[String, Object]): CachedKafkaProducer = { + pool.acquire(kafkaParams) + } + + def release(producer: CachedKafkaProducer): Unit = { + pool.release(producer) + } + + def reset(): Unit = pool.reset() +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala deleted file mode 100644 index 7425a74315e1a..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.{util => ju} -import java.util.concurrent.ConcurrentMap - -import org.apache.kafka.clients.producer.KafkaProducer -import org.apache.kafka.common.serialization.ByteArraySerializer -import org.scalatest.PrivateMethodTester - -import org.apache.spark.sql.test.SharedSparkSession - -class CachedKafkaProducerSuite extends SharedSparkSession with PrivateMethodTester with KafkaTest { - - type KP = KafkaProducer[Array[Byte], Array[Byte]] - - protected override def beforeEach(): Unit = { - super.beforeEach() - CachedKafkaProducer.clear() - } - - test("Should return the cached instance on calling getOrCreate with same params.") { - val kafkaParams = new ju.HashMap[String, Object]() - kafkaParams.put("acks", "0") - // Here only host should be resolvable, it does not need a running instance of kafka server. - kafkaParams.put("bootstrap.servers", "127.0.0.1:9022") - kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName) - kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName) - val producer = CachedKafkaProducer.getOrCreate(kafkaParams) - val producer2 = CachedKafkaProducer.getOrCreate(kafkaParams) - assert(producer == producer2) - - val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]](Symbol("getAsMap")) - val map = CachedKafkaProducer.invokePrivate(cacheMap()) - assert(map.size == 1) - } - - test("Should close the correct kafka producer for the given kafkaPrams.") { - val kafkaParams = new ju.HashMap[String, Object]() - kafkaParams.put("acks", "0") - kafkaParams.put("bootstrap.servers", "127.0.0.1:9022") - kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName) - kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName) - val producer: KP = CachedKafkaProducer.getOrCreate(kafkaParams) - kafkaParams.put("acks", "1") - val producer2: KP = CachedKafkaProducer.getOrCreate(kafkaParams) - // With updated conf, a new producer instance should be created. - assert(producer != producer2) - - val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]](Symbol("getAsMap")) - val map = CachedKafkaProducer.invokePrivate(cacheMap()) - assert(map.size == 2) - - CachedKafkaProducer.close(kafkaParams) - val map2 = CachedKafkaProducer.invokePrivate(cacheMap()) - assert(map2.size == 1) - import scala.collection.JavaConverters._ - val (seq: Seq[(String, Object)], _producer: KP) = map2.asScala.toArray.apply(0) - assert(_producer == producer) - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala index 19acda95c707c..087d938f8ed8e 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.kafka010 import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.kafka010.producer.InternalKafkaProducerPool /** A trait to clean cached Kafka producers in `afterAll` */ trait KafkaTest extends BeforeAndAfterAll { @@ -27,6 +28,6 @@ trait KafkaTest extends BeforeAndAfterAll { override def afterAll(): Unit = { super.afterAll() - CachedKafkaProducer.clear() + InternalKafkaProducerPool.reset() } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPoolSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPoolSuite.scala new file mode 100644 index 0000000000000..266304bbc975b --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPoolSuite.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.producer + +import java.{util => ju} +import java.util.concurrent.{Executors, TimeUnit} + +import scala.util.Random + +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.jmock.lib.concurrent.DeterministicScheduler + +import org.apache.spark.SparkConf +import org.apache.spark.sql.kafka010.{PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL, PRODUCER_CACHE_TIMEOUT} +import org.apache.spark.sql.kafka010.producer.InternalKafkaProducerPool.CachedProducerEntry +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.ManualClock + +class InternalKafkaProducerPoolSuite extends SharedSparkSession { + + private var pool: InternalKafkaProducerPool = _ + + protected override def afterEach(): Unit = { + if (pool != null) { + try { + pool.shutdown() + pool = null + } catch { + // ignore as it's known issue, DeterministicScheduler doesn't support shutdown + case _: UnsupportedOperationException => + } + } + } + + test("Should return same cached instance on calling acquire with same params.") { + pool = new InternalKafkaProducerPool(new SparkConf()) + + val kafkaParams = getTestKafkaParams() + val producer = pool.acquire(kafkaParams) + val producer2 = pool.acquire(kafkaParams) + assert(producer === producer2) + + val map = pool.getAsMap + assert(map.size === 1) + val cacheEntry = map.head._2 + assertCacheEntry(pool, cacheEntry, 2L) + + pool.release(producer) + assertCacheEntry(pool, cacheEntry, 1L) + + pool.release(producer2) + assertCacheEntry(pool, cacheEntry, 0L) + + val producer3 = pool.acquire(kafkaParams) + assertCacheEntry(pool, cacheEntry, 1L) + assert(producer === producer3) + } + + test("Should return different cached instances on calling acquire with different params.") { + pool = new InternalKafkaProducerPool(new SparkConf()) + + val kafkaParams = getTestKafkaParams() + val producer = pool.acquire(kafkaParams) + kafkaParams.put("acks", "1") + val producer2 = pool.acquire(kafkaParams) + // With updated conf, a new producer instance should be created. + assert(producer !== producer2) + + val map = pool.getAsMap + assert(map.size === 2) + val cacheEntry = map.find(_._2.producer.id == producer.id).get._2 + assertCacheEntry(pool, cacheEntry, 1L) + val cacheEntry2 = map.find(_._2.producer.id == producer2.id).get._2 + assertCacheEntry(pool, cacheEntry2, 1L) + } + + test("expire instances") { + val minEvictableIdleTimeMillis = 2000L + val evictorThreadRunIntervalMillis = 500L + + val conf = new SparkConf() + conf.set(PRODUCER_CACHE_TIMEOUT, minEvictableIdleTimeMillis) + conf.set(PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL, evictorThreadRunIntervalMillis) + + val scheduler = new DeterministicScheduler() + val clock = new ManualClock() + pool = new InternalKafkaProducerPool(scheduler, clock, conf) + + val kafkaParams = getTestKafkaParams() + + var map = pool.getAsMap + assert(map.isEmpty) + + val producer = pool.acquire(kafkaParams) + map = pool.getAsMap + assert(map.size === 1) + + clock.advance(minEvictableIdleTimeMillis + 100) + scheduler.tick(evictorThreadRunIntervalMillis + 100, TimeUnit.MILLISECONDS) + map = pool.getAsMap + assert(map.size === 1) + + pool.release(producer) + + // This will clean up expired instance from cache. + clock.advance(minEvictableIdleTimeMillis + 100) + scheduler.tick(evictorThreadRunIntervalMillis + 100, TimeUnit.MILLISECONDS) + + map = pool.getAsMap + assert(map.size === 0) + } + + test("reference counting with concurrent access") { + pool = new InternalKafkaProducerPool(new SparkConf()) + + val kafkaParams = getTestKafkaParams() + + val numThreads = 100 + val numProducerUsages = 500 + + def produce(i: Int): Unit = { + val producer = pool.acquire(kafkaParams) + try { + val map = pool.getAsMap + assert(map.size === 1) + val cacheEntry = map.head._2 + assert(cacheEntry.refCount > 0L) + assert(cacheEntry.expireAt === Long.MaxValue) + + Thread.sleep(Random.nextInt(100)) + } finally { + pool.release(producer) + } + } + + val threadpool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numProducerUsages).map { i => + threadpool.submit(new Runnable { + override def run(): Unit = { produce(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + } finally { + threadpool.shutdown() + } + + val map = pool.getAsMap + assert(map.size === 1) + + val cacheEntry = map.head._2 + assertCacheEntry(pool, cacheEntry, 0L) + } + + private def getTestKafkaParams(): ju.HashMap[String, Object] = { + val kafkaParams = new ju.HashMap[String, Object]() + kafkaParams.put("acks", "0") + // Here only host should be resolvable, it does not need a running instance of kafka server. + kafkaParams.put("bootstrap.servers", "127.0.0.1:9022") + kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName) + kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName) + kafkaParams + } + + private def assertCacheEntry( + pool: InternalKafkaProducerPool, + cacheEntry: CachedProducerEntry, + expectedRefCount: Long): Unit = { + val timeoutVal = pool.cacheExpireTimeout + assert(cacheEntry.refCount === expectedRefCount) + if (expectedRefCount > 0) { + assert(cacheEntry.expireAt === Long.MaxValue) + } else { + assert(cacheEntry.expireAt <= pool.clock.getTimeMillis() + timeoutVal) + } + } +} From a7aac78d54577f69e1a3eca91916a37c730dec90 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Wed, 11 Dec 2019 16:11:15 +0900 Subject: [PATCH 2/6] Fix UT failures - SparkEnv is not always available --- .../sql/kafka010/producer/InternalKafkaProducerPool.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala index 9384713af250f..997f28ba35bae 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala @@ -148,7 +148,8 @@ private[producer] class InternalKafkaProducerPool( } private[kafka010] object InternalKafkaProducerPool extends Logging { - private val pool = new InternalKafkaProducerPool(SparkEnv.get.conf) + private val pool = new InternalKafkaProducerPool( + Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())) private type CacheKey = Seq[(String, Object)] private type Producer = KafkaProducer[Array[Byte], Array[Byte]] From ebb9341f2ec3f52bb1ddf01ccfad784ee3ef3a11 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Fri, 13 Dec 2019 20:04:53 +0900 Subject: [PATCH 3/6] Fix build failure --- .../scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index ac242ba3d1356..e2dcd62005310 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -370,7 +370,7 @@ class KafkaContinuousSinkSuite extends KafkaSinkStreamingSuiteBase { iter.foreach(writeTask.write(_)) writeTask.commit() } finally { - writeTask.invalidateProducer() + writeTask.close() } } } From 6ea2fc5fac4f17cfdb87d95ee73c7f5c6213b622 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Tue, 17 Dec 2019 05:49:39 +0900 Subject: [PATCH 4/6] Reflect review comments --- .../producer/CachedKafkaProducer.scala | 4 +- .../producer/InternalKafkaProducerPool.scala | 46 +++++++++++-------- .../InternalKafkaProducerPoolSuite.scala | 10 ++-- 3 files changed, 34 insertions(+), 26 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/CachedKafkaProducer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/CachedKafkaProducer.scala index 6682944fe4f4e..83519de0d3b1e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/CachedKafkaProducer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/CachedKafkaProducer.scala @@ -26,8 +26,8 @@ import org.apache.kafka.clients.producer.KafkaProducer import org.apache.spark.internal.Logging private[kafka010] class CachedKafkaProducer( - val cacheKey: Seq[(String, Object)], - val producer: KafkaProducer[Array[Byte], Array[Byte]]) extends Logging { + val cacheKey: Seq[(String, Object)], + val producer: KafkaProducer[Array[Byte], Array[Byte]]) extends Logging { val id: String = ju.UUID.randomUUID().toString private[producer] def close(): Unit = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala index 997f28ba35bae..fde94d4bb8fda 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala @@ -48,7 +48,7 @@ private[producer] class InternalKafkaProducerPool( } /** exposed for testing */ - private[producer] val cacheExpireTimeout: Long = conf.get(PRODUCER_CACHE_TIMEOUT) + private[producer] val cacheExpireTimeoutMillis: Long = conf.get(PRODUCER_CACHE_TIMEOUT) private val evictorThreadRunIntervalMillis = conf.get(PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL) @@ -83,7 +83,8 @@ private[producer] class InternalKafkaProducerPool( val entry = cache.getOrElseUpdate(paramsSeq, { val producer = createKafkaProducer(paramsSeq) val cachedProducer = new CachedKafkaProducer(paramsSeq, producer) - new CachedProducerEntry(cachedProducer, clock, cacheExpireTimeout) + new CachedProducerEntry(cachedProducer, + TimeUnit.MILLISECONDS.toNanos(cacheExpireTimeoutMillis)) }) entry.handleBorrowed() entry.producer @@ -98,23 +99,24 @@ private[producer] class InternalKafkaProducerPool( synchronized { cache.get(producer.cacheKey) match { - case Some(entry) if entry.producer.id == producer.id => entry.handleReturned() - case _ => closeProducerNotInCache(producer) + case Some(entry) if entry.producer.id == producer.id => + entry.handleReturned(clock.nanoTime()) + case _ => + closeProducerNotInCache(producer) } } } private[producer] def shutdown(): Unit = { + scheduled.foreach(_.cancel(true)) ThreadUtils.shutdown(executorService) } /** exposed for testing. */ private[producer] def reset(): Unit = synchronized { scheduled.foreach(_.cancel(true)) - cache.foreach { case (k, v) => - cache.remove(k) - v.producer.close() - } + cache.foreach { case (_, v) => v.producer.close() } + cache.clear() scheduled = startEvictorThread() } @@ -122,11 +124,16 @@ private[producer] class InternalKafkaProducerPool( private[producer] def getAsMap: Map[CacheKey, CachedProducerEntry] = cache.toMap private def evictExpired(): Unit = { + val curTimeNs = clock.nanoTime() val producers = new mutable.ArrayBuffer[CachedProducerEntry]() synchronized { - cache.filter { case (_, v) => v.expired }.foreach { case (k, v) => - cache.remove(k) - producers += v + cache.retain { case (_, v) => + if (v.expired(curTimeNs)) { + producers += v + false + } else { + true + } } } producers.foreach { _.producer.close() } @@ -142,8 +149,7 @@ private[producer] class InternalKafkaProducerPool( } private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = { - val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1) - paramsSeq + kafkaParams.asScala.toSeq.sortBy(x => x._1) } } @@ -169,8 +175,7 @@ private[kafka010] object InternalKafkaProducerPool extends Logging { */ private[producer] class CachedProducerEntry( val producer: CachedKafkaProducer, - clock: Clock, - cacheExpireTimeout: Long) { + cacheExpireTimeoutNs: Long) { private var _refCount: Long = 0L private var _expireAt: Long = Long.MaxValue @@ -183,14 +188,17 @@ private[kafka010] object InternalKafkaProducerPool extends Logging { _expireAt = Long.MaxValue } - def handleReturned(): Unit = { + def handleReturned(curTimeNs: Long): Unit = { _refCount -= 1 - if (_refCount <= 0) { - _expireAt = clock.getTimeMillis() + cacheExpireTimeout + require(_refCount < 0, "Reference count shouldn't be negative. Returning same producer " + + "multiple times would occur this bug. Check the logic around returning producer.") + + if (_refCount == 0) { + _expireAt = curTimeNs + cacheExpireTimeoutNs } } - def expired: Boolean = _refCount <= 0 && _expireAt < clock.getTimeMillis() + def expired(curTimeNs: Long): Boolean = _refCount == 0 && _expireAt < curTimeNs } def acquire(kafkaParams: ju.Map[String, Object]): CachedKafkaProducer = { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPoolSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPoolSuite.scala index 266304bbc975b..97885754f204c 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPoolSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPoolSuite.scala @@ -53,7 +53,7 @@ class InternalKafkaProducerPoolSuite extends SharedSparkSession { val kafkaParams = getTestKafkaParams() val producer = pool.acquire(kafkaParams) val producer2 = pool.acquire(kafkaParams) - assert(producer === producer2) + assert(producer eq producer2) val map = pool.getAsMap assert(map.size === 1) @@ -68,7 +68,7 @@ class InternalKafkaProducerPoolSuite extends SharedSparkSession { val producer3 = pool.acquire(kafkaParams) assertCacheEntry(pool, cacheEntry, 1L) - assert(producer === producer3) + assert(producer eq producer3) } test("Should return different cached instances on calling acquire with different params.") { @@ -79,7 +79,7 @@ class InternalKafkaProducerPoolSuite extends SharedSparkSession { kafkaParams.put("acks", "1") val producer2 = pool.acquire(kafkaParams) // With updated conf, a new producer instance should be created. - assert(producer !== producer2) + assert(producer ne producer2) val map = pool.getAsMap assert(map.size === 2) @@ -181,12 +181,12 @@ class InternalKafkaProducerPoolSuite extends SharedSparkSession { pool: InternalKafkaProducerPool, cacheEntry: CachedProducerEntry, expectedRefCount: Long): Unit = { - val timeoutVal = pool.cacheExpireTimeout + val timeoutVal = TimeUnit.MILLISECONDS.toNanos(pool.cacheExpireTimeoutMillis) assert(cacheEntry.refCount === expectedRefCount) if (expectedRefCount > 0) { assert(cacheEntry.expireAt === Long.MaxValue) } else { - assert(cacheEntry.expireAt <= pool.clock.getTimeMillis() + timeoutVal) + assert(cacheEntry.expireAt <= pool.clock.nanoTime() + timeoutVal) } } } From 5cbd425259b3e67562f6fa9e30f53c76ef08b89c Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Tue, 17 Dec 2019 07:17:50 +0900 Subject: [PATCH 5/6] Silly mistake --- .../spark/sql/kafka010/producer/InternalKafkaProducerPool.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala index fde94d4bb8fda..0f6d617a89719 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala @@ -190,7 +190,7 @@ private[kafka010] object InternalKafkaProducerPool extends Logging { def handleReturned(curTimeNs: Long): Unit = { _refCount -= 1 - require(_refCount < 0, "Reference count shouldn't be negative. Returning same producer " + + require(_refCount >= 0, "Reference count shouldn't be negative. Returning same producer " + "multiple times would occur this bug. Check the logic around returning producer.") if (_refCount == 0) { From e9dc140ccac0958fc31a71a04d845f64528ace93 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Wed, 18 Dec 2019 07:18:54 +0900 Subject: [PATCH 6/6] Reflect review comments --- .../producer/InternalKafkaProducerPool.scala | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala index 0f6d617a89719..7a0c68eb74a35 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala @@ -50,12 +50,11 @@ private[producer] class InternalKafkaProducerPool( /** exposed for testing */ private[producer] val cacheExpireTimeoutMillis: Long = conf.get(PRODUCER_CACHE_TIMEOUT) - private val evictorThreadRunIntervalMillis = conf.get(PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL) - @GuardedBy("this") private val cache = new mutable.HashMap[CacheKey, CachedProducerEntry] private def startEvictorThread(): Option[ScheduledFuture[_]] = { + val evictorThreadRunIntervalMillis = conf.get(PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL) if (evictorThreadRunIntervalMillis > 0) { val future = executorService.scheduleAtFixedRate(() => { Utils.tryLogNonFatalError(evictExpired()) @@ -92,32 +91,26 @@ private[producer] class InternalKafkaProducerPool( } private[producer] def release(producer: CachedKafkaProducer): Unit = { - def closeProducerNotInCache(producer: CachedKafkaProducer): Unit = { - logWarning(s"Released producer ${producer.id} is not a member of the cache. Closing.") - producer.close() - } - synchronized { cache.get(producer.cacheKey) match { case Some(entry) if entry.producer.id == producer.id => entry.handleReturned(clock.nanoTime()) case _ => - closeProducerNotInCache(producer) + logWarning(s"Released producer ${producer.id} is not a member of the cache. Closing.") + producer.close() } } } private[producer] def shutdown(): Unit = { - scheduled.foreach(_.cancel(true)) + scheduled.foreach(_.cancel(false)) ThreadUtils.shutdown(executorService) } /** exposed for testing. */ private[producer] def reset(): Unit = synchronized { - scheduled.foreach(_.cancel(true)) cache.foreach { case (_, v) => v.producer.close() } cache.clear() - scheduled = startEvictorThread() } /** exposed for testing */ @@ -189,10 +182,10 @@ private[kafka010] object InternalKafkaProducerPool extends Logging { } def handleReturned(curTimeNs: Long): Unit = { - _refCount -= 1 - require(_refCount >= 0, "Reference count shouldn't be negative. Returning same producer " + + require(_refCount > 0, "Reference count shouldn't become negative. Returning same producer " + "multiple times would occur this bug. Check the logic around returning producer.") + _refCount -= 1 if (_refCount == 0) { _expireAt = curTimeNs + cacheExpireTimeoutNs }