Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ public class WorkerSketchFetcher implements AutoCloseable
private final WorkerClient workerClient;
private final ExecutorService executorService;

public WorkerSketchFetcher(WorkerClient workerClient, ClusterStatisticsMergeMode clusterStatisticsMergeMode, int statisticsMaxRetainedBytes)
public WorkerSketchFetcher(
WorkerClient workerClient,
ClusterStatisticsMergeMode clusterStatisticsMergeMode,
int statisticsMaxRetainedBytes
)
{
this.workerClient = workerClient;
this.clusterStatisticsMergeMode = clusterStatisticsMergeMode;
Expand All @@ -86,14 +90,14 @@ public CompletableFuture<Either<Long, ClusterByPartitions>> submitFetcherTask(
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
case AUTO:
if (clusterBy.getBucketByCount() == 0) {
log.debug("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
// If there is no time clustering, there is no scope for sequential merge
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
} else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
log.debug("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId());
log.info("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId());
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
}
log.debug("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
default:
throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode);
Expand Down Expand Up @@ -128,12 +132,6 @@ CompletableFuture<Either<Long, ClusterByPartitions>> inMemoryFullSketchMerging(
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
partitionFuture.whenComplete((result, exception) -> {
if (exception != null || (result != null && result.isError())) {
snapshotFuture.cancel(true);
}
});

try {
ClusterByStatisticsSnapshot clusterByStatisticsSnapshot = snapshotFuture.get();
if (clusterByStatisticsSnapshot == null) {
Expand All @@ -151,12 +149,15 @@ CompletableFuture<Either<Long, ClusterByPartitions>> inMemoryFullSketchMerging(
}
catch (Exception e) {
synchronized (mergedStatisticsCollector) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
if (!partitionFuture.isDone()) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
}
}
}
});
});

return partitionFuture;
}

Expand Down Expand Up @@ -247,11 +248,6 @@ public void submitFetchingTasksForNextTimeChunk()
stageDefinition.getStageNumber(),
timeChunk
);
partitionFuture.whenComplete((result, exception) -> {
if (exception != null || (result != null && result.isError())) {
snapshotFuture.cancel(true);
}
});

try {
ClusterByStatisticsSnapshot snapshotForTimeChunk = snapshotFuture.get();
Expand Down Expand Up @@ -289,8 +285,10 @@ public void submitFetchingTasksForNextTimeChunk()
}
catch (Exception e) {
synchronized (mergedStatisticsCollector) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
if (!partitionFuture.isDone()) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
}
}
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
Expand All @@ -46,7 +45,6 @@
import java.util.Set;
import java.util.SortedMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
Expand All @@ -56,7 +54,6 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
Expand Down Expand Up @@ -107,52 +104,8 @@ public void setUp()
public void tearDown() throws Exception
{
mocks.close();
}

@Test
public void test_submitFetcherTask_parallelFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception
{
// Store futures in a queue
final Queue<ListenableFuture<ClusterByStatisticsSnapshot>> futureQueue = new ConcurrentLinkedQueue<>();
final List<String> workerIds = ImmutableList.of("0", "1", "2", "3");
final CountDownLatch latch = new CountDownLatch(workerIds.size());

target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.PARALLEL, 300_000_000));

// When fetching snapshots, return a mock and add future to queue
doAnswer(invocation -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotListenableFuture =
spy(Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)));
futureQueue.add(snapshotListenableFuture);
latch.countDown();
latch.await();
return snapshotListenableFuture;
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt());

// Cause a worker to fail instead of returning the result
doAnswer(invocation -> {
latch.countDown();
latch.await();
return Futures.immediateFailedFuture(new InterruptedException("interrupted"));
}).when(workerClient).fetchClusterByStatisticsSnapshot(eq("2"), any(), anyInt());

CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
workerIds,
stageDefinition
);

// Assert that the final result is failed and all other task futures are also cancelled.
Assert.assertThrows(CompletionException.class, eitherCompletableFuture::join);
Thread.sleep(1000);

Assert.assertTrue(eitherCompletableFuture.isCompletedExceptionally());
// Verify that the statistics collector was cleared due to the error.
verify(mergedClusterByStatisticsCollector1, times(1)).clear();
// Verify that other task futures were requested to be cancelled.
Assert.assertFalse(futureQueue.isEmpty());
for (ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture : futureQueue) {
verify(snapshotFuture, times(1)).cancel(eq(true));
if (target != null) {
target.close();
}
}

Expand Down Expand Up @@ -194,54 +147,6 @@ public void test_submitFetcherTask_parallelFetch_mergePerformedCorrectly()
Assert.assertEquals(expectedPartitions1, eitherCompletableFuture.get().valueOrThrow());
}

@Test
public void test_submitFetcherTask_sequentialFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception
{
// Store futures in a queue
final Queue<ListenableFuture<ClusterByStatisticsSnapshot>> futureQueue = new ConcurrentLinkedQueue<>();

SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(1L, ImmutableSet.of(0, 1, 2), 2L, ImmutableSet.of(0, 1, 4));
doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap();

final CyclicBarrier barrier = new CyclicBarrier(3);
target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.SEQUENTIAL, 300_000_000));

// When fetching snapshots, return a mock and add future to queue
doAnswer(invocation -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotListenableFuture =
spy(Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)));
futureQueue.add(snapshotListenableFuture);
barrier.await();
return snapshotListenableFuture;
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(anyString(), anyString(), anyInt(), anyLong());

// Cause a worker in the second time chunk to fail instead of returning the result
doAnswer(invocation -> {
barrier.await();
return Futures.immediateFailedFuture(new InterruptedException("interrupted"));
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(eq("4"), any(), anyInt(), eq(2L));

CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
ImmutableList.of("0", "1", "2", "3", "4"),
stageDefinition
);

// Assert that the final result is failed and all other task futures are also cancelled.
Assert.assertThrows(CompletionException.class, eitherCompletableFuture::join);
Thread.sleep(1000);

Assert.assertTrue(eitherCompletableFuture.isCompletedExceptionally());
// Verify that the correct statistics collector was cleared due to the error.
verify(mergedClusterByStatisticsCollector1, times(0)).clear();
verify(mergedClusterByStatisticsCollector2, times(1)).clear();
// Verify that other task futures were requested to be cancelled.
Assert.assertFalse(futureQueue.isEmpty());
for (ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture : futureQueue) {
verify(snapshotFuture, times(1)).cancel(eq(true));
}
}

@Test
public void test_submitFetcherTask_sequentialFetch_mergePerformedCorrectly()
throws ExecutionException, InterruptedException
Expand Down