@@ -21,6 +21,7 @@

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
+import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
@@ -38,7 +39,6 @@
import org.apache.samza.operators.spec.JoinOperatorSpec;
import org.apache.samza.operators.spec.OperatorSpec;
import org.apache.samza.system.StreamSpec;
-import org.apache.samza.system.SystemStream;
import org.apache.samza.table.TableSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -54,14 +54,16 @@
public class ExecutionPlanner {
private static final Logger log = LoggerFactory.getLogger(ExecutionPlanner.class);

-static final int MAX_INFERRED_PARTITIONS = 256;
+/* package private */ static final int MAX_INFERRED_PARTITIONS = 256;

private final Config config;
+private final StreamConfig streamConfig;
private final StreamManager streamManager;

public ExecutionPlanner(Config config, StreamManager streamManager) {
this.config = config;
this.streamManager = streamManager;
+this.streamConfig = new StreamConfig(config);
}

public ExecutionPlan plan(OperatorSpecGraph specGraph) {
@@ -71,12 +73,10 @@ public ExecutionPlan plan(OperatorSpecGraph specGraph) {
JobGraph jobGraph = createJobGraph(specGraph);

// fetch the external streams partition info
-updateExistingPartitions(jobGraph, streamManager);
+fetchInputAndOutputStreamPartitions(jobGraph, streamManager);

-if (!jobGraph.getIntermediateStreamEdges().isEmpty()) {
-// figure out the partitions for internal streams
-calculatePartitions(jobGraph);
-}
+// figure out the partitions for internal streams
+calculatePartitions(jobGraph);

return jobGraph;
}
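
[Reviewer note, not part of the diff] With this change calculatePartitions runs unconditionally; when the graph has no intermediate streams these passes find nothing to size or validate, so the old isEmpty() guard was unnecessary. A minimal sketch of the planner entry point as exercised here (config, streamManager and specGraph are assumed to come from the application setup):

    ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
    // Fetches input/output partition counts, then sizes intermediate streams.
    ExecutionPlan plan = planner.plan(specGraph);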
@@ -85,9 +85,9 @@ private void validateConfig() {
ApplicationConfig appConfig = new ApplicationConfig(config);
ClusterManagerConfig clusterConfig = new ClusterManagerConfig(config);
// currently we don't support host-affinity in batch mode
-if (appConfig.getAppMode() == ApplicationConfig.ApplicationMode.BATCH
-&& clusterConfig.getHostAffinityEnabled()) {
-throw new SamzaException("Host affinity is not supported in batch mode. Please configure job.host-affinity.enabled=false.");
+if (appConfig.getAppMode() == ApplicationConfig.ApplicationMode.BATCH && clusterConfig.getHostAffinityEnabled()) {
+throw new SamzaException(String.format("Host affinity is not supported in batch mode. Please configure %s=false.",
+ClusterManagerConfig.CLUSTER_MANAGER_HOST_AFFINITY_ENABLED));
}
}
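
[Reviewer note, not part of the diff] A minimal sketch of the configuration this check rejects, assuming the standard Samza property names "app.mode" and "job.host-affinity.enabled" (the latter being what ClusterManagerConfig.CLUSTER_MANAGER_HOST_AFFINITY_ENABLED resolves to):

    import com.google.common.collect.ImmutableMap;
    import org.apache.samza.config.Config;
    import org.apache.samza.config.MapConfig;

    // Batch mode combined with host affinity is the case validateConfig rejects.
    Config config = new MapConfig(ImmutableMap.of(
        "app.mode", "BATCH",
        "job.host-affinity.enabled", "true"));
    // new ExecutionPlanner(config, streamManager).plan(specGraph) now fails fast,
    // and the new message names the exact property to flip to false.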

@@ -96,30 +96,33 @@ private void validateConfig() {
*/
/* package private */ JobGraph createJobGraph(OperatorSpecGraph specGraph) {
JobGraph jobGraph = new JobGraph(config, specGraph);
-StreamConfig streamConfig = new StreamConfig(config);

// Source streams contain both input and intermediate streams.
Set<StreamSpec> sourceStreams = getStreamSpecs(specGraph.getInputOperators().keySet(), streamConfig);
// Sink streams contain both output and intermediate streams.
Set<StreamSpec> sinkStreams = getStreamSpecs(specGraph.getOutputStreams().keySet(), streamConfig);
-Set<StreamSpec> intStreams = new HashSet<>(sourceStreams);
-Set<TableSpec> tables = new HashSet<>(specGraph.getTables().keySet());
-intStreams.retainAll(sinkStreams);
-sourceStreams.removeAll(intStreams);
-sinkStreams.removeAll(intStreams);

+Set<StreamSpec> intermediateStreams = Sets.intersection(sourceStreams, sinkStreams);
+Set<StreamSpec> inputStreams = Sets.difference(sourceStreams, intermediateStreams);
+Set<StreamSpec> outputStreams = Sets.difference(sinkStreams, intermediateStreams);

+Set<TableSpec> tables = specGraph.getTables().keySet();

// For this phase, we have a single job node for the whole dag
String jobName = config.get(JobConfig.JOB_NAME());
String jobId = config.get(JobConfig.JOB_ID(), "1");
JobNode node = jobGraph.getOrCreateJobNode(jobName, jobId);

-// add sources
-sourceStreams.forEach(spec -> jobGraph.addSource(spec, node));
+// Add input streams
+inputStreams.forEach(spec -> jobGraph.addInputStream(spec, node));

-// add sinks
-sinkStreams.forEach(spec -> jobGraph.addSink(spec, node));
+// Add output streams
+outputStreams.forEach(spec -> jobGraph.addOutputStream(spec, node));

-// add intermediate streams
-intStreams.forEach(spec -> jobGraph.addIntermediateStream(spec, node, node));
+// Add intermediate streams
+intermediateStreams.forEach(spec -> jobGraph.addIntermediateStream(spec, node, node));

-// add tables
+// Add tables
tables.forEach(spec -> jobGraph.addTable(spec, node));

jobGraph.validate();
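
[Reviewer note, not part of the diff] Guava's Sets.intersection/difference return unmodifiable live views, which is why the rewrite can drop the mutable HashSet copies and the retainAll/removeAll bookkeeping. A small sketch of the classification rule, using hypothetical stream names in place of StreamSpecs:

    import com.google.common.collect.ImmutableSet;
    import com.google.common.collect.Sets;
    import java.util.Set;

    Set<String> sources = ImmutableSet.of("pageviews", "rekeyed"); // consumed by the app
    Set<String> sinks = ImmutableSet.of("rekeyed", "output");      // produced by the app
    Sets.intersection(sources, sinks); // [rekeyed]: produced and consumed, so intermediate
    Sets.difference(sources, sinks);   // [pageviews]: input only
    Sets.difference(sinks, sources);   // [output]: output only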
@@ -132,71 +135,80 @@ private void validateConfig() {
*/
/* package private */ void calculatePartitions(JobGraph jobGraph) {
// calculate the partitions for the input streams of join operators
-calculateJoinInputPartitions(jobGraph, config);
+calculateJoinInputPartitions(jobGraph, streamConfig);

// calculate the partitions for the rest of intermediate streams
-calculateIntStreamPartitions(jobGraph, config);
+calculateIntermediateStreamPartitions(jobGraph, config);

// validate all the partitions are assigned
-validatePartitions(jobGraph);
+validateIntermediateStreamPartitions(jobGraph);
}

/**
-* Fetch the partitions of source/sink streams and update the StreamEdges.
+* Fetch the partitions of input/output streams and update the corresponding StreamEdges.
* @param jobGraph {@link JobGraph}
* @param streamManager the {@link StreamManager} to interface with the streams.
*/
-/* package private */ static void updateExistingPartitions(JobGraph jobGraph, StreamManager streamManager) {
+/* package private */ static void fetchInputAndOutputStreamPartitions(JobGraph jobGraph, StreamManager streamManager) {
Set<StreamEdge> existingStreams = new HashSet<>();
-existingStreams.addAll(jobGraph.getSources());
-existingStreams.addAll(jobGraph.getSinks());
+existingStreams.addAll(jobGraph.getInputStreams());
+existingStreams.addAll(jobGraph.getOutputStreams());

+// System to StreamEdges
Multimap<String, StreamEdge> systemToStreamEdges = HashMultimap.create();
-// group the StreamEdge(s) based on the system name
-existingStreams.forEach(streamEdge -> {
-SystemStream systemStream = streamEdge.getSystemStream();
-systemToStreamEdges.put(systemStream.getSystem(), streamEdge);
-});
-for (Map.Entry<String, Collection<StreamEdge>> entry : systemToStreamEdges.asMap().entrySet()) {
-String systemName = entry.getKey();
-Collection<StreamEdge> streamEdges = entry.getValue();

+// Group StreamEdges by system
+for (StreamEdge streamEdge : existingStreams) {
+String system = streamEdge.getSystemStream().getSystem();
+systemToStreamEdges.put(system, streamEdge);
+}

+// Fetch partition count for every set of StreamEdges belonging to a particular system.
+for (String system : systemToStreamEdges.keySet()) {
+Collection<StreamEdge> streamEdges = systemToStreamEdges.get(system);

+// Map every stream to its corresponding StreamEdge so we can retrieve a StreamEdge given its stream.
Map<String, StreamEdge> streamToStreamEdge = new HashMap<>();
-// create the stream name to StreamEdge mapping for this system
-streamEdges.forEach(streamEdge -> streamToStreamEdge.put(streamEdge.getSystemStream().getStream(), streamEdge));
-// retrieve the partition counts for the streams in this system
-Map<String, Integer> streamToPartitionCount = streamManager.getStreamPartitionCounts(systemName, streamToStreamEdge.keySet());
-// set the partitions of a stream to its StreamEdge
-streamToPartitionCount.forEach((stream, partitionCount) -> {
-streamToStreamEdge.get(stream).setPartitionCount(partitionCount);
-log.info("Partition count is {} for stream {}", partitionCount, stream);
-});
+for (StreamEdge streamEdge : streamEdges) {
+streamToStreamEdge.put(streamEdge.getSystemStream().getStream(), streamEdge);
+}

+// Retrieve partition count for every set of streams.
+Set<String> streams = streamToStreamEdge.keySet();
+Map<String, Integer> streamToPartitionCount = streamManager.getStreamPartitionCounts(system, streams);

+// Retrieve StreamEdge corresponding to every stream and set partition count on it.
+for (Map.Entry<String, Integer> entry : streamToPartitionCount.entrySet()) {
+String stream = entry.getKey();
+Integer partitionCount = entry.getValue();
+streamToStreamEdge.get(stream).setPartitionCount(partitionCount);
+log.info("Fetched partition count value {} for stream {}", partitionCount, stream);
+}
}
}
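
[Reviewer note, not part of the diff] A minimal sketch of the Multimap grouping, with hypothetical system and stream names standing in for StreamEdges. Grouping first means the StreamManager is queried once per system, passing all of that system's streams to a single getStreamPartitionCounts call:

    import com.google.common.collect.HashMultimap;
    import com.google.common.collect.Multimap;

    Multimap<String, String> systemToStreams = HashMultimap.create();
    systemToStreams.put("kafka", "pageviews");
    systemToStreams.put("kafka", "clicks");
    systemToStreams.put("hdfs", "daily-dump");
    systemToStreams.keySet();     // [kafka, hdfs]: one partition-count query per system
    systemToStreams.get("kafka"); // [pageviews, clicks]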

/**
* Calculate the partitions for the input streams of join operators
*/
-/* package private */ static void calculateJoinInputPartitions(JobGraph jobGraph, Config config) {
+/* package private */ static void calculateJoinInputPartitions(JobGraph jobGraph, StreamConfig streamConfig) {
// mapping from a join spec to all the source stream edges it is reachable from
-Multimap<OperatorSpec, StreamEdge> joinSpecToStreamEdges = HashMultimap.create();
+Multimap<JoinOperatorSpec, StreamEdge> joinSpecToStreamEdges = HashMultimap.create();
// reverse mapping of the above
-Multimap<StreamEdge, OperatorSpec> streamEdgeToJoinSpecs = HashMultimap.create();
+Multimap<StreamEdge, JoinOperatorSpec> streamEdgeToJoinSpecs = HashMultimap.create();
// A queue of joins with known input partitions
-Queue<OperatorSpec> joinQ = new LinkedList<>();
+Queue<JoinOperatorSpec> joinQ = new LinkedList<>();
// The visited set keeps track of the join specs that have already been inserted in the queue
-Set<OperatorSpec> visited = new HashSet<>();
+Set<JoinOperatorSpec> visited = new HashSet<>();

-StreamConfig streamConfig = new StreamConfig(config);

-jobGraph.getSpecGraph().getInputOperators().forEach((key, value) -> {
-StreamEdge streamEdge = jobGraph.getOrCreateStreamEdge(getStreamSpec(key, streamConfig));
+jobGraph.getSpecGraph().getInputOperators().forEach((streamId, inputOperatorSpec) -> {
+StreamEdge streamEdge = jobGraph.getOrCreateStreamEdge(getStreamSpec(streamId, streamConfig));
// Traverses the StreamGraph to find and update mappings for all Joins reachable from this input StreamEdge
-findReachableJoins(value, streamEdge, joinSpecToStreamEdges, streamEdgeToJoinSpecs, joinQ, visited);
+findReachableJoins(inputOperatorSpec, streamEdge, joinSpecToStreamEdges, streamEdgeToJoinSpecs, joinQ, visited);
});

// At this point, joinQ contains joinSpecs where at least one of the input stream edge partitions is known.
while (!joinQ.isEmpty()) {
-OperatorSpec join = joinQ.poll();
+JoinOperatorSpec join = joinQ.poll();
int partitions = StreamEdge.PARTITIONS_UNKNOWN;
// loop through the input streams to the join and find the partition count
for (StreamEdge edge : joinSpecToStreamEdges.get(join)) {
@@ -223,7 +235,7 @@ private void validateConfig() {
edge.setPartitionCount(partitions);

// find other joins that can be inferred by setting this edge
-for (OperatorSpec op : streamEdgeToJoinSpecs.get(edge)) {
+for (JoinOperatorSpec op : streamEdgeToJoinSpecs.get(edge)) {
if (!visited.contains(op)) {
joinQ.add(op);
visited.add(op);
@@ -244,17 +256,19 @@ private void validateConfig() {
* @param joinQ queue that contains joinSpecs where at least one of the input stream edge partitions is known.
*/
private static void findReachableJoins(OperatorSpec operatorSpec, StreamEdge sourceStreamEdge,
-Multimap<OperatorSpec, StreamEdge> joinSpecToStreamEdges,
-Multimap<StreamEdge, OperatorSpec> streamEdgeToJoinSpecs,
-Queue<OperatorSpec> joinQ, Set<OperatorSpec> visited) {
+Multimap<JoinOperatorSpec, StreamEdge> joinSpecToStreamEdges,
+Multimap<StreamEdge, JoinOperatorSpec> streamEdgeToJoinSpecs,
+Queue<JoinOperatorSpec> joinQ, Set<JoinOperatorSpec> visited) {

if (operatorSpec instanceof JoinOperatorSpec) {
-joinSpecToStreamEdges.put(operatorSpec, sourceStreamEdge);
-streamEdgeToJoinSpecs.put(sourceStreamEdge, operatorSpec);
+JoinOperatorSpec joinOperatorSpec = (JoinOperatorSpec) operatorSpec;
+joinSpecToStreamEdges.put(joinOperatorSpec, sourceStreamEdge);
+streamEdgeToJoinSpecs.put(sourceStreamEdge, joinOperatorSpec);

-if (!visited.contains(operatorSpec) && sourceStreamEdge.getPartitionCount() > 0) {
+if (!visited.contains(joinOperatorSpec) && sourceStreamEdge.getPartitionCount() > 0) {
// put the joins with known input partitions into the queue and mark as visited
-joinQ.add(operatorSpec);
-visited.add(operatorSpec);
+joinQ.add(joinOperatorSpec);
+visited.add(joinOperatorSpec);
}
}

@@ -265,23 +279,34 @@ private static void findReachableJoins(OperatorSpec operatorSpec, StreamEdge sourceStreamEdge,
}
}
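
[Reviewer note, not part of the diff] findReachableJoins seeds a breadth-first propagation: every input of a join must have the same partition count, so one known input pins the others. A simplified sketch of a single propagation step, using a plain map of hypothetical stream names in place of StreamEdges:

    import java.util.HashMap;
    import java.util.Map;

    final int PARTITIONS_UNKNOWN = -1; // stands in for StreamEdge.PARTITIONS_UNKNOWN
    Map<String, Integer> partitionCounts = new HashMap<>();
    partitionCounts.put("orders", 16);                         // join input with known count
    partitionCounts.put("orders-rekeyed", PARTITIONS_UNKNOWN); // intermediate input of the same join
    if (partitionCounts.get("orders-rekeyed") == PARTITIONS_UNKNOWN) {
      // the join forces its unknown input to match the known one
      partitionCounts.put("orders-rekeyed", partitionCounts.get("orders"));
    }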

-private static void calculateIntStreamPartitions(JobGraph jobGraph, Config config) {
-int partitions = config.getInt(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(), StreamEdge.PARTITIONS_UNKNOWN);
-if (partitions < 0) {
+private static void calculateIntermediateStreamPartitions(JobGraph jobGraph, Config config) {
+final String defaultPartitionsConfigProperty = JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS();
+int partitions = config.getInt(defaultPartitionsConfigProperty, StreamEdge.PARTITIONS_UNKNOWN);
+if (partitions == StreamEdge.PARTITIONS_UNKNOWN) {
// use the following simple algo to figure out the partitions
// partition = MAX(MAX(Input topic partitions), MAX(Output topic partitions))
// partition will be further bounded by MAX_INFERRED_PARTITIONS.
// This is important when running in hadoop where an HDFS input can have lots of files (partitions).
-int maxInPartitions = maxPartition(jobGraph.getSources());
-int maxOutPartitions = maxPartition(jobGraph.getSinks());
+int maxInPartitions = maxPartitions(jobGraph.getInputStreams());
+int maxOutPartitions = maxPartitions(jobGraph.getOutputStreams());
partitions = Math.max(maxInPartitions, maxOutPartitions);

if (partitions > MAX_INFERRED_PARTITIONS) {
log.warn(String.format("Inferred intermediate stream partition count %d is greater than the max %d. Using the max.",
partitions, MAX_INFERRED_PARTITIONS));
partitions = MAX_INFERRED_PARTITIONS;
}
+} else {
+// Reject any zero or other negative values explicitly specified in config.
+if (partitions <= 0) {
+throw new SamzaException(String.format("Invalid value %d specified for config property %s", partitions,
+defaultPartitionsConfigProperty));
+}

+log.info("Using partition count value {} specified for config property {}", partitions,
+defaultPartitionsConfigProperty);
}

for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
if (edge.getPartitionCount() <= 0) {
log.info("Set the partition count for intermediate stream {} to {}.", edge.getName(), partitions);
@@ -290,16 +315,15 @@ private static void calculateIntStreamPartitions(JobGraph jobGraph, Config config) {
}
}
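
[Reviewer note, not part of the diff] The inference branch worked on sample numbers: with input streams of 8 and 32 partitions, one output stream of 16, and no value set for JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS:

    int maxInPartitions = Math.max(8, 32);                        // 32
    int maxOutPartitions = 16;
    int partitions = Math.max(maxInPartitions, maxOutPartitions); // 32
    partitions = Math.min(partitions, 256);                       // under MAX_INFERRED_PARTITIONS, so unchanged

The new else branch is a behavioral addition: an explicitly configured value of zero or less now fails the plan instead of being silently replaced by an inferred count.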

-private static void validatePartitions(JobGraph jobGraph) {
+private static void validateIntermediateStreamPartitions(JobGraph jobGraph) {
for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
if (edge.getPartitionCount() <= 0) {
throw new SamzaException(String.format("Failure to assign the partitions to Stream %s", edge.getName()));
throw new SamzaException(String.format("Failed to assign valid partition count to Stream %s", edge.getName()));
}
}
}

-/* package private */ static int maxPartition(Collection<StreamEdge> edges) {
-return edges.stream().map(StreamEdge::getPartitionCount).reduce(Integer::max).orElse(StreamEdge.PARTITIONS_UNKNOWN);
+/* package private */ static int maxPartitions(Collection<StreamEdge> edges) {
+return edges.stream().mapToInt(StreamEdge::getPartitionCount).max().orElse(StreamEdge.PARTITIONS_UNKNOWN);
}
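
[Reviewer note, not part of the diff] mapToInt(...).max() is the idiomatic form of the old map(...).reduce(Integer::max): it avoids boxing and yields an OptionalInt. Either way, an empty edge collection falls back to the PARTITIONS_UNKNOWN sentinel instead of throwing:

    import java.util.stream.IntStream;

    IntStream.of(8, 32, 16).max().orElse(-1); // 32
    IntStream.empty().max().orElse(-1);       // -1, i.e. partitions unknown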

}