docs/multi-stage-query/reference.md (2 changes: 1 addition & 1 deletion)
@@ -325,7 +325,7 @@ The following table lists the context parameters for the MSQ task engine:
| `maxParseExceptions`| SELECT, INSERT, REPLACE<br /><br />Maximum number of parse exceptions that are ignored while executing the query before it stops with `TooManyWarningsFault`. To ignore all the parse exceptions, set the value to -1.| 0 |
| `rowsPerSegment` | INSERT or REPLACE<br /><br />The number of rows per segment to target. The actual number of rows per segment may be somewhat higher or lower than this number. In most cases, use the default. For general information about sizing rows per segment, see [Segment Size Optimization](../operations/segment-optimization.md). | 3,000,000 |
| `indexSpec` | INSERT or REPLACE<br /><br />An [`indexSpec`](../ingestion/ingestion-spec.md#indexspec) to use when generating segments. May be a JSON string or object. See [Front coding](../ingestion/ingestion-spec.md#front-coding) for details on configuring an `indexSpec` with front coding. | See [`indexSpec`](../ingestion/ingestion-spec.md#indexspec). |
-| `clusterStatisticsMergeMode` | Whether to use parallel or sequential mode for merging of the worker sketches. Can be `PARALLEL`, `SEQUENTIAL` or `AUTO`. See [Sketch Merging Mode](#sketch-merging-mode) for more information. | `AUTO` |
+| `clusterStatisticsMergeMode` | Whether to use parallel or sequential mode for merging of the worker sketches. Can be `PARALLEL`, `SEQUENTIAL` or `AUTO`. See [Sketch Merging Mode](#sketch-merging-mode) for more information. | `PARALLEL` |

## Sketch Merging Mode
This section details the advantages and performance of various Cluster By Statistics Merge Modes.
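To pin a query to one mode explicitly, pass the `clusterStatisticsMergeMode` context parameter with the query. A minimal sketch in Java, mirroring the test added at the end of this PR; the literal key corresponds to `MultiStageQueryContext.CTX_CLUSTER_STATISTICS_MERGE_MODE`:

```java
import com.google.common.collect.ImmutableMap;

import java.util.Map;

public class MergeModeContextExample
{
  // Build a per-query MSQ context that forces SEQUENTIAL sketch merging.
  public static Map<String, Object> withSequentialMerge(Map<String, Object> baseContext)
  {
    return ImmutableMap.<String, Object>builder()
                       .putAll(baseContext)
                       .put("clusterStatisticsMergeMode", "SEQUENTIAL")
                       .build();
  }
}
```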
@@ -263,6 +263,7 @@ public class ControllerImpl implements Controller
// For live reports. Written by the main controller thread, read by HTTP threads.
private final ConcurrentHashMap<Integer, Integer> stagePartitionCountsForLiveReports = new ConcurrentHashMap<>();


private WorkerSketchFetcher workerSketchFetcher;
// Time at which the query started.
// For live reports. Written by the main controller thread, read by HTTP threads.
@@ -624,14 +625,21 @@ public void updatePartialKeyStatisticsInformation(int stageNumber, int workerNum
workerSketchFetcher.submitFetcherTask(
completeKeyStatisticsInformation,
workerTaskIds,
-stageDef
+stageDef,
+queryKernel.getWorkerInputsForStage(stageId).workers()
+// we only need tasks which are active for this stage.
);

// Add the listener to handle completion.
clusterByPartitionsCompletableFuture.whenComplete((clusterByPartitionsEither, throwable) -> {
addToKernelManipulationQueue(holder -> {
if (throwable != null) {
-holder.failStageForReason(stageId, UnknownFault.forException(throwable));
+log.error("Error while fetching stats for stageId[%s]", stageId);
+if (throwable instanceof MSQException) {
+holder.failStageForReason(stageId, ((MSQException) throwable).getFault());
+} else {
+holder.failStageForReason(stageId, UnknownFault.forException(throwable));
+}
} else if (clusterByPartitionsEither.isError()) {
holder.failStageForReason(stageId, new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
} else {
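The `whenComplete` handler above no longer collapses every sketch-fetch failure into `UnknownFault`; a typed `MSQException` keeps its original fault. The pattern in isolation, as a sketch (the `faultFor` helper name and the import paths are assumed for illustration, not taken from this diff):

```java
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.indexing.error.UnknownFault;

public class FaultUnwrapExample
{
  // Keep the typed fault when the failure is an MSQException; otherwise fall
  // back to UnknownFault, exactly as the controller change above does.
  static MSQFault faultFor(Throwable throwable)
  {
    if (throwable instanceof MSQException) {
      return ((MSQException) throwable).getFault();
    }
    return UnknownFault.forException(throwable);
  }
}
```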
@@ -57,9 +57,13 @@ public ListenableFuture<Void> postWorkOrder(String workerTaskId, WorkOrder workO
}

@Override
-public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(String workerTaskId, String queryId, int stageNumber)
+public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
+String workerTaskId,
+String queryId,
+int stageNumber
+)
{
-return client.fetchClusterByStatisticsSnapshot(workerTaskId, queryId, stageNumber);
+return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, queryId, stageNumber));
}

@Override
@@ -70,7 +74,11 @@ public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSna
long timeChunk
)
{
-return client.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, queryId, stageNumber, timeChunk);
+return wrap(
+workerTaskId,
+client,
+c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, queryId, stageNumber, timeChunk)
+);
}

@Override
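The `wrap(...)` helper called above is defined outside the visible hunks. A plausible sketch of its job, assuming it converts transport-level failures from the worker client into a worker-attributed error so the controller can fail the stage with a meaningful fault; the `ClientFn` interface, the signature, and the error type here are assumptions, not shown in this diff:

```java
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;

public class WrapSketch
{
  @FunctionalInterface
  interface ClientFn<C, T>
  {
    ListenableFuture<T> apply(C client);
  }

  // Run the client call and re-throw any failure tagged with the worker task id.
  static <C, T> ListenableFuture<T> wrap(String workerTaskId, C client, ClientFn<C, T> fn)
  {
    SettableFuture<T> retVal = SettableFuture.create();
    Futures.addCallback(
        fn.apply(client),
        new FutureCallback<T>()
        {
          @Override
          public void onSuccess(T result)
          {
            retVal.set(result);
          }

          @Override
          public void onFailure(Throwable t)
          {
            // In the real code this would likely be an MSQException carrying a
            // worker-scoped fault (e.g. a WorkerRpcFailedFault); name assumed.
            retVal.setException(new RuntimeException("RPC to worker[" + workerTaskId + "] failed", t));
          }
        },
        MoreExecutors.directExecutor()
    );
    return retVal;
  }
}
```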
@@ -571,16 +571,37 @@ public void postFinish()
@Override
public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId)
{
-return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
+if (stageKernelMap.get(stageId) == null) {
+throw new ISE("Requested statistics snapshot for non-existent stageId[%s].", stageId);
+} else if (stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot() == null) {
+throw new ISE(
+"Requested statistics snapshot has not been generated yet for stageId[%s]",
+stageId
+);
+} else {
+return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
+}
}

@Override
public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk)
{
-ClusterByStatisticsSnapshot snapshot = stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
-return snapshot.getSnapshotForTimeChunk(timeChunk);
+if (stageKernelMap.get(stageId) == null) {
+throw new ISE("Requested statistics snapshot for non-existent stageId[%s].", stageId);
+} else if (stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot() == null) {
+throw new ISE(
+"Requested statistics snapshot has not been generated yet for stageId[%s]",
+stageId
+);
+} else {
+return stageKernelMap.get(stageId)
+.getResultKeyStatisticsSnapshot()
+.getSnapshotForTimeChunk(timeChunk);
+}
}

@Override
public CounterSnapshotsTree getCounters()
{
@@ -643,7 +664,7 @@ private OutputChannelFactory makeStageOutputChannelFactory(final FrameContext fr
/**
* Decorates the server-wide {@link QueryProcessingPool} such that any Callables and Runnables, not just
* {@link PrioritizedCallable} and {@link PrioritizedRunnable}, may be added to it.
-*
+* <p>
* In production, the underlying {@link QueryProcessingPool} pool is set up by
* {@link org.apache.druid.guice.DruidProcessingModule}.
*/
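On the javadoc fixed above: the decorator's point is that the underlying pool only schedules prioritized tasks, so arbitrary work must be adapted before submission. A minimal sketch of that adaptation, assuming a fixed priority of 0 is acceptable (`QueryProcessingPool` and `AbstractPrioritizedCallable` are existing Druid types; the wrapper class name is illustrative):

```java
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.druid.query.AbstractPrioritizedCallable;
import org.apache.druid.query.QueryProcessingPool;

import java.util.concurrent.Callable;

public class DecoratedProcessingPool
{
  private final QueryProcessingPool delegate;

  public DecoratedProcessingPool(QueryProcessingPool delegate)
  {
    this.delegate = delegate;
  }

  // Adapt a plain Callable into the prioritized form the pool expects.
  public <T> ListenableFuture<T> submitAny(Callable<T> callable)
  {
    return delegate.submit(
        new AbstractPrioritizedCallable<T>(0)
        {
          @Override
          public T call() throws Exception
          {
            return callable.call();
          }
        }
    );
  }
}
```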
@@ -20,6 +20,7 @@
package org.apache.druid.msq.exec;

import com.google.common.util.concurrent.ListenableFuture;
+import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
@@ -40,7 +41,7 @@
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
-import java.util.stream.IntStream;
+import java.util.stream.Collectors;

/**
* Queues up fetching sketches from workers and progressively generates partitions boundaries.
@@ -78,7 +79,8 @@ public WorkerSketchFetcher(
public CompletableFuture<Either<Long, ClusterByPartitions>> submitFetcherTask(
CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
List<String> workerTaskIds,
-StageDefinition stageDefinition
+StageDefinition stageDefinition,
+IntSet workersForStage
)
{
ClusterBy clusterBy = stageDefinition.getClusterBy();
@@ -87,18 +89,31 @@ public CompletableFuture<Either<Long, ClusterByPartitions>> submitFetcherTask(
case SEQUENTIAL:
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
case PARALLEL:
-return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
case AUTO:
if (clusterBy.getBucketByCount() == 0) {
log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
log.info(
"Query[%s] stage[%d] for AUTO mode: chose PARALLEL mode to merge key statistics",
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
// If there is no time clustering, there is no scope for sequential merge
-return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
-} else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
-log.info("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId());
+return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
+} else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD
+|| completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
+log.info(
+"Query[%s] stage[%d] in AUTO mode: chose SEQUENTIAL mode to merge key statistics",
+stageDefinition.getId().getQueryId(),
+stageDefinition.getStageNumber()
+);
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
}
log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
log.info(
"Query[%s] stage[%d] for AUTO mode: chose PARALLEL mode to merge key statistics",
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
default:
throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode);
}
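Condensed, the AUTO branch above makes this decision (a self-contained restatement for readability; `MergeMode` stands in for the module's `ClusterStatisticsMergeMode` enum, and the threshold parameters stand for the `WORKER_THRESHOLD` and `BYTES_THRESHOLD` class constants):

```java
public class AutoModeDecision
{
  enum MergeMode { PARALLEL, SEQUENTIAL }

  static MergeMode chooseAutoMode(
      int bucketByCount,    // clusterBy.getBucketByCount()
      int maxWorkerCount,   // stageDefinition.getMaxWorkerCount()
      long bytesRetained,   // completeKeyStatisticsInformation.getBytesRetained()
      int workerThreshold,  // WORKER_THRESHOLD
      long bytesThreshold   // BYTES_THRESHOLD
  )
  {
    if (bucketByCount == 0) {
      // No time clustering, so a sequential per-time-chunk merge has no scope.
      return MergeMode.PARALLEL;
    }
    if (maxWorkerCount > workerThreshold || bytesRetained > bytesThreshold) {
      // Too many workers or too much retained sketch data to merge fully in memory.
      return MergeMode.SEQUENTIAL;
    }
    return MergeMode.PARALLEL;
  }
}
```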
@@ -111,20 +126,28 @@ public CompletableFuture<Either<Long, ClusterByPartitions>> submitFetcherTask(
*/
CompletableFuture<Either<Long, ClusterByPartitions>> inMemoryFullSketchMerging(
StageDefinition stageDefinition,
-List<String> workerTaskIds
+List<String> workerTaskIds,
+IntSet workersForStage
)
{
CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new CompletableFuture<>();

// Create a new key statistics collector to merge worker sketches into
final ClusterByStatisticsCollector mergedStatisticsCollector =
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
-final int workerCount = workerTaskIds.size();
+final int workerCount = workersForStage.size();
// Guarded by synchronized mergedStatisticsCollector
final Set<Integer> finishedWorkers = new HashSet<>();

+log.info(
+"Fetching stats using %s mode for stage[%d] from workers[%s]",
+ClusterStatisticsMergeMode.PARALLEL,
+stageDefinition.getStageNumber(),
+workersForStage.stream().map(Object::toString).collect(Collectors.joining(","))
+);

// Submit a task for each worker to fetch statistics
-IntStream.range(0, workerCount).forEach(workerNo -> {
+workersForStage.forEach(workerNo -> {
executorService.submit(() -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
workerClient.fetchClusterByStatisticsSnapshot(
@@ -177,6 +200,13 @@ CompletableFuture<Either<Long, ClusterByPartitions>> sequentialTimeChunkMerging(
workerTaskIds,
completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().entrySet().iterator()
);

+log.info(
+"Fetching stats using %s mode for stage[%d] from tasks[%s]",
+ClusterStatisticsMergeMode.SEQUENTIAL,
+stageDefinition.getStageNumber(),
+String.join(", ", workerTaskIds)
+);
sequentialFetchStage.submitFetchingTasksForNextTimeChunk();
return sequentialFetchStage.getPartitionFuture();
}
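For orientation, the sequential strategy driven by `submitFetchingTasksForNextTimeChunk()` processes one time chunk at a time, querying only the workers recorded for that chunk in `getTimeSegmentVsWorkerMap()`. A simplified, self-contained sketch of that walk (fetching and merging of real sketches are abstracted behind a callback; names are illustrative):

```java
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.function.BiConsumer;

public class SequentialChunkWalk
{
  // Visit time chunks in order; for each chunk, touch only the workers that
  // reported data for it, so at most one chunk's sketches are held at a time.
  public static void walk(
      SortedMap<Long, Set<Integer>> timeChunkToWorkers,
      BiConsumer<Long, Integer> fetchAndMergeSketch
  )
  {
    for (Map.Entry<Long, Set<Integer>> chunk : timeChunkToWorkers.entrySet()) {
      for (int workerNo : chunk.getValue()) {
        fetchAndMergeSketch.accept(chunk.getKey(), workerNo);
      }
      // The real implementation generates partition boundaries for this chunk
      // here before advancing, completing the partition future after the last chunk.
    }
  }
}
```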
@@ -19,11 +19,13 @@

package org.apache.druid.msq.indexing;

+import com.google.common.collect.ImmutableMap;
import it.unimi.dsi.fastutil.bytes.ByteArrays;
import org.apache.commons.lang.mutable.MutableLong;
import org.apache.druid.frame.file.FrameFileHttpResponseHandler;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.indexing.common.TaskToolbox;
+import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.exec.Worker;
import org.apache.druid.msq.kernel.StageId;
@@ -71,7 +73,7 @@ public WorkerChatHandler(TaskToolbox toolbox, Worker worker)

/**
* Returns up to {@link #CHANNEL_DATA_CHUNK_SIZE} bytes of stage output data.
-*
+* <p>
* See {@link org.apache.druid.msq.exec.WorkerClient#fetchChannelData} for the client-side code that calls this API.
*/
@GET
@@ -193,17 +195,30 @@ public Response httpFetchKeyStatistics(
ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper());
ClusterByStatisticsSnapshot clusterByStatisticsSnapshot;
StageId stageId = new StageId(queryId, stageNumber);
-clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId);
-return Response.status(Response.Status.ACCEPTED)
-.entity(clusterByStatisticsSnapshot)
-.build();
+try {
+clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId);
+return Response.status(Response.Status.ACCEPTED)
+.entity(clusterByStatisticsSnapshot)
+.build();
+}
+catch (Exception e) {
+String errorMessage = StringUtils.format(
+"Invalid request for key statistics for query[%s] and stage[%d]",
+queryId,
+stageNumber
+);
+log.error(e, errorMessage);
+return Response.status(Response.Status.BAD_REQUEST)
+.entity(ImmutableMap.<String, Object>of("error", errorMessage))
+.build();
+}
}

@POST
@Path("/keyStatisticsForTimeChunk/{queryId}/{stageNumber}/{timeChunk}")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
-public Response httpSketch(
+public Response httpFetchKeyStatisticsWithSnapshot(
@PathParam("queryId") final String queryId,
@PathParam("stageNumber") final int stageNumber,
@PathParam("timeChunk") final long timeChunk,
@@ -213,10 +228,24 @@ public Response httpSketch(
ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper());
ClusterByStatisticsSnapshot snapshotForTimeChunk;
StageId stageId = new StageId(queryId, stageNumber);
-snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk);
-return Response.status(Response.Status.ACCEPTED)
-.entity(snapshotForTimeChunk)
-.build();
+try {
+snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk);
+return Response.status(Response.Status.ACCEPTED)
+.entity(snapshotForTimeChunk)
+.build();
+}
+catch (Exception e) {
+String errorMessage = StringUtils.format(
+"Invalid request for key statistics for query[%s], stage[%d] and timeChunk[%d]",
+queryId,
+stageNumber,
+timeChunk
+);
+log.error(e, errorMessage);
+return Response.status(Response.Status.BAD_REQUEST)
+.entity(ImmutableMap.<String, Object>of("error", errorMessage))
+.build();
+}
}

/**
@@ -48,8 +48,9 @@ public class WorkerStageKernel

private WorkerStagePhase phase = WorkerStagePhase.NEW;

+// We read this variable in the main thread and the netty threads
@Nullable
-private ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot;
+private volatile ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot;

@Nullable
private ClusterByPartitions resultPartitionBoundaries;
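Why `volatile` is the right fix above: the snapshot is produced on the main worker thread but read by the netty threads serving the new key-statistics endpoints. A minimal model of the publish/read pattern (illustrative only, not Druid code):

```java
// Without volatile, the reading thread could keep seeing a stale null even
// after the writer has produced the snapshot; the volatile write/read pair
// establishes the needed happens-before edge.
class SnapshotHolder<T>
{
  private volatile T snapshot;

  void publish(T s)   // called by the main worker thread
  {
    snapshot = s;
  }

  T read()            // called by netty (HTTP) threads
  {
    return snapshot;  // may be null until published; callers must handle that
  }
}
```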
@@ -25,6 +25,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.frame.key.RowKey;
+import org.apache.druid.java.util.common.ISE;

import javax.annotation.Nullable;
import java.util.Collections;
@@ -61,6 +62,9 @@ Map<Long, Bucket> getBuckets()
public ClusterByStatisticsSnapshot getSnapshotForTimeChunk(long timeChunk)
{
Bucket bucket = buckets.get(timeChunk);
+if (bucket == null) {
+throw new ISE("ClusterByStatistics not present for requested timeChunk[%d]", timeChunk);
+}
return new ClusterByStatisticsSnapshot(ImmutableMap.of(timeChunk, bucket), null);
}

@@ -60,7 +60,7 @@ public class MultiStageQueryContext

public static final String CTX_ENABLE_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage";
public static final String CTX_CLUSTER_STATISTICS_MERGE_MODE = "clusterStatisticsMergeMode";
-public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.AUTO.toString();
+public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.PARALLEL.toString();
private static final boolean DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = false;

public static final String CTX_DESTINATION = "destination";
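With the default flipped to `PARALLEL`, a lookup of the context parameter would resolve along these lines (a sketch; the resolver method is hypothetical, only the two constants appear in this diff):

```java
import java.util.Locale;
import java.util.Map;

public class MergeModeResolution
{
  // Hypothetical resolution of the merge mode from a query context map,
  // defaulting to DEFAULT_CLUSTER_STATISTICS_MERGE_MODE, i.e. "PARALLEL".
  static String resolveMergeMode(Map<String, Object> queryContext)
  {
    Object raw = queryContext.get("clusterStatisticsMergeMode");
    return raw == null ? "PARALLEL" : raw.toString().toUpperCase(Locale.US);
  }
}
```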
@@ -128,6 +128,32 @@ public void testInsertOnFoo1WithTimeFunction()

}

+@Test
+public void testInsertOnFoo1WithTimeFunctionWithSequential()
+{
+RowSignature rowSignature = RowSignature.builder()
+.add("__time", ColumnType.LONG)
+.add("dim1", ColumnType.STRING)
+.add("cnt", ColumnType.LONG).build();
+Map<String, Object> context = ImmutableMap.<String, Object>builder()
+.putAll(DEFAULT_MSQ_CONTEXT)
+.put(
+MultiStageQueryContext.CTX_CLUSTER_STATISTICS_MERGE_MODE,
+ClusterStatisticsMergeMode.SEQUENTIAL.toString()
+)
+.build();
+
+testIngestQuery().setSql(
+"insert into foo1 select floor(__time to day) as __time , dim1 , count(*) as cnt from foo where dim1 is not null group by 1, 2 PARTITIONED by day clustered by dim1")
+.setQueryContext(context)
+.setExpectedDataSource("foo1")
+.setExpectedRowSignature(rowSignature)
+.setExpectedSegment(expectedFooSegments())
+.setExpectedResultRows(expectedFooRows())
+.verifyResults();
+
+}
+
@Test
public void testInsertOnFoo1WithMultiValueDim()
{