diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
index fa3ea6131a507..5fafa3210c3c9 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
@@ -107,7 +107,7 @@ class CachedKafkaConsumer[K, V] private(
 private[kafka010]
 object CachedKafkaConsumer extends Logging {
 
-  private case class CacheKey(groupId: String, topic: String, partition: Int)
+  private case class CacheKey(groupId: String, topic: String, partition: Int, threadId: Long)
 
   // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap
   private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null
@@ -147,9 +147,10 @@ object CachedKafkaConsumer extends Logging {
       groupId: String,
       topic: String,
       partition: Int,
+      threadId: Long,
       kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] =
     CachedKafkaConsumer.synchronized {
-      val k = CacheKey(groupId, topic, partition)
+      val k = CacheKey(groupId, topic, partition, threadId)
       val v = cache.get(k)
       if (null == v) {
         logInfo(s"Cache miss for $k")
@@ -175,8 +176,8 @@ object CachedKafkaConsumer extends Logging {
     new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams)
 
   /** remove consumer for given groupId, topic, and partition, if it exists */
-  def remove(groupId: String, topic: String, partition: Int): Unit = {
-    val k = CacheKey(groupId, topic, partition)
+  def remove(groupId: String, topic: String, partition: Int, threadId: Long): Unit = {
+    val k = CacheKey(groupId, topic, partition, threadId)
     logInfo(s"Removing $k from cache")
     val v = CachedKafkaConsumer.synchronized {
       cache.remove(k)
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
index d9fc9cc206647..4d26d371aafef 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
@@ -193,6 +193,8 @@ private[spark] class KafkaRDD[K, V](
     logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " +
       s"offsets ${part.fromOffset} -> ${part.untilOffset}")
 
+    val threadId = Thread.currentThread().getId
+
     val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
 
     context.addTaskCompletionListener{ context => closeIfNeeded() }
@@ -201,9 +203,9 @@ private[spark] class KafkaRDD[K, V](
       CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
       if (context.attemptNumber >= 1) {
         // just in case the prior attempt failures were cache related
-        CachedKafkaConsumer.remove(groupId, part.topic, part.partition)
+        CachedKafkaConsumer.remove(groupId, part.topic, part.partition, threadId)
       }
-      CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams)
+      CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, threadId, kafkaParams)
     } else {
       CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams)
     }
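
Below is a minimal sketch (not part of the patch, and not Spark's actual API) of the idea the diff appears to implement: KafkaConsumer is not safe for use from multiple threads, so keying the executor-side consumer cache by the calling thread's id keeps two concurrently running tasks from sharing one consumer for the same groupId/topic/partition. The names ThreadKeyedCache and getOrCreate are illustrative only, and AnyRef stands in for the real consumer type.

    import java.{util => ju}

    object ThreadKeyedCache {
      // Same shape as the patched CacheKey: the thread id is part of the key,
      // so a cached value is never handed to more than one task thread.
      private case class CacheKey(groupId: String, topic: String, partition: Int, threadId: Long)

      private val cache = new ju.LinkedHashMap[CacheKey, AnyRef]()

      def getOrCreate(groupId: String, topic: String, partition: Int)(create: => AnyRef): AnyRef =
        synchronized {
          val k = CacheKey(groupId, topic, partition, Thread.currentThread().getId)
          var v = cache.get(k)
          if (v == null) {
            // Cache miss for this thread: build a new value instead of reusing another thread's.
            v = create
            cache.put(k, v)
          }
          v
        }
    }

Two tasks calling getOrCreate concurrently for the same topic-partition therefore receive two distinct cached instances rather than racing on one.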