diff --git a/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java b/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java index 8cc6348a7c8a..7e32e098a13e 100644 --- a/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java +++ b/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java @@ -106,6 +106,8 @@ public Sequence run(QueryPlus queryPlus, ResponseContext responseContext) if (useResultCache && newResultSetId != null && newResultSetId.equals(existingResultSetId)) { log.debug("Return cached result set as there is no change in identifiers for query %s ", query.getId()); + // Call accumulate on the sequence to ensure that all Wrapper/Closer/Baggage/etc. get called + resultFromClient.accumulate(null, (accumulated, in) -> accumulated); return deserializeResults(cachedResultSet, strategy, existingResultSetId); } else { @Nullable diff --git a/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java b/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java index 3cb4ae528e67..a0c14239e494 100644 --- a/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java +++ b/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java @@ -23,8 +23,13 @@ import org.apache.druid.client.cache.Cache; import org.apache.druid.client.cache.CacheConfig; import org.apache.druid.client.cache.MapCache; +import org.apache.druid.collections.BlockingPool; +import org.apache.druid.collections.DefaultBlockingPool; +import org.apache.druid.collections.ReferenceCountingResourceHolder; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.query.timeseries.TimeseriesResultValue; import org.apache.druid.timeline.DataSegment; import org.joda.time.Interval; @@ -32,10 +37,16 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; +import org.mockito.stubbing.Answer; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; +import static org.junit.Assert.fail; + public class ResultLevelCachingQueryRunnerTest extends QueryRunnerBasedOnClusteredClientTestBase { private Cache cache; @@ -252,7 +263,7 @@ public void testPopulateCacheWhenQueryThrowExceptionShouldNotCache() ); try { sequence.toList(); - Assert.fail("Expected to throw an exception"); + fail("Expected to throw an exception"); } catch (RuntimeException e) { Assert.assertEquals("Exception for testing", e.getMessage()); @@ -264,6 +275,71 @@ public void testPopulateCacheWhenQueryThrowExceptionShouldNotCache() } } + @Test + public void testUseCacheAndReleaseResourceFromClient() + { + final BlockingPool mergePool = new DefaultBlockingPool<>(() -> ByteBuffer.allocate(1), 1); + prepareCluster(10); + final Query> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval()); + CacheConfig cacheConfig = newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE); + final QueryRunner> baseRunner = cachingClusteredClient.getQueryRunnerForIntervals(query, query.getIntervals()); + RetryQueryRunner> spyRunner = Mockito.spy(new RetryQueryRunner<>( + baseRunner, + cachingClusteredClient::getQueryRunnerForSegments, + new RetryQueryRunnerConfig(), + objectMapper + )); + Mockito.doAnswer((Answer) invocation -> { + List> resoruce = mergePool.takeBatch(1, 1); + if (resoruce.isEmpty()) { + fail("Resource should not be empty"); + } + Sequence> realSequence = (Sequence>) invocation.callRealMethod(); + Closer closer = Closer.create(); + closer.register(() -> resoruce.forEach(ReferenceCountingResourceHolder::close)); + return Sequences.withBaggage(realSequence, closer); + }).when(spyRunner).run(ArgumentMatchers.any(), ArgumentMatchers.any()); + + final ResultLevelCachingQueryRunner> queryRunner1 = new ResultLevelCachingQueryRunner<>( + spyRunner, + conglomerate.getToolChest(query), + query, + objectMapper, + cache, + cacheConfig + ); + + final Sequence> sequence1 = queryRunner1.run( + QueryPlus.wrap(query), + responseContext() + ); + final List> results1 = sequence1.toList(); + Assert.assertEquals(0, cache.getStats().getNumHits()); + Assert.assertEquals(1, cache.getStats().getNumEntries()); + Assert.assertEquals(1, cache.getStats().getNumMisses()); + + + final Sequence> sequence2 = queryRunner1.run( + QueryPlus.wrap(query), + responseContext() + ); + final List> results2 = sequence2.toList(); + Assert.assertEquals(results1, results2); + Assert.assertEquals(1, cache.getStats().getNumHits()); + Assert.assertEquals(1, cache.getStats().getNumEntries()); + Assert.assertEquals(1, cache.getStats().getNumMisses()); + + final Sequence> sequence3 = queryRunner1.run( + QueryPlus.wrap(query), + responseContext() + ); + final List> results3 = sequence3.toList(); + Assert.assertEquals(results1, results3); + Assert.assertEquals(2, cache.getStats().getNumHits()); + Assert.assertEquals(1, cache.getStats().getNumEntries()); + Assert.assertEquals(1, cache.getStats().getNumMisses()); + } + private ResultLevelCachingQueryRunner createQueryRunner( CacheConfig cacheConfig, Query query