Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ public class WorkerSketchFetcher implements AutoCloseable
private final WorkerClient workerClient;
private final ExecutorService executorService;

public WorkerSketchFetcher(WorkerClient workerClient, ClusterStatisticsMergeMode clusterStatisticsMergeMode, int statisticsMaxRetainedBytes)
public WorkerSketchFetcher(
WorkerClient workerClient,
ClusterStatisticsMergeMode clusterStatisticsMergeMode,
int statisticsMaxRetainedBytes
)
{
this.workerClient = workerClient;
this.clusterStatisticsMergeMode = clusterStatisticsMergeMode;
Expand All @@ -86,14 +90,14 @@ public CompletableFuture<Either<Long, ClusterByPartitions>> submitFetcherTask(
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
case AUTO:
if (clusterBy.getBucketByCount() == 0) {
log.debug("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
// If there is no time clustering, there is no scope for sequential merge
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
} else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
log.debug("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId());
log.info("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId());
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
}
log.debug("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
default:
throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode);
Expand Down Expand Up @@ -128,12 +132,6 @@ CompletableFuture<Either<Long, ClusterByPartitions>> inMemoryFullSketchMerging(
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
partitionFuture.whenComplete((result, exception) -> {
if (exception != null || (result != null && result.isError())) {
snapshotFuture.cancel(true);
}
});

try {
ClusterByStatisticsSnapshot clusterByStatisticsSnapshot = snapshotFuture.get();
if (clusterByStatisticsSnapshot == null) {
Expand All @@ -151,12 +149,15 @@ CompletableFuture<Either<Long, ClusterByPartitions>> inMemoryFullSketchMerging(
}
catch (Exception e) {
synchronized (mergedStatisticsCollector) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
if (!partitionFuture.isDone()) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
}
}
}
});
});

return partitionFuture;
}

Expand Down Expand Up @@ -247,11 +248,6 @@ public void submitFetchingTasksForNextTimeChunk()
stageDefinition.getStageNumber(),
timeChunk
);
partitionFuture.whenComplete((result, exception) -> {
if (exception != null || (result != null && result.isError())) {
snapshotFuture.cancel(true);
}
});

try {
ClusterByStatisticsSnapshot snapshotForTimeChunk = snapshotFuture.get();
Expand Down Expand Up @@ -289,8 +285,10 @@ public void submitFetchingTasksForNextTimeChunk()
}
catch (Exception e) {
synchronized (mergedStatisticsCollector) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
if (!partitionFuture.isDone()) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
}
}
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
Expand All @@ -46,7 +45,6 @@
import java.util.Set;
import java.util.SortedMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
Expand All @@ -56,7 +54,6 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
Expand Down Expand Up @@ -107,52 +104,8 @@ public void setUp()
public void tearDown() throws Exception
{
mocks.close();
}

@Test
public void test_submitFetcherTask_parallelFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception
{
// Store futures in a queue
final Queue<ListenableFuture<ClusterByStatisticsSnapshot>> futureQueue = new ConcurrentLinkedQueue<>();
final List<String> workerIds = ImmutableList.of("0", "1", "2", "3");
final CountDownLatch latch = new CountDownLatch(workerIds.size());

target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.PARALLEL, 300_000_000));

// When fetching snapshots, return a mock and add future to queue
doAnswer(invocation -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotListenableFuture =
spy(Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)));
futureQueue.add(snapshotListenableFuture);
latch.countDown();
latch.await();
return snapshotListenableFuture;
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt());

// Cause a worker to fail instead of returning the result
doAnswer(invocation -> {
latch.countDown();
latch.await();
return Futures.immediateFailedFuture(new InterruptedException("interrupted"));
}).when(workerClient).fetchClusterByStatisticsSnapshot(eq("2"), any(), anyInt());

CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
workerIds,
stageDefinition
);

// Assert that the final result is failed and all other task futures are also cancelled.
Assert.assertThrows(CompletionException.class, eitherCompletableFuture::join);
Thread.sleep(1000);

