diff --git a/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java b/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java
index ea892feb7b..810f4240be 100644
--- a/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java
+++ b/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java
@@ -21,6 +21,7 @@
 import com.google.common.collect.HashMultimap;
 import com.google.common.collect.Multimap;
+import com.google.common.collect.Sets;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -38,7 +39,6 @@
 import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.system.StreamSpec;
-import org.apache.samza.system.SystemStream;
 import org.apache.samza.table.TableSpec;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -54,14 +54,16 @@ public class ExecutionPlanner {

   private static final Logger log = LoggerFactory.getLogger(ExecutionPlanner.class);

-  static final int MAX_INFERRED_PARTITIONS = 256;
+  /* package private */ static final int MAX_INFERRED_PARTITIONS = 256;

   private final Config config;
+  private final StreamConfig streamConfig;
   private final StreamManager streamManager;

   public ExecutionPlanner(Config config, StreamManager streamManager) {
     this.config = config;
     this.streamManager = streamManager;
+    this.streamConfig = new StreamConfig(config);
   }

   public ExecutionPlan plan(OperatorSpecGraph specGraph) {
@@ -71,12 +73,10 @@ public ExecutionPlan plan(OperatorSpecGraph specGraph) {
     JobGraph jobGraph = createJobGraph(specGraph);

     // fetch the external streams partition info
-    updateExistingPartitions(jobGraph, streamManager);
+    fetchInputAndOutputStreamPartitions(jobGraph, streamManager);

-    if (!jobGraph.getIntermediateStreamEdges().isEmpty()) {
-      // figure out the partitions for internal streams
-      calculatePartitions(jobGraph);
-    }
+    // figure out the partitions for internal streams
+    calculatePartitions(jobGraph);

     return jobGraph;
   }
@@ -85,9 +85,9 @@ private void validateConfig() {
     ApplicationConfig appConfig = new ApplicationConfig(config);
     ClusterManagerConfig clusterConfig = new ClusterManagerConfig(config);
     // currently we don't support host-affinity in batch mode
-    if (appConfig.getAppMode() == ApplicationConfig.ApplicationMode.BATCH
-        && clusterConfig.getHostAffinityEnabled()) {
-      throw new SamzaException("Host affinity is not supported in batch mode. Please configure job.host-affinity.enabled=false.");
+    if (appConfig.getAppMode() == ApplicationConfig.ApplicationMode.BATCH && clusterConfig.getHostAffinityEnabled()) {
+      throw new SamzaException(String.format("Host affinity is not supported in batch mode. Please configure %s=false.",
+          ClusterManagerConfig.CLUSTER_MANAGER_HOST_AFFINITY_ENABLED));
     }
   }

@@ -96,30 +96,33 @@ private void validateConfig() {
    */
   /* package private */ JobGraph createJobGraph(OperatorSpecGraph specGraph) {
     JobGraph jobGraph = new JobGraph(config, specGraph);
-    StreamConfig streamConfig = new StreamConfig(config);
+
+    // Source streams contain both input and intermediate streams.
     Set<StreamSpec> sourceStreams = getStreamSpecs(specGraph.getInputOperators().keySet(), streamConfig);
+    // Sink streams contain both output and intermediate streams.
     Set<StreamSpec> sinkStreams = getStreamSpecs(specGraph.getOutputStreams().keySet(), streamConfig);
-    Set<StreamSpec> intStreams = new HashSet<>(sourceStreams);
-    Set<TableSpec> tables = new HashSet<>(specGraph.getTables().keySet());
-    intStreams.retainAll(sinkStreams);
-    sourceStreams.removeAll(intStreams);
-    sinkStreams.removeAll(intStreams);
+
+    Set<StreamSpec> intermediateStreams = Sets.intersection(sourceStreams, sinkStreams);
+    Set<StreamSpec> inputStreams = Sets.difference(sourceStreams, intermediateStreams);
+    Set<StreamSpec> outputStreams = Sets.difference(sinkStreams, intermediateStreams);
+
+    Set<TableSpec> tables = specGraph.getTables().keySet();

     // For this phase, we have a single job node for the whole dag
     String jobName = config.get(JobConfig.JOB_NAME());
     String jobId = config.get(JobConfig.JOB_ID(), "1");
     JobNode node = jobGraph.getOrCreateJobNode(jobName, jobId);

-    // add sources
-    sourceStreams.forEach(spec -> jobGraph.addSource(spec, node));
+    // Add input streams
+    inputStreams.forEach(spec -> jobGraph.addInputStream(spec, node));

-    // add sinks
-    sinkStreams.forEach(spec -> jobGraph.addSink(spec, node));
+    // Add output streams
+    outputStreams.forEach(spec -> jobGraph.addOutputStream(spec, node));

-    // add intermediate streams
-    intStreams.forEach(spec -> jobGraph.addIntermediateStream(spec, node, node));
+    // Add intermediate streams
+    intermediateStreams.forEach(spec -> jobGraph.addIntermediateStream(spec, node, node));

-    // add tables
+    // Add tables
     tables.forEach(spec -> jobGraph.addTable(spec, node));

     jobGraph.validate();
@@ -132,71 +135,80 @@ private void validateConfig() {
    */
   /* package private */ void calculatePartitions(JobGraph jobGraph) {
     // calculate the partitions for the input streams of join operators
-    calculateJoinInputPartitions(jobGraph, config);
+    calculateJoinInputPartitions(jobGraph, streamConfig);

     // calculate the partitions for the rest of intermediate streams
-    calculateIntStreamPartitions(jobGraph, config);
+    calculateIntermediateStreamPartitions(jobGraph, config);

     // validate all the partitions are assigned
-    validatePartitions(jobGraph);
+    validateIntermediateStreamPartitions(jobGraph);
   }
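
For reviewers unfamiliar with the Guava calls introduced in `createJobGraph` above: `Sets.intersection` and `Sets.difference` return live, unmodifiable views over the backing sets rather than copies, which is also why the defensive `new HashSet<>(...)` copies could be dropped. A minimal, self-contained sketch of the classification (plain strings stand in for `StreamSpec`; all names here are illustrative, not Samza API):

```java
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.Set;

public class StreamClassificationSketch {
  public static void main(String[] args) {
    // Streams the job consumes (inputs + intermediates) and produces (outputs + intermediates).
    Set<String> sourceStreams = ImmutableSet.of("input1", "input2", "partition-p1");
    Set<String> sinkStreams = ImmutableSet.of("output1", "partition-p1");

    // A stream that is both consumed and produced is an intermediate stream.
    Set<String> intermediateStreams = Sets.intersection(sourceStreams, sinkStreams);
    Set<String> inputStreams = Sets.difference(sourceStreams, intermediateStreams);
    Set<String> outputStreams = Sets.difference(sinkStreams, intermediateStreams);

    System.out.println(intermediateStreams); // [partition-p1]
    System.out.println(inputStreams);        // [input1, input2]
    System.out.println(outputStreams);       // [output1]
    // Note: these are views backed by the original sets, not copies; they must not be mutated.
  }
}
```
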
   /**
-   * Fetch the partitions of source/sink streams and update the StreamEdges.
+   * Fetch the partitions of input/output streams and update the corresponding StreamEdges.
    * @param jobGraph {@link JobGraph}
    * @param streamManager the {@link StreamManager} to interface with the streams.
    */
-  /* package private */ static void updateExistingPartitions(JobGraph jobGraph, StreamManager streamManager) {
+  /* package private */ static void fetchInputAndOutputStreamPartitions(JobGraph jobGraph, StreamManager streamManager) {
     Set<StreamEdge> existingStreams = new HashSet<>();
-    existingStreams.addAll(jobGraph.getSources());
-    existingStreams.addAll(jobGraph.getSinks());
+    existingStreams.addAll(jobGraph.getInputStreams());
+    existingStreams.addAll(jobGraph.getOutputStreams());

+    // System to StreamEdges
     Multimap<String, StreamEdge> systemToStreamEdges = HashMultimap.create();

-    // group the StreamEdge(s) based on the system name
-    existingStreams.forEach(streamEdge -> {
-        SystemStream systemStream = streamEdge.getSystemStream();
-        systemToStreamEdges.put(systemStream.getSystem(), streamEdge);
-      });
-    for (Map.Entry<String, Collection<StreamEdge>> entry : systemToStreamEdges.asMap().entrySet()) {
-      String systemName = entry.getKey();
-      Collection<StreamEdge> streamEdges = entry.getValue();
+    // Group StreamEdges by system
+    for (StreamEdge streamEdge : existingStreams) {
+      String system = streamEdge.getSystemStream().getSystem();
+      systemToStreamEdges.put(system, streamEdge);
+    }
+
+    // Fetch partition count for every set of StreamEdges belonging to a particular system.
+    for (String system : systemToStreamEdges.keySet()) {
+      Collection<StreamEdge> streamEdges = systemToStreamEdges.get(system);
+
+      // Map every stream to its corresponding StreamEdge so we can retrieve a StreamEdge given its stream.
       Map<String, StreamEdge> streamToStreamEdge = new HashMap<>();
-      // create the stream name to StreamEdge mapping for this system
-      streamEdges.forEach(streamEdge -> streamToStreamEdge.put(streamEdge.getSystemStream().getStream(), streamEdge));
-      // retrieve the partition counts for the streams in this system
-      Map<String, Integer> streamToPartitionCount = streamManager.getStreamPartitionCounts(systemName, streamToStreamEdge.keySet());
-      // set the partitions of a stream to its StreamEdge
-      streamToPartitionCount.forEach((stream, partitionCount) -> {
-          streamToStreamEdge.get(stream).setPartitionCount(partitionCount);
-          log.info("Partition count is {} for stream {}", partitionCount, stream);
-        });
+      for (StreamEdge streamEdge : streamEdges) {
+        streamToStreamEdge.put(streamEdge.getSystemStream().getStream(), streamEdge);
+      }
+
+      // Retrieve partition count for every set of streams.
+      Set<String> streams = streamToStreamEdge.keySet();
+      Map<String, Integer> streamToPartitionCount = streamManager.getStreamPartitionCounts(system, streams);
+
+      // Retrieve StreamEdge corresponding to every stream and set partition count on it.
+      for (Map.Entry<String, Integer> entry : streamToPartitionCount.entrySet()) {
+        String stream = entry.getKey();
+        Integer partitionCount = entry.getValue();
+        streamToStreamEdge.get(stream).setPartitionCount(partitionCount);
+        log.info("Fetched partition count value {} for stream {}", partitionCount, stream);
+      }
     }
   }

   /**
    * Calculate the partitions for the input streams of join operators
    */
-  /* package private */ static void calculateJoinInputPartitions(JobGraph jobGraph, Config config) {
+  /* package private */ static void calculateJoinInputPartitions(JobGraph jobGraph, StreamConfig streamConfig) {
     // mapping from a source stream to all join specs reachable from it
-    Multimap<OperatorSpec, StreamEdge> joinSpecToStreamEdges = HashMultimap.create();
+    Multimap<JoinOperatorSpec, StreamEdge> joinSpecToStreamEdges = HashMultimap.create();
     // reverse mapping of the above
-    Multimap<StreamEdge, OperatorSpec> streamEdgeToJoinSpecs = HashMultimap.create();
+    Multimap<StreamEdge, JoinOperatorSpec> streamEdgeToJoinSpecs = HashMultimap.create();
     // A queue of joins with known input partitions
-    Queue<OperatorSpec> joinQ = new LinkedList<>();
+    Queue<JoinOperatorSpec> joinQ = new LinkedList<>();
     // The visited set keeps track of the join specs that have been already inserted in the queue before
-    Set<OperatorSpec> visited = new HashSet<>();
+    Set<JoinOperatorSpec> visited = new HashSet<>();

-    StreamConfig streamConfig = new StreamConfig(config);
-
-    jobGraph.getSpecGraph().getInputOperators().forEach((key, value) -> {
-        StreamEdge streamEdge = jobGraph.getOrCreateStreamEdge(getStreamSpec(key, streamConfig));
+    jobGraph.getSpecGraph().getInputOperators().forEach((streamId, inputOperatorSpec) -> {
+        StreamEdge streamEdge = jobGraph.getOrCreateStreamEdge(getStreamSpec(streamId, streamConfig));
         // Traverses the StreamGraph to find and update mappings for all Joins reachable from this input StreamEdge
-        findReachableJoins(value, streamEdge, joinSpecToStreamEdges, streamEdgeToJoinSpecs, joinQ, visited);
+        findReachableJoins(inputOperatorSpec, streamEdge, joinSpecToStreamEdges, streamEdgeToJoinSpecs, joinQ, visited);
       });

     // At this point, joinQ contains joinSpecs where at least one of the input stream edge partitions is known.
     while (!joinQ.isEmpty()) {
-      OperatorSpec join = joinQ.poll();
+      JoinOperatorSpec join = joinQ.poll();
       int partitions = StreamEdge.PARTITIONS_UNKNOWN;
       // loop through the input streams to the join and find the partition count
       for (StreamEdge edge : joinSpecToStreamEdges.get(join)) {
@@ -223,7 +235,7 @@ private void validateConfig() {
         edge.setPartitionCount(partitions);

         // find other joins can be inferred by setting this edge
-        for (OperatorSpec op : streamEdgeToJoinSpecs.get(edge)) {
+        for (JoinOperatorSpec op : streamEdgeToJoinSpecs.get(edge)) {
           if (!visited.contains(op)) {
             joinQ.add(op);
             visited.add(op);
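
The loop above computes a fixed point: it starts from joins with at least one input of known partition count, pins the remaining inputs of each such join to that count, and re-examines any join fed by a newly pinned stream. A toy model of the propagation, using plain maps in place of `StreamEdge`/`JoinOperatorSpec` (all names hypothetical; the conflict handling is simplified to a single exception, standing in for the `SamzaException` the planner throws):

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

public class JoinPartitionPropagationSketch {
  static final int UNKNOWN = -1;

  public static void main(String[] args) {
    // Stream -> partition count; intermediate streams start unknown.
    Map<String, Integer> partitions = new HashMap<>();
    partitions.put("input1", 64);
    partitions.put("p1", UNKNOWN);
    partitions.put("p2", UNKNOWN);

    // Join -> its input streams. j1 has one known input, j2 has none yet.
    Map<String, List<String>> joinInputs = new HashMap<>();
    joinInputs.put("j1", Arrays.asList("input1", "p1"));
    joinInputs.put("j2", Arrays.asList("p1", "p2"));

    // Seed the queue with joins that already have a known input partition count.
    Queue<String> joinQ = new ArrayDeque<>();
    Set<String> visited = new HashSet<>();
    for (String join : joinInputs.keySet()) {
      boolean anyKnown = joinInputs.get(join).stream().anyMatch(s -> partitions.get(s) != UNKNOWN);
      if (anyKnown && visited.add(join)) {
        joinQ.add(join);
      }
    }

    while (!joinQ.isEmpty()) {
      String join = joinQ.poll();
      // All known inputs of a join must agree on a single partition count.
      int known = UNKNOWN;
      for (String stream : joinInputs.get(join)) {
        int p = partitions.get(stream);
        if (p != UNKNOWN) {
          if (known != UNKNOWN && known != p) {
            throw new IllegalStateException("Joined streams disagree on partitions: " + join);
          }
          known = p;
        }
      }
      // Pin unknown inputs to the agreed count; this may resolve further joins.
      for (String stream : joinInputs.get(join)) {
        if (partitions.get(stream) == UNKNOWN) {
          partitions.put(stream, known);
          for (Map.Entry<String, List<String>> e : joinInputs.entrySet()) {
            if (e.getValue().contains(stream) && visited.add(e.getKey())) {
              joinQ.add(e.getKey());
            }
          }
        }
      }
    }
    System.out.println(partitions); // {input1=64, p1=64, p2=64} (order may vary)
  }
}
```

With `input1` at 64 partitions, the sketch pins `p1` to 64 via `j1`, which in turn resolves `p2` via `j2`; if two already-known inputs of a join disagree (the `testRejectsInvalidJoin` scenario below), it fails instead.
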
@@ -244,17 +256,19 @@ private void validateConfig() {
    * @param joinQ queue that contains joinSpecs where at least one of the input stream edge partitions is known.
    */
   private static void findReachableJoins(OperatorSpec operatorSpec, StreamEdge sourceStreamEdge,
-      Multimap<OperatorSpec, StreamEdge> joinSpecToStreamEdges,
-      Multimap<StreamEdge, OperatorSpec> streamEdgeToJoinSpecs,
-      Queue<OperatorSpec> joinQ, Set<OperatorSpec> visited) {
+      Multimap<JoinOperatorSpec, StreamEdge> joinSpecToStreamEdges,
+      Multimap<StreamEdge, JoinOperatorSpec> streamEdgeToJoinSpecs,
+      Queue<JoinOperatorSpec> joinQ, Set<JoinOperatorSpec> visited) {
+
     if (operatorSpec instanceof JoinOperatorSpec) {
-      joinSpecToStreamEdges.put(operatorSpec, sourceStreamEdge);
-      streamEdgeToJoinSpecs.put(sourceStreamEdge, operatorSpec);
+      JoinOperatorSpec joinOperatorSpec = (JoinOperatorSpec) operatorSpec;
+      joinSpecToStreamEdges.put(joinOperatorSpec, sourceStreamEdge);
+      streamEdgeToJoinSpecs.put(sourceStreamEdge, joinOperatorSpec);

-      if (!visited.contains(operatorSpec) && sourceStreamEdge.getPartitionCount() > 0) {
+      if (!visited.contains(joinOperatorSpec) && sourceStreamEdge.getPartitionCount() > 0) {
         // put the joins with known input partitions into the queue and mark as visited
-        joinQ.add(operatorSpec);
-        visited.add(operatorSpec);
+        joinQ.add(joinOperatorSpec);
+        visited.add(joinOperatorSpec);
       }
     }

@@ -265,15 +279,16 @@ private static void findReachableJoins(OperatorSpec operatorSpec, StreamEdge sou
     }
   }

-  private static void calculateIntStreamPartitions(JobGraph jobGraph, Config config) {
-    int partitions = config.getInt(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(), StreamEdge.PARTITIONS_UNKNOWN);
-    if (partitions < 0) {
+  private static void calculateIntermediateStreamPartitions(JobGraph jobGraph, Config config) {
+    final String defaultPartitionsConfigProperty = JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS();
+    int partitions = config.getInt(defaultPartitionsConfigProperty, StreamEdge.PARTITIONS_UNKNOWN);
+    if (partitions == StreamEdge.PARTITIONS_UNKNOWN) {
       // use the following simple algo to figure out the partitions
       // partition = MAX(MAX(Input topic partitions), MAX(Output topic partitions))
       // partition will be further bounded by MAX_INFERRED_PARTITIONS.
       // This is important when running in hadoop where an HDFS input can have lots of files (partitions).
-      int maxInPartitions = maxPartition(jobGraph.getSources());
-      int maxOutPartitions = maxPartition(jobGraph.getSinks());
+      int maxInPartitions = maxPartitions(jobGraph.getInputStreams());
+      int maxOutPartitions = maxPartitions(jobGraph.getOutputStreams());
       partitions = Math.max(maxInPartitions, maxOutPartitions);

       if (partitions > MAX_INFERRED_PARTITIONS) {
@@ -281,7 +296,17 @@ private static void calculateIntStreamPartitions(JobGraph jobGraph, Config confi
         log.warn(String.format("Inferred intermediate stream partition count %d is greater than the max %d. Using the max.",
             partitions, MAX_INFERRED_PARTITIONS));
       }
+    } else {
+      // Reject any zero or other negative values explicitly specified in config.
+      if (partitions <= 0) {
+        throw new SamzaException(String.format("Invalid value %d specified for config property %s", partitions,
+            defaultPartitionsConfigProperty));
+      }
+
+      log.info("Using partition count value {} specified for config property {}", partitions,
+          defaultPartitionsConfigProperty);
     }
+
     for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
       if (edge.getPartitionCount() <= 0) {
         log.info("Set the partition count for intermediate stream {} to {}.", edge.getName(), partitions);
@@ -290,16 +315,15 @@ private static void calculateIntStreamPartitions(JobGraph jobGraph, Config confi
     }
   }

-  private static void validatePartitions(JobGraph jobGraph) {
+  private static void validateIntermediateStreamPartitions(JobGraph jobGraph) {
     for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
       if (edge.getPartitionCount() <= 0) {
-        throw new SamzaException(String.format("Failure to assign the partitions to Stream %s", edge.getName()));
+        throw new SamzaException(String.format("Failed to assign valid partition count to Stream %s", edge.getName()));
       }
     }
   }

-  /* package private */ static int maxPartition(Collection<StreamEdge> edges) {
-    return edges.stream().map(StreamEdge::getPartitionCount).reduce(Integer::max).orElse(StreamEdge.PARTITIONS_UNKNOWN);
+  /* package private */ static int maxPartitions(Collection<StreamEdge> edges) {
+    return edges.stream().mapToInt(StreamEdge::getPartitionCount).max().orElse(StreamEdge.PARTITIONS_UNKNOWN);
   }
-
 }
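
The net inference rule in `calculateIntermediateStreamPartitions` above: an explicitly configured value wins and must be positive; otherwise the count is max(input partitions, output partitions), capped at `MAX_INFERRED_PARTITIONS`. A standalone sketch of just that arithmetic (method and exception names are illustrative, not the Samza API):

```java
public class PartitionInferenceSketch {
  static final int PARTITIONS_UNKNOWN = -1;     // mirrors StreamEdge.PARTITIONS_UNKNOWN
  static final int MAX_INFERRED_PARTITIONS = 256;

  /** Configured value wins (must be positive); otherwise max(in, out), capped at 256. */
  static int inferPartitions(int configured, int maxInPartitions, int maxOutPartitions) {
    if (configured != PARTITIONS_UNKNOWN) {
      if (configured <= 0) {
        throw new IllegalArgumentException("Invalid partition count: " + configured);
      }
      return configured;
    }
    return Math.min(Math.max(maxInPartitions, maxOutPartitions), MAX_INFERRED_PARTITIONS);
  }

  public static void main(String[] args) {
    System.out.println(inferPartitions(PARTITIONS_UNKNOWN, 64, 8));   // 64: inferred from inputs
    System.out.println(inferPartitions(PARTITIONS_UNKNOWN, 1024, 8)); // 256: capped (e.g. many HDFS files)
    System.out.println(inferPartitions(16, 64, 8));                   // 16: explicit config wins
  }
}
```
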
diff --git a/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java b/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java
index f49e6db6fc..5b190954c6 100644
--- a/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java
+++ b/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java
@@ -54,8 +54,8 @@
   private final Map<String, JobNode> nodes = new HashMap<>();
   private final Map<String, StreamEdge> edges = new HashMap<>();
-  private final Set<StreamEdge> sources = new HashSet<>();
-  private final Set<StreamEdge> sinks = new HashSet<>();
+  private final Set<StreamEdge> inputStreams = new HashSet<>();
+  private final Set<StreamEdge> outputStreams = new HashSet<>();
   private final Set<StreamEdge> intermediateStreams = new HashSet<>();
   private final Set<TableSpec> tables = new HashSet<>();
   private final Config config;
@@ -115,26 +115,26 @@ public OperatorSpecGraph getSpecGraph() {

   /**
    * Add a source stream to a {@link JobNode}
-   * @param input source stream
-   * @param node the job node that consumes from the source
+   * @param streamSpec input stream
+   * @param node the job node that consumes from the streamSpec
    */
-  void addSource(StreamSpec input, JobNode node) {
-    StreamEdge edge = getOrCreateStreamEdge(input);
+  void addInputStream(StreamSpec streamSpec, JobNode node) {
+    StreamEdge edge = getOrCreateStreamEdge(streamSpec);
     edge.addTargetNode(node);
     node.addInEdge(edge);
-    sources.add(edge);
+    inputStreams.add(edge);
   }

   /**
-   * Add a sink stream to a {@link JobNode}
-   * @param output sink stream
-   * @param node the job node that outputs to the sink
+   * Add an output stream to a {@link JobNode}
+   * @param streamSpec output stream
+   * @param node the job node that outputs to the output stream
    */
-  void addSink(StreamSpec output, JobNode node) {
-    StreamEdge edge = getOrCreateStreamEdge(output);
+  void addOutputStream(StreamSpec streamSpec, JobNode node) {
+    StreamEdge edge = getOrCreateStreamEdge(streamSpec);
     edge.addSourceNode(node);
     node.addOutEdge(edge);
-    sinks.add(edge);
+    outputStreams.add(edge);
   }

   /**
@@ -204,19 +204,19 @@ List<JobNode> getJobNodes() {
   }

   /**
-   * Returns the source streams in the graph
+   * Returns the input streams in the graph
    * @return unmodifiable set of {@link StreamEdge}
    */
-  Set<StreamEdge> getSources() {
-    return Collections.unmodifiableSet(sources);
+  Set<StreamEdge> getInputStreams() {
+    return Collections.unmodifiableSet(inputStreams);
   }

   /**
-   * Return the sink streams in the graph
+   * Return the output streams in the graph
    * @return unmodifiable set of {@link StreamEdge}
    */
-  Set<StreamEdge> getSinks() {
-    return Collections.unmodifiableSet(sinks);
+  Set<StreamEdge> getOutputStreams() {
+    return Collections.unmodifiableSet(outputStreams);
   }

   /**
@@ -236,22 +236,22 @@ Set<StreamEdge> getIntermediateStreamEdges() {
   }

   /**
-   * Validate the graph has the correct topology, meaning the sources are coming from external streams,
-   * sinks are going to external streams, and the nodes are connected with intermediate streams.
-   * Also validate all the nodes are reachable from the sources.
+   * Validate the graph has the correct topology, meaning the input streams are coming from external streams,
+   * output streams are going to external streams, and the nodes are connected with intermediate streams.
+   * Also validate all the nodes are reachable from the input streams.
    */
   void validate() {
-    validateSources();
-    validateSinks();
+    validateInputStreams();
+    validateOutputStreams();
     validateInternalStreams();
     validateReachability();
   }

   /**
-   * Validate the sources should have indegree being 0 and outdegree greater than 0
+   * Validate the input streams should have indegree being 0 and outdegree greater than 0
    */
-  private void validateSources() {
-    sources.forEach(edge -> {
+  private void validateInputStreams() {
+    inputStreams.forEach(edge -> {
       if (!edge.getSourceNodes().isEmpty()) {
         throw new IllegalArgumentException(
             String.format("Source stream %s should not have producers.", edge.getName()));
@@ -264,10 +264,10 @@ private void validateSources() {
   }

   /**
-   * Validate the sinks should have outdegree being 0 and indegree greater than 0
+   * Validate the output streams should have outdegree being 0 and indegree greater than 0
    */
-  private void validateSinks() {
-    sinks.forEach(edge -> {
+  private void validateOutputStreams() {
+    outputStreams.forEach(edge -> {
       if (!edge.getTargetNodes().isEmpty()) {
         throw new IllegalArgumentException(
             String.format("Sink stream %s should not have consumers", edge.getName()));
@@ -284,8 +284,8 @@ private void validateSinks() {
    */
   private void validateInternalStreams() {
     Set<StreamEdge> internalEdges = new HashSet<>(edges.values());
-    internalEdges.removeAll(sources);
-    internalEdges.removeAll(sinks);
+    internalEdges.removeAll(inputStreams);
+    internalEdges.removeAll(outputStreams);

     internalEdges.forEach(edge -> {
       if (edge.getSourceNodes().isEmpty() || edge.getTargetNodes().isEmpty()) {
@@ -296,10 +296,10 @@ private void validateInternalStreams() {
   }

   /**
-   * Validate all nodes are reachable by sources.
+   * Validate all nodes are reachable by input streams.
    */
   private void validateReachability() {
-    // validate all nodes are reachable from the sources
+    // validate all nodes are reachable from the input streams
     final Set<JobNode> reachable = findReachable();
     if (reachable.size() != nodes.size()) {
       Set<JobNode> unreachable = new HashSet<>(nodes.values());
@@ -317,8 +317,8 @@ Set<JobNode> findReachable() {
     Queue<JobNode> queue = new ArrayDeque<>();
     Set<JobNode> visited = new HashSet<>();

-    sources.forEach(source -> {
-      List<JobNode> next = source.getTargetNodes();
+    inputStreams.forEach(input -> {
+      List<JobNode> next = input.getTargetNodes();
       queue.addAll(next);
       visited.addAll(next);
     });
@@ -353,11 +353,11 @@ List<JobNode> topologicalSort() {
     pnodes.forEach(node -> {
       String nid = node.getId();
       //only count the degrees of intermediate streams
-      long degree = node.getInEdges().stream().filter(e -> !sources.contains(e)).count();
+      long degree = node.getInEdges().stream().filter(e -> !inputStreams.contains(e)).count();
       indegree.put(nid, degree);

       if (degree == 0L) {
-        // start from the nodes that has no intermediate input streams, so it only consumes from sources
+        // start from the nodes that has no intermediate input streams, so it only consumes from input streams
         q.add(node);
         visited.add(node);
       }
@@ -410,9 +410,9 @@ List<JobNode> topologicalSort() {
           q.add(minNode);
           visited.add(minNode);
         } else {
-          // all the remaining nodes should be reachable from sources
-          // start from sources again to find the next node that hasn't been visited
-          JobNode nextNode = sources.stream().flatMap(source -> source.getTargetNodes().stream())
+          // all the remaining nodes should be reachable from input streams
+          // start from input streams again to find the next node that hasn't been visited
+          JobNode nextNode = inputStreams.stream().flatMap(input -> input.getTargetNodes().stream())
               .filter(node -> !visited.contains(node))
               .findAny().get();
           q.add(nextNode);
diff --git a/samza-core/src/main/java/org/apache/samza/execution/JobGraphJsonGenerator.java b/samza-core/src/main/java/org/apache/samza/execution/JobGraphJsonGenerator.java
index 3a8d5c9427..91453d2dbd 100644
--- a/samza-core/src/main/java/org/apache/samza/execution/JobGraphJsonGenerator.java
+++ b/samza-core/src/main/java/org/apache/samza/execution/JobGraphJsonGenerator.java
@@ -134,8 +134,8 @@ static final class JobGraphJson {
     jobGraphJson.sinkStreams = new HashMap<>();
     jobGraphJson.intermediateStreams = new HashMap<>();
     jobGraphJson.tables = new HashMap<>();
-    jobGraph.getSources().forEach(e -> buildStreamEdgeJson(e, jobGraphJson.sourceStreams));
-    jobGraph.getSinks().forEach(e -> buildStreamEdgeJson(e, jobGraphJson.sinkStreams));
+    jobGraph.getInputStreams().forEach(e -> buildStreamEdgeJson(e, jobGraphJson.sourceStreams));
+    jobGraph.getOutputStreams().forEach(e -> buildStreamEdgeJson(e, jobGraphJson.sinkStreams));
     jobGraph.getIntermediateStreamEdges().forEach(e -> buildStreamEdgeJson(e, jobGraphJson.intermediateStreams));
     jobGraph.getTables().forEach(t -> buildTableJson(t, jobGraphJson.tables));
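
For context on the `JobGraph` changes above: `findReachable` is a plain breadth-first search seeded with the consumers of the input streams, and `validateReachability` flags any job node the search never visits. A self-contained equivalent over a simple adjacency map (types and names illustrative):

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

public class ReachabilitySketch {
  /** BFS from the given seed nodes; any node not returned is unreachable. */
  static Set<String> findReachable(Map<String, List<String>> adjacency, Set<String> seeds) {
    Queue<String> queue = new ArrayDeque<>(seeds);
    Set<String> visited = new HashSet<>(seeds);
    while (!queue.isEmpty()) {
      String node = queue.poll();
      for (String next : adjacency.getOrDefault(node, Collections.emptyList())) {
        if (visited.add(next)) {
          queue.add(next);
        }
      }
    }
    return visited;
  }

  public static void main(String[] args) {
    Map<String, List<String>> adjacency = new HashMap<>();
    adjacency.put("n1", Arrays.asList("n2"));
    adjacency.put("n2", Arrays.asList("n3"));
    adjacency.put("n4", Arrays.asList("n2")); // n4 itself is never reached from the seeds
    Set<String> seeds = new HashSet<>(Arrays.asList("n1"));
    System.out.println(findReachable(adjacency, seeds)); // [n1, n2, n3] -> n4 would fail validation
  }
}
```
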
diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java b/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java
index 61cf6c52f1..c08922545d 100644
--- a/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java
+++ b/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java
@@ -28,10 +28,12 @@
 import java.util.Map;
 import java.util.Set;
 import org.apache.samza.Partition;
+import org.apache.samza.SamzaException;
 import org.apache.samza.application.StreamApplicationDescriptorImpl;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
+import org.apache.samza.config.StreamConfig;
 import org.apache.samza.config.TaskConfig;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.MessageStream;
@@ -72,7 +74,6 @@ public class TestExecutionPlanner {
   private GenericInputDescriptor<KV<Object, Object>> input2Descriptor;
   private StreamSpec input3Spec;
   private GenericInputDescriptor<KV<Object, Object>> input3Descriptor;
-  private StreamSpec input4Spec;
   private GenericInputDescriptor<KV<Object, Object>> input4Descriptor;
   private StreamSpec output1Spec;
   private GenericOutputDescriptor<KV<Object, Object>> output1Descriptor;
@@ -168,44 +169,49 @@ private StreamApplicationDescriptorImpl createStreamGraphWithJoin() {

   private StreamApplicationDescriptorImpl createStreamGraphWithJoinAndWindow() {
     return new StreamApplicationDescriptorImpl(appDesc -> {
-      MessageStream<KV<Object, Object>> messageStream1 =
-          appDesc.getInputStream(input1Descriptor)
-              .map(m -> m);
+      MessageStream<KV<Object, Object>> messageStream1 = appDesc.getInputStream(input1Descriptor).map(m -> m);
       MessageStream<KV<Object, Object>> messageStream2 =
-          appDesc.getInputStream(input2Descriptor)
-              .partitionBy(m -> m.key, m -> m.value, "p1")
-              .filter(m -> true);
+          appDesc.getInputStream(input2Descriptor).partitionBy(m -> m.key, m -> m.value, "p1").filter(m -> true);
       MessageStream<KV<Object, Object>> messageStream3 =
-          appDesc.getInputStream(input3Descriptor)
-              .filter(m -> true)
-              .partitionBy(m -> m.key, m -> m.value, "p2")
-              .map(m -> m);
+          appDesc.getInputStream(input3Descriptor).filter(m -> true).partitionBy(m -> m.key, m -> m.value, "p2").map(m -> m);
       OutputStream<KV<Object, Object>> output1 = appDesc.getOutputStream(output1Descriptor);
       OutputStream<KV<Object, Object>> output2 = appDesc.getOutputStream(output2Descriptor);

       messageStream1.map(m -> m)
-          .filter(m->true)
-          .window(Windows.keyedTumblingWindow(m -> m, Duration.ofMillis(8), mock(Serde.class), mock(Serde.class)), "w1");
+          .filter(m -> true)
+          .window(Windows.keyedTumblingWindow(m -> m, Duration.ofMillis(8), mock(Serde.class), mock(Serde.class)), "w1");

       messageStream2.map(m -> m)
-          .filter(m->true)
-          .window(Windows.keyedTumblingWindow(m -> m, Duration.ofMillis(16), mock(Serde.class), mock(Serde.class)), "w2");
+          .filter(m -> true)
+          .window(Windows.keyedTumblingWindow(m -> m, Duration.ofMillis(16), mock(Serde.class), mock(Serde.class)), "w2");
+
+      messageStream1.join(messageStream2, (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
+          mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(1600), "j1").sendTo(output1);
+      messageStream3.join(messageStream2, (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
+          mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(100), "j2").sendTo(output2);
+      messageStream3.join(messageStream2, (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
+          mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(252), "j3").sendTo(output2);
+    }, config);
+  }
+
+  private StreamApplicationDescriptorImpl createStreamGraphWithInvalidJoin() {
+    /**
+     * input1 (64) --
+     *              |
+     *             join -> output1 (8)
+     *              |
+     * input3 (32) --
+     */
+    return new StreamApplicationDescriptorImpl(appDesc -> {
+      MessageStream<KV<Object, Object>> messageStream1 = appDesc.getInputStream(input1Descriptor);
+      MessageStream<KV<Object, Object>> messageStream3 = appDesc.getInputStream(input3Descriptor);
+      OutputStream<KV<Object, Object>> output1 = appDesc.getOutputStream(output1Descriptor);

       messageStream1
-          .join(messageStream2,
-              (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
-              mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(1600), "j1")
-          .sendTo(output1);
-      messageStream3
-          .join(messageStream2,
-              (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
-              mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(100), "j2")
-          .sendTo(output2);
-      messageStream3
-          .join(messageStream2,
-              (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
-              mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(252), "j3")
-          .sendTo(output2);
+          .join(messageStream3,
+              (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
+              mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(2), "j1")
+          .sendTo(output1);
     }, config);
   }

@@ -225,7 +231,6 @@ public void setup() {
     input1Spec = new StreamSpec("input1", "input1", "system1");
     input2Spec = new StreamSpec("input2", "input2", "system2");
     input3Spec = new StreamSpec("input3", "input3", "system2");
-    input4Spec = new StreamSpec("input4", "input4", "system1");

     output1Spec = new StreamSpec("output1", "output1", "system1");
     output2Spec = new StreamSpec("output2", "output2", "system2");
@@ -265,8 +270,8 @@ public void testCreateProcessorGraph() {
     StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin();
     JobGraph jobGraph = planner.createJobGraph(graphSpec.getOperatorSpecGraph());

-    assertTrue(jobGraph.getSources().size() == 3);
-    assertTrue(jobGraph.getSinks().size() == 2);
+    assertTrue(jobGraph.getInputStreams().size() == 3);
+    assertTrue(jobGraph.getOutputStreams().size() == 2);
     assertTrue(jobGraph.getIntermediateStreams().size() == 2); // two streams generated by partitionBy
   }

@@ -276,7 +281,7 @@ public void testFetchExistingStreamPartitions() {
     StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin();
     JobGraph jobGraph = planner.createJobGraph(graphSpec.getOperatorSpecGraph());
-    ExecutionPlanner.updateExistingPartitions(jobGraph, streamManager);
+    ExecutionPlanner.fetchInputAndOutputStreamPartitions(jobGraph, streamManager);

     assertTrue(jobGraph.getOrCreateStreamEdge(input1Spec).getPartitionCount() == 64);
     assertTrue(jobGraph.getOrCreateStreamEdge(input2Spec).getPartitionCount() == 16);
     assertTrue(jobGraph.getOrCreateStreamEdge(input3Spec).getPartitionCount() == 32);
@@ -294,8 +299,8 @@ public void testCalculateJoinInputPartitions() {
     StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin();
     JobGraph jobGraph = planner.createJobGraph(graphSpec.getOperatorSpecGraph());

-    ExecutionPlanner.updateExistingPartitions(jobGraph, streamManager);
-    ExecutionPlanner.calculateJoinInputPartitions(jobGraph, config);
+    ExecutionPlanner.fetchInputAndOutputStreamPartitions(jobGraph, streamManager);
+    ExecutionPlanner.calculateJoinInputPartitions(jobGraph, new StreamConfig(config));

     // the partitions should be the same as input1
     jobGraph.getIntermediateStreams().forEach(edge -> {
@@ -303,6 +308,14 @@
     });
   }

+  @Test(expected = SamzaException.class)
+  public void testRejectsInvalidJoin() {
+    ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
+    StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithInvalidJoin();
+
+    planner.plan(graphSpec.getOperatorSpecGraph());
+  }
+
   @Test
   public void testDefaultPartitions() {
     Map<String, String> map = new HashMap<>(config);
@@ -321,7 +334,7 @@ public void testDefaultPartitions() {
   }

   @Test
-  public void testTriggerIntervalForJoins() throws Exception {
+  public void testTriggerIntervalForJoins() {
     Map<String, String> map = new HashMap<>(config);
     map.put(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(), String.valueOf(DEFAULT_PARTITIONS));
     Config cfg = new MapConfig(map);
@@ -336,7 +349,7 @@ public void testTriggerIntervalForJoins() {
   }

   @Test
-  public void testTriggerIntervalForWindowsAndJoins() throws Exception {
+  public void testTriggerIntervalForWindowsAndJoins() {
     Map<String, String> map = new HashMap<>(config);
     map.put(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(), String.valueOf(DEFAULT_PARTITIONS));
     Config cfg = new MapConfig(map);
@@ -352,7 +365,7 @@ public void testTriggerIntervalForWindowsAndJoins() {
   }

   @Test
-  public void testTriggerIntervalWithInvalidWindowMs() throws Exception {
+  public void testTriggerIntervalWithInvalidWindowMs() {
     Map<String, String> map = new HashMap<>(config);
     map.put(TaskConfig.WINDOW_MS(), "-1");
     map.put(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(), String.valueOf(DEFAULT_PARTITIONS));
@@ -368,9 +381,8 @@ public void testTriggerIntervalWithInvalidWindowMs() {
     assertEquals("4", jobConfigs.get(0).get(TaskConfig.WINDOW_MS()));
   }

-
   @Test
-  public void testTriggerIntervalForStatelessOperators() throws Exception {
+  public void testTriggerIntervalForStatelessOperators() {
     Map<String, String> map = new HashMap<>(config);
     map.put(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(), String.valueOf(DEFAULT_PARTITIONS));
     Config cfg = new MapConfig(map);
@@ -384,7 +396,7 @@ public void testTriggerIntervalForStatelessOperators() {
   }

   @Test
-  public void testTriggerIntervalWhenWindowMsIsConfigured() throws Exception {
+  public void testTriggerIntervalWhenWindowMsIsConfigured() {
     Map<String, String> map = new HashMap<>(config);
     map.put(TaskConfig.WINDOW_MS(), "2000");
     map.put(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(), String.valueOf(DEFAULT_PARTITIONS));
@@ -399,7 +411,7 @@ public void testTriggerIntervalWhenWindowMsIsConfigured() {
   }

   @Test
-  public void testCalculateIntStreamPartitions() throws Exception {
+  public void testCalculateIntStreamPartitions() {
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createSimpleGraph();
     JobGraph jobGraph = (JobGraph) planner.plan(graphSpec.getOperatorSpecGraph());
@@ -423,10 +435,10 @@ public void testMaxPartition() {
     edge.setPartitionCount(16);
     edges.add(edge);

-    assertEquals(32, ExecutionPlanner.maxPartition(edges));
+    assertEquals(32, ExecutionPlanner.maxPartitions(edges));

     edges = Collections.emptyList();
-    assertEquals(StreamEdge.PARTITIONS_UNKNOWN, ExecutionPlanner.maxPartition(edges));
+    assertEquals(StreamEdge.PARTITIONS_UNKNOWN, ExecutionPlanner.maxPartitions(edges));
   }

   @Test
diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestJobGraph.java b/samza-core/src/test/java/org/apache/samza/execution/TestJobGraph.java
index 73452d8b77..ed35d6725b 100644
--- a/samza-core/src/test/java/org/apache/samza/execution/TestJobGraph.java
+++ b/samza-core/src/test/java/org/apache/samza/execution/TestJobGraph.java
@@ -74,9 +74,9 @@ private void createGraph1() {
     JobNode n10 = graph1.getOrCreateJobNode("10", "1");
     JobNode n11 = graph1.getOrCreateJobNode("11", "1");

-    graph1.addSource(genStream(), n5);
-    graph1.addSource(genStream(), n7);
-    graph1.addSource(genStream(), n3);
+    graph1.addInputStream(genStream(), n5);
+    graph1.addInputStream(genStream(), n7);
+    graph1.addInputStream(genStream(), n3);
     graph1.addIntermediateStream(genStream(), n5, n11);
     graph1.addIntermediateStream(genStream(), n7, n11);
     graph1.addIntermediateStream(genStream(), n7, n8);
@@ -85,9 +85,9 @@ private void createGraph1() {
     graph1.addIntermediateStream(genStream(), n11, n9);
     graph1.addIntermediateStream(genStream(), n8, n9);
     graph1.addIntermediateStream(genStream(), n11, n10);
-    graph1.addSink(genStream(), n2);
-    graph1.addSink(genStream(), n9);
-    graph1.addSink(genStream(), n10);
+    graph1.addOutputStream(genStream(), n2);
+    graph1.addOutputStream(genStream(), n9);
+    graph1.addOutputStream(genStream(), n10);
   }

   /**
@@ -108,7 +108,7 @@ private void createGraph2() {
     JobNode n6 = graph2.getOrCreateJobNode("6", "1");
     JobNode n7 = graph2.getOrCreateJobNode("7", "1");

-    graph2.addSource(genStream(), n1);
+    graph2.addInputStream(genStream(), n1);
     graph2.addIntermediateStream(genStream(), n1, n2);
     graph2.addIntermediateStream(genStream(), n2, n3);
     graph2.addIntermediateStream(genStream(), n3, n4);
@@ -117,7 +117,7 @@ private void createGraph2() {
     graph2.addIntermediateStream(genStream(), n6, n2);
     graph2.addIntermediateStream(genStream(), n5, n5);
     graph2.addIntermediateStream(genStream(), n5, n7);
-    graph2.addSink(genStream(), n7);
+    graph2.addOutputStream(genStream(), n7);
   }

   /**
@@ -132,7 +132,7 @@ private void createGraph3() {
     JobNode n1 = graph3.getOrCreateJobNode("1", "1");
     JobNode n2 = graph3.getOrCreateJobNode("2", "1");

-    graph3.addSource(genStream(), n1);
+    graph3.addInputStream(genStream(), n1);
     graph3.addIntermediateStream(genStream(), n1, n1);
     graph3.addIntermediateStream(genStream(), n1, n2);
     graph3.addIntermediateStream(genStream(), n2, n2);
@@ -149,7 +149,7 @@ private void createGraph4() {

     JobNode n1 = graph4.getOrCreateJobNode("1", "1");

-    graph4.addSource(genStream(), n1);
+    graph4.addInputStream(genStream(), n1);
     graph4.addIntermediateStream(genStream(), n1, n1);
   }

@@ -180,12 +180,12 @@ public void testAddSource() {
     StreamSpec s1 = genStream();
     StreamSpec s2 = genStream();
     StreamSpec s3 = genStream();
-    graph.addSource(s1, n1);
-    graph.addSource(s2, n1);
-    graph.addSource(s3, n2);
-    graph.addSource(s3, n3);
+    graph.addInputStream(s1, n1);
+    graph.addInputStream(s2, n1);
+    graph.addInputStream(s3, n2);
+    graph.addInputStream(s3, n3);

-    assertTrue(graph.getSources().size() == 3);
+    assertTrue(graph.getInputStreams().size() == 3);

     assertTrue(graph.getOrCreateJobNode("1", "1").getInEdges().size() == 2);
     assertTrue(graph.getOrCreateJobNode("2", "1").getInEdges().size() == 1);
@@ -214,11 +214,11 @@ public void testAddSink() {
     StreamSpec s1 = genStream();
     StreamSpec s2 = genStream();
     StreamSpec s3 = genStream();
-    graph.addSink(s1, n1);
-    graph.addSink(s2, n2);
-    graph.addSink(s3, n2);
+    graph.addOutputStream(s1, n1);
+    graph.addOutputStream(s2, n2);
+    graph.addOutputStream(s3, n2);

-    assertTrue(graph.getSinks().size() == 3);
+    assertTrue(graph.getOutputStreams().size() == 3);

     assertTrue(graph.getOrCreateJobNode("1", "1").getOutEdges().size() == 1);
     assertTrue(graph.getOrCreateJobNode("2", "1").getOutEdges().size() == 2);
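
One small note on the `maxPartition` to `maxPartitions` change in `ExecutionPlanner`: besides the rename, the body moved from a boxed `reduce(Integer::max)` to the primitive `mapToInt(...).max()`, keeping the same empty-collection fallback via `orElse`. A quick standalone check (the sentinel mirrors `StreamEdge.PARTITIONS_UNKNOWN`; names otherwise illustrative):

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class MaxPartitionsSketch {
  static final int PARTITIONS_UNKNOWN = -1;

  static int maxPartitions(List<Integer> partitionCounts) {
    // IntStream.max() yields an OptionalInt; orElse supplies the sentinel for empty input.
    return partitionCounts.stream().mapToInt(Integer::intValue).max().orElse(PARTITIONS_UNKNOWN);
  }

  public static void main(String[] args) {
    System.out.println(maxPartitions(Arrays.asList(32, 16)));   // 32
    System.out.println(maxPartitions(Collections.emptyList())); // -1 (PARTITIONS_UNKNOWN)
  }
}
```
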