@@ -21,6 +21,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.common.guava.FutureUtils;
@@ -203,6 +204,7 @@ public int getChannelNumber(int rowNumber, int numRows, int numChannels)
private final List<KeyColumn> sortKey = ImmutableList.of(new KeyColumn(KEY, KeyOrder.ASCENDING));

private List<List<Frame>> channelFrames;
private ListeningExecutorService innerExec;
private FrameProcessorExecutor exec;
private List<BlockingQueueFrameChannel> channels;

@@ -226,7 +228,7 @@ public void setupTrial()
frameReader = FrameReader.create(signature);

exec = new FrameProcessorExecutor(
MoreExecutors.listeningDecorator(
innerExec = MoreExecutors.listeningDecorator(
Execs.singleThreaded(StringUtils.encodeForFormat(getClass().getSimpleName()))
)
);
@@ -335,8 +337,8 @@ public void setupInvocation() throws IOException
@TearDown(Level.Trial)
public void tearDown() throws Exception
{
exec.getExecutorService().shutdownNow();
if (!exec.getExecutorService().awaitTermination(1, TimeUnit.MINUTES)) {
innerExec.shutdownNow();
if (!innerExec.awaitTermination(1, TimeUnit.MINUTES)) {
throw new ISE("Could not terminate executor after 1 minute");
}
}
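The teardown change above suggests FrameProcessorExecutor no longer exposes its underlying executor service, so the benchmark now retains its own reference to the decorated executor. A minimal sketch of the resulting lifecycle, assuming only the constructor and names visible in this diff (the class name ExecutorLifecycleSketch is illustrative):

import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import org.apache.druid.frame.processor.FrameProcessorExecutor;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs;

import java.util.concurrent.TimeUnit;

class ExecutorLifecycleSketch
{
  private ListeningExecutorService innerExec;
  private FrameProcessorExecutor exec;

  void setup()
  {
    // Keep a handle on the decorated executor; the wrapper is used only for
    // submitting frame processors.
    exec = new FrameProcessorExecutor(
        innerExec = MoreExecutors.listeningDecorator(
            Execs.singleThreaded(StringUtils.encodeForFormat(getClass().getSimpleName()))
        )
    );
  }

  void tearDown() throws Exception
  {
    // Shut down through the retained reference rather than asking the wrapper
    // for its executor service.
    innerExec.shutdownNow();
    if (!innerExec.awaitTermination(1, TimeUnit.MINUTES)) {
      throw new ISE("Could not terminate executor after 1 minute");
    }
  }
}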
ControllerContext.java
@@ -21,15 +21,18 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.inject.Injector;
import org.apache.druid.indexing.common.TaskLockType;
import org.apache.druid.indexing.common.actions.TaskActionClient;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.input.InputSpecSlicer;
import org.apache.druid.msq.input.table.SegmentsInputSlice;
import org.apache.druid.msq.input.table.TableInputSpec;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig;
import org.apache.druid.msq.querykit.QueryKit;
import org.apache.druid.msq.querykit.QueryKitSpec;
import org.apache.druid.query.Query;
import org.apache.druid.server.DruidNode;

/**
@@ -41,7 +44,7 @@ public interface ControllerContext
/**
* Configuration for {@link org.apache.druid.msq.kernel.controller.ControllerQueryKernel}.
*/
ControllerQueryKernelConfig queryKernelConfig(MSQSpec querySpec, QueryDefinition queryDef);
ControllerQueryKernelConfig queryKernelConfig(String queryId, MSQSpec querySpec);

/**
* Callback from the controller implementation to "register" the controller. Used in the indexing task implementation
@@ -73,20 +76,25 @@ public interface ControllerContext
/**
* Provides an {@link InputSpecSlicer} that slices {@link TableInputSpec} into {@link SegmentsInputSlice}.
*/
InputSpecSlicer newTableInputSpecSlicer();
InputSpecSlicer newTableInputSpecSlicer(WorkerManager workerManager);

/**
* Provide access to segment actions in the Overlord. Only called for ingestion queries, i.e., where
* {@link MSQSpec#getDestination()} is {@link org.apache.druid.msq.indexing.destination.DataSourceMSQDestination}.
*/
TaskActionClient taskActionClient();

/**
* Task lock type.
*/
TaskLockType taskLockType();

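The new taskLockType() moves lock-type selection behind the interface; the ControllerImpl hunks below delete the inline MultiStageQueryContext.validateAndGetTaskLockType(...) calls in favor of context.taskLockType(). A hypothetical partial implementation, assuming the context holds the MSQSpec and that the destination is a DataSourceMSQDestination (per the taskActionClient() javadoc above, this path only applies to ingestion queries); class and field layout here are illustrative, not the real implementation:

import org.apache.druid.indexing.common.TaskLockType;
import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination;
import org.apache.druid.msq.util.MultiStageQueryContext;

class TaskControllerContextSketch
{
  private final MSQSpec querySpec;

  TaskControllerContextSketch(final MSQSpec querySpec)
  {
    this.querySpec = querySpec;
  }

  public TaskLockType taskLockType()
  {
    // Assumption: only called for ingestion queries, where the destination is
    // a DataSourceMSQDestination (see the taskActionClient() javadoc).
    final DataSourceMSQDestination destination =
        (DataSourceMSQDestination) querySpec.getDestination();

    // Same helper the controller previously called inline; the ControllerImpl
    // hunks below remove those call sites.
    return MultiStageQueryContext.validateAndGetTaskLockType(
        querySpec.getQuery().context(),
        destination.isReplaceTimeChunks()
    );
  }
}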
/**
* Provides services about workers: starting, canceling, obtaining status.
*
* @param queryId query ID
* @param querySpec query spec
* @param queryKernelConfig config from {@link #queryKernelConfig(MSQSpec, QueryDefinition)}
* @param queryKernelConfig config from {@link #queryKernelConfig(String, MSQSpec)}
* @param workerFailureListener listener that receives callbacks when workers fail
*/
WorkerManager newWorkerManager(
@@ -100,4 +108,15 @@ WorkerManager newWorkerManager(
* Client for communicating with workers.
*/
WorkerClient newWorkerClient();

/**
* Create a {@link QueryKitSpec}. This method provides controller contexts a way to customize parameters around the
* number of workers and partitions.
*/
QueryKitSpec makeQueryKitSpec(
QueryKit<Query<?>> queryKit,
String queryId,
MSQSpec querySpec,
ControllerQueryKernelConfig queryKernelConfig
);
}
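Taken together, the reworked interface lets the controller assemble everything it needs before planning: queryKernelConfig(String, MSQSpec) no longer requires a QueryDefinition, and makeQueryKitSpec bundles the kit, query ID, and sizing parameters that makeQueryDefinition previously took as separate arguments. A sketch of the caller-side flow, mirroring ControllerImpl#initializeQueryDefAndState and makeQueryDefinition in the next file; the class and method names are illustrative, and the argument shapes are taken from this diff:

import org.apache.druid.msq.exec.ControllerContext;
import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig;
import org.apache.druid.msq.querykit.QueryKit;
import org.apache.druid.msq.querykit.QueryKitSpec;
import org.apache.druid.msq.querykit.ShuffleSpecFactory;
import org.apache.druid.query.Query;

class PlanningFlowSketch
{
  static QueryDefinition plan(
      final ControllerContext context,
      final String queryId,
      final MSQSpec querySpec,
      final QueryKit<Query<?>> queryKit,
      final Query<?> queryToPlan,
      final ShuffleSpecFactory resultShuffleSpecFactory
  )
  {
    // Kernel config comes first: it no longer needs a QueryDefinition.
    final ControllerQueryKernelConfig queryKernelConfig =
        context.queryKernelConfig(queryId, querySpec);

    // Bundle the kit, query ID, and sizing parameters into one spec.
    final QueryKitSpec queryKitSpec =
        context.makeQueryKitSpec(queryKit, queryId, querySpec, queryKernelConfig);

    // The kit now reads the query ID and worker counts from the spec instead
    // of taking them as separate arguments.
    return queryKitSpec.getQueryKit().makeQueryDefinition(
        queryKitSpec,
        queryToPlan,
        resultShuffleSpecFactory,
        0
    );
  }
}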
ControllerImpl.java
@@ -151,6 +151,7 @@
import org.apache.druid.msq.kernel.controller.WorkerInputs;
import org.apache.druid.msq.querykit.MultiQueryKit;
import org.apache.druid.msq.querykit.QueryKit;
import org.apache.druid.msq.querykit.QueryKitSpec;
import org.apache.druid.msq.querykit.QueryKitUtils;
import org.apache.druid.msq.querykit.ShuffleSpecFactory;
import org.apache.druid.msq.querykit.WindowOperatorQueryKit;
@@ -224,6 +225,7 @@
public class ControllerImpl implements Controller
{
private static final Logger log = new Logger(ControllerImpl.class);
private static final String RESULT_READER_CANCELLATION_ID = "result-reader";

private final String queryId;
private final MSQSpec querySpec;
@@ -364,7 +366,7 @@ private void runInternal(final QueryListener queryListener, final Closer closer)

// Execution-related: run the multi-stage QueryDefinition.
final InputSpecSlicerFactory inputSpecSlicerFactory =
makeInputSpecSlicerFactory(context.newTableInputSpecSlicer());
makeInputSpecSlicerFactory(context.newTableInputSpecSlicer(workerManager));

final Pair<ControllerQueryKernel, ListenableFuture<?>> queryRunResult =
new RunQueryUntilDone(
@@ -560,12 +562,12 @@ public void addToKernelManipulationQueue(Consumer<ControllerQueryKernel> kernelC
private QueryDefinition initializeQueryDefAndState(final Closer closer)
{
this.selfDruidNode = context.selfNode();
this.netClient = new ExceptionWrappingWorkerClient(context.newWorkerClient());
closer.register(netClient);
this.netClient = closer.register(new ExceptionWrappingWorkerClient(context.newWorkerClient()));
this.queryKernelConfig = context.queryKernelConfig(queryId, querySpec);

final QueryContext queryContext = querySpec.getQuery().context();
final QueryDefinition queryDef = makeQueryDefinition(
queryId(),
makeQueryControllerToolKit(),
context.makeQueryKitSpec(makeQueryControllerToolKit(), queryId, querySpec, queryKernelConfig),
querySpec,
context.jsonMapper(),
resultsContext
@@ -587,7 +589,6 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer)
QueryValidator.validateQueryDef(queryDef);
queryDefRef.set(queryDef);

queryKernelConfig = context.queryKernelConfig(querySpec, queryDef);
workerManager = context.newWorkerManager(
queryId,
querySpec,
@@ -612,7 +613,7 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer)
);
}

final long maxParseExceptions = MultiStageQueryContext.getMaxParseExceptions(querySpec.getQuery().context());
final long maxParseExceptions = MultiStageQueryContext.getMaxParseExceptions(queryContext);
this.faultsExceededChecker = new FaultsExceededChecker(
ImmutableMap.of(CannotParseExternalDataFault.CODE, maxParseExceptions)
);
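An aside on the netClient line a few hunks up in this same method: Druid's Closer.register returns its argument, which is what lets "create, register, assign" collapse into a single statement. An illustrative standalone use of that idiom, assuming only the Guava-style API that Druid's Closer mirrors (the class and file names here are made up):

import org.apache.druid.java.util.common.io.Closer;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

class CloserSketch
{
  // register(...) hands back the same object it was given, so the resource is
  // both leak-safe and immediately usable.
  static String firstLine(final String path) throws IOException
  {
    final Closer closer = Closer.create();
    try {
      final BufferedReader reader =
          closer.register(new BufferedReader(new FileReader(path)));
      return reader.readLine();
    }
    finally {
      closer.close();
    }
  }
}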
@@ -624,7 +625,7 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer)
stageDefinition.getId().getStageNumber(),
finalizeClusterStatisticsMergeMode(
stageDefinition,
MultiStageQueryContext.getClusterStatisticsMergeMode(querySpec.getQuery().context())
MultiStageQueryContext.getClusterStatisticsMergeMode(queryContext)
)
)
);
@@ -920,7 +921,7 @@ private List<SegmentIdWithShardSpec> generateSegmentIdsWithShardSpecs(
destination,
partitionBoundaries,
keyReader,
MultiStageQueryContext.validateAndGetTaskLockType(QueryContext.of(querySpec.getQuery().getContext()), false),
context.taskLockType(),
isStageOutputEmpty
);
}
@@ -1191,7 +1192,7 @@ private Int2ObjectMap<Object> makeWorkerFactoryInfosForStage(
}

@SuppressWarnings("rawtypes")
private QueryKit makeQueryControllerToolKit()
private QueryKit<Query<?>> makeQueryControllerToolKit()
{
final Map<Class<? extends Query>, QueryKit> kitMap =
ImmutableMap.<Class<? extends Query>, QueryKit>builder()
@@ -1328,10 +1329,7 @@ private void publishAllSegments(
(DataSourceMSQDestination) querySpec.getDestination();
final Set<DataSegment> segmentsWithTombstones = new HashSet<>(segments);
int numTombstones = 0;
final TaskLockType taskLockType = MultiStageQueryContext.validateAndGetTaskLockType(
QueryContext.of(querySpec.getQuery().getContext()),
destination.isReplaceTimeChunks()
);
final TaskLockType taskLockType = context.taskLockType();

if (destination.isReplaceTimeChunks()) {
final List<Interval> intervalsToDrop = findIntervalsToDrop(Preconditions.checkNotNull(segments, "segments"));
@@ -1715,8 +1713,7 @@ private void cleanUpDurableStorageIfNeeded()

@SuppressWarnings("unchecked")
private static QueryDefinition makeQueryDefinition(
final String queryId,
@SuppressWarnings("rawtypes") final QueryKit toolKit,
final QueryKitSpec queryKitSpec,
final MSQSpec querySpec,
final ObjectMapper jsonMapper,
final ResultsContext resultsContext
@@ -1725,11 +1722,11 @@
final MSQTuningConfig tuningConfig = querySpec.getTuningConfig();
final ColumnMappings columnMappings = querySpec.getColumnMappings();
final Query<?> queryToPlan;
final ShuffleSpecFactory shuffleSpecFactory;
final ShuffleSpecFactory resultShuffleSpecFactory;

if (MSQControllerTask.isIngestion(querySpec)) {
shuffleSpecFactory = querySpec.getDestination()
.getShuffleSpecFactory(tuningConfig.getRowsPerSegment());
resultShuffleSpecFactory = querySpec.getDestination()
.getShuffleSpecFactory(tuningConfig.getRowsPerSegment());

if (!columnMappings.hasUniqueOutputColumnNames()) {
// We do not expect to hit this case in production, because the SQL validator checks that column names
@@ -1753,7 +1750,7 @@ private static QueryDefinition makeQueryDefinition(
queryToPlan = querySpec.getQuery();
}
} else {
shuffleSpecFactory =
resultShuffleSpecFactory =
querySpec.getDestination()
.getShuffleSpecFactory(MultiStageQueryContext.getRowsPerPage(querySpec.getQuery().context()));
queryToPlan = querySpec.getQuery();
@@ -1762,12 +1759,10 @@ private static QueryDefinition makeQueryDefinition(
final QueryDefinition queryDef;

try {
queryDef = toolKit.makeQueryDefinition(
queryId,
queryDef = queryKitSpec.getQueryKit().makeQueryDefinition(
queryKitSpec,
queryToPlan,
toolKit,
shuffleSpecFactory,
tuningConfig.getMaxNumWorkers(),
resultShuffleSpecFactory,
0
);
}
@@ -1796,7 +1791,7 @@ private static QueryDefinition makeQueryDefinition(

// Add all query stages.
// Set shuffleCheckHasMultipleValues on the stage that serves as input to the final segment-generation stage.
final QueryDefinitionBuilder builder = QueryDefinition.builder(queryId);
final QueryDefinitionBuilder builder = QueryDefinition.builder(queryKitSpec.getQueryId());

for (final StageDefinition stageDef : queryDef.getStageDefinitions()) {
if (stageDef.equals(finalShuffleStageDef)) {
@@ -1822,7 +1817,7 @@ private static QueryDefinition makeQueryDefinition(
// attaching new query results stage if the final stage does sort during shuffle so that results are ordered.
StageDefinition finalShuffleStageDef = queryDef.getFinalStageDefinition();
if (finalShuffleStageDef.doesSortDuringShuffle()) {
final QueryDefinitionBuilder builder = QueryDefinition.builder(queryId);
final QueryDefinitionBuilder builder = QueryDefinition.builder(queryKitSpec.getQueryId());
builder.addAll(queryDef);
builder.add(StageDefinition.builder(queryDef.getNextStageNumber())
.inputs(new StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber()))
@@ -1859,15 +1854,15 @@ private static QueryDefinition makeQueryDefinition(
}

final ResultFormat resultFormat = exportMSQDestination.getResultFormat();
final QueryDefinitionBuilder builder = QueryDefinition.builder(queryId);
final QueryDefinitionBuilder builder = QueryDefinition.builder(queryKitSpec.getQueryId());
builder.addAll(queryDef);
builder.add(StageDefinition.builder(queryDef.getNextStageNumber())
.inputs(new StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber()))
.maxWorkerCount(tuningConfig.getMaxNumWorkers())
.signature(queryDef.getFinalStageDefinition().getSignature())
.shuffleSpec(null)
.processorFactory(new ExportResultsFrameProcessorFactory(
queryId,
queryKitSpec.getQueryId(),
exportStorageProvider,
resultFormat,
columnMappings,
@@ -2183,6 +2178,34 @@ private static void logKernelStatus(final String queryId, final ControllerQueryK
}
}

/**
* Create a result-reader executor for {@link RunQueryUntilDone#readQueryResults()}.
*/
private static FrameProcessorExecutor createResultReaderExec(final String queryId)
{
return new FrameProcessorExecutor(
MoreExecutors.listeningDecorator(
Execs.singleThreaded(StringUtils.encodeForFormat("msq-result-reader[" + queryId + "]")))
);
}

/**
* Cancel any currently-running work and shut down a result-reader executor, like one created by
* {@link #createResultReaderExec(String)}.
*/
private static void closeResultReaderExec(final FrameProcessorExecutor exec)
{
try {
exec.cancel(RESULT_READER_CANCELLATION_ID);
}
catch (Exception e) {
throw new RuntimeException(e);
}
finally {
exec.shutdownNow();
}
}
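These two helpers centralize setup and teardown that startQueryResultsReader (further down) previously inlined, keyed by the shared RESULT_READER_CANCELLATION_ID constant rather than a method-local string. The call sequence, condensed from the lines of that method shown below; "resultsReader" stands in for the reader processor the method builds:

// Create the executor and register the shared cancellation ID.
final FrameProcessorExecutor resultReaderExec = createResultReaderExec(queryId());
resultReaderExec.registerCancellationId(RESULT_READER_CANCELLATION_ID);

// Run the reader under that same ID...
queryResultsReaderFuture =
    resultReaderExec.runFully(resultsReader, RESULT_READER_CANCELLATION_ID);

// ...and let the query-wide closer cancel outstanding work, then shut down.
closer.register(() -> closeResultReaderExec(resultReaderExec));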

private void stopExternalFetchers()
{
if (workerSketchFetcher != null) {
@@ -2692,12 +2715,9 @@ private void startQueryResultsReader()
inputChannelFactory = new WorkerInputChannelFactory(netClient, () -> taskIds);
}

final FrameProcessorExecutor resultReaderExec = new FrameProcessorExecutor(
MoreExecutors.listeningDecorator(
Execs.singleThreaded(StringUtils.encodeForFormat("msq-result-reader[" + queryId() + "]")))
);
final FrameProcessorExecutor resultReaderExec = createResultReaderExec(queryId());
resultReaderExec.registerCancellationId(RESULT_READER_CANCELLATION_ID);

final String cancellationId = "results-reader";
ReadableConcatFrameChannel resultsChannel = null;

try {
@@ -2707,7 +2727,7 @@ private void startQueryResultsReader()
inputChannelFactory,
() -> ArenaMemoryAllocator.createOnHeap(5_000_000),
resultReaderExec,
cancellationId,
RESULT_READER_CANCELLATION_ID,
null,
MultiStageQueryContext.removeNullBytes(querySpec.getQuery().context())
);
@@ -2741,7 +2761,7 @@ private void startQueryResultsReader()
queryListener
);

queryResultsReaderFuture = resultReaderExec.runFully(resultsReader, cancellationId);
queryResultsReaderFuture = resultReaderExec.runFully(resultsReader, RESULT_READER_CANCELLATION_ID);

// When results are done being read, kick the main thread.
// Important: don't use FutureUtils.futureWithBaggage, because we need queryResultsReaderFuture to resolve
@@ -2758,23 +2778,13 @@ private void startQueryResultsReader()
e,
() -> CloseableUtils.closeAll(
finalResultsChannel,
() -> resultReaderExec.getExecutorService().shutdownNow()
() -> closeResultReaderExec(resultReaderExec)
)
);
}

// Result reader is set up. Register with the query-wide closer.
closer.register(() -> {
try {
resultReaderExec.cancel(cancellationId);
}
catch (Exception e) {
throw new RuntimeException(e);
}
finally {
resultReaderExec.getExecutorService().shutdownNow();
}
});
closer.register(() -> closeResultReaderExec(resultReaderExec));
}

/**