Assert.assertTrue(eitherCompletableFuture.isCompletedExceptionally());
// Verify that the statistics collector was cleared due to the error.
verify(mergedClusterByStatisticsCollector1, times(1)).clear();
// Verify that other task futures were requested to be cancelled.
Assert.assertFalse(futureQueue.isEmpty());
for (ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture : futureQueue) {
verify(snapshotFuture, times(1)).cancel(eq(true));
if (target != null) {
target.close();
}
}

Expand Down Expand Up @@ -194,54 +147,6 @@ public void test_submitFetcherTask_parallelFetch_mergePerformedCorrectly()
Assert.assertEquals(expectedPartitions1, eitherCompletableFuture.get().valueOrThrow());
}

/**
 * Checks the failure path of a SEQUENTIAL-mode fetch: when the snapshot fetch for worker "4" in
 * time chunk 2 returns a failed future, the future returned by {@code submitFetcherTask} must
 * complete exceptionally, only {@code mergedClusterByStatisticsCollector2} must be cleared, and
 * every snapshot future captured by the stub must have had {@code cancel(true)} called exactly once.
 */
@Test
public void test_submitFetcherTask_sequentialFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception
{
// Store futures in a queue
final Queue<ListenableFuture<ClusterByStatisticsSnapshot>> futureQueue = new ConcurrentLinkedQueue<>();

// Two time chunks: chunk 1 maps to workers {0, 1, 2}, chunk 2 maps to workers {0, 1, 4}.
SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(1L, ImmutableSet.of(0, 1, 2), 2L, ImmutableSet.of(0, 1, 4));
doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap();

// Three-party barrier awaited inside both stubs below.
// NOTE(review): presumably sized to the three workers per time chunk so that all fetches of a
// chunk are in flight before any of them returns - confirm against the fetcher's threading.
final CyclicBarrier barrier = new CyclicBarrier(3);
target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.SEQUENTIAL, 300_000_000));

// When fetching snapshots, return a mock and add future to queue
doAnswer(invocation -> {
// The future is a Mockito spy so that cancel(true) calls can be verified at the end of the test.
ListenableFuture<ClusterByStatisticsSnapshot> snapshotListenableFuture =
spy(Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)));
futureQueue.add(snapshotListenableFuture);
barrier.await();
return snapshotListenableFuture;
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(anyString(), anyString(), anyInt(), anyLong());

// Cause a worker in the second time chunk to fail instead of returning the result
// This more specific stub (worker "4", time chunk 2L) is registered after the generic one above.
// The failed future is not added to futureQueue, so it is not part of the cancel verification.
doAnswer(invocation -> {
barrier.await();
return Futures.immediateFailedFuture(new InterruptedException("interrupted"));
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(eq("4"), any(), anyInt(), eq(2L));

// Five worker ids are passed in; the time-chunk map above references only workers 0, 1, 2 and 4.
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
ImmutableList.of("0", "1", "2", "3", "4"),
stageDefinition
);

// Assert that the final result is failed and all other task futures are also cancelled.
Assert.assertThrows(CompletionException.class, eitherCompletableFuture::join);
// NOTE(review): fixed one-second sleep, presumably to let asynchronous cancellation callbacks run
// before the verifications below; this makes the test timing-dependent.
Thread.sleep(1000);

Assert.assertTrue(eitherCompletableFuture.isCompletedExceptionally());
// Verify that the correct statistics collector was cleared due to the error.
// NOTE(review): collector1/collector2 are fixtures declared outside this method; presumably they
// belong to time chunks 1 and 2 respectively - confirm in setUp().
verify(mergedClusterByStatisticsCollector1, times(0)).clear();
verify(mergedClusterByStatisticsCollector2, times(1)).clear();
// Verify that other task futures were requested to be cancelled.
Assert.assertFalse(futureQueue.isEmpty());
for (ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture : futureQueue) {
verify(snapshotFuture, times(1)).cancel(eq(true));
}
}

@Test
public void test_submitFetcherTask_sequentialFetch_mergePerformedCorrectly()
throws ExecutionException, InterruptedException
Expand Down