diff --git a/core/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java b/core/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java index 7e4518947787..a18a1c805c3d 100644 --- a/core/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java +++ b/core/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java @@ -22,13 +22,16 @@ import com.google.common.collect.Lists; import com.google.common.collect.Ordering; import org.apache.druid.java.util.common.RE; +import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.utils.JvmUtils; import javax.annotation.Nullable; +import java.io.Closeable; import java.io.IOException; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -300,6 +303,7 @@ private MergeCombinePartitioningAction( @Override protected void compute() { + List> sequenceCursors = new ArrayList<>(sequences.size()); try { final int parallelTaskCount = computeNumTasks(); @@ -315,7 +319,6 @@ protected void compute() QueuePusher> resultsPusher = new QueuePusher<>(out, hasTimeout, timeoutAt); - List> sequenceCursors = new ArrayList<>(sequences.size()); for (Sequence s : sequences) { sequenceCursors.add(new YielderBatchedResultsCursor<>(new SequenceBatcher<>(s, batchSize), orderingFn)); } @@ -340,8 +343,9 @@ protected void compute() spawnParallelTasks(parallelTaskCount); } } - catch (Exception ex) { - cancellationGizmo.cancel(ex); + catch (Throwable t) { + closeAllCursors(sequenceCursors); + cancellationGizmo.cancel(t); out.offer(ResultBatch.TERMINAL); } } @@ -624,6 +628,8 @@ protected void compute() // if we got the cancellation signal, go ahead and write terminal value into output queue to help gracefully // allow downstream stuff to stop LOG.debug("cancelled after %s tasks", metricsAccumulator.getTaskCount()); + // make sure to close underlying cursors + closeAllCursors(pQueue); outputQueue.offer(ResultBatch.TERMINAL); } else { // if priority queue is empty, push the final accumulated value into the output batch and push it out @@ -635,8 +641,9 @@ protected void compute() LOG.debug("merge combine complete after %s tasks", metricsAccumulator.getTaskCount()); } } - catch (Exception ex) { - cancellationGizmo.cancel(ex); + catch (Throwable t) { + closeAllCursors(pQueue); + cancellationGizmo.cancel(t); outputQueue.offer(ResultBatch.TERMINAL); } } @@ -695,13 +702,15 @@ private PrepareMergeCombineInputsAction( @Override protected void compute() { + PriorityQueue> cursors = new PriorityQueue<>(partition.size()); try { - PriorityQueue> cursors = new PriorityQueue<>(partition.size()); for (BatchedResultsCursor cursor : partition) { // this is blocking cursor.initialize(); if (!cursor.isDone()) { cursors.offer(cursor); + } else { + cursor.close(); } } @@ -722,8 +731,9 @@ protected void compute() outputQueue.offer(ResultBatch.TERMINAL); } } - catch (Exception ex) { - cancellationGizmo.cancel(ex); + catch (Throwable t) { + closeAllCursors(partition); + cancellationGizmo.cancel(t); outputQueue.offer(ResultBatch.TERMINAL); } } @@ -849,6 +859,7 @@ static Yielder> fromSequence(Sequence sequence, int batchS new YieldingAccumulator, E>() { int count = 0; + @Override public ResultBatch accumulate(ResultBatch accumulated, E in) { @@ -913,7 +924,7 @@ public boolean isReleasable() * from these cursors, and combine results with the same ordering using the combining function. */ abstract static class BatchedResultsCursor - implements ForkJoinPool.ManagedBlocker, Comparable> + implements ForkJoinPool.ManagedBlocker, Comparable>, Closeable { final Ordering ordering; volatile ResultBatch resultBatch; @@ -939,7 +950,8 @@ void nextBatch() } } - public void close() + @Override + public void close() throws IOException { // nothing to close for blocking queue, but yielders will need to clean up or they will leak resources } @@ -1034,14 +1046,11 @@ public boolean isReleasable() } @Override - public void close() + public void close() throws IOException { - try { + if (yielder != null) { yielder.close(); } - catch (IOException e) { - throw new RuntimeException("Failed to close yielder", e); - } } } @@ -1135,21 +1144,21 @@ public boolean isReleasable() */ static class CancellationGizmo { - private final AtomicReference exception = new AtomicReference<>(null); + private final AtomicReference throwable = new AtomicReference<>(null); - void cancel(Exception ex) + void cancel(Throwable t) { - exception.compareAndSet(null, ex); + throwable.compareAndSet(null, t); } boolean isCancelled() { - return exception.get() != null; + return throwable.get() != null; } RuntimeException getRuntimeException() { - Exception ex = exception.get(); + Throwable ex = throwable.get(); if (ex instanceof RuntimeException) { return (RuntimeException) ex; } @@ -1350,4 +1359,11 @@ long getTotalCpuTimeNanos() return totalCpuTimeNanos; } } + + private static void closeAllCursors(final Collection> cursors) + { + Closer closer = Closer.create(); + closer.registerAll(cursors); + CloseQuietly.close(closer); + } } diff --git a/core/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java b/core/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java index e96e1e3a0aa1..8e2b4e5025c2 100644 --- a/core/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java +++ b/core/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java @@ -40,6 +40,7 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BinaryOperator; import java.util.function.Consumer; @@ -63,6 +64,9 @@ public class ParallelMergeCombiningSequenceTest private ForkJoinPool pool; + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @Before public void setup() { @@ -80,8 +84,6 @@ public void teardown() pool.shutdown(); } - @Rule - public ExpectedException expectedException = ExpectedException.none(); @Test public void testOrderedResultBatchFromSequence() throws IOException @@ -448,12 +450,21 @@ public void testExceptionOnInputSequenceRead() throws Exception "exploded" ); assertException(input); + } + @Test + public void testExceptionOnInputSequenceRead2() throws Exception + { + List> input = new ArrayList<>(); input.add(nonBlockingSequence(5)); input.add(nonBlockingSequence(25)); input.add(explodingSequence(11)); input.add(nonBlockingSequence(12)); + expectedException.expect(RuntimeException.class); + expectedException.expectMessage( + "exploded" + ); assertException(input); } @@ -653,6 +664,12 @@ private void assertException( parallelMergeCombineYielder.close(); } catch (Exception ex) { + sequences.forEach(sequence -> { + if (sequence instanceof ExplodingSequence) { + ExplodingSequence exploder = (ExplodingSequence) sequence; + Assert.assertEquals(1, exploder.getCloseCount()); + } + }); LOG.warn(ex, "exception:"); throw ex; } @@ -808,42 +825,60 @@ private static Sequence nonBlockingSequence(int size) private static Sequence explodingSequence(int explodeAfter) { final int explodeAt = explodeAfter + 1; - return new BaseSequence<>( - new BaseSequence.IteratorMaker>() - { - @Override - public Iterator make() + + // we start at one because we only need to close if the sequence is actually made + AtomicInteger explodedIteratorMakerCleanup = new AtomicInteger(1); + + // just hijacking this class to use it's interface... which i override.. + return new ExplodingSequence( + new BaseSequence<>( + new BaseSequence.IteratorMaker>() { - return new Iterator() + @Override + public Iterator make() { - int mergeKey = 0; - int rowCounter = 0; - @Override - public boolean hasNext() + // we got yielder, decrement so we expect it to be incremented again on cleanup + explodedIteratorMakerCleanup.decrementAndGet(); + return new Iterator() { - return rowCounter < explodeAt; - } + int mergeKey = 0; + int rowCounter = 0; + @Override + public boolean hasNext() + { + return rowCounter < explodeAt; + } - @Override - public IntPair next() - { - if (rowCounter == explodeAfter) { - throw new RuntimeException("exploded"); + @Override + public IntPair next() + { + if (rowCounter == explodeAfter) { + throw new RuntimeException("exploded"); + } + mergeKey += incrementMergeKeyAmount(); + rowCounter++; + return makeIntPair(mergeKey); } - mergeKey += incrementMergeKeyAmount(); - rowCounter++; - return makeIntPair(mergeKey); - } - }; - } + }; + } - @Override - public void cleanup(Iterator iterFromMake) - { - // nothing to cleanup + @Override + public void cleanup(Iterator iterFromMake) + { + explodedIteratorMakerCleanup.incrementAndGet(); + } } - } - ); + ), + false, + false + ) + { + @Override + public long getCloseCount() + { + return explodedIteratorMakerCleanup.get(); + } + }; } private static List generateOrderedPairs(int length)