diff --git a/core/src/main/java/org/apache/druid/java/util/common/parsers/CloseableIterator.java b/core/src/main/java/org/apache/druid/java/util/common/parsers/CloseableIterator.java index 45cda5cbc452..d5915369eada 100644 --- a/core/src/main/java/org/apache/druid/java/util/common/parsers/CloseableIterator.java +++ b/core/src/main/java/org/apache/druid/java/util/common/parsers/CloseableIterator.java @@ -75,6 +75,7 @@ private CloseableIterator findNextIeteratorIfNecessary() if (iterator != null) { try { iterator.close(); + iterator = null; } catch (IOException e) { throw new UncheckedIOException(e); @@ -112,6 +113,10 @@ public R next() public void close() throws IOException { delegate.close(); + if (iterator != null) { + iterator.close(); + iterator = null; + } } }; } diff --git a/core/src/test/java/org/apache/druid/java/util/common/parsers/CloseableIteratorTest.java b/core/src/test/java/org/apache/druid/java/util/common/parsers/CloseableIteratorTest.java index 5434e2de3909..be2d1d58bd5c 100644 --- a/core/src/test/java/org/apache/druid/java/util/common/parsers/CloseableIteratorTest.java +++ b/core/src/test/java/org/apache/druid/java/util/common/parsers/CloseableIteratorTest.java @@ -23,6 +23,7 @@ import org.junit.Assert; import org.junit.Test; +import java.io.IOException; import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -54,10 +55,18 @@ public void testMap() } @Test - public void testFlatMap() + public void testFlatMap() throws IOException { - final CloseableIterator actual = generateTestIterator(8) - .flatMap(list -> CloseableIterators.withEmptyBaggage(list.iterator())); + List> innerIterators = new ArrayList<>(); + final CloseTrackingCloseableIterator actual = new CloseTrackingCloseableIterator<>( + generateTestIterator(8) + .flatMap(list -> { + CloseTrackingCloseableIterator inner = + new CloseTrackingCloseableIterator<>(CloseableIterators.withEmptyBaggage(list.iterator())); + innerIterators.add(inner); + return inner; + }) + ); final Iterator expected = IntStream .range(0, 8) .flatMap(i -> IntStream.range(0, i)) @@ -67,6 +76,48 @@ public void testFlatMap() } Assert.assertFalse(actual.hasNext()); Assert.assertFalse(expected.hasNext()); + actual.close(); + Assert.assertEquals(1, actual.closeCount); + for (CloseTrackingCloseableIterator iter : innerIterators) { + Assert.assertEquals(1, iter.closeCount); + } + } + + @Test + public void testFlatMapClosedEarly() throws IOException + { + final int numIterations = 8; + List> innerIterators = new ArrayList<>(); + final CloseTrackingCloseableIterator actual = new CloseTrackingCloseableIterator<>( + generateTestIterator(numIterations) + .flatMap(list -> { + CloseTrackingCloseableIterator inner = + new CloseTrackingCloseableIterator<>(CloseableIterators.withEmptyBaggage(list.iterator())); + innerIterators.add(inner); + return inner; + }) + ); + final Iterator expected = IntStream + .range(0, numIterations) + .flatMap(i -> IntStream.range(0, i)) + .iterator(); + + // burn through the first few iterators + int cnt = 0; + int numFlatIterations = 5; + while (expected.hasNext() && actual.hasNext() && cnt++ < numFlatIterations) { + Assert.assertEquals(expected.next(), actual.next()); + } + // but stop while we still have an open current inner iterator and a few remaining inner iterators + Assert.assertTrue(actual.hasNext()); + Assert.assertTrue(expected.hasNext()); + Assert.assertEquals(4, innerIterators.size()); + Assert.assertTrue(innerIterators.get(innerIterators.size() - 1).hasNext()); + actual.close(); + Assert.assertEquals(1, actual.closeCount); + for (CloseTrackingCloseableIterator iter : innerIterators) { + Assert.assertEquals(1, iter.closeCount); + } } private static CloseableIterator> generateTestIterator(int numIterates) @@ -99,4 +150,36 @@ public void close() } }; } + + static class CloseTrackingCloseableIterator implements CloseableIterator + { + CloseableIterator inner; + int closeCount; + + public CloseTrackingCloseableIterator(CloseableIterator toTrack) + { + this.inner = toTrack; + this.closeCount = 0; + } + + + @Override + public void close() throws IOException + { + inner.close(); + closeCount++; + } + + @Override + public boolean hasNext() + { + return inner.hasNext(); + } + + @Override + public T next() + { + return inner.next(); + } + } }