diff --git a/core/src/main/java/org/apache/druid/collections/SerializablePair.java b/core/src/main/java/org/apache/druid/collections/SerializablePair.java index ce98933ee8dc..66f235054314 100644 --- a/core/src/main/java/org/apache/druid/collections/SerializablePair.java +++ b/core/src/main/java/org/apache/druid/collections/SerializablePair.java @@ -31,15 +31,17 @@ public SerializablePair(@JsonProperty("lhs") T1 lhs, @JsonProperty("rhs") T2 rhs super(lhs, rhs); } + @Override @JsonProperty public T1 getLhs() { - return lhs; + return super.getLhs(); } + @Override @JsonProperty public T2 getRhs() { - return rhs; + return super.getRhs(); } } diff --git a/core/src/main/java/org/apache/druid/common/guava/GuavaUtils.java b/core/src/main/java/org/apache/druid/common/guava/GuavaUtils.java index 47d996033290..48c1e825531d 100644 --- a/core/src/main/java/org/apache/druid/common/guava/GuavaUtils.java +++ b/core/src/main/java/org/apache/druid/common/guava/GuavaUtils.java @@ -22,8 +22,12 @@ import com.google.common.base.Preconditions; import com.google.common.base.Strings; import com.google.common.primitives.Longs; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; import javax.annotation.Nullable; +import java.util.List; +import java.util.stream.Stream; /** */ @@ -62,4 +66,17 @@ public static > T getEnumIfPresent(final Class enumClass, f return null; } + + /** + * Materialize the stream of futures into a single listenable future that will return the list of results. + * + * @param futures The futures to collect into a single Listenable future + * @param The return value for the futures + * + * @return A single ListenableFuture whose return value is a list of the completed values of the input stream. 
+ */ + public static ListenableFuture> allFuturesAsList(Stream> futures) + { + return Futures.allAsList(futures::iterator); + } } diff --git a/core/src/main/java/org/apache/druid/java/util/common/JodaUtils.java b/core/src/main/java/org/apache/druid/java/util/common/JodaUtils.java index 75998cd7d35e..198eee00e7d8 100644 --- a/core/src/main/java/org/apache/druid/java/util/common/JodaUtils.java +++ b/core/src/main/java/org/apache/druid/java/util/common/JodaUtils.java @@ -28,6 +28,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.SortedSet; +import java.util.concurrent.TimeoutException; /** */ @@ -138,4 +139,23 @@ public static DateTime maxDateTime(DateTime... times) return max; } } + + /** + * Return a quantity of milliseconds approximately until deadline. If deadline has passed, throw TimeoutException + * + * @param deadline The time on or after which things should be considered "timed out" + * + * @return A millisecond number where, if one were to wait that many milliseconds, the deadline would + * probably have passed. 
Always greater than zero + * + * @throws TimeoutException If the deadline has already passed (ties are treated as having passed the deadline) + */ + public static long timeoutForDeadline(DateTime deadline) throws TimeoutException + { + final DateTime now = DateTimes.nowUtc(); + if (now.isAfter(deadline) || now.isEqual(deadline)) { + throw new TimeoutException(StringUtils.format("Deadline passed: [%s]", deadline)); + } + return deadline.getMillis() - now.getMillis(); + } } diff --git a/core/src/main/java/org/apache/druid/java/util/common/Pair.java b/core/src/main/java/org/apache/druid/java/util/common/Pair.java index 4b2acad6f1b9..e18bed5c3481 100644 --- a/core/src/main/java/org/apache/druid/java/util/common/Pair.java +++ b/core/src/main/java/org/apache/druid/java/util/common/Pair.java @@ -20,7 +20,10 @@ package org.apache.druid.java.util.common; import javax.annotation.Nullable; +import java.util.Map; import java.util.Objects; +import java.util.stream.Collector; +import java.util.stream.Collectors; /** */ @@ -32,6 +35,14 @@ public static Pair of(@Nullable T1 lhs, @Nullable T2 rhs) return new Pair<>(lhs, rhs); } + public static Collector, ?, Map> mapCollector() + { + return Collectors.toMap( + Pair::getLhs, + Pair::getRhs + ); + } + @Nullable public final T1 lhs; @@ -47,6 +58,18 @@ public Pair( this.rhs = rhs; } + @Nullable + public T1 getLhs() + { + return lhs; + } + + @Nullable + public T2 getRhs() + { + return rhs; + } + @Override public boolean equals(Object o) { diff --git a/core/src/main/java/org/apache/druid/java/util/common/concurrent/Execs.java b/core/src/main/java/org/apache/druid/java/util/common/concurrent/Execs.java index e69e661108ec..488e99db787a 100644 --- a/core/src/main/java/org/apache/druid/java/util/common/concurrent/Execs.java +++ b/core/src/main/java/org/apache/druid/java/util/common/concurrent/Execs.java @@ -22,13 +22,21 @@ import com.google.common.base.Preconditions; import com.google.common.base.Strings; import 
com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.JodaUtils; +import org.apache.druid.java.util.common.StringUtils; +import org.joda.time.DateTime; import javax.annotation.Nullable; import javax.validation.constraints.NotNull; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinWorkerThread; +import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.RejectedExecutionHandler; import java.util.concurrent.ScheduledExecutorService; @@ -36,6 +44,8 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicLong; /** */ @@ -147,4 +157,113 @@ public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) } ); } + + private static final AtomicLong fjpWorkerThreadCount = new AtomicLong(0L); + + public static ForkJoinWorkerThread makeWorkerThread(String name, ForkJoinPool pool) + { + final ForkJoinWorkerThread t = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool); + t.setDaemon(true); + final long threadNumber = fjpWorkerThreadCount.incrementAndGet(); + t.setName(StringUtils.nonStrictFormat(name, threadNumber)); + return t; + } + + private static final int DUMMY_THREAD_NUMBER = 17; + + /** + * Fail fast if the format can't take a single argument integer for a thread counter. 
+ * + * Note that LACK of any argument in the format string still renders a valid name + * + * @param format The name format to check + * + * @throws java.util.IllegalFormatException if the format passed in is not able to take a single thread parameter + */ + public static void checkThreadNameFormat(String format) + { + StringUtils.format(format, DUMMY_THREAD_NUMBER); + } + + /** + * Get the result for the future (without timeout), but do so in a way safe for running in a ForkJoinPool + * + * @param future The future to block on completion + * @param The type of the return value + * + * @return The result of the future if successfully completed, or one of the exceptions if not + * + * @throws InterruptedException If the call to future.get() was interrupted + * @throws ExecutionException If the future completed with an exception + */ + public static T futureManagedBlockGet(final Future future) + throws InterruptedException, ExecutionException + { + ForkJoinPool.managedBlock(new ForkJoinPool.ManagedBlocker() + { + @Override + public boolean block() throws InterruptedException + { + try { + future.get(); + } + catch (ExecutionException e) { + // Ignore, will be caught when get is called below + } + return true; + } + + @Override + public boolean isReleasable() + { + return future.isDone(); + } + }); + return future.get(); + } + + /** + * Attempt to get the result of the future before the deadline, but do so in a way safe to run in a ForkJoinPool. + * The deadline is best effort. It is possible the future completes, but the deadline is exceeded before the result + * can be returned. In such a scenario a TimeoutException will be thrown. + * + * The caller is responsible for handling the state of the Future in the case of an exception being thrown. + * Specifically, if an InterruptedException or a TimeoutException is thrown, there is no attempt in this method + * to change the behavior of the future. 
The caller should handle the potentially still active future as they see fit. + * + * @param future The future to await completion + * @param deadline Best effort deadline for the completion of the future. + * @param The future's yielded type + * + * @return The yield of the future or else a thrown exception + * + * @throws InterruptedException If the call to future.get is interrupted + * @throws TimeoutException If the deadline is exceeded + * @throws ExecutionException If the future completed with an exception + */ + public static T futureManagedBlockGet(final Future future, final DateTime deadline) + throws InterruptedException, TimeoutException, ExecutionException + { + ForkJoinPool.managedBlock(new ForkJoinPool.ManagedBlocker() + { + @Override + public boolean block() throws InterruptedException + { + try { + future.get(JodaUtils.timeoutForDeadline(deadline), TimeUnit.MILLISECONDS); + } + catch (ExecutionException | TimeoutException e) { + // Will get caught later + } + return true; + } + + @Override + public boolean isReleasable() + { + return future.isDone() || deadline.isBefore(DateTimes.nowUtc()); + } + }); + return future.get(JodaUtils.timeoutForDeadline(deadline), TimeUnit.MILLISECONDS); + } } diff --git a/core/src/main/java/org/apache/druid/java/util/common/guava/MergeWorkTask.java b/core/src/main/java/org/apache/druid/java/util/common/guava/MergeWorkTask.java new file mode 100644 index 000000000000..7f5d4cd07c5b --- /dev/null +++ b/core/src/main/java/org/apache/druid/java/util/common/guava/MergeWorkTask.java @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.java.util.common.guava; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.druid.java.util.common.Pair; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinTask; +import java.util.function.Function; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +public class MergeWorkTask extends ForkJoinTask> +{ + + /** + * Take a stream of sequences, split them as possible, and do intermediate merges. If the input stream is not + * a parallel stream, do a traditional merge. The stream attempts to use groups of {@code batchSize} to do its work, + * but this goal is on a best effort basis. Input streams that cannot be split or are not sized or not subsized + * might not be eligible for this parallelization. The intermediate merges are done in the passed in ForkJoinPool, + * but the final merge is still done when the returned sequence is accumulated. The intermediate merges are yielded + * in the order in which they are ready. 
+ * + * Exceptions that happen during execution of the merge are passed through and bubbled up during the resulting sequence + * iteration + * + * @param mergerFn The function that will merge a stream of sequences into a single sequence. If the + * baseSequences stream is parallel, this work will be done in the FJP, otherwise it + * will be called directly. + * @param baseSequences The sequences that need to be merged + * @param batchSize The input stream should be split down to this number if possible. This sets the target number of segments per merge thread work + * @param fjp The ForkJoinPool to do the intermediate merges in. + * @param The result type + * + * @return A Sequence that will be the merged results of the sub-sequences + * + * @throws RuntimeException Will throw a RuntimeException during iteration through the returned Sequence if a Throwable + * was encountered in an intermediate merge + */ + public static Sequence parallelMerge( + Stream> baseSequences, + Function>, Sequence> mergerFn, + long batchSize, + ForkJoinPool fjp + ) + { + if (!baseSequences.isParallel()) { + // Don't even try. + return mergerFn.apply(baseSequences); + } + if (batchSize < 1) { + throw new IllegalArgumentException("Batch size must be greater than 0"); + } + + // At first glance this looks like an alternative implementation for a RecursiveTask because it does the following: + // 1. Divides the input work up into batches + // 2. Joins the results in a merging operation + // + // While these are true, there are some differences in this implementation and a raw RecursiveTask that are worth + // calling out. First, the results are fed into a BlockingQueue so that the final merge can accumulate as soon as + // the first intermediate result is available. This design constraint makes a RecursiveTask rather odd since the + // intended use case would have intermediate merges chain up to the top merge, rather than a single top merge + // accumulating the total results. 
This does not preclude a RecursiveAction that can feed the results into a + // blocking queue. + // + // But in such an implementation the total needed queue size is not known until all the recursive actions are + // forked off similar to the implementation here. The difference being the implementation below has a dedicated + // action submitted to the fjp for joining the result and feeding it into the result stream. Since this dedicated + // feeder work is submitted after all the tasks are launched, the total queue size needed is known ahead of time, + // and the blocking queue can be pre-allocated with the correct capacity to ensure submission to the queue never + // blocks. + // + // In addition, there exists an ability in this implementation to cancel all the forked tasks if the stream is + // closed (like on the case of query cancellation). + // + // Since there is a desire to + // 1. Ensure the intermediate results do not block when being fed into the final merge queue + // 2. Have the ability to cancel outstanding work tasks if the resulting Sequence is cancelled + // this implementation deviates from a straight up RecursiveTask or RecursiveAction implementation to attempt to + // provide an easy to follow and reason about workflow. + + @SuppressWarnings("unchecked") // Wildcard erasure is fine here + final Spliterator> baseSpliterator = (Spliterator>) baseSequences.spliterator(); + + // Accumulate a list of forked off tasks + final List>> tasks = new ArrayList<>(); + final long totalResults = baseSpliterator.estimateSize(); + long dequeueInitialCapacity = totalResults / batchSize + 1; + if (dequeueInitialCapacity < 16) { + // 16 is the default element count size in ArrayDeque at the time of this writing. 
+ dequeueInitialCapacity = 16; + } + final Deque>> spliteratorStack = new ArrayDeque<>((int) dequeueInitialCapacity); + + // Push the base spliterator onto the stack, keep splitting until we can't or splits are small + spliteratorStack.push(baseSpliterator); + while (!spliteratorStack.isEmpty()) { + + final Spliterator> pop = spliteratorStack.pop(); + if (pop.estimateSize() <= batchSize) { + // Batch is small enough, yay! + tasks.add(fjp.submit(new MergeWorkTask<>(mergerFn, pop))); + continue; + } + + final Spliterator> other = pop.trySplit(); + if (other == null) { + // splits are too big, but we can't split any more + tasks.add(fjp.submit(new MergeWorkTask<>(mergerFn, pop))); + continue; + } + spliteratorStack.push(pop); + spliteratorStack.push(other); + } + + // We guarantee enough space to put all the results so that the FJP doesn't block waiting for results to come in + final BlockingQueue, Throwable>> readyForFinalMerge = new ArrayBlockingQueue<>(tasks.size()); + // Submit a simple feeder into the final merge queue. Since readyForFinalMerge is sized to the number of tasks, + // the readyForFinalMerge.add call should never block. 
+ tasks.forEach(task -> fjp.submit(() -> { + try { + readyForFinalMerge.add(Pair.of(task.join(), null)); + } + catch (Throwable t) { + // FJP.join exceptions are different than executor service's `.get()` + readyForFinalMerge.add(Pair.of(null, t)); + } + })); + + final long totalAdditions = tasks.size(); + return mergerFn.apply( + StreamSupport.stream( + Spliterators.spliterator( + new Iterator>() + { + long taken = 0L; + + @Override + public boolean hasNext() + { + return taken < totalAdditions; + } + + @Override + public Sequence next() + { + if (taken >= totalAdditions) { + throw new NoSuchElementException(); + } + try { + taken++; + final Pair, Throwable> result = readyForFinalMerge.take(); + if (result.rhs != null) { + throw new RuntimeException("failed in executing merge task", result.rhs); + } + return result.lhs; + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted waiting for intermediate merge", e); + } + } + }, + totalAdditions, + Spliterator.NONNULL | Spliterator.SIZED + ), + false + ).onClose(() -> tasks.forEach(t -> t.cancel(true))) + ); + } + + private final Spliterator> baseSpliterator; + private final Function>, Sequence> mergerFn; + private Sequence result; + + @VisibleForTesting + MergeWorkTask( + Function>, Sequence> mergerFn, + Spliterator> baseSpliterator + ) + { + this.mergerFn = mergerFn; + this.baseSpliterator = baseSpliterator; + } + + @Override + public Sequence getRawResult() + { + return result; + } + + @Override + protected void setRawResult(Sequence value) + { + result = value; + } + + @Override + protected boolean exec() + { + // Force materialization "work" in this thread + // For singleton lists it is not clear it is even worth the optimization of short circuiting the merge for the + // extra code maintenance overhead + result = mergerFn.apply(StreamSupport.stream(baseSpliterator, false)); + return true; + } +} diff --git 
a/core/src/main/java/org/apache/druid/java/util/common/guava/Sequences.java b/core/src/main/java/org/apache/druid/java/util/common/guava/Sequences.java index 2bab97141d5d..d6db132389bd 100644 --- a/core/src/main/java/org/apache/druid/java/util/common/guava/Sequences.java +++ b/core/src/main/java/org/apache/druid/java/util/common/guava/Sequences.java @@ -30,6 +30,7 @@ import java.util.Iterator; import java.util.List; import java.util.concurrent.Executor; +import java.util.stream.Stream; /** */ @@ -58,6 +59,27 @@ public void cleanup(Iterator iterFromMake) ); } + public static Sequence fromStream(final Stream stream) + { + return new BaseSequence<>( + new BaseSequence.IteratorMaker>() + { + @Override + @SuppressWarnings("unchecked") + public Iterator make() + { + return (Iterator) stream.iterator(); + } + + @Override + public void cleanup(Iterator iterFromMake) + { + stream.close(); + } + } + ); + } + @SuppressWarnings("unchecked") public static Sequence empty() { diff --git a/core/src/test/java/org/apache/druid/common/guava/CombiningSequenceTest.java b/core/src/test/java/org/apache/druid/common/guava/CombiningSequenceTest.java deleted file mode 100644 index 9f64d6732863..000000000000 --- a/core/src/test/java/org/apache/druid/common/guava/CombiningSequenceTest.java +++ /dev/null @@ -1,296 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.druid.common.guava; - -import com.google.common.base.Predicate; -import com.google.common.collect.Iterables; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; -import com.google.common.collect.Ordering; -import org.apache.druid.java.util.common.Pair; -import org.apache.druid.java.util.common.guava.Sequence; -import org.apache.druid.java.util.common.guava.Sequences; -import org.apache.druid.java.util.common.guava.Yielder; -import org.apache.druid.java.util.common.guava.YieldingAccumulator; -import org.junit.Assert; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; - -import javax.annotation.Nullable; -import java.io.Closeable; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -@RunWith(Parameterized.class) -public class CombiningSequenceTest -{ - @Parameterized.Parameters - public static Collection valuesToTry() - { - return Arrays.asList(new Object[][]{ - {1}, {2}, {3}, {4}, {5}, {1000} - }); - } - - private final int yieldEvery; - - public CombiningSequenceTest(int yieldEvery) - { - this.yieldEvery = yieldEvery; - } - - @Test - public void testMerge() throws Exception - { - List> pairs = Arrays.asList( - Pair.of(0, 1), - Pair.of(0, 2), - Pair.of(0, 3), - Pair.of(1, 1), - Pair.of(2, 1), - Pair.of(5, 1), - Pair.of(5, 10), - Pair.of(6, 1), - Pair.of(5, 1) - ); - List> 
expected = Arrays.asList( - Pair.of(0, 6), - Pair.of(1, 1), - Pair.of(2, 1), - Pair.of(5, 11), - Pair.of(6, 1), - Pair.of(5, 1) - ); - - testCombining(pairs, expected); - } - - @Test - public void testNoMergeOne() throws Exception - { - List> pairs = Collections.singletonList( - Pair.of(0, 1) - ); - - List> expected = Collections.singletonList( - Pair.of(0, 1) - ); - - testCombining(pairs, expected); - } - - @Test - public void testMergeMany() throws Exception - { - List> pairs = Arrays.asList( - Pair.of(0, 6), - Pair.of(1, 1), - Pair.of(2, 1), - Pair.of(5, 11), - Pair.of(6, 1), - Pair.of(5, 1) - ); - - List> expected = Arrays.asList( - Pair.of(0, 6), - Pair.of(1, 1), - Pair.of(2, 1), - Pair.of(5, 11), - Pair.of(6, 1), - Pair.of(5, 1) - ); - - testCombining(pairs, expected); - } - - @Test - public void testNoMergeTwo() throws Exception - { - List> pairs = Arrays.asList( - Pair.of(0, 1), - Pair.of(1, 1) - ); - - List> expected = Arrays.asList( - Pair.of(0, 1), - Pair.of(1, 1) - ); - - testCombining(pairs, expected); - } - - @Test - public void testMergeTwo() throws Exception - { - List> pairs = Arrays.asList( - Pair.of(0, 1), - Pair.of(0, 1) - ); - - List> expected = Collections.singletonList( - Pair.of(0, 2) - ); - - testCombining(pairs, expected); - } - - @Test - public void testMergeSomeThingsMergedAtEnd() throws Exception - { - List> pairs = Arrays.asList( - Pair.of(0, 1), - Pair.of(0, 2), - Pair.of(0, 3), - Pair.of(1, 1), - Pair.of(2, 1), - Pair.of(5, 1), - Pair.of(5, 10), - Pair.of(6, 1), - Pair.of(5, 1), - Pair.of(5, 2), - Pair.of(5, 2), - Pair.of(5, 2), - Pair.of(5, 2), - Pair.of(5, 2) - ); - List> expected = Arrays.asList( - Pair.of(0, 6), - Pair.of(1, 1), - Pair.of(2, 1), - Pair.of(5, 11), - Pair.of(6, 1), - Pair.of(5, 11) - ); - - testCombining(pairs, expected); - } - - @Test - public void testNothing() throws Exception - { - testCombining(Collections.emptyList(), Collections.emptyList()); - } - - private void testCombining(List> pairs, List> expected) - 
throws Exception - { - for (int limit = 0; limit < expected.size() + 1; limit++) { - // limit = 0 doesn't work properly; it returns 1 element - final int expectedLimit = limit == 0 ? 1 : limit; - - testCombining( - pairs, - Lists.newArrayList(Iterables.limit(expected, expectedLimit)), - limit - ); - } - } - - private void testCombining( - List> pairs, - List> expected, - int limit - ) throws Exception - { - // Test that closing works too - final CountDownLatch closed = new CountDownLatch(1); - final Closeable closeable = closed::countDown; - - Sequence> seq = CombiningSequence.create( - Sequences.simple(pairs).withBaggage(closeable), - Ordering.natural().onResultOf(p -> p.lhs), - (lhs, rhs) -> { - if (lhs == null) { - return rhs; - } - - if (rhs == null) { - return lhs; - } - - return Pair.of(lhs.lhs, lhs.rhs + rhs.rhs); - } - ).limit(limit); - - List> merged = seq.toList(); - - Assert.assertEquals(expected, merged); - - Yielder> yielder = seq.toYielder( - null, - new YieldingAccumulator, Pair>() - { - int count = 0; - - @Override - public Pair accumulate( - Pair lhs, Pair rhs - ) - { - count++; - if (count % yieldEvery == 0) { - yield(); - } - return rhs; - } - } - ); - - Iterator> expectedVals = Iterators.filter( - expected.iterator(), - new Predicate>() - { - int count = 0; - - @Override - public boolean apply( - @Nullable Pair input - ) - { - count++; - if (count % yieldEvery == 0) { - return true; - } - return false; - } - } - ); - - if (expectedVals.hasNext()) { - while (!yielder.isDone()) { - final Pair expectedVal = expectedVals.next(); - final Pair actual = yielder.get(); - Assert.assertEquals(expectedVal, actual); - yielder = yielder.next(actual); - } - } - Assert.assertTrue(yielder.isDone()); - Assert.assertFalse(expectedVals.hasNext()); - yielder.close(); - - Assert.assertTrue("resource closed", closed.await(10000, TimeUnit.MILLISECONDS)); - } -} diff --git a/core/src/test/java/org/apache/druid/common/utils/JodaUtilsTest.java 
b/core/src/test/java/org/apache/druid/common/utils/JodaUtilsTest.java index c5610f969c80..f1827946a903 100644 --- a/core/src/test/java/org/apache/druid/common/utils/JodaUtilsTest.java +++ b/core/src/test/java/org/apache/druid/common/utils/JodaUtilsTest.java @@ -19,6 +19,7 @@ package org.apache.druid.common.utils; +import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.JodaUtils; import org.joda.time.Duration; @@ -30,6 +31,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.TimeoutException; /** */ @@ -144,4 +146,35 @@ public void testMinMaxPeriod() Assert.assertEquals(Long.MAX_VALUE, period.getMinutes()); } + @Test(expected = TimeoutException.class) + public void testDeadlinePassed() throws TimeoutException + { + JodaUtils.timeoutForDeadline(DateTimes.nowUtc().minus(1)); + } + + @Test(expected = TimeoutException.class) + public void testDeadlineMinInstant() throws TimeoutException + { + JodaUtils.timeoutForDeadline(DateTimes.utc(JodaUtils.MIN_INSTANT)); + } + + @Test + public void testDeadlineMaxInstant() throws TimeoutException + { + final long ms = JodaUtils.timeoutForDeadline(DateTimes.utc(JodaUtils.MAX_INSTANT)); + Assert.assertTrue(ms > 0); + } + + @Test(expected = TimeoutException.class) + public void testDeadlineLongMin() throws TimeoutException + { + JodaUtils.timeoutForDeadline(DateTimes.utc(Long.MIN_VALUE)); + } + + @Test + public void testDeadlineLongMax() throws TimeoutException + { + final long ms = JodaUtils.timeoutForDeadline(DateTimes.utc(Long.MAX_VALUE)); + Assert.assertTrue(ms > 0); + } } diff --git a/core/src/test/java/org/apache/druid/concurrent/ExecsTest.java b/core/src/test/java/org/apache/druid/concurrent/ExecsTest.java deleted file mode 100644 index cedd99dcdd74..000000000000 --- a/core/src/test/java/org/apache/druid/concurrent/ExecsTest.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to 
the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.druid.concurrent; - -import com.google.common.base.Throwables; -import com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.apache.druid.java.util.common.concurrent.Execs; -import org.apache.druid.java.util.common.logger.Logger; -import org.junit.Assert; -import org.junit.Test; - -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicInteger; - -public class ExecsTest -{ - private static final Logger log = new Logger(ExecsTest.class); - - @Test - public void testBlockingExecutorServiceZeroCapacity() throws Exception - { - runTest(0); - } - - @Test - public void testBlockingExecutorServiceOneCapacity() throws Exception - { - runTest(1); - } - - @Test - public void testBlockingExecutorServiceThreeCapacity() throws Exception - { - runTest(3); - } - - private static void runTest(final int capacity) throws Exception - { - final int nTasks = (capacity + 1) * 3; - final ExecutorService blockingExecutor = Execs.newBlockingSingleThreaded("ExecsTest-Blocking-%d", capacity); - final CountDownLatch queueShouldBeFullSignal = new 
CountDownLatch(capacity + 1); - final CountDownLatch taskCompletedSignal = new CountDownLatch(nTasks); - final CountDownLatch taskStartSignal = new CountDownLatch(1); - final AtomicInteger producedCount = new AtomicInteger(); - final AtomicInteger consumedCount = new AtomicInteger(); - final ExecutorService producer = Executors.newSingleThreadExecutor( - new ThreadFactoryBuilder().setNameFormat( - "ExecsTest-Producer-%d" - ).build() - ); - producer.submit( - new Runnable() - { - @Override - public void run() - { - for (int i = 0; i < nTasks; i++) { - final int taskID = i; - log.info("Produced task %d", taskID); - blockingExecutor.submit( - new Runnable() - { - @Override - public void run() - { - log.info("Starting task: %s", taskID); - try { - taskStartSignal.await(); - consumedCount.incrementAndGet(); - taskCompletedSignal.countDown(); - } - catch (Exception e) { - throw Throwables.propagate(e); - } - log.info("Completed task: %s", taskID); - } - } - ); - producedCount.incrementAndGet(); - queueShouldBeFullSignal.countDown(); - } - } - } - ); - - queueShouldBeFullSignal.await(); - // Verify that the producer blocks. I don't think it's possible to be sure that the producer is blocking (since - // it could be doing nothing for any reason). But waiting a short period of time and checking that it hasn't done - // anything should hopefully be sufficient. 
- Thread.sleep(500); - Assert.assertEquals(capacity + 1, producedCount.get()); - // let the tasks run - taskStartSignal.countDown(); - // wait until all tasks complete - taskCompletedSignal.await(); - // verify all tasks consumed - Assert.assertEquals(nTasks, consumedCount.get()); - // cleanup - blockingExecutor.shutdown(); - producer.shutdown(); - } -} diff --git a/core/src/test/java/org/apache/druid/java/util/common/guava/MergeSequenceTest.java b/core/src/test/java/org/apache/druid/java/util/common/guava/MergeSequenceTest.java index 61af63869487..bd2044519466 100644 --- a/core/src/test/java/org/apache/druid/java/util/common/guava/MergeSequenceTest.java +++ b/core/src/test/java/org/apache/druid/java/util/common/guava/MergeSequenceTest.java @@ -21,17 +21,35 @@ import com.google.common.collect.Lists; import com.google.common.collect.Ordering; -import junit.framework.Assert; +import org.junit.Assert; import org.junit.Test; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Stream; /** */ public class MergeSequenceTest { + public static > Supplier> naturalMergeSupplier( + Supplier>> stream + ) + { + return () -> naturalMerge(stream.get()); + } + + public static > Sequence naturalMerge(Stream> stream) + { + return new MergeSequence<>( + Ordering.natural(), + Sequences.fromStream(stream) + ); + } + @Test public void testSanity() throws Exception { @@ -40,13 +58,23 @@ public void testSanity() throws Exception TestSequence.create(2, 8), TestSequence.create(4, 6, 8) ); + final List expected = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9); - MergeSequence seq = new MergeSequence<>(Ordering.natural(), (Sequence) Sequences.simple(testSeqs)); - SequenceTestHelper.testAll(seq, Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9)); + MergeSequence seq = new MergeSequence<>( + Ordering.natural(), + (Sequence) Sequences.simple(testSeqs) + ); + SequenceTestHelper.testAll(seq, 
expected); for (TestSequence sequence : testSeqs) { Assert.assertTrue(sequence.isClosed()); } + + SequenceTestHelper.testAll( + naturalMergeSupplier(testSeqs::stream), + expected + ); + } @Test @@ -58,12 +86,19 @@ public void testWorksWhenBeginningOutOfOrder() throws Exception TestSequence.create(4, 6, 8) ); - MergeSequence seq = new MergeSequence<>(Ordering.natural(), (Sequence) Sequences.simple(testSeqs)); - SequenceTestHelper.testAll(seq, Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9)); + final List expected = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9); + + MergeSequence seq = new MergeSequence<>(Ordering.natural(), Sequences.simple(testSeqs)); + SequenceTestHelper.testAll(seq, expected); for (TestSequence sequence : testSeqs) { Assert.assertTrue(sequence.isClosed()); } + + SequenceTestHelper.testAll( + naturalMergeSupplier(testSeqs::stream), + expected + ); } @Test @@ -76,12 +111,22 @@ public void testMergeEmpties() throws Exception TestSequence.create(4, 6, 8) ); - MergeSequence seq = new MergeSequence<>(Ordering.natural(), (Sequence) Sequences.simple(testSeqs)); - SequenceTestHelper.testAll(seq, Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9)); + final List expected = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9); + + MergeSequence seq = new MergeSequence<>( + Ordering.natural(), + Sequences.simple(testSeqs) + ); + SequenceTestHelper.testAll(seq, expected); for (TestSequence sequence : testSeqs) { Assert.assertTrue(sequence.isClosed()); } + + SequenceTestHelper.testAll( + naturalMergeSupplier(testSeqs::stream), + expected + ); } @Test @@ -94,12 +139,22 @@ public void testMergeEmpties1() throws Exception TestSequence.create(4, 6, 8) ); - MergeSequence seq = new MergeSequence<>(Ordering.natural(), (Sequence) Sequences.simple(testSeqs)); - SequenceTestHelper.testAll(seq, Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9)); + final List expected = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9); + + MergeSequence seq = new MergeSequence<>( + Ordering.natural(), + 
Sequences.simple(testSeqs) + ); + SequenceTestHelper.testAll(seq, expected); for (TestSequence sequence : testSeqs) { Assert.assertTrue(sequence.isClosed()); } + + SequenceTestHelper.testAll( + naturalMergeSupplier(testSeqs::stream), + expected + ); } @Test @@ -113,12 +168,22 @@ public void testMergeEmpties2() throws Exception TestSequence.create() ); - MergeSequence seq = new MergeSequence<>(Ordering.natural(), (Sequence) Sequences.simple(testSeqs)); - SequenceTestHelper.testAll(seq, Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9)); + final List expected = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9); + + MergeSequence seq = new MergeSequence<>( + Ordering.natural(), + Sequences.simple(testSeqs) + ); + SequenceTestHelper.testAll(seq, expected); for (TestSequence sequence : testSeqs) { Assert.assertTrue(sequence.isClosed()); } + + SequenceTestHelper.testAll( + naturalMergeSupplier(testSeqs::stream), + expected + ); } @Test @@ -130,12 +195,19 @@ public void testScrewsUpOnOutOfOrder() throws Exception TestSequence.create(4, 6) ); - MergeSequence seq = new MergeSequence<>(Ordering.natural(), (Sequence) Sequences.simple(testSeqs)); - SequenceTestHelper.testAll(seq, Arrays.asList(1, 2, 3, 4, 5, 4, 6, 7, 8, 9)); + final List expected = Arrays.asList(1, 2, 3, 4, 5, 4, 6, 7, 8, 9); + + MergeSequence seq = new MergeSequence<>(Ordering.natural(), Sequences.simple(testSeqs)); + SequenceTestHelper.testAll(seq, expected); for (TestSequence sequence : testSeqs) { Assert.assertTrue(sequence.isClosed()); } + + SequenceTestHelper.testAll( + naturalMergeSupplier(testSeqs::stream), + expected + ); } @Test @@ -161,9 +233,12 @@ public void testHierarchicalMerge() throws Exception public void testMergeOne() throws Exception { final Sequence mergeOne = new MergeSequence<>( - Ordering.natural(), Sequences.simple( - Collections.singletonList(TestSequence.create(1)) - ) + Ordering.natural(), + Sequences.>simple( + Collections.singletonList( + TestSequence.create(1) + ) + ) ); 
SequenceTestHelper.testAll(mergeOne, Collections.singletonList(1)); diff --git a/core/src/test/java/org/apache/druid/java/util/common/guava/MergeWorkTaskTest.java b/core/src/test/java/org/apache/druid/java/util/common/guava/MergeWorkTaskTest.java new file mode 100644 index 000000000000..6969de33aaa4 --- /dev/null +++ b/core/src/test/java/org/apache/druid/java/util/common/guava/MergeWorkTaskTest.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +package org.apache.druid.java.util.common.guava; + +import com.google.common.collect.Lists; +import com.google.common.collect.Ordering; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class MergeWorkTaskTest +{ + @Test + public void testNotParallelSequence() throws Exception + { + final ArrayList> testSeqs = Lists.newArrayList( + TestSequence.create(1, 3, 5, 7, 9), + TestSequence.create(2, 8), + TestSequence.create(4, 6, 8) + ); + final List expected = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9); + + SequenceTestHelper.testAll(() -> MergeWorkTask.parallelMerge( + testSeqs.stream(), + s -> new MergeSequence<>(Ordering.natural(), Sequences.fromStream(s)), + 999, + ForkJoinPool.commonPool() + ), expected); + } + + @Test + public void testOneBatchParallelSequence() throws Exception + { + final ArrayList> testSeqs = Lists.newArrayList( + TestSequence.create(1, 3, 5, 7, 9), + TestSequence.create(2, 8), + TestSequence.create(4, 6, 8) + ); + final List expected = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9); + + SequenceTestHelper.testAll(() -> MergeWorkTask.parallelMerge( + testSeqs.stream().parallel(), + MergeSequenceTest::naturalMerge, + 999, + ForkJoinPool.commonPool() + ), expected); + } + + @Test + public void testAllBatchParallelSequence() throws Exception + { + final ArrayList> testSeqs = Lists.newArrayList( + TestSequence.create(1, 3, 5, 7, 9), + TestSequence.create(2, 8), + TestSequence.create(4, 6, 8) + ); + final List expected = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9); + + SequenceTestHelper.testAll(() -> MergeWorkTask.parallelMerge( + testSeqs.stream().parallel(), + 
MergeSequenceTest::naturalMerge, + 1, + ForkJoinPool.commonPool() + ), expected); + } + + @Test + public void testSomeBatchParallelSequence() throws Exception + { + final ArrayList> testSeqs = Lists.newArrayList( + TestSequence.create(1, 3, 5, 7, 9), + TestSequence.create(2, 8), + TestSequence.create(4, 6, 8) + ); + final List expected = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9); + + SequenceTestHelper.testAll(() -> MergeWorkTask.parallelMerge( + testSeqs.stream().parallel(), + MergeSequenceTest::naturalMerge, + 2, + ForkJoinPool.commonPool() + ), expected); + } + + + @Test + public void testFJPChoke() throws Exception + { + final ArrayList> testSeqs = Lists.newArrayList( + TestSequence.create(1, 3, 5, 7, 9), + TestSequence.create(2, 8), + TestSequence.create(4, 6, 8) + ); + final List expected = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 8, 9); + final AtomicReference exception = new AtomicReference<>(null); + final ForkJoinPool fjp = new ForkJoinPool( + 1, + pool -> Execs.makeWorkerThread("test-%s", pool), + (t, e) -> exception.set(e), + false + ); + SequenceTestHelper.testAll(() -> MergeWorkTask.parallelMerge( + testSeqs.stream().parallel(), + MergeSequenceTest::naturalMerge, + 1, + fjp + ), expected); + fjp.shutdown(); + Assert.assertTrue(fjp.awaitTermination(5, TimeUnit.SECONDS)); + Assert.assertNull(exception.get()); + } + + @Test + public void testBigMerge() throws Exception + { + final AtomicReference exception = new AtomicReference<>(null); + final ForkJoinPool fjp = new ForkJoinPool( + 4, + pool -> Execs.makeWorkerThread("test-%s", pool), + (t, e) -> exception.set(e), + false + ); + try (AutoCloseable closeable = fjp::shutdown) { + // Take a big list of numbers, scatter them among a bunch of different buckets, then make sure the parallel merge + // returns the original list + + final List intList = IntStream.range(0, 10000).boxed().collect(Collectors.toList()); + final List> listList = new ArrayList<>(); + for (int i = 0; i < 500; i++) { + listList.add(new 
ArrayList<>()); + } + final Random r = new Random(37489165L); + intList.forEach(i -> listList.get(r.nextInt(listList.size())).add(i)); + SequenceTestHelper.testAll(() -> MergeWorkTask.parallelMerge( + listList.stream( + ).map( + TestSequence::create + ).parallel(), + MergeSequenceTest::naturalMerge, + 10, + fjp + ), intList); + } + Assert.assertTrue(fjp.awaitTermination(5, TimeUnit.SECONDS)); + Assert.assertNull(exception.get()); + } +} diff --git a/core/src/test/java/org/apache/druid/java/util/common/guava/SequenceTestHelper.java b/core/src/test/java/org/apache/druid/java/util/common/guava/SequenceTestHelper.java index e4450210ae24..31e266f4ff4b 100644 --- a/core/src/test/java/org/apache/druid/java/util/common/guava/SequenceTestHelper.java +++ b/core/src/test/java/org/apache/druid/java/util/common/guava/SequenceTestHelper.java @@ -19,29 +19,47 @@ package org.apache.druid.java.util.common.guava; -import junit.framework.Assert; +import org.junit.Assert; import java.io.IOException; import java.util.Iterator; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; /** */ public class SequenceTestHelper { + public static void testAll(Sequence seq, List nums) throws IOException + { + testAll(() -> seq, nums); + } + + public static void testAll(Supplier> seq, List nums) throws IOException { testAll("", seq, nums); } public static void testAll(String prefix, Sequence seq, List nums) throws IOException + { + testAll(prefix, () -> seq, nums); + } + + public static void testAll(String prefix, Supplier> seq, List nums) throws IOException { testAccumulation(prefix, seq, nums); testYield(prefix, seq, nums); } public static void testYield(final String prefix, Sequence seq, final List nums) throws IOException + { + testYield(prefix, () -> seq, nums); + } + + public static void testYield(final String prefix, Supplier> seq, final List nums) + throws IOException { testYield(prefix, 3, seq, nums); testYield(prefix, 1, seq, nums); 
@@ -53,9 +71,19 @@ public static void testYield( Sequence seq, final List nums ) throws IOException + { + testYield(prefix, numToTake, () -> seq, nums); + } + + public static void testYield( + final String prefix, + final int numToTake, + Supplier> seq, + final List nums + ) throws IOException { Iterator numsIter = nums.iterator(); - Yielder yielder = seq.toYielder( + Yielder yielder = seq.get().toYielder( 0, new YieldingAccumulator() { @@ -97,15 +125,20 @@ public Integer accumulate(Integer accumulated, Integer in) yielder.close(); } - public static void testAccumulation(final String prefix, Sequence seq, final List nums) + { + testAccumulation(prefix, () -> seq, nums); + } + + public static void testAccumulation(final String prefix, Supplier> seq, final List nums) { int expectedSum = 0; for (Integer num : nums) { expectedSum += num; } - int sum = seq.accumulate( + + int sum = seq.get().accumulate( 0, new Accumulator() { diff --git a/docs/content/querying/query-context.md b/docs/content/querying/query-context.md index 3df148b20b57..42ead5ecff87 100644 --- a/docs/content/querying/query-context.md +++ b/docs/content/querying/query-context.md @@ -23,6 +23,7 @@ The query context is used for various query configuration parameters. The follow |maxQueuedBytes | `druid.broker.http.maxQueuedBytes` | Maximum number of bytes queued per query before exerting backpressure on the channel to the data server. Similar to `maxScatterGatherBytes`, except unlike that configuration, this one will trigger backpressure rather than query failure. 
Zero means disabled.| |serializeDateTimeAsLong| `false` | If true, DateTime is serialized as long in the result returned by broker and the data transportation between broker and compute node| |serializeDateTimeAsLongInner| `false` | If true, DateTime is serialized as long in the data transportation between broker and compute node| +|intermediateMergeBatchThreshold|none, do not use|(EXPERIMENTAL) (positive integer) If present, will attempt to do parallel intermediate merges at the broker with a batch size of `intermediateMergeBatchThreshold` server results. This can greatly speed up the merging of large result sets across a large quantity of servers. The maximum number of active parallel merges across all queries is currently tied to the same count as the Common ForkJoinPool and can be set for both by the system property `java.util.concurrent.ForkJoinPool.common.parallelism`| In addition, some query types offer context parameters specific to that query type. diff --git a/processing/src/main/java/org/apache/druid/guice/ForkJoinPoolProvider.java b/processing/src/main/java/org/apache/druid/guice/ForkJoinPoolProvider.java new file mode 100644 index 000000000000..2db413acc057 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/guice/ForkJoinPoolProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.guice; + +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.logger.Logger; + +import javax.inject.Provider; +import java.util.concurrent.ForkJoinPool; + +public class ForkJoinPoolProvider implements Provider +{ + private static final Logger LOG = new Logger(ForkJoinPoolProvider.class); + + private final String nameFormat; + + public ForkJoinPoolProvider(String nameFormat) + { + // Fail fast on bad name format + Execs.checkThreadNameFormat(nameFormat); + this.nameFormat = nameFormat; + } + + @Override + public LifecycleForkJoinPool get() + { + return new LifecycleForkJoinPool( + // This should probably be configurable. Until then, just piggyback off the common pool's parallelism + ForkJoinPool.commonPool().getParallelism(), + pool -> Execs.makeWorkerThread(nameFormat, pool), + (t, e) -> LOG.error(e, "Unhandled exception in thread [%s]", t), + false + ); + } +} diff --git a/processing/src/main/java/org/apache/druid/guice/LifecycleForkJoinPool.java b/processing/src/main/java/org/apache/druid/guice/LifecycleForkJoinPool.java new file mode 100644 index 000000000000..c06b82468284 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/guice/LifecycleForkJoinPool.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.guice; + + +import org.apache.druid.java.util.common.lifecycle.LifecycleStop; +import org.apache.druid.java.util.common.logger.Logger; + +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeUnit; + +public class LifecycleForkJoinPool extends ForkJoinPool +{ + private static final Logger LOG = new Logger(LifecycleForkJoinPool.class); + + public LifecycleForkJoinPool( + int parallelism, + ForkJoinWorkerThreadFactory factory, + Thread.UncaughtExceptionHandler handler, + boolean asyncMode + ) + { + super(parallelism, factory, handler, asyncMode); + } + + @LifecycleStop + public void stop() + { + LOG.info("Shutting down ForkJoinPool [%s]", this); + shutdown(); + try { + // This should be configurable https://github.com/apache/incubator-druid/issues/6264 + if (!awaitTermination(1, TimeUnit.MINUTES)) { + LOG.warn("Failed to complete all tasks in FJP [%s]", this); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("interrupted on shutdown", e); + } + } +} diff --git a/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java b/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java index 0cfce5766090..642eba466772 100644 --- a/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java +++ 
b/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java @@ -20,14 +20,14 @@ package org.apache.druid.query; import com.google.common.base.Throwables; -import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; import com.google.common.collect.Ordering; -import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; +import org.apache.druid.common.guava.GuavaUtils; +import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.guava.BaseSequence; import org.apache.druid.java.util.common.guava.MergeIterable; import org.apache.druid.java.util.common.guava.Sequence; @@ -40,8 +40,9 @@ import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; -import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; /** * A QueryRunner that combines a list of other QueryRunners and executes them in parallel on an executor. @@ -60,7 +61,7 @@ public class ChainedExecutionQueryRunner implements QueryRunner { private static final Logger log = new Logger(ChainedExecutionQueryRunner.class); - private final Iterable> queryables; + private final Stream> queryables; private final ListeningExecutorService exec; private final QueryWatcher queryWatcher; @@ -70,7 +71,7 @@ public ChainedExecutionQueryRunner( QueryRunner... 
queryables ) { - this(exec, queryWatcher, Arrays.asList(queryables)); + this(exec, queryWatcher, Arrays.stream(queryables)); } public ChainedExecutionQueryRunner( @@ -78,79 +79,92 @@ public ChainedExecutionQueryRunner( QueryWatcher queryWatcher, Iterable> queryables ) + { + this(exec, queryWatcher, StreamSupport.stream(queryables.spliterator(), false)); + } + + public ChainedExecutionQueryRunner( + ExecutorService exec, + QueryWatcher queryWatcher, + Stream> queryables + ) { // listeningDecorator will leave PrioritizedExecutorService unchanged, // since it already implements ListeningExecutorService this.exec = MoreExecutors.listeningDecorator(exec); - this.queryables = Iterables.unmodifiableIterable(queryables); this.queryWatcher = queryWatcher; + this.queryables = queryables; } @Override public Sequence run(final QueryPlus queryPlus, final Map responseContext) { - Query query = queryPlus.getQuery(); + final Query query = queryPlus.getQuery(); final int priority = QueryContexts.getPriority(query); - final Ordering ordering = query.getResultOrdering(); + final Ordering ordering = query.getResultOrdering(); final QueryPlus threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState(); - return new BaseSequence>( + return new BaseSequence<>( new BaseSequence.IteratorMaker>() { @Override public Iterator make() { // Make it a List<> to materialize all of the values (so that it will submit everything to the executor) - ListenableFuture>> futures = Futures.allAsList( - Lists.newArrayList( - Iterables.transform( - queryables, - input -> { - if (input == null) { - throw new ISE("Null queryRunner! Looks to be some segment unmapping action happening"); + final ListenableFuture>> futures = GuavaUtils.allFuturesAsList( + queryables.map( + // Don't use peek here: https://github.com/apache/incubator-druid/pull/5913#discussion_r213472699 + queryRunner -> { + if (queryRunner == null) { + throw new ISE("Null queryRunner! 
Looks to be some segment unmapping action happening"); + } + return queryRunner; + } + ).map( + queryRunner -> new AbstractPrioritizedCallable>(priority) + { + @Override + public Iterable call() + { + try { + Sequence result = queryRunner.run(threadSafeQueryPlus, responseContext); + if (result == null) { + throw new ISE("Got a null result! Segments are missing!"); } - return exec.submit( - new AbstractPrioritizedCallable>(priority) - { - @Override - public Iterable call() - { - try { - Sequence result = input.run(threadSafeQueryPlus, responseContext); - if (result == null) { - throw new ISE("Got a null result! Segments are missing!"); - } - - List retVal = result.toList(); - if (retVal == null) { - throw new ISE("Got a null list of results! WTF?!"); - } + List retVal = result.toList(); + if (retVal == null) { + throw new ISE("Got a null list of results! WTF?!"); + } - return retVal; - } - catch (QueryInterruptedException e) { - throw Throwables.propagate(e); - } - catch (Exception e) { - log.error(e, "Exception with one of the sequences!"); - throw Throwables.propagate(e); - } - } - } - ); + return retVal; + } + catch (QueryInterruptedException e) { + throw Throwables.propagate(e); + } + catch (Exception e) { + log.error(e, "Exception with one of the sequences!"); + throw Throwables.propagate(e); } - ) - ) + } + } + ).map(exec::submit) ); queryWatcher.registerQuery(query, futures); try { + final List> result; + if (QueryContexts.hasTimeout(query)) { + result = Execs.futureManagedBlockGet( + futures, + DateTimes.nowUtc().plusMillis((int) QueryContexts.getTimeout(query)) + ); + } else { + result = Execs.futureManagedBlockGet(futures); + } return new MergeIterable<>( ordering.nullsFirst(), - QueryContexts.hasTimeout(query) ? 
- futures.get(QueryContexts.getTimeout(query), TimeUnit.MILLISECONDS) : - futures.get() + result ).iterator(); } catch (InterruptedException e) { diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java b/processing/src/main/java/org/apache/druid/query/QueryContexts.java index 3399aa9b1a26..60c5397814b6 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java @@ -25,6 +25,7 @@ import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.Numbers; +import java.util.OptionalLong; import java.util.concurrent.TimeUnit; @PublicApi @@ -36,6 +37,7 @@ public class QueryContexts public static final String MAX_QUEUED_BYTES_KEY = "maxQueuedBytes"; public static final String DEFAULT_TIMEOUT_KEY = "defaultTimeout"; public static final String CHUNK_PERIOD_KEY = "chunkPeriod"; + public static final String INTERMEDIATE_MERGE_BATCH_THRESHOLD = "intermediateMergeBatchThreshold"; public static final boolean DEFAULT_BY_SEGMENT = false; public static final boolean DEFAULT_POPULATE_CACHE = true; @@ -172,6 +174,26 @@ public static Query verifyMaxQueryTimeout(Query query, long maxQueryTi } } + /** + * Return an optional long of the batch size. 
If the batch is less than 1 (0 or negative) then just return empty + * + * @param query The query whose context is to be used + * @param The query result type + * + * @return An optional long which, if present, will only be a positive long + */ + public static OptionalLong getIntermediateMergeBatchThreshold(Query query) + { + final OptionalLong optionalLong = parseLong(query, INTERMEDIATE_MERGE_BATCH_THRESHOLD); + if (!optionalLong.isPresent()) { + return optionalLong; + } + if (optionalLong.getAsLong() < 1) { + return OptionalLong.empty(); + } + return optionalLong; + } + public static long getMaxQueuedBytes(Query query, long defaultValue) { return parseLong(query, MAX_QUEUED_BYTES_KEY, defaultValue); @@ -222,6 +244,12 @@ static long parseLong(Query query, String key, long defaultValue) return val == null ? defaultValue : Numbers.parseLong(val); } + static OptionalLong parseLong(Query query, String key) + { + final Object val = query.getContextValue(key); + return val == null ? OptionalLong.empty() : OptionalLong.of(Numbers.parseLong(val)); + } + static int parseInt(Query query, String key, int defaultValue) { final Object val = query.getContextValue(key); diff --git a/processing/src/main/java/org/apache/druid/query/QueryRunner.java b/processing/src/main/java/org/apache/druid/query/QueryRunner.java index a7d62d4514bd..7461db4d7f20 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/QueryRunner.java @@ -31,4 +31,10 @@ public interface QueryRunner * Runs the given query and returns results in a time-ordered sequence. 
*/ Sequence run(QueryPlus queryPlus, Map responseContext); + + @SuppressWarnings("unchecked") + static QueryRunner of(Sequence s) + { + return (ignored0, ignored1) -> (Sequence) s; + } } diff --git a/processing/src/main/java/org/apache/druid/query/metadata/SegmentMetadataQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/metadata/SegmentMetadataQueryQueryToolChest.java index 554835ae5c95..5a2048e4b08a 100644 --- a/processing/src/main/java/org/apache/druid/query/metadata/SegmentMetadataQueryQueryToolChest.java +++ b/processing/src/main/java/org/apache/druid/query/metadata/SegmentMetadataQueryQueryToolChest.java @@ -111,7 +111,8 @@ public Sequence doRun( Map context ) { - SegmentMetadataQuery updatedQuery = ((SegmentMetadataQuery) queryPlus.getQuery()).withFinalizedAnalysisTypes(config); + SegmentMetadataQuery updatedQuery = ((SegmentMetadataQuery) queryPlus.getQuery()) + .withFinalizedAnalysisTypes(config); QueryPlus updatedQueryPlus = queryPlus.withQuery(updatedQuery); return new MappedSequence<>( CombiningSequence.create( diff --git a/processing/src/test/java/org/apache/druid/guice/ForkJoinPoolProviderTest.java b/processing/src/test/java/org/apache/druid/guice/ForkJoinPoolProviderTest.java new file mode 100644 index 000000000000..79d536c4c167 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/guice/ForkJoinPoolProviderTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.guice; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ForkJoinPool; + +public class ForkJoinPoolProviderTest +{ + private static final String GOOD_NAME_FORMAT = "test-fjp-%d"; + private static final ForkJoinPoolProvider FORK_JOIN_POOL_PROVIDER = new ForkJoinPoolProvider(GOOD_NAME_FORMAT); + + @Test + public void testThreadThrowsException() throws InterruptedException + { + final ForkJoinPool fjp = FORK_JOIN_POOL_PROVIDER.get(); + final RuntimeException re = new RuntimeException("test exception"); + try { + fjp.submit(() -> { + throw re; + }).get(); + } + catch (ExecutionException e) { + if (!re.equals(e.getCause().getCause())) { + throw new RuntimeException("Unexpected exception", e); + } + return; + } + Assert.fail("Should have thrown exception"); + } + + @Test + public void testThreadSwallowsException() + { + final ForkJoinPool fjp = FORK_JOIN_POOL_PROVIDER.get(); + final RuntimeException re = new RuntimeException("test exception"); + fjp.execute(() -> { + throw re; + }); + } +} diff --git a/processing/src/test/java/org/apache/druid/query/QueryRunnerTestHelper.java b/processing/src/test/java/org/apache/druid/query/QueryRunnerTestHelper.java index 14e7af3bd87a..1aaf320ae9eb 100644 --- a/processing/src/test/java/org/apache/druid/query/QueryRunnerTestHelper.java +++ b/processing/src/test/java/org/apache/druid/query/QueryRunnerTestHelper.java @@ -20,11 +20,13 @@ package org.apache.druid.query; import com.google.common.base.Function; 
+import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.util.concurrent.MoreExecutors; +import org.apache.druid.collections.StupidPool; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.granularity.Granularities; @@ -48,16 +50,40 @@ import org.apache.druid.query.aggregation.post.ConstantPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.groupby.GroupByQuery; +import org.apache.druid.query.groupby.GroupByQueryConfig; +import org.apache.druid.query.groupby.GroupByQueryRunnerTest; +import org.apache.druid.query.groupby.strategy.GroupByStrategySelector; +import org.apache.druid.query.metadata.SegmentMetadataQueryConfig; +import org.apache.druid.query.metadata.SegmentMetadataQueryQueryToolChest; +import org.apache.druid.query.metadata.SegmentMetadataQueryRunnerFactory; +import org.apache.druid.query.metadata.metadata.SegmentMetadataQuery; +import org.apache.druid.query.scan.ScanQuery; +import org.apache.druid.query.scan.ScanQueryConfig; +import org.apache.druid.query.scan.ScanQueryEngine; +import org.apache.druid.query.scan.ScanQueryQueryToolChest; +import org.apache.druid.query.scan.ScanQueryRunnerFactory; +import org.apache.druid.query.select.SelectQuery; +import org.apache.druid.query.select.SelectQueryConfig; +import org.apache.druid.query.select.SelectQueryEngine; +import org.apache.druid.query.select.SelectQueryQueryToolChest; +import org.apache.druid.query.select.SelectQueryRunnerFactory; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.query.spec.QuerySegmentSpec; import 
org.apache.druid.query.spec.SpecificSegmentSpec; +import org.apache.druid.query.timeseries.TimeseriesQuery; import org.apache.druid.query.timeseries.TimeseriesQueryEngine; import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest; import org.apache.druid.query.timeseries.TimeseriesQueryRunnerFactory; +import org.apache.druid.query.topn.TopNQuery; +import org.apache.druid.query.topn.TopNQueryConfig; +import org.apache.druid.query.topn.TopNQueryQueryToolChest; +import org.apache.druid.query.topn.TopNQueryRunnerFactory; import org.apache.druid.segment.IncrementalIndexSegment; import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.QueryableIndexSegment; import org.apache.druid.segment.Segment; +import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.TestIndex; import org.apache.druid.segment.incremental.IncrementalIndex; import org.apache.druid.timeline.TimelineObjectHolder; @@ -66,6 +92,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -80,7 +107,8 @@ public class QueryRunnerTestHelper { - public static final QueryWatcher NOOP_QUERYWATCHER = (query, future) -> {}; + public static final QueryWatcher NOOP_QUERYWATCHER = (query, future) -> { + }; public static final String segmentId = "testSegment"; public static final String dataSource = "testing"; @@ -275,6 +303,104 @@ public Object[] apply(@Nullable Object input) ); } + + public static final Map, QueryRunnerFactory> DEFAULT_CONGLOMERATE_MAP = ImmutableMap + ., QueryRunnerFactory>builder() + .put( + SegmentMetadataQuery.class, + new SegmentMetadataQueryRunnerFactory( + new SegmentMetadataQueryQueryToolChest( + new SegmentMetadataQueryConfig("P1W") + ), + QueryRunnerTestHelper.NOOP_QUERYWATCHER + ) + ) + .put( + ScanQuery.class, + new ScanQueryRunnerFactory( + new ScanQueryQueryToolChest( + new ScanQueryConfig(), + new 
DefaultGenericQueryMetricsFactory(TestHelper.makeJsonMapper()) + ), + new ScanQueryEngine() + ) + ) + .put( + SelectQuery.class, + new SelectQueryRunnerFactory( + new SelectQueryQueryToolChest( + TestHelper.makeJsonMapper(), + QueryRunnerTestHelper.NoopIntervalChunkingQueryRunnerDecorator(), + Suppliers.ofInstance( + new SelectQueryConfig(true) + ) + ), + new SelectQueryEngine(), + QueryRunnerTestHelper.NOOP_QUERYWATCHER + ) + ) + .put( + TimeseriesQuery.class, + new TimeseriesQueryRunnerFactory( + new TimeseriesQueryQueryToolChest( + QueryRunnerTestHelper.NoopIntervalChunkingQueryRunnerDecorator() + ), + new TimeseriesQueryEngine(), + QueryRunnerTestHelper.NOOP_QUERYWATCHER + ) + ) + .put( + TopNQuery.class, + new TopNQueryRunnerFactory( + new StupidPool<>( + "test-TopNQueryRunnerFactory-bufferPool", + () -> ByteBuffer.allocate(10 << 20) + ), + new TopNQueryQueryToolChest( + new TopNQueryConfig(), + QueryRunnerTestHelper.NoopIntervalChunkingQueryRunnerDecorator() + ), + QueryRunnerTestHelper.NOOP_QUERYWATCHER + ) + ) + .put( + GroupByQuery.class, + GroupByQueryRunnerTest.makeQueryRunnerFactory( + GroupByQueryRunnerTest.DEFAULT_MAPPER, + new GroupByQueryConfig() + { + @Override + public String getDefaultStrategy() + { + return GroupByStrategySelector.STRATEGY_V2; + } + }, + new DruidProcessingConfig() + { + @Override + public String getFormatString() + { + return "test-processing-%s"; + } + + @Override + public int intermediateComputeSizeBytes() + { + return 10 << 20; + } + + @Override + public int getNumMergeBuffers() + { + // Need 3 buffers for CalciteQueryTest.testDoubleNestedGroupby. + // Two buffers for the broker and one for the queryable + return 3; + } + } + ).getLhs() + ) + .build(); + // simple cartesian iterable public static Iterable cartesian(final Iterable... 
iterables) { diff --git a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java index 99df5f0a06f9..6c691311ed00 100644 --- a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java +++ b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java @@ -21,15 +21,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.base.Function; -import com.google.common.base.Optional; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Maps; import com.google.common.collect.Ordering; -import com.google.common.collect.RangeSet; -import com.google.common.collect.Sets; import com.google.common.hash.Hasher; import com.google.common.hash.Hashing; import com.google.inject.Inject; @@ -40,23 +34,27 @@ import org.apache.druid.client.selector.QueryableDruidServer; import org.apache.druid.client.selector.ServerSelector; import org.apache.druid.guice.annotations.Client; +import org.apache.druid.guice.annotations.Processing; import org.apache.druid.guice.annotations.Smile; import org.apache.druid.guice.http.DruidHttpClientConfig; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.concurrent.Execs; -import org.apache.druid.java.util.common.guava.BaseSequence; -import org.apache.druid.java.util.common.guava.LazySequence; +import org.apache.druid.java.util.common.guava.MergeSequence; +import org.apache.druid.java.util.common.guava.MergeWorkTask; import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequences; import 
org.apache.druid.java.util.emitter.EmittingLogger; import org.apache.druid.query.BySegmentResultValueClass; import org.apache.druid.query.CacheStrategy; +import org.apache.druid.query.FluentQueryRunnerBuilder; import org.apache.druid.query.Query; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; +import org.apache.druid.query.QueryRunnerFactory; +import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.QueryToolChest; import org.apache.druid.query.QueryToolChestWarehouse; @@ -65,8 +63,10 @@ import org.apache.druid.query.aggregation.MetricManipulatorFns; import org.apache.druid.query.filter.DimFilterUtils; import org.apache.druid.query.spec.MultipleSpecificSegmentSpec; +import org.apache.druid.server.DruidNode; import org.apache.druid.server.QueryResource; import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.server.coordination.ServerType; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.TimelineLookup; import org.apache.druid.timeline.TimelineObjectHolder; @@ -80,44 +80,69 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; -import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Set; -import java.util.SortedMap; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.Spliterators; +import java.util.concurrent.ForkJoinPool; +import java.util.function.Function; import java.util.function.UnaryOperator; import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; /** */ public class CachingClusteredClient implements QuerySegmentWalker { private static final EmittingLogger log = new EmittingLogger(CachingClusteredClient.class); + 
private static final DruidServer ALREADY_CACHED_SERVER = new DruidServer( + new DruidNode( + "__internal-client-cache", + "localhost", + false, + -1, + -1, + true, + false + ), + new DruidServerConfig(), + ServerType.HISTORICAL + ); + private final QueryRunnerFactoryConglomerate conglomerate; private final QueryToolChestWarehouse warehouse; private final TimelineServerView serverView; private final Cache cache; private final ObjectMapper objectMapper; private final CachePopulator cachePopulator; private final CacheConfig cacheConfig; + private final ForkJoinPool mergeFjp; private final DruidHttpClientConfig httpClientConfig; @Inject public CachingClusteredClient( + QueryRunnerFactoryConglomerate conglomerate, QueryToolChestWarehouse warehouse, TimelineServerView serverView, Cache cache, @Smile ObjectMapper objectMapper, + @Processing ForkJoinPool mergeFjp, CachePopulator cachePopulator, CacheConfig cacheConfig, @Client DruidHttpClientConfig httpClientConfig ) { + this.conglomerate = conglomerate; this.warehouse = warehouse; this.serverView = serverView; this.cache = cache; this.objectMapper = objectMapper; this.cachePopulator = cachePopulator; this.cacheConfig = cacheConfig; + this.mergeFjp = mergeFjp; this.httpClientConfig = httpClientConfig; if (cacheConfig.isQueryCacheable(Query.GROUP_BY) && (cacheConfig.isUseCache() || cacheConfig.isPopulateCache())) { @@ -144,57 +169,77 @@ public ServerView.CallbackAction segmentRemoved(DruidServerMetadata server, Data @Override public QueryRunner getQueryRunnerForIntervals(final Query query, final Iterable intervals) { - return new QueryRunner() - { - @Override - public Sequence run(final QueryPlus queryPlus, final Map responseContext) - { - return CachingClusteredClient.this.run(queryPlus, responseContext, timeline -> timeline); - } - }; + return runAndMergeWithTimelineChange(query, UnaryOperator.identity()); } /** * Run a query. 
The timelineConverter will be given the "master" timeline and can be used to return a different * timeline, if desired. This is used by getQueryRunnerForSegments. */ - private Sequence run( + @VisibleForTesting + Stream> run( final QueryPlus queryPlus, final Map responseContext, final UnaryOperator> timelineConverter ) { - return new SpecificQueryRunnable<>(queryPlus, responseContext).run(timelineConverter); + return new SpecificQueryRunnable<>(queryPlus, responseContext).runByServer(timelineConverter); + } + + private QueryRunner runAndMergeWithTimelineChange( + final Query query, + final UnaryOperator> timelineConverter + ) + { + final OptionalLong mergeBatch = QueryContexts.getIntermediateMergeBatchThreshold(query); + + if (mergeBatch.isPresent()) { + final QueryRunnerFactory> queryRunnerFactory = conglomerate.findFactory(query); + final QueryToolChest> toolChest = queryRunnerFactory.getToolchest(); + return (queryPlus, responseContext) -> { + final Stream> sequences = run(queryPlus, responseContext, timelineConverter); + return MergeWorkTask.parallelMerge( + sequences.parallel(), + (Stream> sequenceStream) -> + new FluentQueryRunnerBuilder<>(toolChest) + .create( + queryRunnerFactory.mergeRunners( + mergeFjp, + sequenceStream.map(QueryRunner::of).collect(Collectors.toList()) + ) + ) + .mergeResults() + .run(queryPlus, responseContext), + mergeBatch.getAsLong(), + mergeFjp + ); + }; + } else { + return (queryPlus, responseContext) -> { + final Stream> sequences = run(queryPlus, responseContext, timelineConverter); + return new MergeSequence<>(query.getResultOrdering(), Sequences.fromStream(sequences)); + }; + } } @Override public QueryRunner getQueryRunnerForSegments(final Query query, final Iterable specs) { - return new QueryRunner() - { - @Override - public Sequence run(final QueryPlus queryPlus, final Map responseContext) - { - return CachingClusteredClient.this.run( - queryPlus, - responseContext, - timeline -> { - final VersionedIntervalTimeline timeline2 
= - new VersionedIntervalTimeline<>(Ordering.natural()); - for (SegmentDescriptor spec : specs) { - final PartitionHolder entry = timeline.findEntry(spec.getInterval(), spec.getVersion()); - if (entry != null) { - final PartitionChunk chunk = entry.getChunk(spec.getPartitionNumber()); - if (chunk != null) { - timeline2.add(spec.getInterval(), spec.getVersion(), chunk); - } - } - } - return timeline2; - } - ); + return runAndMergeWithTimelineChange(query, timeline -> { + final VersionedIntervalTimeline timeline2 = new VersionedIntervalTimeline<>( + Ordering.natural() + ); + for (SegmentDescriptor spec : specs) { + final PartitionHolder entry = timeline.findEntry(spec.getInterval(), spec.getVersion()); + if (entry != null) { + final PartitionChunk chunk = entry.getChunk(spec.getPartitionNumber()); + if (chunk != null) { + timeline2.add(spec.getInterval(), spec.getVersion(), chunk); + } + } } - }; + return timeline2; + }); } /** @@ -235,9 +280,9 @@ private class SpecificQueryRunnable this.downstreamQuery = query.withOverriddenContext(makeDownstreamQueryContext()); } - private ImmutableMap makeDownstreamQueryContext() + private Map makeDownstreamQueryContext() { - final ImmutableMap.Builder contextBuilder = new ImmutableMap.Builder<>(); + final Map contextBuilder = new LinkedHashMap<>(); final int priority = QueryContexts.getPriority(query); contextBuilder.put(QueryContexts.PRIORITY_KEY, priority); @@ -247,74 +292,73 @@ private ImmutableMap makeDownstreamQueryContext() contextBuilder.put(CacheConfig.POPULATE_CACHE, false); contextBuilder.put("bySegment", true); } - return contextBuilder.build(); + return Collections.unmodifiableMap(contextBuilder); } - Sequence run(final UnaryOperator> timelineConverter) + /** + * This is the main workflow for the query setup. The sequences are created but not accumulated here. + * + * @param timelineConverter Any manipulations to the timeline that need done + * + * @return A stream of the sequences. 
Each sequence is either a server result or the total cache result. A + * spliterator on the returned stream should be sized and subsized. + */ + Stream> runByServer(final UnaryOperator> timelineConverter) { @Nullable TimelineLookup timeline = serverView.getTimeline(query.getDataSource()); if (timeline == null) { - return Sequences.empty(); + return Stream.empty(); } timeline = timelineConverter.apply(timeline); if (uncoveredIntervalsLimit > 0) { computeUncoveredIntervals(timeline); } - final Set segments = computeSegmentsToQuery(timeline); + Stream segments = computeSegmentsToQuery(timeline); @Nullable final byte[] queryCacheKey = computeQueryCacheKey(); if (query.getContext().get(QueryResource.HEADER_IF_NONE_MATCH) != null) { + // Materialize for computeCurrentEtag, then re-stream + final List materializedSegments = segments.collect(Collectors.toList()); + segments = materializedSegments.stream(); + @Nullable final String prevEtag = (String) query.getContext().get(QueryResource.HEADER_IF_NONE_MATCH); @Nullable - final String currentEtag = computeCurrentEtag(segments, queryCacheKey); + final String currentEtag = computeCurrentEtag(materializedSegments, queryCacheKey); if (currentEtag != null && currentEtag.equals(prevEtag)) { - return Sequences.empty(); + return Stream.empty(); } } - final List> alreadyCachedResults = pruneSegmentsWithCachedResults(queryCacheKey, segments); - final SortedMap> segmentsByServer = groupSegmentsByServer(segments); - return new LazySequence<>(() -> { - List> sequencesByInterval = new ArrayList<>(alreadyCachedResults.size() + segmentsByServer.size()); - addSequencesFromCache(sequencesByInterval, alreadyCachedResults); - addSequencesFromServer(sequencesByInterval, segmentsByServer); - return Sequences - .simple(sequencesByInterval) - .flatMerge(seq -> seq, query.getResultOrdering()); - }); - } - - private Set computeSegmentsToQuery(TimelineLookup timeline) - { - final List> serversLookup = toolChest.filterSegments( - query, - 
query.getIntervals().stream().flatMap(i -> timeline.lookup(i).stream()).collect(Collectors.toList()) + // This pipeline follows a few general steps: + // 1. Fetch cache results - Unfortunately this is an eager operation so that the non cached items can + // be batched per server. Cached results are assigned to a mock server ALREADY_CACHED_SERVER + // 2. Group the segment information by server + // 3. Per server (including the ALREADY_CACHED_SERVER) create the appropriate Sequence results - cached results + // are handled in their own merge + final Stream>> cacheResolvedResults = deserializeFromCache( + maybeFetchCacheResults(queryCacheKey, segments) ); - - final Set segments = Sets.newLinkedHashSet(); - final Map>> dimensionRangeCache = Maps.newHashMap(); - // Filter unneeded chunks based on partition dimension - for (TimelineObjectHolder holder : serversLookup) { - final Set> filteredChunks = DimFilterUtils.filterShards( - query.getFilter(), - holder.getObject(), - partitionChunk -> partitionChunk.getObject().getSegment().getShardSpec(), - dimensionRangeCache - ); - for (PartitionChunk chunk : filteredChunks) { - ServerSelector server = chunk.getObject(); - final SegmentDescriptor segment = new SegmentDescriptor( - holder.getInterval(), - holder.getVersion(), - chunk.getChunkNumber() - ); - segments.add(new ServerToSegment(server, segment)); - } - } - return segments; + final Pair>>> serverCountAndStream = + groupCachedResultsByServer(cacheResolvedResults); + + // Divide user-provided maxQueuedBytes by the number of servers, and limit each server to that much. 
+ final long maxQueuedBytes = QueryContexts.getMaxQueuedBytes(query, httpClientConfig.getMaxQueuedBytes()); + final long maxQueuedBytesPerServer = maxQueuedBytes / Math.max(serverCountAndStream.getLhs(), 1); + + return serverCountAndStream + .getRhs() + .map(s -> this.runOnServer(s, maxQueuedBytesPerServer)) + // We do a hard materialization here so that the resulting spliterators have properties that we want + // Otherwise the stream's spliterator is of a hash map entry spliterator from the group-by-server operation + // This also causes eager initialization of the **sequences**, aka forking off the direct druid client requests + // Sequence result accumulation should still be lazy + // + // See https://github.com/apache/incubator-druid/issues/6421 + .collect(Collectors.toList()) + .stream(); } private void computeUncoveredIntervals(TimelineLookup timeline) @@ -358,6 +402,41 @@ private void computeUncoveredIntervals(TimelineLookup ti } } + /** + * Create a stream of the partition chunks which are relevant to this query + * + * @param holder The holder of the shard to server component of the timeline + * + * @return Chunks and the segment descriptors corresponding to the chunk + */ + private Stream extractServerAndSegment(TimelineObjectHolder holder) + { + return DimFilterUtils + .filterShards( + query.getFilter(), + holder.getObject(), + partitionChunk -> partitionChunk.getObject().getSegment().getShardSpec(), + Maps.newHashMap() + ) + .stream() + .map(chunk -> new ServerToSegment( + chunk.getObject(), + new SegmentDescriptor(holder.getInterval(), holder.getVersion(), chunk.getChunkNumber()) + )); + } + + private Stream computeSegmentsToQuery(TimelineLookup timeline) + { + return toolChest + .filterSegments( + query, + query.getIntervals().stream().flatMap(i -> timeline.lookup(i).stream()).collect(Collectors.toList()) + ) + .stream() + .flatMap(this::extractServerAndSegment) + .distinct(); + } + @Nullable private byte[] computeQueryCacheKey() { @@ -371,7 +450,7 @@ 
private byte[] computeQueryCacheKey() } @Nullable - private String computeCurrentEtag(final Set segments, @Nullable byte[] queryCacheKey) + private String computeCurrentEtag(final Iterable segments, @Nullable byte[] queryCacheKey) { Hasher hasher = Hashing.sha1().newHasher(); boolean hasOnlyHistoricalSegments = true; @@ -394,173 +473,251 @@ private String computeCurrentEtag(final Set segments, @Nullable } } - private List> pruneSegmentsWithCachedResults( + private Pair> lookupInCache( + Pair key, + Map> cache + ) + { + final ServerToSegment segment = key.getLhs(); + final Cache.NamedKey segmentCacheKey = key.getRhs(); + final Interval segmentQueryInterval = segment.getSegmentDescriptor().getInterval(); + final Optional cachedValue = Optional + .ofNullable(cache.get(segmentCacheKey)) + // Shouldn't happen in practice, but can screw up unit tests where cache state is mutated in crazy + // ways when the cache returns null instead of an optional. + .orElse(Optional.empty()); + if (!cachedValue.isPresent()) { + // if populating cache, add segment to list of segments to cache if it is not cached + final String segmentIdentifier = segment.getServer().getSegment().getIdentifier(); + addCachePopulatorKey(segmentCacheKey, segmentIdentifier, segmentQueryInterval); + } + return Pair.of(segment, cachedValue); + } + + /** + * This materializes the input segment stream in order to let the BulkGet stuff in the cache system work + * + * @param queryCacheKey The cache key that is for the query (not-segment) portion + * @param segments The segments to check if they are in cache + * + * @return A stream of the server and segment combinations as well as an optional that is present + * if a cached value was found + */ + private Stream>> maybeFetchCacheResults( final byte[] queryCacheKey, - final Set segments + final Stream segments ) { if (queryCacheKey == null) { - return Collections.emptyList(); + return segments.map(s -> Pair.of(s, Optional.empty())); } - final List> 
alreadyCachedResults = Lists.newArrayList(); - Map perSegmentCacheKeys = computePerSegmentCacheKeys(segments, queryCacheKey); - // Pull cached segments from cache and remove from set of segments to query - final Map cachedValues = computeCachedValues(perSegmentCacheKeys); - - perSegmentCacheKeys.forEach((segment, segmentCacheKey) -> { - final Interval segmentQueryInterval = segment.getSegmentDescriptor().getInterval(); - - final byte[] cachedValue = cachedValues.get(segmentCacheKey); - if (cachedValue != null) { - // remove cached segment from set of segments to query - segments.remove(segment); - alreadyCachedResults.add(Pair.of(segmentQueryInterval, cachedValue)); - } else if (populateCache) { - // otherwise, if populating cache, add segment to list of segments to cache - final String segmentIdentifier = segment.getServer().getSegment().getIdentifier(); - addCachePopulatorKey(segmentCacheKey, segmentIdentifier, segmentQueryInterval); - } - }); - return alreadyCachedResults; + // We materialize the stream here in order to have the bulk cache fetching work as expected + final List> materializedKeyList = computePerSegmentCacheKeys( + segments, + queryCacheKey + ).collect(Collectors.toList()); + + // Do bulk fetch + final Map> cachedValues = computeCachedValues(materializedKeyList.stream()) + .collect(Pair.mapCollector()); + + // A limitation of the cache system is that the cached values are returned without passing through the original + // objects. 
This hash join is a way to get the ServerToSegment and Optional matched up again + return materializedKeyList + .stream() + .map(serializedPairSegmentAndKey -> lookupInCache(serializedPairSegmentAndKey, cachedValues)); } - private Map computePerSegmentCacheKeys( - Set segments, + private Stream> computePerSegmentCacheKeys( + Stream segments, byte[] queryCacheKey ) { - // cacheKeys map must preserve segment ordering, in order for shards to always be combined in the same order - Map cacheKeys = Maps.newLinkedHashMap(); - for (ServerToSegment serverToSegment : segments) { - final Cache.NamedKey segmentCacheKey = CacheUtil.computeSegmentCacheKey( - serverToSegment.getServer().getSegment().getIdentifier(), - serverToSegment.getSegmentDescriptor(), - queryCacheKey - ); - cacheKeys.put(serverToSegment, segmentCacheKey); - } - return cacheKeys; + return segments + .map(serverToSegment -> { + // cacheKeys map must preserve segment ordering, in order for shards to always be combined in the same order + final Cache.NamedKey segmentCacheKey = CacheUtil.computeSegmentCacheKey( + serverToSegment.getServer().getSegment().getIdentifier(), + serverToSegment.getSegmentDescriptor(), + queryCacheKey + ); + return Pair.of(serverToSegment, segmentCacheKey); + }); } - private Map computeCachedValues(Map cacheKeys) + private Stream>> computeCachedValues( + Stream> cacheKeys + ) { if (useCache) { - return cache.getBulk(Iterables.limit(cacheKeys.values(), cacheConfig.getCacheBulkMergeLimit())); + return cache.getBulk(cacheKeys.limit(cacheConfig.getCacheBulkMergeLimit()).map(Pair::getRhs)); } else { - return ImmutableMap.of(); + return Stream.empty(); } } + private String cacheKey(String segmentId, Interval segmentInterval) + { + return StringUtils.format("%s_%s", segmentId, segmentInterval); + } + private void addCachePopulatorKey( Cache.NamedKey segmentCacheKey, String segmentIdentifier, Interval segmentQueryInterval ) { - cachePopulatorKeyMap.put( - StringUtils.format("%s_%s", 
segmentIdentifier, segmentQueryInterval), - segmentCacheKey - ); + cachePopulatorKeyMap.put(cacheKey(segmentIdentifier, segmentQueryInterval), segmentCacheKey); } @Nullable private Cache.NamedKey getCachePopulatorKey(String segmentId, Interval segmentInterval) { - return cachePopulatorKeyMap.get(StringUtils.format("%s_%s", segmentId, segmentInterval)); + return cachePopulatorKeyMap.get(cacheKey(segmentId, segmentInterval)); } - private SortedMap> groupSegmentsByServer(Set segments) + /** + * Check the input stream to see what was cached and what was not. For the ones that were cached, merge the results + * and return the merged sequence. For the ones that were NOT cached, get the server result sequence queued up into + * the stream response + * + * @param segmentOrResult A list that is traversed in order to determine what should be sent back. All segments + * should be on the same server. + * + * @return A sequence of either the merged cached results, or the server results from any particular server + */ + private Sequence runOnServer(List> segmentOrResult, long maxQueuedBytesPerServer) { - final SortedMap> serverSegments = Maps.newTreeMap(); - for (ServerToSegment serverToSegment : segments) { - final QueryableDruidServer queryableDruidServer = serverToSegment.getServer().pick(); - - if (queryableDruidServer == null) { - log.makeAlert( - "No servers found for SegmentDescriptor[%s] for DataSource[%s]?! How can this be?!", - serverToSegment.getSegmentDescriptor(), - query.getDataSource() - ).emit(); - } else { - final DruidServer server = queryableDruidServer.getServer(); - serverSegments.computeIfAbsent(server, s -> new ArrayList<>()).add(serverToSegment.getSegmentDescriptor()); - } + final List segmentsOfServer = segmentOrResult + .stream() + .map(ServerMaybeSegmentMaybeCache::getSegmentDescriptor) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + + // We should only ever have cache or queries to run, not both. 
So if we have no segments, try caches + if (segmentsOfServer.isEmpty()) { + // Have a special sequence for the cache results so the merge doesn't go all crazy. + // See org.apache.druid.java.util.common.guava.MergeSequenceTest.testScrewsUpOnOutOfOrder for an example + // With zero results actually being found (no segments no caches) this should essentially return a no-op + // merge sequence + return new MergeSequence<>(query.getResultOrdering(), Sequences.fromStream( + segmentOrResult + .stream() + .map(ServerMaybeSegmentMaybeCache::getCachedValue) + .filter(Objects::nonNull) + .map(Collections::singletonList) + .map(Sequences::simple) + )); + } + + final DruidServer server = segmentOrResult.get(0).getServer(); + final QueryRunner serverRunner = serverView.getQueryRunner(server); + + if (serverRunner == null) { + log.error("Server[%s] doesn't have a query runner", server); + return Sequences.empty(); + } + + final MultipleSpecificSegmentSpec segmentsOfServerSpec = new MultipleSpecificSegmentSpec(segmentsOfServer); + + final Sequence serverResults; + + if (isBySegment) { + serverResults = getBySegmentServerResults(serverRunner, segmentsOfServerSpec, maxQueuedBytesPerServer); + } else if (!server.segmentReplicatable() || !populateCache) { + serverResults = getSimpleServerResults(serverRunner, segmentsOfServerSpec, maxQueuedBytesPerServer); + } else { + serverResults = getAndCacheServerResults(serverRunner, segmentsOfServerSpec, maxQueuedBytesPerServer); } - return serverSegments; + return serverResults; } - private void addSequencesFromCache( - final List> listOfSequences, - final List> cachedResults - ) + private ServerMaybeSegmentMaybeCache pickServer(Pair> tuple) { - if (strategy == null) { - return; + final Optional maybeResult = tuple.getRhs(); + if (maybeResult.isPresent()) { + return new ServerMaybeSegmentMaybeCache<>(ALREADY_CACHED_SERVER, null, maybeResult.get()); } - - final Function pullFromCacheFunction = strategy.pullFromSegmentLevelCache(); - final 
TypeReference cacheObjectClazz = strategy.getCacheObjectClazz(); - for (Pair cachedResultPair : cachedResults) { - final byte[] cachedResult = cachedResultPair.rhs; - Sequence cachedSequence = new BaseSequence<>( - new BaseSequence.IteratorMaker>() - { - @Override - public Iterator make() - { - try { - if (cachedResult.length == 0) { - return Collections.emptyIterator(); - } - - return objectMapper.readValues( - objectMapper.getFactory().createParser(cachedResult), - cacheObjectClazz - ); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public void cleanup(Iterator iterFromMake) - { - } - } - ); - listOfSequences.add(Sequences.map(cachedSequence, pullFromCacheFunction)); + final ServerToSegment serverToSegment = tuple.getLhs(); + final QueryableDruidServer queryableDruidServer = serverToSegment.getServer().pick(); + if (queryableDruidServer == null) { + log.makeAlert( + "No servers found for SegmentDescriptor[%s] for DataSource[%s]?! How can this be?!", + serverToSegment.getSegmentDescriptor(), + query.getDataSource() + ).emit(); + return new ServerMaybeSegmentMaybeCache<>(ALREADY_CACHED_SERVER, null, null); } + final DruidServer server = queryableDruidServer.getServer(); + return new ServerMaybeSegmentMaybeCache<>(server, serverToSegment.getSegmentDescriptor(), null); } - private void addSequencesFromServer( - final List> listOfSequences, - final SortedMap> segmentsByServer + /** + * This materializes the input stream in order to group it by server. This method takes in the stream of cache + * resolved items and will group all the items by server. Each entry in the output stream contains a list whose + * entries' getServer is the same. Each entry will either have a present segemnt descriptor or a present result, + * but not both. Downstream consumers should check each and handle appropriately. 
+ * + * @param cacheResolvedStream The stream of cache resolved items to group by server + * + * @return A pair of the count of servers (for backpressure calculations) and a stream of potentially cached results per server + */ + + private Pair>>> groupCachedResultsByServer( + Stream>> cacheResolvedStream ) { - segmentsByServer.forEach((server, segmentsOfServer) -> { - final QueryRunner serverRunner = serverView.getQueryRunner(server); - - if (serverRunner == null) { - log.error("Server[%s] doesn't have a query runner", server); - return; - } - - final MultipleSpecificSegmentSpec segmentsOfServerSpec = new MultipleSpecificSegmentSpec(segmentsOfServer); - // Divide user-provided maxQueuedBytes by the number of servers, and limit each server to that much. - final long maxQueuedBytes = QueryContexts.getMaxQueuedBytes(query, httpClientConfig.getMaxQueuedBytes()); - final long maxQueuedBytesPerServer = maxQueuedBytes / segmentsByServer.size(); - final Sequence serverResults; + final Map>> groupedServers = cacheResolvedStream + .map(this::pickServer) + .collect(Collectors.groupingBy(ServerMaybeSegmentMaybeCache::getServer)); + return Pair.of(groupedServers.size(), groupedServers + .values() + // At this point we have the segments per server, and a special entry for the pre-cached results. + // As of the time of this writing, this results in a java.util.HashMap.ValueSpliterator which + // does not have great properties for splitting in parallel since it does not have total size awareness + // yet. 
I hope future implementations of the grouping collector can handle such a scenario where the + // grouping result is immutable and can be split very easily into parallel spliterators + .stream() + .filter(l -> !l.isEmpty()) + // Get rid of any alerted conditions missing queryableDruidServer + .filter(l -> l.get(0).getCachedValue() != null || l.get(0).getSegmentDescriptor() != null)); + } - if (isBySegment) { - serverResults = getBySegmentServerResults(serverRunner, segmentsOfServerSpec, maxQueuedBytesPerServer); - } else if (!server.segmentReplicatable() || !populateCache) { - serverResults = getSimpleServerResults(serverRunner, segmentsOfServerSpec, maxQueuedBytesPerServer); - } else { - serverResults = getAndCacheServerResults(serverRunner, segmentsOfServerSpec, maxQueuedBytesPerServer); + private Stream>> deserializeFromCache( + final Stream>> cachedResults + ) + { + if (strategy == null) { + return cachedResults.map(s -> Pair.of(s.getLhs(), Optional.empty())); + } + final Function pullFromCacheFunction = strategy.pullFromSegmentLevelCache()::apply; + final TypeReference cacheObjectClazz = strategy.getCacheObjectClazz(); + return cachedResults.flatMap(cachedResultPair -> { + if (!cachedResultPair.getRhs().isPresent()) { + return Stream.of(Pair.of(cachedResultPair.getLhs(), Optional.empty())); + } + final byte[] cachedResult = cachedResultPair.getRhs().get(); + try { + if (cachedResult.length == 0) { + return Stream.of(Pair.of(cachedResultPair.getLhs(), Optional.empty())); + } + // Query granularity in a segment may be higher fidelity than the segment as a file, + // so this might have multiple results + return StreamSupport + .stream( + Spliterators.spliteratorUnknownSize( + objectMapper.readValues(objectMapper.getFactory().createParser(cachedResult), cacheObjectClazz), + 0 + ), + false + ) + .map(pullFromCacheFunction) + .map(obj -> Pair.of(cachedResultPair.getLhs(), Optional.ofNullable(obj))); + } + catch (IOException e) { + throw new RuntimeException(e); } - 
listOfSequences.add(serverResults); }); } @@ -578,11 +735,11 @@ private Sequence getBySegmentServerResults( ); // bySegment results need to be de-serialized, see DirectDruidClient.run() return (Sequence) resultsBySegments - .map(result -> result.map( - resultsOfSegment -> resultsOfSegment.mapResults( + .map(result -> result + .map(resultsOfSegment -> resultsOfSegment.mapResults( toolChest.makePreComputeManipulatorFn(query, MetricManipulatorFns.deserializing())::apply - ) - )); + )) + ); } @SuppressWarnings("unchecked") @@ -598,6 +755,26 @@ private Sequence getSimpleServerResults( ); } + private Sequence bySegmentWithCachePopulator( + Result> result, + Function cachePrep + ) + { + final BySegmentResultValueClass resultsOfSegment = result.getValue(); + final Cache.NamedKey cachePopulatorKey = getCachePopulatorKey( + resultsOfSegment.getSegmentId(), + resultsOfSegment.getInterval() + ); + Sequence res = Sequences + .simple(resultsOfSegment.getResults()); + if (cachePopulatorKey != null) { + res = cachePopulator.wrap(res, cachePrep, cache, cachePopulatorKey); + } + return res.map( + toolChest.makePreComputeManipulatorFn(downstreamQuery, MetricManipulatorFns.deserializing())::apply + ); + } + private Sequence getAndCacheServerResults( final QueryRunner serverRunner, final MultipleSpecificSegmentSpec segmentsOfServerSpec, @@ -612,22 +789,46 @@ private Sequence getAndCacheServerResults( .withMaxQueuedBytes(maxQueuedBytesPerServer), responseContext ); - final Function cacheFn = strategy.prepareForSegmentLevelCache(); - + final Function cacheFn = strategy.prepareForSegmentLevelCache()::apply; return resultsBySegments - .map(result -> { - final BySegmentResultValueClass resultsOfSegment = result.getValue(); - final Cache.NamedKey cachePopulatorKey = - getCachePopulatorKey(resultsOfSegment.getSegmentId(), resultsOfSegment.getInterval()); - Sequence res = Sequences.simple(resultsOfSegment.getResults()); - if (cachePopulatorKey != null) { - res = cachePopulator.wrap(res, 
cacheFn::apply, cache, cachePopulatorKey); - } - return res.map( - toolChest.makePreComputeManipulatorFn(downstreamQuery, MetricManipulatorFns.deserializing())::apply - ); - }) - .flatMerge(seq -> seq, query.getResultOrdering()); + .map(result -> bySegmentWithCachePopulator(result, cacheFn)) + .flatMerge(Function.identity(), query.getResultOrdering()); + } + } + + // POJO + private static class ServerMaybeSegmentMaybeCache + { + private final DruidServer server; + private final SegmentDescriptor segmentDescriptor; + private final T cachedValue; + + public DruidServer getServer() + { + return server; + } + + @Nullable + public SegmentDescriptor getSegmentDescriptor() + { + return segmentDescriptor; + } + + @Nullable + public T getCachedValue() + { + return cachedValue; + } + + private ServerMaybeSegmentMaybeCache( + DruidServer server, + @Nullable SegmentDescriptor segmentDescriptor, + @Nullable T cachedValue + ) + { + this.server = server; + this.segmentDescriptor = segmentDescriptor; + this.cachedValue = cachedValue; } } @@ -640,12 +841,12 @@ private ServerToSegment(ServerSelector server, SegmentDescriptor segment) ServerSelector getServer() { - return lhs; + return super.getLhs(); } SegmentDescriptor getSegmentDescriptor() { - return rhs; + return super.getRhs(); } } } diff --git a/server/src/main/java/org/apache/druid/client/cache/Cache.java b/server/src/main/java/org/apache/druid/client/cache/Cache.java index 725a880a4da7..27e29106934d 100644 --- a/server/src/main/java/org/apache/druid/client/cache/Cache.java +++ b/server/src/main/java/org/apache/druid/client/cache/Cache.java @@ -20,6 +20,7 @@ package org.apache.druid.client.cache; import com.google.common.base.Preconditions; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.emitter.service.ServiceEmitter; @@ -27,6 +28,8 @@ import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Map; +import java.util.Optional; 
+import java.util.stream.Stream; /** */ @@ -34,16 +37,30 @@ public interface Cache { @Nullable byte[] get(NamedKey key); + void put(NamedKey key, byte[] value); /** * Resulting map should not contain any null values (i.e. cache misses should not be included) * * @param keys + * * @return */ Map getBulk(Iterable keys); + /** + * Returns a stream of the input keys with an optional byte array if the key was found in the cache + * + * @param keys + * + * @return + */ + default Stream>> getBulk(Stream keys) + { + return keys.map(key -> new Pair<>(key, Optional.ofNullable(get(key)))); + } + void close(String namespace); CacheStats getStats(); @@ -52,6 +69,7 @@ public interface Cache /** * Custom metrics not covered by CacheStats may be emitted by this method. + * * @param emitter The service emitter to emit on. */ void doMonitor(ServiceEmitter emitter); @@ -73,9 +91,9 @@ public byte[] toByteArray() { final byte[] nsBytes = StringUtils.toUtf8(this.namespace); return ByteBuffer.allocate(Integer.BYTES + nsBytes.length + this.key.length) - .putInt(nsBytes.length) - .put(nsBytes) - .put(this.key).array(); + .putInt(nsBytes.length) + .put(nsBytes) + .put(this.key).array(); } @Override diff --git a/server/src/main/java/org/apache/druid/client/cache/CaffeineCache.java b/server/src/main/java/org/apache/druid/client/cache/CaffeineCache.java index ec18cd2d7aba..62db94276af4 100644 --- a/server/src/main/java/org/apache/druid/client/cache/CaffeineCache.java +++ b/server/src/main/java/org/apache/druid/client/cache/CaffeineCache.java @@ -28,16 +28,19 @@ import net.jpountz.lz4.LZ4Compressor; import net.jpountz.lz4.LZ4Factory; import net.jpountz.lz4.LZ4FastDecompressor; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.java.util.emitter.service.ServiceMetricEvent; import java.nio.ByteBuffer; import java.util.Map; +import java.util.Optional; 
import java.util.OptionalLong; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; public class CaffeineCache implements org.apache.druid.client.cache.Cache { @@ -105,6 +108,21 @@ public Map getBulk(Iterable keys) return ImmutableMap.copyOf(Maps.transformValues(cache.getAllPresent(keys), this::deserialize)); } + @Override + public Stream>> getBulk(Stream keys) + { + return keys.map( + k -> Pair.of( + k, + Optional.ofNullable( + cache.getIfPresent(k) + ).map( + this::deserialize + ) + ) + ); + } + // This is completely racy with put. Any values missed should be evicted later anyways. So no worries. @Override public void close(String namespace) diff --git a/server/src/main/java/org/apache/druid/client/cache/HybridCache.java b/server/src/main/java/org/apache/druid/client/cache/HybridCache.java index d940f7166185..38eba49bca8a 100644 --- a/server/src/main/java/org/apache/druid/client/cache/HybridCache.java +++ b/server/src/main/java/org/apache/druid/client/cache/HybridCache.java @@ -21,14 +21,19 @@ import com.google.common.collect.Maps; import com.google.common.collect.Sets; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.emitter.service.ServiceEmitter; import javax.annotation.Nullable; import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; -import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.LongAdder; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class HybridCache implements Cache { @@ -38,8 +43,8 @@ public class HybridCache implements Cache private final Cache level1; private final Cache level2; - private final AtomicLong hitCount = new AtomicLong(0); - private final AtomicLong missCount = new AtomicLong(0); + private final LongAdder 
hitCount = new LongAdder(); + private final LongAdder missCount = new LongAdder(); public HybridCache(HybridCacheConfig config, Cache level1, Cache level2) { @@ -61,10 +66,10 @@ public byte[] get(NamedKey key) } } if (res != null) { - hitCount.incrementAndGet(); + hitCount.increment(); return res; } else { - missCount.incrementAndGet(); + missCount.increment(); return null; } } @@ -93,7 +98,7 @@ public Map getBulk(Iterable keys) { Set remaining = Sets.newHashSet(keys); Map res = level1.getBulk(keys); - hitCount.addAndGet(res.size()); + hitCount.add(res.size()); remaining = Sets.difference(remaining, res.keySet()); @@ -104,8 +109,8 @@ public Map getBulk(Iterable keys) } int size = res2.size(); - hitCount.addAndGet(size); - missCount.addAndGet(remaining.size() - size); + hitCount.add(size); + missCount.add(remaining.size() - size); if (size != 0) { res = Maps.newHashMap(res); @@ -124,6 +129,49 @@ private Map getBulkL2(Iterable keys) } } + @Override + public Stream>> getBulk(Stream keys) + { + if (!config.getUseL2()) { + return level1.getBulk(keys); + } + final List>> materializedL1Results = level1 + .getBulk(keys) + .collect(Collectors.toList()); + final List>> materializedL2Results = level2 + .getBulk( + materializedL1Results.stream( + ).filter( + s -> !s.getRhs().isPresent() + ).map( + Pair::getLhs + ) + ).collect(Collectors.toList()); + // The l2 list should only have "missing" ones from l1. 
So we loop through and look for the missing L1 results + // and replace with whatever l2 found + int l2Pos = 0; + for (int i = 0; i < materializedL1Results.size(); i++) { + final Pair> me = materializedL1Results.get(i); + if (!me.getRhs().isPresent()) { + final Pair> other = materializedL2Results.get(l2Pos++); + if (!me.getLhs().equals(other.getLhs())) { + // sanity check for something very broken + break; + } + materializedL1Results.set(i, other); + } + } + // Register hits/misses early so it doesn't require the stream to be consumed + materializedL1Results.forEach(sp -> { + if (sp.getRhs().isPresent()) { + hitCount.increment(); + } else { + missCount.increment(); + } + }); + return materializedL1Results.stream(); + } + @Override public void close(String namespace) { @@ -137,8 +185,8 @@ public CacheStats getStats() CacheStats stats1 = level1.getStats(); CacheStats stats2 = level2.getStats(); return new CacheStats( - hitCount.get(), - missCount.get(), + hitCount.longValue(), + missCount.longValue(), stats1.getNumEntries() + stats2.getNumEntries(), stats1.getSizeInBytes() + stats2.getSizeInBytes(), stats1.getNumEvictions() + stats2.getNumEvictions(), diff --git a/server/src/main/java/org/apache/druid/client/cache/MemcachedCache.java b/server/src/main/java/org/apache/druid/client/cache/MemcachedCache.java index a3fd26caf073..6975ea145212 100644 --- a/server/src/main/java/org/apache/druid/client/cache/MemcachedCache.java +++ b/server/src/main/java/org/apache/druid/client/cache/MemcachedCache.java @@ -19,7 +19,6 @@ package org.apache.druid.client.cache; -import com.google.common.base.Function; import com.google.common.base.Preconditions; import com.google.common.base.Predicate; import com.google.common.base.Supplier; @@ -27,7 +26,6 @@ import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Maps; import com.google.common.hash.HashFunction; import 
com.google.common.hash.Hashing; import net.spy.memcached.AddrUtil; @@ -45,6 +43,7 @@ import org.apache.commons.codec.digest.DigestUtils; import org.apache.druid.collections.ResourceHolder; import org.apache.druid.collections.StupidResourceHolder; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.emitter.service.ServiceEmitter; @@ -57,9 +56,12 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -69,6 +71,9 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; public class MemcachedCache implements Cache { @@ -524,59 +529,35 @@ private static byte[] deserializeValue(NamedKey key, byte[] bytes) return value; } - @Override - public Map getBulk(Iterable keys) + + Map getCacheMap(Collection keys) { + if (keys.isEmpty()) { + return Collections.emptyMap(); + } + // Hold onto the client until the future is fetched try (ResourceHolder clientHolder = client.get()) { - Map keyLookup = Maps.uniqueIndex( - keys, - new Function() - { - @Override - public String apply( - @Nullable NamedKey input - ) - { - return computeKeyHash(memcachedPrefix, input); - } - } - ); - - Map results = Maps.newHashMap(); - - BulkFuture> future; + final BulkFuture> future; try { - future = clientHolder.get().asyncGetBulk(keyLookup.keySet()); + future = clientHolder.get().asyncGetBulk(keys); } catch (IllegalStateException e) { // operation did not 
get queued in time (queue is full) errorCount.incrementAndGet(); log.warn(e, "Unable to queue cache operation"); - return results; + return Collections.emptyMap(); } try { - Map some = future.getSome(timeout, TimeUnit.MILLISECONDS); + final Map some = future.getSome(timeout, TimeUnit.MILLISECONDS); if (future.isTimeout()) { future.cancel(false); timeoutCount.incrementAndGet(); } - missCount.addAndGet(keyLookup.size() - some.size()); + missCount.addAndGet(keys.size() - some.size()); hitCount.addAndGet(some.size()); - - for (Map.Entry entry : some.entrySet()) { - final NamedKey key = keyLookup.get(entry.getKey()); - final byte[] value = (byte[]) entry.getValue(); - if (value != null) { - results.put( - key, - deserializeValue(key, value) - ); - } - } - - return results; + return some; } catch (InterruptedException e) { Thread.currentThread().interrupt(); @@ -585,11 +566,44 @@ public String apply( catch (ExecutionException e) { errorCount.incrementAndGet(); log.warn(e, "Exception pulling item from cache"); - return results; + return Collections.emptyMap(); } } } + @Override + public Stream>> getBulk(Stream keys) + { + final List> materializedKeys = keys.map( + k -> Pair.of(k, computeKeyHash(memcachedPrefix, k)) + ).collect( + Collectors.toList() + ); + final Map some = getCacheMap( + materializedKeys + .stream() + .map(Pair::getRhs) + .collect(Collectors.toList()) + ); + return materializedKeys.stream().map(k -> { + final NamedKey key = k.getLhs(); + final String cacheKey = k.getRhs(); + return Pair.of( + key, + Optional.ofNullable(some.get(cacheKey)) + .map(val -> deserializeValue(key, (byte[]) val)) + ); + }); + } + + @Override + public Map getBulk(Iterable keys) + { + return getBulk(StreamSupport.stream(keys.spliterator(), false)) + .filter(s -> s.getRhs().isPresent()) + .collect(Collectors.toMap(Pair::getLhs, s -> s.getRhs().get())); + } + @Override public void close(String namespace) { diff --git 
a/server/src/test/java/org/apache/druid/client/CachingClusteredClientFunctionalityTest.java b/server/src/test/java/org/apache/druid/client/CachingClusteredClientFunctionalityTest.java index 9cec0227aac3..76eb9e3303a0 100644 --- a/server/src/test/java/org/apache/druid/client/CachingClusteredClientFunctionalityTest.java +++ b/server/src/test/java/org/apache/druid/client/CachingClusteredClientFunctionalityTest.java @@ -38,12 +38,15 @@ import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.query.DataSource; +import org.apache.druid.query.DefaultQueryRunnerFactoryConglomerate; import org.apache.druid.query.Druids; import org.apache.druid.query.Query; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; +import org.apache.druid.query.QueryRunnerTestHelper; import org.apache.druid.query.QueryToolChestWarehouse; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.select.SelectQueryConfig; @@ -68,6 +71,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; +import java.util.concurrent.ForkJoinPool; /** */ @@ -98,6 +102,21 @@ public void setUp() { timeline = new VersionedIntervalTimeline<>(Ordering.natural()); serverView = EasyMock.createNiceMock(TimelineServerView.class); + + final QueryRunner emptyQueryRunner = EasyMock.createStrictMock(QueryRunner.class); + + EasyMock.expect( + emptyQueryRunner.run(EasyMock.anyObject(), EasyMock.anyObject()) + ).andReturn( + Sequences.empty() + ).anyTimes(); + EasyMock.expect( + serverView.getQueryRunner(EasyMock.anyObject()) + ).andReturn( + emptyQueryRunner + ).anyTimes(); + + EasyMock.replay(serverView, emptyQueryRunner); cache = MapCache.create(100000); client = makeClient( new 
ForegroundCachePopulator(OBJECT_MAPPER, new CachePopulatorStats(), -1) @@ -238,6 +257,9 @@ protected CachingClusteredClient makeClient( ) { return new CachingClusteredClient( + new DefaultQueryRunnerFactoryConglomerate( + QueryRunnerTestHelper.DEFAULT_CONGLOMERATE_MAP + ), WAREHOUSE, new TimelineServerView() { @@ -279,6 +301,7 @@ public void registerServerRemovedCallback(Executor exec, ServerRemovedCallback c }, cache, OBJECT_MAPPER, + ForkJoinPool.commonPool(), cachePopulator, new CacheConfig() { diff --git a/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java b/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java index 573c5c924658..bf1d8f6a3b94 100644 --- a/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java +++ b/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java @@ -69,12 +69,15 @@ import org.apache.druid.java.util.common.guava.Comparators; import org.apache.druid.java.util.common.guava.FunctionalIterable; import org.apache.druid.java.util.common.guava.MergeIterable; +import org.apache.druid.java.util.common.guava.MergeSequence; +import org.apache.druid.java.util.common.guava.MergeWorkTask; import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.nary.TrinaryFn; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.query.BySegmentResultValueClass; import org.apache.druid.query.DataSource; +import org.apache.druid.query.DefaultQueryRunnerFactoryConglomerate; import org.apache.druid.query.Druids; import org.apache.druid.query.FinalizeResultsQueryRunner; import org.apache.druid.query.Query; @@ -160,10 +163,13 @@ import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.Spliterator; import java.util.TreeMap; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentLinkedDeque; import 
java.util.concurrent.Executor; +import java.util.concurrent.ForkJoinPool; +import java.util.stream.Stream; /** */ @@ -281,14 +287,7 @@ public static Iterable constructorFeeder() { return Lists.transform( Lists.newArrayList(new RangeIterable(RANDOMNESS)), - new Function() - { - @Override - public Object[] apply(Integer input) - { - return new Object[]{input}; - } - } + input -> new Object[]{input} ); } @@ -555,9 +554,9 @@ public void testCachingOverBulkLimitEnforcesLimit() final Map context = new HashMap<>(); final Cache cache = EasyMock.createStrictMock(Cache.class); - final Capture> cacheKeyCapture = EasyMock.newCapture(); + final Capture> cacheKeyCapture = EasyMock.newCapture(); EasyMock.expect(cache.getBulk(EasyMock.capture(cacheKeyCapture))) - .andReturn(ImmutableMap.of()) + .andReturn(Stream.empty()) .once(); EasyMock.replay(cache); client = makeClient(new ForegroundCachePopulator(JSON_MAPPER, new CachePopulatorStats(), -1), cache, limit); @@ -575,22 +574,31 @@ public void testCachingOverBulkLimitEnforcesLimit() getDefaultQueryRunner().run(QueryPlus.wrap(query), context); Assert.assertTrue("Capture cache keys", cacheKeyCapture.hasCaptured()); - Assert.assertTrue("Cache key below limit", ImmutableList.copyOf(cacheKeyCapture.getValue()).size() <= limit); + Assert.assertTrue("Cache key below limit", cacheKeyCapture.getValue().count() <= limit); EasyMock.verify(cache); EasyMock.reset(cache); cacheKeyCapture.reset(); EasyMock.expect(cache.getBulk(EasyMock.capture(cacheKeyCapture))) - .andReturn(ImmutableMap.of()) + .andReturn(Stream.empty()) .once(); EasyMock.replay(cache); client = makeClient(new ForegroundCachePopulator(JSON_MAPPER, new CachePopulatorStats(), -1), cache, 0); + + // Direct druid client runners are eagerly forked off + QueryRunner runner = EasyMock.createStrictMock(QueryRunner.class); + EasyMock.expect(runner.run(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(Sequences.empty()).once(); + + EasyMock.reset(serverView); + 
EasyMock.expect(serverView.getQueryRunner(lastServer)).andReturn(runner).once(); + EasyMock.replay(runner, serverView); + getDefaultQueryRunner().run(QueryPlus.wrap(query), context); - EasyMock.verify(cache); - EasyMock.verify(dataSegment); + + EasyMock.verify(cache, dataSegment, serverView, runner); Assert.assertTrue("Capture cache keys", cacheKeyCapture.hasCaptured()); - Assert.assertTrue("Cache Keys empty", ImmutableList.copyOf(cacheKeyCapture.getValue()).isEmpty()); + Assert.assertTrue("Cache Keys empty", cacheKeyCapture.getValue().count() == 0); } @Test @@ -1174,14 +1182,14 @@ public void testSearchCaching() public void testSearchCachingRenamedOutput() { final Druids.SearchQueryBuilder builder = Druids.newSearchQueryBuilder() - .dataSource(DATA_SOURCE) - .filters(DIM_FILTER) - .granularity(GRANULARITY) - .limit(1000) - .intervals(SEG_SPEC) - .dimensions(Collections.singletonList(TOP_DIM)) - .query("how") - .context(CONTEXT); + .dataSource(DATA_SOURCE) + .filters(DIM_FILTER) + .granularity(GRANULARITY) + .limit(1000) + .intervals(SEG_SPEC) + .dimensions(Collections.singletonList(TOP_DIM)) + .query("how") + .context(CONTEXT); testQueryCaching( getDefaultQueryRunner(), @@ -1289,18 +1297,18 @@ public void testSelectCaching() Intervals.of("2011-01-05/2011-01-10"), makeSelectResults(dimensions, metrics, DateTimes.of("2011-01-05"), - DateTimes.of("2011-01-06"), - DateTimes.of("2011-01-07"), ImmutableMap.of("a", "f", "rows", 7), ImmutableMap.of("a", "ff"), - DateTimes.of("2011-01-08"), ImmutableMap.of("a", "g", "rows", 8), - DateTimes.of("2011-01-09"), ImmutableMap.of("a", "h", "rows", 9) + DateTimes.of("2011-01-06"), + DateTimes.of("2011-01-07"), ImmutableMap.of("a", "f", "rows", 7), ImmutableMap.of("a", "ff"), + DateTimes.of("2011-01-08"), ImmutableMap.of("a", "g", "rows", 8), + DateTimes.of("2011-01-09"), ImmutableMap.of("a", "h", "rows", 9) ), Intervals.of("2011-01-05/2011-01-10"), makeSelectResults(dimensions, metrics, DateTimes.of("2011-01-05T01"), 
ImmutableMap.of("a", "d", "rows", 5), - DateTimes.of("2011-01-06T01"), ImmutableMap.of("a", "e", "rows", 6), - DateTimes.of("2011-01-07T01"), ImmutableMap.of("a", "f", "rows", 7), - DateTimes.of("2011-01-08T01"), ImmutableMap.of("a", "g", "rows", 8), - DateTimes.of("2011-01-09T01"), ImmutableMap.of("a", "h", "rows", 9) + DateTimes.of("2011-01-06T01"), ImmutableMap.of("a", "e", "rows", 6), + DateTimes.of("2011-01-07T01"), ImmutableMap.of("a", "f", "rows", 7), + DateTimes.of("2011-01-08T01"), ImmutableMap.of("a", "g", "rows", 8), + DateTimes.of("2011-01-09T01"), ImmutableMap.of("a", "h", "rows", 9) ) ); @@ -1315,17 +1323,17 @@ public void testSelectCaching() HashMap context = new HashMap(); TestHelper.assertExpectedResults( makeSelectResults(dimensions, metrics, DateTimes.of("2011-01-01"), ImmutableMap.of("a", "b", "rows", 1), - DateTimes.of("2011-01-02"), ImmutableMap.of("a", "c", "rows", 5), - DateTimes.of("2011-01-05"), - DateTimes.of("2011-01-05T01"), ImmutableMap.of("a", "d", "rows", 5), - DateTimes.of("2011-01-06"), - DateTimes.of("2011-01-06T01"), ImmutableMap.of("a", "e", "rows", 6), - DateTimes.of("2011-01-07"), ImmutableMap.of("a", "f", "rows", 7), ImmutableMap.of("a", "ff"), - DateTimes.of("2011-01-07T01"), ImmutableMap.of("a", "f", "rows", 7), - DateTimes.of("2011-01-08"), ImmutableMap.of("a", "g", "rows", 8), - DateTimes.of("2011-01-08T01"), ImmutableMap.of("a", "g", "rows", 8), - DateTimes.of("2011-01-09"), ImmutableMap.of("a", "h", "rows", 9), - DateTimes.of("2011-01-09T01"), ImmutableMap.of("a", "h", "rows", 9) + DateTimes.of("2011-01-02"), ImmutableMap.of("a", "c", "rows", 5), + DateTimes.of("2011-01-05"), + DateTimes.of("2011-01-05T01"), ImmutableMap.of("a", "d", "rows", 5), + DateTimes.of("2011-01-06"), + DateTimes.of("2011-01-06T01"), ImmutableMap.of("a", "e", "rows", 6), + DateTimes.of("2011-01-07"), ImmutableMap.of("a", "f", "rows", 7), ImmutableMap.of("a", "ff"), + DateTimes.of("2011-01-07T01"), ImmutableMap.of("a", "f", "rows", 7), + 
DateTimes.of("2011-01-08"), ImmutableMap.of("a", "g", "rows", 8), + DateTimes.of("2011-01-08T01"), ImmutableMap.of("a", "g", "rows", 8), + DateTimes.of("2011-01-09"), ImmutableMap.of("a", "h", "rows", 9), + DateTimes.of("2011-01-09T01"), ImmutableMap.of("a", "h", "rows", 9) ), runner.run(QueryPlus.wrap(builder.intervals("2011-01-01/2011-01-10").build()), context) ); @@ -1338,14 +1346,14 @@ public void testSelectCachingRenamedOutputName() final Set metrics = Sets.newHashSet("rows"); Druids.SelectQueryBuilder builder = Druids.newSelectQueryBuilder() - .dataSource(DATA_SOURCE) - .intervals(SEG_SPEC) - .filters(DIM_FILTER) - .granularity(GRANULARITY) - .dimensions(Collections.singletonList("a")) - .metrics(Collections.singletonList("rows")) - .pagingSpec(new PagingSpec(null, 3)) - .context(CONTEXT); + .dataSource(DATA_SOURCE) + .intervals(SEG_SPEC) + .filters(DIM_FILTER) + .granularity(GRANULARITY) + .dimensions(Collections.singletonList("a")) + .metrics(Collections.singletonList("rows")) + .pagingSpec(new PagingSpec(null, 3)) + .context(CONTEXT); testQueryCaching( getDefaultQueryRunner(), @@ -1433,9 +1441,9 @@ public void testSelectCachingRenamedOutputName() public void testGroupByCaching() { List aggsWithUniques = ImmutableList.builder() - .addAll(AGGS) - .add(new HyperUniquesAggregatorFactory("uniques", "uniques")) - .build(); + .addAll(AGGS) + .add(new HyperUniquesAggregatorFactory("uniques", "uniques")) + .build(); final HashFunction hashFn = Hashing.murmur3_128(); @@ -1637,9 +1645,11 @@ For dim1 (2011-01-06/2011-01-10), the combined range for the bound filters is {( makeTimeResults(DateTimes.of("2011-01-01"), 50, 5000, DateTimes.of("2011-01-02"), 10, 1252, DateTimes.of("2011-01-03"), 20, 6213, - DateTimes.of("2011-01-04"), 30, 743), + DateTimes.of("2011-01-04"), 30, 743 + ), makeTimeResults(DateTimes.of("2011-01-07"), 60, 6020, - DateTimes.of("2011-01-08"), 70, 250) + DateTimes.of("2011-01-08"), 70, 250 + ) ); testQueryCachingWithFilter( @@ -1677,14 +1687,14 @@ 
public void testSingleDimensionPruning() ); final Druids.TimeseriesQueryBuilder builder = Druids.newTimeseriesQueryBuilder() - .dataSource(DATA_SOURCE) - .filters(filter) - .granularity(GRANULARITY) - .intervals(SEG_SPEC) - .context(CONTEXT) - .intervals("2011-01-05/2011-01-10") - .aggregators(RENAMED_AGGS) - .postAggregators(RENAMED_POST_AGGS); + .dataSource(DATA_SOURCE) + .filters(filter) + .granularity(GRANULARITY) + .intervals(SEG_SPEC) + .context(CONTEXT) + .intervals("2011-01-05/2011-01-10") + .aggregators(RENAMED_AGGS) + .postAggregators(RENAMED_POST_AGGS); TimeseriesQuery query = builder.build(); Map context = new HashMap<>(); @@ -1740,7 +1750,8 @@ public void testSingleDimensionPruning() } private ServerSelector makeMockSingleDimensionSelector( - DruidServer server, String dimension, String start, String end, int partitionNum) + DruidServer server, String dimension, String start, String end, int partitionNum + ) { DataSegment segment = EasyMock.createNiceMock(DataSegment.class); EasyMock.expect(segment.getIdentifier()).andReturn(DATA_SOURCE).anyTimes(); @@ -1870,7 +1881,12 @@ public void testQueryCachingWithFilter( @Override public Sequence answer() { - return toFilteredQueryableTimeseriesResults((TimeseriesQuery) capture.getValue().getQuery(), segmentIds, queryIntervals, results); + return toFilteredQueryableTimeseriesResults( + (TimeseriesQuery) capture.getValue().getQuery(), + segmentIds, + queryIntervals, + results + ); } }) .times(0, 1); @@ -1928,7 +1944,11 @@ private Sequence> toFilteredQueryableTimeseriesRes MultipleSpecificSegmentSpec spec = (MultipleSpecificSegmentSpec) query.getQuerySegmentSpec(); List> ret = Lists.newArrayList(); for (SegmentDescriptor descriptor : spec.getDescriptors()) { - String id = StringUtils.format("%s_%s", queryIntervals.indexOf(descriptor.getInterval()), descriptor.getPartitionNumber()); + String id = StringUtils.format( + "%s_%s", + queryIntervals.indexOf(descriptor.getInterval()), + descriptor.getPartitionNumber() + 
); int index = segmentIds.indexOf(id); if (index != -1) { ret.add(new Result( @@ -1994,8 +2014,8 @@ public void testQueryCaching( .andReturn(expectations.getQueryRunner()) .once(); - final Capture capture = new Capture(); - final Capture context = new Capture(); + final Capture capture = EasyMock.newCapture(); + final Capture context = EasyMock.newCapture(); queryCaptures.add(capture); QueryRunner queryable = expectations.getQueryRunner(); @@ -2088,50 +2108,40 @@ public void testQueryCaching( } runWithMocks( - new Runnable() - { - @Override - public void run() - { - HashMap context = new HashMap(); - for (int i = 0; i < numTimesToQuery; ++i) { - TestHelper.assertExpectedResults( - new MergeIterable<>( - Comparators.naturalNullsFirst(), - FunctionalIterable - .create(new RangeIterable(expectedResultsRangeStart, expectedResultsRangeEnd)) - .transformCat( - new Function>>>() - { - @Override - public Iterable>> apply(@Nullable Integer input) - { - List>> retVal = Lists.newArrayList(); - - final Map exps = serverExpectationList.get(input); - for (ServerExpectations expectations : exps.values()) { - for (ServerExpectation expectation : expectations) { - retVal.add(expectation.getResults()); - } - } - - return retVal; + () -> { + HashMap context = new HashMap<>(); + for (int i1 = 0; i1 < numTimesToQuery; ++i1) { + TestHelper.assertExpectedResults( + new MergeIterable<>( + Comparators.naturalNullsFirst(), + FunctionalIterable + .create(new RangeIterable(expectedResultsRangeStart, expectedResultsRangeEnd)) + .transformCat( + (Function>>>) input -> { + List>> retVal = Lists.newArrayList(); + + final Map exps = serverExpectationList.get(input); + for (ServerExpectations expectations : exps.values()) { + for (ServerExpectation expectation : expectations) { + retVal.add(expectation.getResults()); } } - ) - ), - runner.run( - QueryPlus.wrap( - query.withQuerySegmentSpec( - new MultipleIntervalSegmentSpec(ImmutableList.of(actualQueryInterval)) - ) - ), - context - ) - ); - if 
(queryCompletedCallback != null) { - queryCompletedCallback.run(); - } + + return retVal; + } + ) + ), + runner.run( + QueryPlus.wrap( + query.withQuerySegmentSpec( + new MultipleIntervalSegmentSpec(ImmutableList.of(actualQueryInterval)) + ) + ), + context + ) + ); + if (queryCompletedCallback != null) { + queryCompletedCallback.run(); } } }, @@ -2440,13 +2450,13 @@ private Iterable> makeTimeResults(Object... object (DateTime) objects[i], new TimeseriesResultValue( ImmutableMap.builder() - .put("rows", objects[i + 1]) - .put("imps", objects[i + 2]) - .put("impers", objects[i + 2]) - .put("avg_imps_per_row", avg_impr) - .put("avg_imps_per_row_half", avg_impr / 2) - .put("avg_imps_per_row_double", avg_impr * 2) - .build() + .put("rows", objects[i + 1]) + .put("imps", objects[i + 2]) + .put("impers", objects[i + 2]) + .put("avg_imps_per_row", avg_impr) + .put("avg_imps_per_row_half", avg_impr / 2) + .put("avg_imps_per_row_double", avg_impr * 2) + .build() ) ) ); @@ -2563,7 +2573,11 @@ private Iterable> makeSearchResults(String dim, Object return retVal; } - private Iterable> makeSelectResults(Set dimensions, Set metrics, Object... objects) + private Iterable> makeSelectResults( + Set dimensions, + Set metrics, + Object... 
objects + ) { List> retVal = Lists.newArrayList(); int index = 0; @@ -2579,7 +2593,8 @@ private Iterable> makeSelectResults(Set dimens retVal.add(new Result<>( timestamp, new SelectResultValue(ImmutableMap.of(timestamp.toString(), 0), - dimensions, metrics, values) + dimensions, metrics, values + ) )); } return retVal; @@ -2625,6 +2640,9 @@ protected CachingClusteredClient makeClient( ) { return new CachingClusteredClient( + new DefaultQueryRunnerFactoryConglomerate( + QueryRunnerTestHelper.DEFAULT_CONGLOMERATE_MAP + ), WAREHOUSE, new TimelineServerView() { @@ -2665,6 +2683,7 @@ public void registerServerRemovedCallback(Executor exec, ServerRemovedCallback c }, cache, JSON_MAPPER, + ForkJoinPool.commonPool(), cachePopulator, new CacheConfig() { @@ -3080,7 +3099,10 @@ public void testIfNoneMatch() TimeBoundaryQuery query = Druids.newTimeBoundaryQueryBuilder() .dataSource(DATA_SOURCE) .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(interval))) - .context(ImmutableMap.of("If-None-Match", "aVJV29CJY93rszVW/QBy0arWZo0=")) + .context(ImmutableMap.of( + "If-None-Match", + "aVJV29CJY93rszVW/QBy0arWZo0=" + )) .build(); @@ -3093,13 +3115,124 @@ public void testIfNoneMatch() @SuppressWarnings("unchecked") private QueryRunner getDefaultQueryRunner() { - return new QueryRunner() { - @Override - public Sequence run(final QueryPlus queryPlus, final Map responseContext) - { - return client.getQueryRunnerForIntervals(queryPlus.getQuery(), queryPlus.getQuery().getIntervals()) - .run(queryPlus, responseContext); - } - }; + return (queryPlus, responseContext) -> client + .getQueryRunnerForIntervals(queryPlus.getQuery(), queryPlus.getQuery().getIntervals()) + .run(queryPlus, responseContext); + } + + @Test + public void testSpliterator() + { + { + // populate cache selectively + final Druids.TimeseriesQueryBuilder builder = Druids.newTimeseriesQueryBuilder() + .dataSource(DATA_SOURCE) + .intervals(SEG_SPEC) + .filters(DIM_FILTER) + .granularity(GRANULARITY) + 
.aggregators(AGGS) + .postAggregators(POST_AGGS) + .context(CONTEXT); + + QueryRunner runner = new FinalizeResultsQueryRunner( + getDefaultQueryRunner(), + new TimeseriesQueryQueryToolChest( + QueryRunnerTestHelper.NoopIntervalChunkingQueryRunnerDecorator() + ) + ); + + testQueryCaching( + runner, + 1, + true, + builder.build(), + Intervals.of("2011-01-01/2011-01-02"), makeTimeResults(DateTimes.of("2011-01-01"), 50, 5000), + Intervals.of("2011-01-02/2011-01-03"), makeTimeResults(DateTimes.of("2011-01-02"), 30, 6000), + Intervals.of("2011-01-04/2011-01-05"), makeTimeResults(DateTimes.of("2011-01-04"), 23, 85312), + + Intervals.of("2011-01-05/2011-01-10"), + makeTimeResults( + DateTimes.of("2011-01-05"), 85, 102, + DateTimes.of("2011-01-06"), 412, 521, + DateTimes.of("2011-01-07"), 122, 21894, + DateTimes.of("2011-01-08"), 5, 20, + DateTimes.of("2011-01-09"), 18, 521 + ), + + Intervals.of("2011-01-05/2011-01-10"), + makeTimeResults( + DateTimes.of("2011-01-05T01"), 80, 100, + DateTimes.of("2011-01-06T01"), 420, 520, + DateTimes.of("2011-01-07T01"), 12, 2194, + DateTimes.of("2011-01-08T01"), 59, 201, + DateTimes.of("2011-01-09T01"), 181, 52 + ) + ); + } + final Druids.TimeseriesQueryBuilder builder = Druids + .newTimeseriesQueryBuilder() + .dataSource(DATA_SOURCE) + .intervals(SEG_SPEC) + .filters(DIM_FILTER) + .granularity(GRANULARITY) + .aggregators(AGGS) + .postAggregators(POST_AGGS) + .context(CONTEXT); + + + final HashMap context = new HashMap<>(); + final TimeseriesQuery query = builder + .intervals("2011-01-01/2011-01-10") + .aggregators(RENAMED_AGGS) + .postAggregators(RENAMED_POST_AGGS) + .build() + .withOverriddenContext(Collections.singletonMap("populateCache", "false")); + + final Stream>> results = client.run( + QueryPlus.wrap(query), + context, + stringServerSelectorTimelineLookup -> stringServerSelectorTimelineLookup + ); + + final Spliterator>> spliterator = results.spliterator(); + + Assert.assertNotNull(spliterator); + final int characteristics = 
spliterator.characteristics(); + Assert.assertEquals(characteristics & Spliterator.SIZED, Spliterator.SIZED); + Assert.assertEquals(characteristics & Spliterator.SUBSIZED, Spliterator.SUBSIZED); + final ArrayList>> sequences = new ArrayList<>(); + spliterator.forEachRemaining(sequences::add); + Assert.assertFalse(sequences.isEmpty()); + + + final Sequence> parallelMergeResults = MergeWorkTask.parallelMerge( + client.run( + QueryPlus.wrap(query), + context, + stringServerSelectorTimelineLookup -> stringServerSelectorTimelineLookup + ).parallel(), + s -> new MergeSequence<>(query.getResultOrdering(), Sequences.fromStream(s)), + 1, + ForkJoinPool.commonPool() + ); + + TestHelper.assertExpectedResults( + makeRenamedTimeResults( + DateTimes.of("2011-01-01"), 50, 5000, + DateTimes.of("2011-01-02"), 30, 6000, + DateTimes.of("2011-01-04"), 23, 85312, + DateTimes.of("2011-01-05"), 85, 102, + DateTimes.of("2011-01-05T01"), 80, 100, + DateTimes.of("2011-01-06"), 412, 521, + DateTimes.of("2011-01-06T01"), 420, 520, + DateTimes.of("2011-01-07"), 122, 21894, + DateTimes.of("2011-01-07T01"), 12, 2194, + DateTimes.of("2011-01-08"), 5, 20, + DateTimes.of("2011-01-08T01"), 59, 201, + DateTimes.of("2011-01-09"), 18, 521, + DateTimes.of("2011-01-09T01"), 181, 52 + ), + parallelMergeResults + ); } } diff --git a/server/src/test/java/org/apache/druid/client/cache/HybridCacheTest.java b/server/src/test/java/org/apache/druid/client/cache/HybridCacheTest.java index 4284bd535b4a..2c86f260ed8e 100644 --- a/server/src/test/java/org/apache/druid/client/cache/HybridCacheTest.java +++ b/server/src/test/java/org/apache/druid/client/cache/HybridCacheTest.java @@ -22,21 +22,24 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; import com.google.common.primitives.Ints; -import com.google.inject.Binder; import com.google.inject.Injector; import com.google.inject.Key; -import com.google.inject.Module; import com.google.inject.name.Names; import 
org.apache.druid.guice.CacheModule; import org.apache.druid.guice.GuiceInjectors; import org.apache.druid.guice.annotations.Global; import org.apache.druid.initialization.Initialization; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.StringUtils; import org.junit.Assert; import org.junit.Test; +import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; public class HybridCacheTest { @@ -53,17 +56,12 @@ public void testInjection() System.setProperty(prefix + ".l2.hosts", "localhost:11711"); final Injector injector = Initialization.makeInjectorWithModules( - GuiceInjectors.makeStartupInjector(), ImmutableList.of( - new Module() - { - @Override - public void configure(Binder binder) - { - binder.bindConstant().annotatedWith(Names.named("serviceName")).to("hybridTest"); - binder.bindConstant().annotatedWith(Names.named("servicePort")).to(0); - binder.bindConstant().annotatedWith(Names.named("tlsServicePort")).to(-1); - binder.install(new CacheModule(prefix)); - } + GuiceInjectors.makeStartupInjector(), ImmutableList.of( + binder -> { + binder.bindConstant().annotatedWith(Names.named("serviceName")).to("hybridTest"); + binder.bindConstant().annotatedWith(Names.named("servicePort")).to(0); + binder.bindConstant().annotatedWith(Names.named("tlsServicePort")).to(-1); + binder.install(new CacheModule(prefix)); } ) ); @@ -98,7 +96,6 @@ public void testSanity() final byte[] value3 = Ints.toByteArray(3); - // test put puts to both cache.put(key1, value1); Assert.assertEquals(value1, l1.get(key1)); @@ -137,6 +134,24 @@ public void testSanity() Assert.assertEquals(hits, cache.getStats().getNumHits()); } + // test streaming bulk get with l1 and l2 + { + final List keys = ImmutableList.of(key1, key2, key3); + final List>> res = cache + .getBulk(keys.stream()).collect(Collectors.toList()); + Assert.assertNotNull(res); + 
Assert.assertEquals(Arrays.asList( + Pair.of(key1, Optional.of(value1)), + Pair.of(key2, Optional.of(value2)), + Pair.of(key3, Optional.of(value3)) + ), res); + + hits += 3; + Assert.assertEquals(0, cache.getStats().getNumMisses()); + Assert.assertEquals(hits, cache.getStats().getNumHits()); + } + + // test bulk get with l1 entries only { final HashSet keys = Sets.newHashSet(key1, key2); @@ -175,5 +190,31 @@ public void testSanity() Assert.assertEquals(++hits, cache.getStats().getNumHits()); Assert.assertEquals(++misses, cache.getStats().getNumMisses()); } + + { + final List keys = ImmutableList.of(key3, key4); + final List>> res = cache + .getBulk(keys.stream()).collect(Collectors.toList()); + Assert.assertNotNull(res); + Assert.assertEquals(Arrays.asList( + Pair.of(key3, Optional.of(value3)), + Pair.of(key4, Optional.empty()) + ), res); + Assert.assertEquals(++misses, cache.getStats().getNumMisses()); + Assert.assertEquals(++hits, cache.getStats().getNumHits()); + } + + { + final List keys = ImmutableList.of(key1, key4); + final List>> res = cache + .getBulk(keys.stream()).collect(Collectors.toList()); + Assert.assertNotNull(res); + Assert.assertEquals(Arrays.asList( + Pair.of(key1, Optional.of(value1)), + Pair.of(key4, Optional.empty()) + ), res); + Assert.assertEquals(++misses, cache.getStats().getNumMisses()); + Assert.assertEquals(++hits, cache.getStats().getNumHits()); + } } } diff --git a/services/src/main/java/org/apache/druid/cli/CliBroker.java b/services/src/main/java/org/apache/druid/cli/CliBroker.java index 009a6fedbc54..2aaaaf7537d6 100644 --- a/services/src/main/java/org/apache/druid/cli/CliBroker.java +++ b/services/src/main/java/org/apache/druid/cli/CliBroker.java @@ -24,6 +24,8 @@ import com.google.inject.Module; import com.google.inject.name.Names; import io.airlift.airline.Command; +import org.apache.druid.guice.ForkJoinPoolProvider; +import org.apache.druid.guice.LifecycleForkJoinPool; import org.apache.druid.client.BrokerSegmentWatcherConfig; 
import org.apache.druid.client.BrokerServerView; import org.apache.druid.client.CachingClusteredClient; @@ -42,8 +44,10 @@ import org.apache.druid.guice.JsonConfigProvider; import org.apache.druid.guice.LazySingleton; import org.apache.druid.guice.LifecycleModule; +import org.apache.druid.guice.ManageLifecycle; import org.apache.druid.guice.QueryRunnerFactoryModule; import org.apache.druid.guice.QueryableModule; +import org.apache.druid.guice.annotations.Processing; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.RetryQueryRunnerConfig; @@ -62,6 +66,7 @@ import org.eclipse.jetty.server.Server; import java.util.List; +import java.util.concurrent.ForkJoinPool; /** */ @@ -96,6 +101,14 @@ protected List getModules() binder.bind(CachingClusteredClient.class).in(LazySingleton.class); binder.bind(BrokerServerView.class).in(LazySingleton.class); binder.bind(TimelineServerView.class).to(BrokerServerView.class).in(LazySingleton.class); + binder.bind(Key.get(LifecycleForkJoinPool.class, Processing.class)) + .toProvider(new ForkJoinPoolProvider("processing-fjp-%s")) + .in(ManageLifecycle.class); + // Bind the lifecycle key, then bind the lifecycle to the forkjoinpool key so that any extensions that + // want to have their own fork join pool instead of this one can do so. + LifecycleModule.register(binder, LifecycleForkJoinPool.class, Processing.class); + binder.bind(Key.get(ForkJoinPool.class, Processing.class)) + .to(Key.get(LifecycleForkJoinPool.class, Processing.class)); JsonConfigProvider.bind(binder, "druid.broker.cache", CacheConfig.class); binder.install(new CacheModule());