From 90b664ba1345af615bf9d749f149ed6621414ae6 Mon Sep 17 00:00:00 2001 From: Adarsh Sanjeev Date: Thu, 1 Dec 2022 17:23:38 +0530 Subject: [PATCH 1/7] Remove stray reference to fix OOM while merging sketches --- .../druid/msq/exec/WorkerSketchFetcher.java | 39 ++++++++++++------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java index c4118a9d38e0..47ccadbefe83 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java @@ -39,7 +39,9 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.stream.IntStream; /** @@ -86,14 +88,14 @@ public CompletableFuture> submitFetcherTask( return inMemoryFullSketchMerging(stageDefinition, workerTaskIds); case AUTO: if (clusterBy.getBucketByCount() == 0) { - log.debug("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId()); + log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId()); // If there is no time clustering, there is no scope for sequential merge return inMemoryFullSketchMerging(stageDefinition, workerTaskIds); } else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) { - log.debug("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId()); + log.info("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId()); return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds); } - log.debug("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId()); + log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId()); return inMemoryFullSketchMerging(stageDefinition, workerTaskIds); default: throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode); @@ -118,6 +120,14 @@ CompletableFuture> inMemoryFullSketchMerging( final int workerCount = workerTaskIds.size(); // Guarded by synchronized mergedStatisticsCollector final Set finishedWorkers = new HashSet<>(); + final Set> futuresToCancel = ConcurrentHashMap.newKeySet(); + partitionFuture.whenComplete((result, exception) -> { + if (exception != null || (result != null && result.isError())) { + for (Future snapshotFuture : futuresToCancel) { + snapshotFuture.cancel(true); + } + } + }); // Submit a task for each worker to fetch statistics IntStream.range(0, workerCount).forEach(workerNo -> { @@ -128,12 +138,7 @@ CompletableFuture> inMemoryFullSketchMerging( stageDefinition.getId().getQueryId(), stageDefinition.getStageNumber() ); - partitionFuture.whenComplete((result, exception) -> { - if (exception != null || (result != null && result.isError())) { - snapshotFuture.cancel(true); - } - }); - + futuresToCancel.add(snapshotFuture); try { ClusterByStatisticsSnapshot clusterByStatisticsSnapshot = snapshotFuture.get(); if (clusterByStatisticsSnapshot == null) { @@ -141,6 +146,7 @@ CompletableFuture> inMemoryFullSketchMerging( } synchronized (mergedStatisticsCollector) { mergedStatisticsCollector.addAll(clusterByStatisticsSnapshot); + futuresToCancel.remove(snapshotFuture); finishedWorkers.add(workerNo); if (finishedWorkers.size() == workerCount) { @@ -231,6 +237,14 @@ public void submitFetchingTasksForNextTimeChunk() stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes); // Guarded by synchronized mergedStatisticsCollector Set finishedWorkers = new HashSet<>(); + final Set> futuresToCancel = ConcurrentHashMap.newKeySet(); + partitionFuture.whenComplete((result, exception) -> { + if (exception != null || (result != null && result.isError())) { + for (Future snapshotFuture : futuresToCancel) { + snapshotFuture.cancel(true); + } + } + }); log.debug("Query [%s]. Submitting request for statistics for time chunk %s to %s workers", stageDefinition.getId().getQueryId(), @@ -247,11 +261,7 @@ public void submitFetchingTasksForNextTimeChunk() stageDefinition.getStageNumber(), timeChunk ); - partitionFuture.whenComplete((result, exception) -> { - if (exception != null || (result != null && result.isError())) { - snapshotFuture.cancel(true); - } - }); + futuresToCancel.add(snapshotFuture); try { ClusterByStatisticsSnapshot snapshotForTimeChunk = snapshotFuture.get(); @@ -260,6 +270,7 @@ public void submitFetchingTasksForNextTimeChunk() } synchronized (mergedStatisticsCollector) { mergedStatisticsCollector.addAll(snapshotForTimeChunk); + futuresToCancel.remove(snapshotFuture); finishedWorkers.add(workerNo); if (finishedWorkers.size() == workerIdsWithTimeChunk.size()) { From 704197d6d35b250ad879eaa4cfb6fc492f3141a6 Mon Sep 17 00:00:00 2001 From: Adarsh Sanjeev Date: Thu, 1 Dec 2022 18:49:02 +0530 Subject: [PATCH 2/7] Update future to add result from executor service --- .../druid/msq/exec/WorkerSketchFetcher.java | 50 ++++++++++--------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java index 47ccadbefe83..d1a25a846bbd 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java @@ -120,25 +120,17 @@ CompletableFuture> inMemoryFullSketchMerging( final int workerCount = workerTaskIds.size(); // Guarded by synchronized mergedStatisticsCollector final Set finishedWorkers = new HashSet<>(); - final Set> futuresToCancel = ConcurrentHashMap.newKeySet(); - partitionFuture.whenComplete((result, exception) -> { - if (exception != null || (result != null && result.isError())) { - for (Future snapshotFuture : futuresToCancel) { - snapshotFuture.cancel(true); - } - } - }); + final Set> futuresToCancel = ConcurrentHashMap.newKeySet(); // Submit a task for each worker to fetch statistics IntStream.range(0, workerCount).forEach(workerNo -> { - executorService.submit(() -> { + futuresToCancel.add(executorService.submit(() -> { ListenableFuture snapshotFuture = workerClient.fetchClusterByStatisticsSnapshot( workerTaskIds.get(workerNo), stageDefinition.getId().getQueryId(), stageDefinition.getStageNumber() ); - futuresToCancel.add(snapshotFuture); try { ClusterByStatisticsSnapshot clusterByStatisticsSnapshot = snapshotFuture.get(); if (clusterByStatisticsSnapshot == null) { @@ -146,7 +138,6 @@ CompletableFuture> inMemoryFullSketchMerging( } synchronized (mergedStatisticsCollector) { mergedStatisticsCollector.addAll(clusterByStatisticsSnapshot); - futuresToCancel.remove(snapshotFuture); finishedWorkers.add(workerNo); if (finishedWorkers.size() == workerCount) { @@ -161,7 +152,17 @@ CompletableFuture> inMemoryFullSketchMerging( mergedStatisticsCollector.clear(); } } - }); + })); + }); + + partitionFuture.whenComplete((result, exception) -> { + if (exception != null || (result != null && result.isError())) { + for (Future future : futuresToCancel) { + if (!future.isDone()) { + future.cancel(true); + } + } + } }); return partitionFuture; } @@ -237,14 +238,7 @@ public void submitFetchingTasksForNextTimeChunk() stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes); // Guarded by synchronized mergedStatisticsCollector Set finishedWorkers = new HashSet<>(); - final Set> futuresToCancel = ConcurrentHashMap.newKeySet(); - partitionFuture.whenComplete((result, exception) -> { - if (exception != null || (result != null && result.isError())) { - for (Future snapshotFuture : futuresToCancel) { - snapshotFuture.cancel(true); - } - } - }); + final Set> futuresToCancel = ConcurrentHashMap.newKeySet(); log.debug("Query [%s]. Submitting request for statistics for time chunk %s to %s workers", stageDefinition.getId().getQueryId(), @@ -253,7 +247,7 @@ public void submitFetchingTasksForNextTimeChunk() // Submits a task for every worker which has a certain time chunk for (int workerNo : workerIdsWithTimeChunk) { - executorService.submit(() -> { + futuresToCancel.add(executorService.submit(() -> { ListenableFuture snapshotFuture = workerClient.fetchClusterByStatisticsSnapshotForTimeChunk( workerTaskIds.get(workerNo), @@ -261,7 +255,6 @@ public void submitFetchingTasksForNextTimeChunk() stageDefinition.getStageNumber(), timeChunk ); - futuresToCancel.add(snapshotFuture); try { ClusterByStatisticsSnapshot snapshotForTimeChunk = snapshotFuture.get(); @@ -270,7 +263,6 @@ public void submitFetchingTasksForNextTimeChunk() } synchronized (mergedStatisticsCollector) { mergedStatisticsCollector.addAll(snapshotForTimeChunk); - futuresToCancel.remove(snapshotFuture); finishedWorkers.add(workerNo); if (finishedWorkers.size() == workerIdsWithTimeChunk.size()) { @@ -304,8 +296,18 @@ public void submitFetchingTasksForNextTimeChunk() mergedStatisticsCollector.clear(); } } - }); + })); } + + partitionFuture.whenComplete((result, exception) -> { + if (exception != null || (result != null && result.isError())) { + for (Future future : futuresToCancel) { + if (!future.isDone()) { + future.cancel(true); + } + } + } + }); } } From cd22ebf4fab9f7371f87446f0e52df95fa1512bf Mon Sep 17 00:00:00 2001 From: Adarsh Sanjeev Date: Sun, 4 Dec 2022 23:00:02 +0530 Subject: [PATCH 3/7] Update tests and address review comments --- .../druid/msq/exec/WorkerSketchFetcher.java | 48 ++++++++++++---- .../msq/exec/WorkerSketchFetcherTest.java | 57 +++++++++++++------ 2 files changed, 77 insertions(+), 28 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java index d1a25a846bbd..e0a537d74794 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java @@ -19,6 +19,7 @@ package org.apache.druid.msq.exec; +import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ListenableFuture; import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterByPartition; @@ -37,9 +38,10 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Queue; import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.stream.IntStream; @@ -61,7 +63,11 @@ public class WorkerSketchFetcher implements AutoCloseable private final WorkerClient workerClient; private final ExecutorService executorService; - public WorkerSketchFetcher(WorkerClient workerClient, ClusterStatisticsMergeMode clusterStatisticsMergeMode, int statisticsMaxRetainedBytes) + public WorkerSketchFetcher( + WorkerClient workerClient, + ClusterStatisticsMergeMode clusterStatisticsMergeMode, + int statisticsMaxRetainedBytes + ) { this.workerClient = workerClient; this.clusterStatisticsMergeMode = clusterStatisticsMergeMode; @@ -69,6 +75,20 @@ public WorkerSketchFetcher(WorkerClient workerClient, ClusterStatisticsMergeMode this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes; } + @VisibleForTesting + WorkerSketchFetcher( + WorkerClient workerClient, + ClusterStatisticsMergeMode clusterStatisticsMergeMode, + int statisticsMaxRetainedBytes, + ExecutorService executorService + ) + { + this.workerClient = workerClient; + this.clusterStatisticsMergeMode = clusterStatisticsMergeMode; + this.executorService = executorService; + this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes; + } + /** * Submits a request to fetch and generate partitions for the given worker statistics and returns a future for it. It * decides based on the statistics if it should fetch sketches one by one or together. @@ -120,11 +140,11 @@ CompletableFuture> inMemoryFullSketchMerging( final int workerCount = workerTaskIds.size(); // Guarded by synchronized mergedStatisticsCollector final Set finishedWorkers = new HashSet<>(); - final Set> futuresToCancel = ConcurrentHashMap.newKeySet(); + final Queue> executorFutures = new ConcurrentLinkedQueue<>(); // Submit a task for each worker to fetch statistics IntStream.range(0, workerCount).forEach(workerNo -> { - futuresToCancel.add(executorService.submit(() -> { + executorFutures.add(executorService.submit(() -> { ListenableFuture snapshotFuture = workerClient.fetchClusterByStatisticsSnapshot( workerTaskIds.get(workerNo), @@ -148,8 +168,10 @@ CompletableFuture> inMemoryFullSketchMerging( } catch (Exception e) { synchronized (mergedStatisticsCollector) { - partitionFuture.completeExceptionally(e); - mergedStatisticsCollector.clear(); + if (!partitionFuture.isDone()) { + partitionFuture.completeExceptionally(e); + mergedStatisticsCollector.clear(); + } } } })); @@ -157,7 +179,7 @@ CompletableFuture> inMemoryFullSketchMerging( partitionFuture.whenComplete((result, exception) -> { if (exception != null || (result != null && result.isError())) { - for (Future future : futuresToCancel) { + for (Future future : executorFutures) { if (!future.isDone()) { future.cancel(true); } @@ -238,7 +260,7 @@ public void submitFetchingTasksForNextTimeChunk() stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes); // Guarded by synchronized mergedStatisticsCollector Set finishedWorkers = new HashSet<>(); - final Set> futuresToCancel = ConcurrentHashMap.newKeySet(); + final Queue> executorFutures = new ConcurrentLinkedQueue<>(); log.debug("Query [%s]. Submitting request for statistics for time chunk %s to %s workers", stageDefinition.getId().getQueryId(), @@ -247,7 +269,7 @@ public void submitFetchingTasksForNextTimeChunk() // Submits a task for every worker which has a certain time chunk for (int workerNo : workerIdsWithTimeChunk) { - futuresToCancel.add(executorService.submit(() -> { + executorFutures.add(executorService.submit(() -> { ListenableFuture snapshotFuture = workerClient.fetchClusterByStatisticsSnapshotForTimeChunk( workerTaskIds.get(workerNo), @@ -292,8 +314,10 @@ public void submitFetchingTasksForNextTimeChunk() } catch (Exception e) { synchronized (mergedStatisticsCollector) { - partitionFuture.completeExceptionally(e); - mergedStatisticsCollector.clear(); + if (!partitionFuture.isDone()) { + partitionFuture.completeExceptionally(e); + mergedStatisticsCollector.clear(); + } } } })); @@ -301,7 +325,7 @@ public void submitFetchingTasksForNextTimeChunk() partitionFuture.whenComplete((result, exception) -> { if (exception != null || (result != null && result.isError())) { - for (Future future : futuresToCancel) { + for (Future future : executorFutures) { if (!future.isDone()) { future.cancel(true); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java index 54c9a792e558..8aa58c4a7082 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java @@ -23,12 +23,12 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSortedMap; import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterByPartition; import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.RowKey; import org.apache.druid.java.util.common.Either; +import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; @@ -51,6 +51,8 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import static org.easymock.EasyMock.mock; import static org.mockito.ArgumentMatchers.any; @@ -80,6 +82,7 @@ public class WorkerSketchFetcherTest private WorkerClient workerClient; private ClusterByPartitions expectedPartitions1; private ClusterByPartitions expectedPartitions2; + private ExecutorService executorService; private AutoCloseable mocks; private WorkerSketchFetcher target; @@ -101,6 +104,7 @@ public void setUp() mergedClusterByStatisticsCollector1, mergedClusterByStatisticsCollector2 ).when(stageDefinition).createResultKeyStatisticsCollector(anyInt()); + executorService = spy(Execs.multiThreaded(4, "SketchFetcherThreadPool-%d")); } @After @@ -113,20 +117,30 @@ public void tearDown() throws Exception public void test_submitFetcherTask_parallelFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception { // Store futures in a queue - final Queue> futureQueue = new ConcurrentLinkedQueue<>(); + final Queue> futureQueue = new ConcurrentLinkedQueue<>(); final List workerIds = ImmutableList.of("0", "1", "2", "3"); final CountDownLatch latch = new CountDownLatch(workerIds.size()); - target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.PARALLEL, 300_000_000)); + target = spy( + new WorkerSketchFetcher( + workerClient, + ClusterStatisticsMergeMode.PARALLEL, + 300_000_000, + executorService + ) + ); + + // When submitting futures from the executor, add it to the list first. + doAnswer(invocation -> { + Future future = spy((Future) invocation.callRealMethod()); + futureQueue.add(future); + return future; + }).when(executorService).submit(any(Runnable.class)); - // When fetching snapshots, return a mock and add future to queue doAnswer(invocation -> { - ListenableFuture snapshotListenableFuture = - spy(Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class))); - futureQueue.add(snapshotListenableFuture); latch.countDown(); latch.await(); - return snapshotListenableFuture; + return Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)); }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt()); // Cause a worker to fail instead of returning the result @@ -151,7 +165,7 @@ public void test_submitFetcherTask_parallelFetch_workerThrowsException_shouldCan verify(mergedClusterByStatisticsCollector1, times(1)).clear(); // Verify that other task futures were requested to be cancelled. Assert.assertFalse(futureQueue.isEmpty()); - for (ListenableFuture snapshotFuture : futureQueue) { + for (Future snapshotFuture : futureQueue) { verify(snapshotFuture, times(1)).cancel(eq(true)); } } @@ -198,21 +212,32 @@ public void test_submitFetcherTask_parallelFetch_mergePerformedCorrectly() public void test_submitFetcherTask_sequentialFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception { // Store futures in a queue - final Queue> futureQueue = new ConcurrentLinkedQueue<>(); + final Queue> futureQueue = new ConcurrentLinkedQueue<>(); SortedMap> timeSegmentVsWorkerMap = ImmutableSortedMap.of(1L, ImmutableSet.of(0, 1, 2), 2L, ImmutableSet.of(0, 1, 4)); doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap(); final CyclicBarrier barrier = new CyclicBarrier(3); - target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.SEQUENTIAL, 300_000_000)); + target = spy( + new WorkerSketchFetcher( + workerClient, + ClusterStatisticsMergeMode.SEQUENTIAL, + 300_000_000, + executorService + ) + ); + + // When submitting futures from the executor, add it to the list first. + doAnswer(invocation -> { + Future future = spy((Future) invocation.callRealMethod()); + futureQueue.add(future); + return future; + }).when(executorService).submit(any(Runnable.class)); // When fetching snapshots, return a mock and add future to queue doAnswer(invocation -> { - ListenableFuture snapshotListenableFuture = - spy(Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class))); - futureQueue.add(snapshotListenableFuture); barrier.await(); - return snapshotListenableFuture; + return Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)); }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(anyString(), anyString(), anyInt(), anyLong()); // Cause a worker in the second time chunk to fail instead of returning the result @@ -237,7 +262,7 @@ public void test_submitFetcherTask_sequentialFetch_workerThrowsException_shouldC verify(mergedClusterByStatisticsCollector2, times(1)).clear(); // Verify that other task futures were requested to be cancelled. Assert.assertFalse(futureQueue.isEmpty()); - for (ListenableFuture snapshotFuture : futureQueue) { + for (Future snapshotFuture : futureQueue) { verify(snapshotFuture, times(1)).cancel(eq(true)); } } From 65ab4746210421d08ccebf5817faa66e9f349617 Mon Sep 17 00:00:00 2001 From: Adarsh Sanjeev Date: Mon, 5 Dec 2022 21:04:46 +0530 Subject: [PATCH 4/7] Address review comments --- .../druid/msq/exec/WorkerSketchFetcher.java | 6 +-- .../msq/exec/WorkerSketchFetcherTest.java | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java index e0a537d74794..c4536697e39e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java @@ -38,10 +38,8 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Queue; import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.stream.IntStream; @@ -140,7 +138,7 @@ CompletableFuture> inMemoryFullSketchMerging( final int workerCount = workerTaskIds.size(); // Guarded by synchronized mergedStatisticsCollector final Set finishedWorkers = new HashSet<>(); - final Queue> executorFutures = new ConcurrentLinkedQueue<>(); + final List> executorFutures = new ArrayList<>(); // Submit a task for each worker to fetch statistics IntStream.range(0, workerCount).forEach(workerNo -> { @@ -260,7 +258,7 @@ public void submitFetchingTasksForNextTimeChunk() stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes); // Guarded by synchronized mergedStatisticsCollector Set finishedWorkers = new HashSet<>(); - final Queue> executorFutures = new ConcurrentLinkedQueue<>(); + final List> executorFutures = new ArrayList<>(); log.debug("Query [%s]. Submitting request for statistics for time chunk %s to %s workers", stageDefinition.getId().getQueryId(), diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java index 8aa58c4a7082..3d31572f7b43 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java @@ -138,18 +138,19 @@ public void test_submitFetcherTask_parallelFetch_workerThrowsException_shouldCan }).when(executorService).submit(any(Runnable.class)); doAnswer(invocation -> { - latch.countDown(); - latch.await(); - return Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)); + String workerId = invocation.getArgument(0); + if ("2".equals(workerId)) { + // Cause a worker to fail instead of returning the result + latch.countDown(); + latch.await(); + return Futures.immediateFailedFuture(new InterruptedException("interrupted")); + } else { + latch.countDown(); + latch.await(); + return Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)); + } }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt()); - // Cause a worker to fail instead of returning the result - doAnswer(invocation -> { - latch.countDown(); - latch.await(); - return Futures.immediateFailedFuture(new InterruptedException("interrupted")); - }).when(workerClient).fetchClusterByStatisticsSnapshot(eq("2"), any(), anyInt()); - CompletableFuture> eitherCompletableFuture = target.submitFetcherTask( completeKeyStatisticsInformation, workerIds, @@ -236,16 +237,18 @@ public void test_submitFetcherTask_sequentialFetch_workerThrowsException_shouldC // When fetching snapshots, return a mock and add future to queue doAnswer(invocation -> { - barrier.await(); - return Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)); + String workerId = invocation.getArgument(0); + long timeChunk = invocation.getArgument(3); + // Cause a worker in the second time chunk to fail instead of returning the result + if ("4".equals(workerId) && timeChunk == 2L) { + barrier.await(); + return Futures.immediateFailedFuture(new InterruptedException("interrupted")); + } else { + barrier.await(); + return Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)); + } }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(anyString(), anyString(), anyInt(), anyLong()); - // Cause a worker in the second time chunk to fail instead of returning the result - doAnswer(invocation -> { - barrier.await(); - return Futures.immediateFailedFuture(new InterruptedException("interrupted")); - }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(eq("4"), any(), anyInt(), eq(2L)); - CompletableFuture> eitherCompletableFuture = target.submitFetcherTask( completeKeyStatisticsInformation, ImmutableList.of("0", "1", "2", "3", "4"), From 22e286f951f8d25a912723d597f4743cff3bee0e Mon Sep 17 00:00:00 2001 From: Adarsh Sanjeev Date: Tue, 6 Dec 2022 16:56:23 +0530 Subject: [PATCH 5/7] Moved mock --- .../org/apache/druid/msq/exec/WorkerSketchFetcherTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java index 3d31572f7b43..0427237bcb05 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java @@ -82,7 +82,6 @@ public class WorkerSketchFetcherTest private WorkerClient workerClient; private ClusterByPartitions expectedPartitions1; private ClusterByPartitions expectedPartitions2; - private ExecutorService executorService; private AutoCloseable mocks; private WorkerSketchFetcher target; @@ -104,7 +103,6 @@ public void setUp() mergedClusterByStatisticsCollector1, mergedClusterByStatisticsCollector2 ).when(stageDefinition).createResultKeyStatisticsCollector(anyInt()); - executorService = spy(Execs.multiThreaded(4, "SketchFetcherThreadPool-%d")); } @After @@ -121,6 +119,7 @@ public void test_submitFetcherTask_parallelFetch_workerThrowsException_shouldCan final List workerIds = ImmutableList.of("0", "1", "2", "3"); final CountDownLatch latch = new CountDownLatch(workerIds.size()); + ExecutorService executorService = spy(Execs.multiThreaded(4, "SketchFetcherThreadPool-%d")); target = spy( new WorkerSketchFetcher( workerClient, @@ -219,6 +218,7 @@ public void test_submitFetcherTask_sequentialFetch_workerThrowsException_shouldC doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap(); final CyclicBarrier barrier = new CyclicBarrier(3); + ExecutorService executorService = spy(Execs.multiThreaded(4, "SketchFetcherThreadPool-%d")); target = spy( new WorkerSketchFetcher( workerClient, From 69abe9670df93202471dfa542f99f086d019b4c5 Mon Sep 17 00:00:00 2001 From: Adarsh Sanjeev Date: Tue, 6 Dec 2022 17:29:10 +0530 Subject: [PATCH 6/7] Close threadpool on teardown --- .../org/apache/druid/msq/exec/WorkerSketchFetcherTest.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java index 0427237bcb05..501ba96e6bdd 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java @@ -109,6 +109,9 @@ public void setUp() public void tearDown() throws Exception { mocks.close(); + if (target != null) { + target.close(); + } } @Test From 0ab623bfc90fcbdb6019204fe81dc0c946ef23de Mon Sep 17 00:00:00 2001 From: Adarsh Sanjeev Date: Wed, 7 Dec 2022 12:39:48 +0530 Subject: [PATCH 7/7] Remove worker task cancel --- .../druid/msq/exec/WorkerSketchFetcher.java | 45 +------ .../msq/exec/WorkerSketchFetcherTest.java | 126 ------------------ 2 files changed, 4 insertions(+), 167 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java index c4536697e39e..dc6f21990587 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java @@ -19,7 +19,6 @@ package org.apache.druid.msq.exec; -import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ListenableFuture; import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterByPartition; @@ -41,7 +40,6 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; import java.util.stream.IntStream; /** @@ -73,20 +71,6 @@ public WorkerSketchFetcher( this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes; } - @VisibleForTesting - WorkerSketchFetcher( - WorkerClient workerClient, - ClusterStatisticsMergeMode clusterStatisticsMergeMode, - int statisticsMaxRetainedBytes, - ExecutorService executorService - ) - { - this.workerClient = workerClient; - this.clusterStatisticsMergeMode = clusterStatisticsMergeMode; - this.executorService = executorService; - this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes; - } - /** * Submits a request to fetch and generate partitions for the given worker statistics and returns a future for it. It * decides based on the statistics if it should fetch sketches one by one or together. @@ -138,11 +122,10 @@ CompletableFuture> inMemoryFullSketchMerging( final int workerCount = workerTaskIds.size(); // Guarded by synchronized mergedStatisticsCollector final Set finishedWorkers = new HashSet<>(); - final List> executorFutures = new ArrayList<>(); // Submit a task for each worker to fetch statistics IntStream.range(0, workerCount).forEach(workerNo -> { - executorFutures.add(executorService.submit(() -> { + executorService.submit(() -> { ListenableFuture snapshotFuture = workerClient.fetchClusterByStatisticsSnapshot( workerTaskIds.get(workerNo), @@ -172,18 +155,9 @@ CompletableFuture> inMemoryFullSketchMerging( } } } - })); + }); }); - partitionFuture.whenComplete((result, exception) -> { - if (exception != null || (result != null && result.isError())) { - for (Future future : executorFutures) { - if (!future.isDone()) { - future.cancel(true); - } - } - } - }); return partitionFuture; } @@ -258,7 +232,6 @@ public void submitFetchingTasksForNextTimeChunk() stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes); // Guarded by synchronized mergedStatisticsCollector Set finishedWorkers = new HashSet<>(); - final List> executorFutures = new ArrayList<>(); log.debug("Query [%s]. Submitting request for statistics for time chunk %s to %s workers", stageDefinition.getId().getQueryId(), @@ -267,7 +240,7 @@ public void submitFetchingTasksForNextTimeChunk() // Submits a task for every worker which has a certain time chunk for (int workerNo : workerIdsWithTimeChunk) { - executorFutures.add(executorService.submit(() -> { + executorService.submit(() -> { ListenableFuture snapshotFuture = workerClient.fetchClusterByStatisticsSnapshotForTimeChunk( workerTaskIds.get(workerNo), @@ -318,18 +291,8 @@ public void submitFetchingTasksForNextTimeChunk() } } } - })); + }); } - - partitionFuture.whenComplete((result, exception) -> { - if (exception != null || (result != null && result.isError())) { - for (Future future : executorFutures) { - if (!future.isDone()) { - future.cancel(true); - } - } - } - }); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java index 501ba96e6bdd..83fb73043bd9 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java @@ -28,7 +28,6 @@ import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.RowKey; import org.apache.druid.java.util.common.Either; -import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; @@ -46,19 +45,15 @@ import java.util.Set; import java.util.SortedMap; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; import static org.easymock.EasyMock.mock; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; @@ -114,65 +109,6 @@ public void tearDown() throws Exception } } - @Test - public void test_submitFetcherTask_parallelFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception - { - // Store futures in a queue - final Queue> futureQueue = new ConcurrentLinkedQueue<>(); - final List workerIds = ImmutableList.of("0", "1", "2", "3"); - final CountDownLatch latch = new CountDownLatch(workerIds.size()); - - ExecutorService executorService = spy(Execs.multiThreaded(4, "SketchFetcherThreadPool-%d")); - target = spy( - new WorkerSketchFetcher( - workerClient, - ClusterStatisticsMergeMode.PARALLEL, - 300_000_000, - executorService - ) - ); - - // When submitting futures from the executor, add it to the list first. - doAnswer(invocation -> { - Future future = spy((Future) invocation.callRealMethod()); - futureQueue.add(future); - return future; - }).when(executorService).submit(any(Runnable.class)); - - doAnswer(invocation -> { - String workerId = invocation.getArgument(0); - if ("2".equals(workerId)) { - // Cause a worker to fail instead of returning the result - latch.countDown(); - latch.await(); - return Futures.immediateFailedFuture(new InterruptedException("interrupted")); - } else { - latch.countDown(); - latch.await(); - return Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)); - } - }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt()); - - CompletableFuture> eitherCompletableFuture = target.submitFetcherTask( - completeKeyStatisticsInformation, - workerIds, - stageDefinition - ); - - // Assert that the final result is failed and all other task futures are also cancelled. - Assert.assertThrows(CompletionException.class, eitherCompletableFuture::join); - Thread.sleep(1000); - - Assert.assertTrue(eitherCompletableFuture.isCompletedExceptionally()); - // Verify that the statistics collector was cleared due to the error. - verify(mergedClusterByStatisticsCollector1, times(1)).clear(); - // Verify that other task futures were requested to be cancelled. - Assert.assertFalse(futureQueue.isEmpty()); - for (Future snapshotFuture : futureQueue) { - verify(snapshotFuture, times(1)).cancel(eq(true)); - } - } - @Test public void test_submitFetcherTask_parallelFetch_mergePerformedCorrectly() throws ExecutionException, InterruptedException @@ -211,68 +147,6 @@ public void test_submitFetcherTask_parallelFetch_mergePerformedCorrectly() Assert.assertEquals(expectedPartitions1, eitherCompletableFuture.get().valueOrThrow()); } - @Test - public void test_submitFetcherTask_sequentialFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception - { - // Store futures in a queue - final Queue> futureQueue = new ConcurrentLinkedQueue<>(); - - SortedMap> timeSegmentVsWorkerMap = ImmutableSortedMap.of(1L, ImmutableSet.of(0, 1, 2), 2L, ImmutableSet.of(0, 1, 4)); - doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap(); - - final CyclicBarrier barrier = new CyclicBarrier(3); - ExecutorService executorService = spy(Execs.multiThreaded(4, "SketchFetcherThreadPool-%d")); - target = spy( - new WorkerSketchFetcher( - workerClient, - ClusterStatisticsMergeMode.SEQUENTIAL, - 300_000_000, - executorService - ) - ); - - // When submitting futures from the executor, add it to the list first. - doAnswer(invocation -> { - Future future = spy((Future) invocation.callRealMethod()); - futureQueue.add(future); - return future; - }).when(executorService).submit(any(Runnable.class)); - - // When fetching snapshots, return a mock and add future to queue - doAnswer(invocation -> { - String workerId = invocation.getArgument(0); - long timeChunk = invocation.getArgument(3); - // Cause a worker in the second time chunk to fail instead of returning the result - if ("4".equals(workerId) && timeChunk == 2L) { - barrier.await(); - return Futures.immediateFailedFuture(new InterruptedException("interrupted")); - } else { - barrier.await(); - return Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)); - } - }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(anyString(), anyString(), anyInt(), anyLong()); - - CompletableFuture> eitherCompletableFuture = target.submitFetcherTask( - completeKeyStatisticsInformation, - ImmutableList.of("0", "1", "2", "3", "4"), - stageDefinition - ); - - // Assert that the final result is failed and all other task futures are also cancelled. - Assert.assertThrows(CompletionException.class, eitherCompletableFuture::join); - Thread.sleep(1000); - - Assert.assertTrue(eitherCompletableFuture.isCompletedExceptionally()); - // Verify that the correct statistics collector was cleared due to the error. - verify(mergedClusterByStatisticsCollector1, times(0)).clear(); - verify(mergedClusterByStatisticsCollector2, times(1)).clear(); - // Verify that other task futures were requested to be cancelled. - Assert.assertFalse(futureQueue.isEmpty()); - for (Future snapshotFuture : futureQueue) { - verify(snapshotFuture, times(1)).cancel(eq(true)); - } - } - @Test public void test_submitFetcherTask_sequentialFetch_mergePerformedCorrectly() throws ExecutionException, InterruptedException