diff --git a/build.gradle b/build.gradle index 48a28f23db..a45a875d1e 100644 --- a/build.gradle +++ b/build.gradle @@ -194,6 +194,7 @@ project(":samza-core_$scalaVersion") { testCompile "org.powermock:powermock-core:$powerMockVersion" testCompile "org.powermock:powermock-module-junit4:$powerMockVersion" testCompile "org.scalatest:scalatest_$scalaVersion:$scalaTestVersion" + testCompile "org.hamcrest:hamcrest-all:$hamcrestVersion" } checkstyle { diff --git a/samza-core/src/main/java/org/apache/samza/application/ApplicationDescriptorImpl.java b/samza-core/src/main/java/org/apache/samza/application/ApplicationDescriptorImpl.java index 96791366f3..b58d5a567a 100644 --- a/samza-core/src/main/java/org/apache/samza/application/ApplicationDescriptorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/application/ApplicationDescriptorImpl.java @@ -19,6 +19,7 @@ package org.apache.samza.application; import java.util.Collections; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; import java.util.Optional; @@ -26,13 +27,20 @@ import org.apache.samza.config.Config; import org.apache.samza.metrics.MetricsReporterFactory; import org.apache.samza.operators.ContextManager; +import org.apache.samza.operators.KV; import org.apache.samza.operators.TableDescriptor; import org.apache.samza.operators.descriptors.base.stream.InputDescriptor; import org.apache.samza.operators.descriptors.base.stream.OutputDescriptor; import org.apache.samza.operators.descriptors.base.system.SystemDescriptor; +import org.apache.samza.operators.spec.InputOperatorSpec; import org.apache.samza.runtime.ProcessorLifecycleListener; import org.apache.samza.runtime.ProcessorLifecycleListenerFactory; +import org.apache.samza.serializers.KVSerde; +import org.apache.samza.serializers.NoOpSerde; +import org.apache.samza.serializers.Serde; import org.apache.samza.task.TaskContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** @@ -46,10 +54,15 @@ */ public abstract class ApplicationDescriptorImpl implements ApplicationDescriptor { + private static final Logger LOGGER = LoggerFactory.getLogger(ApplicationDescriptorImpl.class); - final Config config; private final Class appClass; private final Map reporterFactories = new LinkedHashMap<>(); + // serdes used by input/output/intermediate streams, keyed by streamId + private final Map> streamSerdes = new HashMap<>(); + // serdes used by tables, keyed by tableId + private final Map> tableSerdes = new HashMap<>(); + final Config config; // Default to no-op functions in ContextManager // TODO: this should be replaced by shared context factory defined in SAMZA-1714 @@ -141,6 +154,35 @@ public Optional getDefaultSystemDescriptor() { return Optional.empty(); } + /** + * Get the corresponding {@link KVSerde} for the input {@code inputStreamId} + * + * @param streamId id of the stream + * @return the {@link KVSerde} for the stream. null if the serde is not defined or {@code streamId} does not exist + */ + public KV getStreamSerdes(String streamId) { + return streamSerdes.get(streamId); + } + + /** + * Get the corresponding {@link KVSerde} for the input {@code inputStreamId} + * + * @param tableId id of the table + * @return the {@link KVSerde} for the stream. 
null if the serde is not defined or {@code tableId} does not exist + */ + public KV getTableSerdes(String tableId) { + return tableSerdes.get(tableId); + } + + /** + * Get the map of all {@link InputOperatorSpec}s in this application + * + * @return an immutable map from streamId to {@link InputOperatorSpec}. Defaults to an empty map for low-level {@link TaskApplication} + */ + public Map getInputOperators() { + return Collections.EMPTY_MAP; + } + /** * Get all the {@link InputDescriptor}s to this application * @@ -176,4 +218,66 @@ public Optional getDefaultSystemDescriptor() { */ public abstract Set getSystemDescriptors(); + /** + * Get all the unique input streamIds in this application + * + * @return an immutable set of input streamIds + */ + public abstract Set getInputStreamIds(); + + /** + * Get all the unique output streamIds in this application + * + * @return an immutable set of output streamIds + */ + public abstract Set getOutputStreamIds(); + + KV getOrCreateStreamSerdes(String streamId, Serde serde) { + Serde keySerde, valueSerde; + + KV currentSerdePair = streamSerdes.get(streamId); + + if (serde instanceof KVSerde) { + keySerde = ((KVSerde) serde).getKeySerde(); + valueSerde = ((KVSerde) serde).getValueSerde(); + } else { + keySerde = new NoOpSerde(); + valueSerde = serde; + } + + if (currentSerdePair == null) { + if (keySerde instanceof NoOpSerde) { + LOGGER.info("Using NoOpSerde as the key serde for stream " + streamId + + ". Keys will not be (de)serialized"); + } + if (valueSerde instanceof NoOpSerde) { + LOGGER.info("Using NoOpSerde as the value serde for stream " + streamId + + ". Values will not be (de)serialized"); + } + streamSerdes.put(streamId, KV.of(keySerde, valueSerde)); + } else if (!currentSerdePair.getKey().equals(keySerde) || !currentSerdePair.getValue().equals(valueSerde)) { + throw new IllegalArgumentException(String.format("Serde for stream %s is already defined. Cannot change it to " + + "different serdes.", streamId)); + } + return streamSerdes.get(streamId); + } + + KV getOrCreateTableSerdes(String tableId, KVSerde kvSerde) { + Serde keySerde, valueSerde; + keySerde = kvSerde.getKeySerde(); + valueSerde = kvSerde.getValueSerde(); + + if (!tableSerdes.containsKey(tableId)) { + tableSerdes.put(tableId, KV.of(keySerde, valueSerde)); + return tableSerdes.get(tableId); + } + + KV currentSerdePair = tableSerdes.get(tableId); + if (!currentSerdePair.getKey().equals(keySerde) || !currentSerdePair.getValue().equals(valueSerde)) { + throw new IllegalArgumentException(String.format("Serde for table %s is already defined.
Cannot change it to " + + "different serdes.", tableId)); + } + return tableSerdes.get(tableId); + } + } \ No newline at end of file diff --git a/samza-core/src/main/java/org/apache/samza/application/StreamApplicationDescriptorImpl.java b/samza-core/src/main/java/org/apache/samza/application/StreamApplicationDescriptorImpl.java index d50b0d02b2..512991381b 100644 --- a/samza-core/src/main/java/org/apache/samza/application/StreamApplicationDescriptorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/application/StreamApplicationDescriptorImpl.java @@ -51,7 +51,6 @@ import org.apache.samza.operators.spec.OutputStreamImpl; import org.apache.samza.operators.stream.IntermediateMessageStreamImpl; import org.apache.samza.serializers.KVSerde; -import org.apache.samza.serializers.NoOpSerde; import org.apache.samza.serializers.Serde; import org.apache.samza.table.Table; import org.apache.samza.table.TableSpec; @@ -78,7 +77,7 @@ public class StreamApplicationDescriptorImpl extends ApplicationDescriptorImpl inputOperators = new LinkedHashMap<>(); private final Map outputStreams = new LinkedHashMap<>(); - private final Map tables = new LinkedHashMap<>(); + private final Map tables = new LinkedHashMap<>(); private final Set operatorIds = new HashSet<>(); private Optional defaultSystemDescriptorOptional = Optional.empty(); @@ -125,7 +124,7 @@ public MessageStream getInputStream(InputDescriptor inputDescriptor "getInputStream must not be called multiple times with the same streamId: " + streamId); Serde serde = inputDescriptor.getSerde(); - KV kvSerdes = getKVSerdes(streamId, serde); + KV kvSerdes = getOrCreateStreamSerdes(streamId, serde); if (outputStreams.containsKey(streamId)) { OutputStreamImpl outputStream = outputStreams.get(streamId); Serde keySerde = outputStream.getKeySerde(); @@ -156,7 +155,7 @@ public OutputStream getOutputStream(OutputDescriptor outputDescript "getOutputStream must not be called multiple times with the same streamId: " + streamId); Serde serde = outputDescriptor.getSerde(); - KV kvSerdes = getKVSerdes(streamId, serde); + KV kvSerdes = getOrCreateStreamSerdes(streamId, serde); if (inputOperators.containsKey(streamId)) { InputOperatorSpec inputOperatorSpec = inputOperators.get(streamId); Serde keySerde = inputOperatorSpec.getKeySerde(); @@ -186,13 +185,15 @@ public Table> getTable(TableDescriptor tableDescriptor) String.format("add table descriptors multiple times with the same tableId: %s", tableDescriptor.getTableId())); tableDescriptors.put(tableDescriptor.getTableId(), tableDescriptor); - TableSpec tableSpec = ((BaseTableDescriptor) tableDescriptor).getTableSpec(); - if (tables.containsKey(tableSpec)) { + BaseTableDescriptor baseTableDescriptor = (BaseTableDescriptor) tableDescriptor; + TableSpec tableSpec = baseTableDescriptor.getTableSpec(); + if (tables.containsKey(tableSpec.getId())) { throw new IllegalStateException( String.format("getTable() invoked multiple times with the same tableId: %s", tableId)); } - tables.put(tableSpec, new TableImpl(tableSpec)); - return tables.get(tableSpec); + tables.put(tableSpec.getId(), new TableImpl(tableSpec)); + getOrCreateTableSerdes(tableSpec.getId(), baseTableDescriptor.getSerde()); + return tables.get(tableSpec.getId()); } /** @@ -247,6 +248,16 @@ public Set getSystemDescriptors() { return Collections.unmodifiableSet(new HashSet<>(systemDescriptors.values())); } + @Override + public Set getInputStreamIds() { + return Collections.unmodifiableSet(new HashSet<>(inputOperators.keySet())); + } + + @Override + public Set
getOutputStreamIds() { + return Collections.unmodifiableSet(new HashSet<>(outputStreams.keySet())); + } + /** * Get the default {@link SystemDescriptor} in this application * @@ -306,7 +317,7 @@ public Map getOutputStreams() { return Collections.unmodifiableMap(outputStreams); } - public Map getTables() { + public Map getTables() { return Collections.unmodifiableMap(tables); } @@ -342,7 +353,7 @@ public IntermediateMessageStreamImpl getIntermediateStream(String streamI kvSerdes = new KV<>(null, null); // and that key and msg serdes are provided for job.default.system in configs } else { isKeyed = serde instanceof KVSerde; - kvSerdes = getKVSerdes(streamId, serde); + kvSerdes = getOrCreateStreamSerdes(streamId, serde); } InputTransformer transformer = (InputTransformer) getDefaultSystemDescriptor() @@ -356,29 +367,6 @@ public IntermediateMessageStreamImpl getIntermediateStream(String streamI return new IntermediateMessageStreamImpl<>(this, inputOperators.get(streamId), outputStreams.get(streamId)); } - private KV getKVSerdes(String streamId, Serde serde) { - Serde keySerde, valueSerde; - - if (serde instanceof KVSerde) { - keySerde = ((KVSerde) serde).getKeySerde(); - valueSerde = ((KVSerde) serde).getValueSerde(); - } else { - keySerde = new NoOpSerde(); - valueSerde = serde; - } - - if (keySerde instanceof NoOpSerde) { - LOGGER.info("Using NoOpSerde as the key serde for stream " + streamId + - ". Keys will not be (de)serialized"); - } - if (valueSerde instanceof NoOpSerde) { - LOGGER.info("Using NoOpSerde as the value serde for stream " + streamId + - ". Values will not be (de)serialized"); - } - - return KV.of(keySerde, valueSerde); - } - // check uniqueness of the {@code systemDescriptor} and add if it is unique private void addSystemDescriptor(SystemDescriptor systemDescriptor) { Preconditions.checkState(!systemDescriptors.containsKey(systemDescriptor.getSystemName()) diff --git a/samza-core/src/main/java/org/apache/samza/application/TaskApplicationDescriptorImpl.java b/samza-core/src/main/java/org/apache/samza/application/TaskApplicationDescriptorImpl.java index 3597d7c4fc..d140a907a0 100644 --- a/samza-core/src/main/java/org/apache/samza/application/TaskApplicationDescriptorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/application/TaskApplicationDescriptorImpl.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Set; import org.apache.samza.config.Config; +import org.apache.samza.operators.BaseTableDescriptor; import org.apache.samza.operators.TableDescriptor; import org.apache.samza.operators.descriptors.base.stream.InputDescriptor; import org.apache.samza.operators.descriptors.base.stream.OutputDescriptor; @@ -65,6 +66,7 @@ public void addInputStream(InputDescriptor inputDescriptor) { // TODO: SAMZA-1841: need to add to the broadcast streams if inputDescriptor is for a broadcast stream Preconditions.checkState(!inputDescriptors.containsKey(inputDescriptor.getStreamId()), String.format("add input descriptors multiple times with the same streamId: %s", inputDescriptor.getStreamId())); + getOrCreateStreamSerdes(inputDescriptor.getStreamId(), inputDescriptor.getSerde()); inputDescriptors.put(inputDescriptor.getStreamId(), inputDescriptor); addSystemDescriptor(inputDescriptor.getSystemDescriptor()); } @@ -73,6 +75,7 @@ public void addInputStream(InputDescriptor inputDescriptor) { public void addOutputStream(OutputDescriptor outputDescriptor) { Preconditions.checkState(!outputDescriptors.containsKey(outputDescriptor.getStreamId()), String.format("add output 
descriptors multiple times with the same streamId: %s", outputDescriptor.getStreamId())); + getOrCreateStreamSerdes(outputDescriptor.getStreamId(), outputDescriptor.getSerde()); outputDescriptors.put(outputDescriptor.getStreamId(), outputDescriptor); addSystemDescriptor(outputDescriptor.getSystemDescriptor()); } @@ -81,6 +84,7 @@ public void addOutputStream(OutputDescriptor outputDescriptor) { public void addTable(TableDescriptor tableDescriptor) { Preconditions.checkState(!tableDescriptors.containsKey(tableDescriptor.getTableId()), String.format("add table descriptors multiple times with the same tableId: %s", tableDescriptor.getTableId())); + getOrCreateTableSerdes(tableDescriptor.getTableId(), ((BaseTableDescriptor) tableDescriptor).getSerde()); tableDescriptors.put(tableDescriptor.getTableId(), tableDescriptor); } @@ -111,6 +115,16 @@ public Set getSystemDescriptors() { return Collections.unmodifiableSet(new HashSet<>(systemDescriptors.values())); } + @Override + public Set getInputStreamIds() { + return Collections.unmodifiableSet(new HashSet<>(inputDescriptors.keySet())); + } + + @Override + public Set getOutputStreamIds() { + return Collections.unmodifiableSet(new HashSet<>(outputDescriptors.keySet())); + } + /** * Get the user-defined {@link TaskFactory} * @return the {@link TaskFactory} object diff --git a/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java b/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java index 46aef8d248..eea63878ee 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java +++ b/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java @@ -22,72 +22,57 @@ import com.google.common.collect.HashMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.Sets; -import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import org.apache.samza.SamzaException; +import org.apache.samza.application.ApplicationDescriptor; +import org.apache.samza.application.ApplicationDescriptorImpl; +import org.apache.samza.application.LegacyTaskApplication; import org.apache.samza.config.ApplicationConfig; import org.apache.samza.config.ClusterManagerConfig; import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; import org.apache.samza.config.StreamConfig; -import org.apache.samza.operators.OperatorSpecGraph; -import org.apache.samza.operators.spec.InputOperatorSpec; -import org.apache.samza.operators.spec.JoinOperatorSpec; +import org.apache.samza.operators.BaseTableDescriptor; import org.apache.samza.system.StreamSpec; import org.apache.samza.table.TableSpec; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static org.apache.samza.execution.ExecutionPlanner.StreamEdgeSet.StreamEdgeSetCategory; import static org.apache.samza.util.StreamUtil.*; /** - * The ExecutionPlanner creates the physical execution graph for the {@link OperatorSpecGraph}, and + * The ExecutionPlanner creates the physical execution graph for the {@link ApplicationDescriptorImpl}, and * the intermediate topics needed for the execution. 
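// ---------------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the serde bookkeeping that the new
// ApplicationDescriptorImpl methods provide and that the descriptor add/get calls
// above go through. Stream names and the PageView/Click payload types are
// hypothetical; getOrCreateStreamSerdes is package-private, so this would only
// compile from the same package (e.g. a test).
//
//   // A KVSerde is split into its key and value serdes:
//   appDesc.getOrCreateStreamSerdes("page-views",
//       KVSerde.of(new StringSerde(), new JsonSerdeV2<>(PageView.class)));
//   KV<Serde, Serde> serdes = appDesc.getStreamSerdes("page-views");
//   // serdes.getKey()   -> the StringSerde
//   // serdes.getValue() -> the JsonSerdeV2<PageView>
//
//   // A non-KV serde becomes the value serde; the key falls back to NoOpSerde and an
//   // INFO log notes that keys will not be (de)serialized:
//   appDesc.getOrCreateStreamSerdes("clicks", new JsonSerdeV2<>(Click.class));
//
//   // Re-registering the same streamId with a different serde fails fast with
//   // IllegalArgumentException ("Serde for stream clicks is already defined. ...").
//   appDesc.getOrCreateStreamSerdes("clicks", new StringSerde());
// ---------------------------------------------------------------------------------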
*/ // TODO: ExecutionPlanner needs to be able to generate single node JobGraph for low-level TaskApplication as well (SAMZA-1811) public class ExecutionPlanner { private static final Logger log = LoggerFactory.getLogger(ExecutionPlanner.class); - /* package private */ static final int MAX_INFERRED_PARTITIONS = 256; - private final Config config; - private final StreamConfig streamConfig; private final StreamManager streamManager; public ExecutionPlanner(Config config, StreamManager streamManager) { this.config = config; this.streamManager = streamManager; - this.streamConfig = new StreamConfig(config); } - public ExecutionPlan plan(OperatorSpecGraph opSpecGraph) { + public ExecutionPlan plan(ApplicationDescriptorImpl appDesc) { validateConfig(); - // Create physical job graph based on stream graph - JobGraph jobGraph = createJobGraph(opSpecGraph); - - // Fetch the external streams partition info - fetchInputAndOutputStreamPartitions(jobGraph); + // create physical job graph based on stream graph + JobGraph jobGraph = createJobGraph(config, appDesc); - // Verify agreement in partition count between all joined input/intermediate streams - validateJoinInputStreamPartitions(jobGraph); + // fetch the external streams partition info + setInputAndOutputStreamPartitionCount(jobGraph, streamManager); - if (!jobGraph.getIntermediateStreamEdges().isEmpty()) { - // Set partition count of intermediate streams not participating in joins - setIntermediateStreamPartitions(jobGraph); - - // Validate partition counts were assigned for all intermediate streams - validateIntermediateStreamPartitions(jobGraph); - } + // figure out the partitions for internal streams + new IntermediateStreamManager(config, appDesc).calculatePartitions(jobGraph); return jobGraph; } @@ -103,21 +88,23 @@ private void validateConfig() { } /** - * Creates the physical graph from {@link OperatorSpecGraph} + * Create the physical graph from {@link ApplicationDescriptorImpl} */ - /* package private */ JobGraph createJobGraph(OperatorSpecGraph opSpecGraph) { - JobGraph jobGraph = new JobGraph(config, opSpecGraph); - + /* package private */ + JobGraph createJobGraph(Config config, ApplicationDescriptorImpl appDesc) { + JobGraph jobGraph = new JobGraph(config, appDesc); + StreamConfig streamConfig = new StreamConfig(config); // Source streams contain both input and intermediate streams. - Set sourceStreams = getStreamSpecs(opSpecGraph.getInputOperators().keySet(), streamConfig); + Set sourceStreams = getStreamSpecs(appDesc.getInputStreamIds(), streamConfig); // Sink streams contain both output and intermediate streams. 
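// ---------------------------------------------------------------------------------
// Sketch (not part of this patch) of the classification performed just below in
// createJobGraph(): a stream that the application both consumes and produces is an
// intermediate stream. Plain string ids are used for brevity; the real code operates
// on StreamSpecs resolved from the streamIds.
//
//   Set<String> source = ImmutableSet.of("page-views", "views-by-country");   // inputs + intermediates
//   Set<String> sink = ImmutableSet.of("views-by-country", "view-counts");    // outputs + intermediates
//   Set<String> intermediate = Sets.intersection(source, sink);               // {"views-by-country"}
//   Set<String> inputs = Sets.difference(source, intermediate);               // {"page-views"}
//   Set<String> outputs = Sets.difference(sink, intermediate);                // {"view-counts"}
// ---------------------------------------------------------------------------------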
- Set sinkStreams = getStreamSpecs(opSpecGraph.getOutputStreams().keySet(), streamConfig); + Set sinkStreams = getStreamSpecs(appDesc.getOutputStreamIds(), streamConfig); Set intermediateStreams = Sets.intersection(sourceStreams, sinkStreams); Set inputStreams = Sets.difference(sourceStreams, intermediateStreams); Set outputStreams = Sets.difference(sinkStreams, intermediateStreams); - Set tables = opSpecGraph.getTables().keySet(); + Set tables = appDesc.getTableDescriptors().stream() + .map(tableDescriptor -> ((BaseTableDescriptor) tableDescriptor).getTableSpec()).collect(Collectors.toSet()); // For this phase, we have a single job node for the whole dag String jobName = config.get(JobConfig.JOB_NAME()); @@ -136,15 +123,20 @@ private void validateConfig() { // Add tables tables.forEach(spec -> jobGraph.addTable(spec, node)); - jobGraph.validate(); + if (!LegacyTaskApplication.class.isAssignableFrom(appDesc.getAppClass())) { + // skip the validation when input streamIds are empty. This is only possible for LegacyTaskApplication + jobGraph.validate(); + } return jobGraph; } /** - * Fetches the partitions of input/output streams and update the corresponding StreamEdges. + * Fetch the partitions of source/sink streams and update the StreamEdges. + * @param jobGraph {@link JobGraph} + * @param streamManager the {@link StreamManager} to interface with the streams. */ - /* package private */ void fetchInputAndOutputStreamPartitions(JobGraph jobGraph) { + /* package private */ static void setInputAndOutputStreamPartitionCount(JobGraph jobGraph, StreamManager streamManager) { Set existingStreams = new HashSet<>(); existingStreams.addAll(jobGraph.getInputStreams()); existingStreams.addAll(jobGraph.getOutputStreams()); @@ -182,224 +174,4 @@ private void validateConfig() { } } - /** - * Validates agreement in partition count between input/intermediate streams participating in join operations. - */ - private void validateJoinInputStreamPartitions(JobGraph jobGraph) { - // Group input operator specs (input/intermediate streams) by the joins they participate in. - Multimap joinOpSpecToInputOpSpecs = - OperatorSpecGraphAnalyzer.getJoinToInputOperatorSpecs(jobGraph.getSpecGraph()); - - // Convert every group of input operator specs into a group of corresponding stream edges. - List streamEdgeSets = new ArrayList<>(); - for (JoinOperatorSpec joinOpSpec : joinOpSpecToInputOpSpecs.keySet()) { - Collection joinedInputOpSpecs = joinOpSpecToInputOpSpecs.get(joinOpSpec); - StreamEdgeSet streamEdgeSet = getStreamEdgeSet(joinOpSpec.getOpId(), joinedInputOpSpecs, jobGraph); - streamEdgeSets.add(streamEdgeSet); - } - - /* - * Sort the stream edge groups by their category so they appear in this order: - * 1. groups composed exclusively of stream edges with set partition counts - * 2. groups composed of a mix of stream edges with set/unset partition counts - * 3. groups composed exclusively of stream edges with unset partition counts - * - * This guarantees that we process the most constrained stream edge groups first, - * which is crucial for intermediate stream edges that are members of multiple - * stream edge groups. For instance, if we have the following groups of stream - * edges (partition counts in parentheses, question marks for intermediate streams): - * - * a. e1 (16), e2 (16) - * b. e2 (16), e3 (?) - * c. e3 (?), e4 (?) - * - * processing them in the above order (most constrained first) is guaranteed to - * yield correct assignment of partition counts of e3 and e4 in a single scan. 
- */ - Collections.sort(streamEdgeSets, Comparator.comparingInt(e -> e.getCategory().getSortOrder())); - - // Verify agreement between joined input/intermediate streams. - // This may involve setting partition counts of intermediate stream edges. - streamEdgeSets.forEach(ExecutionPlanner::validateAndAssignStreamEdgeSetPartitions); - } - - /** - * Creates a {@link StreamEdgeSet} whose Id is {@code setId}, and {@link StreamEdge}s - * correspond to the provided {@code inputOpSpecs}. - */ - private StreamEdgeSet getStreamEdgeSet(String setId, Iterable inputOpSpecs, - JobGraph jobGraph) { - - int countStreamEdgeWithSetPartitions = 0; - Set streamEdges = new HashSet<>(); - - for (InputOperatorSpec inputOpSpec : inputOpSpecs) { - StreamEdge streamEdge = jobGraph.getOrCreateStreamEdge(getStreamSpec(inputOpSpec.getStreamId(), streamConfig)); - if (streamEdge.getPartitionCount() != StreamEdge.PARTITIONS_UNKNOWN) { - ++countStreamEdgeWithSetPartitions; - } - streamEdges.add(streamEdge); - } - - // Determine category of stream group based on stream partition counts. - StreamEdgeSetCategory category; - if (countStreamEdgeWithSetPartitions == 0) { - category = StreamEdgeSetCategory.NO_PARTITION_COUNT_SET; - } else if (countStreamEdgeWithSetPartitions == streamEdges.size()) { - category = StreamEdgeSetCategory.ALL_PARTITION_COUNT_SET; - } else { - category = StreamEdgeSetCategory.SOME_PARTITION_COUNT_SET; - } - - return new StreamEdgeSet(setId, streamEdges, category); - } - - /** - * Sets partition count of intermediate streams which have not been assigned partition counts. - */ - private void setIntermediateStreamPartitions(JobGraph jobGraph) { - final String defaultPartitionsConfigProperty = JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(); - int partitions = config.getInt(defaultPartitionsConfigProperty, StreamEdge.PARTITIONS_UNKNOWN); - if (partitions == StreamEdge.PARTITIONS_UNKNOWN) { - // use the following simple algo to figure out the partitions - // partition = MAX(MAX(Input topic partitions), MAX(Output topic partitions)) - // partition will be further bounded by MAX_INFERRED_PARTITIONS. - // This is important when running in hadoop where an HDFS input can have lots of files (partitions). - int maxInPartitions = maxPartitions(jobGraph.getInputStreams()); - int maxOutPartitions = maxPartitions(jobGraph.getOutputStreams()); - partitions = Math.max(maxInPartitions, maxOutPartitions); - - if (partitions > MAX_INFERRED_PARTITIONS) { - partitions = MAX_INFERRED_PARTITIONS; - log.warn(String.format("Inferred intermediate stream partition count %d is greater than the max %d. Using the max.", - partitions, MAX_INFERRED_PARTITIONS)); - } - } else { - // Reject any zero or other negative values explicitly specified in config. - if (partitions <= 0) { - throw new SamzaException(String.format("Invalid value %d specified for config property %s", partitions, - defaultPartitionsConfigProperty)); - } - - log.info("Using partition count value {} specified for config property {}", partitions, - defaultPartitionsConfigProperty); - } - - for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) { - if (edge.getPartitionCount() <= 0) { - log.info("Set the partition count for intermediate stream {} to {}.", edge.getName(), partitions); - edge.setPartitionCount(partitions); - } - } - } - - /** - * Ensures all intermediate streams have been assigned partition counts. 
- */ - private static void validateIntermediateStreamPartitions(JobGraph jobGraph) { - for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) { - if (edge.getPartitionCount() <= 0) { - throw new SamzaException(String.format("Failed to assign valid partition count to Stream %s", edge.getName())); - } - } - } - - /** - * Ensures that all streams in the supplied {@link StreamEdgeSet} agree in partition count. - * This may include setting partition counts of intermediate streams in this set that do not - * have their partition counts set. - */ - private static void validateAndAssignStreamEdgeSetPartitions(StreamEdgeSet streamEdgeSet) { - Set streamEdges = streamEdgeSet.getStreamEdges(); - StreamEdge firstStreamEdgeWithSetPartitions = - streamEdges.stream() - .filter(streamEdge -> streamEdge.getPartitionCount() != StreamEdge.PARTITIONS_UNKNOWN) - .findFirst() - .orElse(null); - - // This group consists exclusively of intermediate streams with unknown partition counts. - // We cannot do any validation/computation of partition counts of such streams right here, - // but they are tackled later in the ExecutionPlanner. - if (firstStreamEdgeWithSetPartitions == null) { - return; - } - - // Make sure all other stream edges in this group have the same partition count. - int partitions = firstStreamEdgeWithSetPartitions.getPartitionCount(); - for (StreamEdge streamEdge : streamEdges) { - int streamPartitions = streamEdge.getPartitionCount(); - if (streamPartitions == StreamEdge.PARTITIONS_UNKNOWN) { - streamEdge.setPartitionCount(partitions); - log.info("Inferred the partition count {} for the join operator {} from {}." - , new Object[] {partitions, streamEdgeSet.getSetId(), firstStreamEdgeWithSetPartitions.getName()}); - } else if (streamPartitions != partitions) { - throw new SamzaException(String.format( - "Unable to resolve input partitions of stream %s for the join %s. Expected: %d, Actual: %d", - streamEdge.getName(), streamEdgeSet.getSetId(), partitions, streamPartitions)); - } - } - } - - /* package private */ static int maxPartitions(Collection edges) { - return edges.stream().mapToInt(StreamEdge::getPartitionCount).max().orElse(StreamEdge.PARTITIONS_UNKNOWN); - } - - /** - * Represents a set of {@link StreamEdge}s. - */ - /* package private */ static class StreamEdgeSet { - - /** - * Indicates whether all stream edges in this group have their partition counts assigned. - */ - public enum StreamEdgeSetCategory { - /** - * All stream edges in this group have their partition counts assigned. - */ - ALL_PARTITION_COUNT_SET(0), - - /** - * Only some stream edges in this group have their partition counts assigned. - */ - SOME_PARTITION_COUNT_SET(1), - - /** - * No stream edge in this group is assigned a partition count. 
- */ - NO_PARTITION_COUNT_SET(2); - - - private final int sortOrder; - - StreamEdgeSetCategory(int sortOrder) { - this.sortOrder = sortOrder; - } - - public int getSortOrder() { - return sortOrder; - } - } - - private final String setId; - private final Set streamEdges; - private final StreamEdgeSetCategory category; - - public StreamEdgeSet(String setId, Set streamEdges, StreamEdgeSetCategory category) { - this.setId = setId; - this.streamEdges = streamEdges; - this.category = category; - } - - public Set getStreamEdges() { - return streamEdges; - } - - public String getSetId() { - return setId; - } - - public StreamEdgeSetCategory getCategory() { - return category; - } - } } diff --git a/samza-core/src/main/java/org/apache/samza/execution/IntermediateStreamManager.java b/samza-core/src/main/java/org/apache/samza/execution/IntermediateStreamManager.java new file mode 100644 index 0000000000..66cbe6a01c --- /dev/null +++ b/samza-core/src/main/java/org/apache/samza/execution/IntermediateStreamManager.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.execution; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Multimap; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.samza.SamzaException; +import org.apache.samza.application.ApplicationDescriptor; +import org.apache.samza.application.ApplicationDescriptorImpl; +import org.apache.samza.config.Config; +import org.apache.samza.config.JobConfig; +import org.apache.samza.operators.spec.InputOperatorSpec; +import org.apache.samza.operators.spec.JoinOperatorSpec; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@link IntermediateStreamManager} calculates intermediate stream partitions based on the high-level application graph. 
+ */ +class IntermediateStreamManager { + + private static final Logger log = LoggerFactory.getLogger(IntermediateStreamManager.class); + + private final Config config; + private final Map inputOperators; + + @VisibleForTesting + static final int MAX_INFERRED_PARTITIONS = 256; + + IntermediateStreamManager(Config config, ApplicationDescriptorImpl appDesc) { + this.config = config; + this.inputOperators = appDesc.getInputOperators(); + } + + /** + * Figure out the number of partitions of all streams + */ + /* package private */ void calculatePartitions(JobGraph jobGraph) { + + // Verify agreement in partition count between all joined input/intermediate streams + validateJoinInputStreamPartitions(jobGraph); + + if (!jobGraph.getIntermediateStreamEdges().isEmpty()) { + // Set partition count of intermediate streams not participating in joins + setIntermediateStreamPartitions(jobGraph); + + // Validate partition counts were assigned for all intermediate streams + validateIntermediateStreamPartitions(jobGraph); + } + } + + /** + * Validates agreement in partition count between input/intermediate streams participating in join operations. + */ + private void validateJoinInputStreamPartitions(JobGraph jobGraph) { + // Group input operator specs (input/intermediate streams) by the joins they participate in. + Multimap joinOpSpecToInputOpSpecs = + OperatorSpecGraphAnalyzer.getJoinToInputOperatorSpecs(inputOperators.values()); + + // Convert every group of input operator specs into a group of corresponding stream edges. + List streamEdgeSets = new ArrayList<>(); + for (JoinOperatorSpec joinOpSpec : joinOpSpecToInputOpSpecs.keySet()) { + Collection joinedInputOpSpecs = joinOpSpecToInputOpSpecs.get(joinOpSpec); + StreamEdgeSet streamEdgeSet = getStreamEdgeSet(joinOpSpec.getOpId(), joinedInputOpSpecs, jobGraph); + streamEdgeSets.add(streamEdgeSet); + } + + /* + * Sort the stream edge groups by their category so they appear in this order: + * 1. groups composed exclusively of stream edges with set partition counts + * 2. groups composed of a mix of stream edges with set/unset partition counts + * 3. groups composed exclusively of stream edges with unset partition counts + * + * This guarantees that we process the most constrained stream edge groups first, + * which is crucial for intermediate stream edges that are members of multiple + * stream edge groups. For instance, if we have the following groups of stream + * edges (partition counts in parentheses, question marks for intermediate streams): + * + * a. e1 (16), e2 (16) + * b. e2 (16), e3 (?) + * c. e3 (?), e4 (?) + * + * processing them in the above order (most constrained first) is guaranteed to + * yield correct assignment of partition counts of e3 and e4 in a single scan. + */ + Collections.sort(streamEdgeSets, Comparator.comparingInt(e -> e.getCategory().getSortOrder())); + + // Verify agreement between joined input/intermediate streams. + // This may involve setting partition counts of intermediate stream edges. + streamEdgeSets.forEach(IntermediateStreamManager::validateAndAssignStreamEdgeSetPartitions); + } + + /** + * Creates a {@link StreamEdgeSet} whose Id is {@code setId}, and {@link StreamEdge}s + * correspond to the provided {@code inputOpSpecs}. 
+ */ + private StreamEdgeSet getStreamEdgeSet(String setId, Iterable inputOpSpecs, + JobGraph jobGraph) { + + int countStreamEdgeWithSetPartitions = 0; + Set streamEdges = new HashSet<>(); + + for (InputOperatorSpec inputOpSpec : inputOpSpecs) { + StreamEdge streamEdge = jobGraph.getStreamEdge(inputOpSpec.getStreamId()); + if (streamEdge.getPartitionCount() != StreamEdge.PARTITIONS_UNKNOWN) { + ++countStreamEdgeWithSetPartitions; + } + streamEdges.add(streamEdge); + } + + // Determine category of stream group based on stream partition counts. + StreamEdgeSet.StreamEdgeSetCategory category; + if (countStreamEdgeWithSetPartitions == 0) { + category = StreamEdgeSet.StreamEdgeSetCategory.NO_PARTITION_COUNT_SET; + } else if (countStreamEdgeWithSetPartitions == streamEdges.size()) { + category = StreamEdgeSet.StreamEdgeSetCategory.ALL_PARTITION_COUNT_SET; + } else { + category = StreamEdgeSet.StreamEdgeSetCategory.SOME_PARTITION_COUNT_SET; + } + + return new StreamEdgeSet(setId, streamEdges, category); + } + + /** + * Sets partition count of intermediate streams which have not been assigned partition counts. + */ + private void setIntermediateStreamPartitions(JobGraph jobGraph) { + final String defaultPartitionsConfigProperty = JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(); + int partitions = config.getInt(defaultPartitionsConfigProperty, StreamEdge.PARTITIONS_UNKNOWN); + if (partitions == StreamEdge.PARTITIONS_UNKNOWN) { + // use the following simple algo to figure out the partitions + // partition = MAX(MAX(Input topic partitions), MAX(Output topic partitions)) + // partition will be further bounded by MAX_INFERRED_PARTITIONS. + // This is important when running in hadoop where an HDFS input can have lots of files (partitions). + int maxInPartitions = maxPartitions(jobGraph.getInputStreams()); + int maxOutPartitions = maxPartitions(jobGraph.getOutputStreams()); + partitions = Math.max(maxInPartitions, maxOutPartitions); + + if (partitions > MAX_INFERRED_PARTITIONS) { + partitions = MAX_INFERRED_PARTITIONS; + log.warn(String.format("Inferred intermediate stream partition count %d is greater than the max %d. Using the max.", + partitions, MAX_INFERRED_PARTITIONS)); + } + } else { + // Reject any zero or other negative values explicitly specified in config. + if (partitions <= 0) { + throw new SamzaException(String.format("Invalid value %d specified for config property %s", partitions, + defaultPartitionsConfigProperty)); + } + + log.info("Using partition count value {} specified for config property {}", partitions, + defaultPartitionsConfigProperty); + } + + for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) { + if (edge.getPartitionCount() <= 0) { + log.info("Set the partition count for intermediate stream {} to {}.", edge.getName(), partitions); + edge.setPartitionCount(partitions); + } + } + } + + /** + * Ensures all intermediate streams have been assigned partition counts. + */ + private static void validateIntermediateStreamPartitions(JobGraph jobGraph) { + for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) { + if (edge.getPartitionCount() <= 0) { + throw new SamzaException(String.format("Failed to assign valid partition count to Stream %s", edge.getName())); + } + } + } + + /** + * Ensures that all streams in the supplied {@link StreamEdgeSet} agree in partition count. + * This may include setting partition counts of intermediate streams in this set that do not + * have their partition counts set. 
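// ---------------------------------------------------------------------------------
// Worked example (not part of this patch) for setIntermediateStreamPartitions when
// JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS() is not configured; the counts are
// hypothetical.
//
//   // input partition counts {64, 50}, output partition counts {32}:
//   //   partitions = max(max(64, 50), max(32)) = 64
//   //   -> every intermediate stream still at PARTITIONS_UNKNOWN gets 64 partitions.
//
//   // input partition counts {512}, output partition counts {32}:
//   //   max(...) = 512 > MAX_INFERRED_PARTITIONS, so the inferred count is capped at 256.
//
//   // An explicitly configured value <= 0 is rejected with a SamzaException.
// ---------------------------------------------------------------------------------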
+ */ + private static void validateAndAssignStreamEdgeSetPartitions(StreamEdgeSet streamEdgeSet) { + Set streamEdges = streamEdgeSet.getStreamEdges(); + StreamEdge firstStreamEdgeWithSetPartitions = + streamEdges.stream() + .filter(streamEdge -> streamEdge.getPartitionCount() != StreamEdge.PARTITIONS_UNKNOWN) + .findFirst() + .orElse(null); + + // This group consists exclusively of intermediate streams with unknown partition counts. + // We cannot do any validation/computation of partition counts of such streams right here, + // but they are tackled later in the ExecutionPlanner. + if (firstStreamEdgeWithSetPartitions == null) { + return; + } + + // Make sure all other stream edges in this group have the same partition count. + int partitions = firstStreamEdgeWithSetPartitions.getPartitionCount(); + for (StreamEdge streamEdge : streamEdges) { + int streamPartitions = streamEdge.getPartitionCount(); + if (streamPartitions == StreamEdge.PARTITIONS_UNKNOWN) { + streamEdge.setPartitionCount(partitions); + log.info("Inferred the partition count {} for the join operator {} from {}.", + new Object[] {partitions, streamEdgeSet.getSetId(), firstStreamEdgeWithSetPartitions.getName()}); + } else if (streamPartitions != partitions) { + throw new SamzaException(String.format( + "Unable to resolve input partitions of stream %s for the join %s. Expected: %d, Actual: %d", + streamEdge.getName(), streamEdgeSet.getSetId(), partitions, streamPartitions)); + } + } + } + + /* package private */ static int maxPartitions(Collection edges) { + return edges.stream().mapToInt(StreamEdge::getPartitionCount).max().orElse(StreamEdge.PARTITIONS_UNKNOWN); + } + + /** + * Represents a set of {@link StreamEdge}s. + */ + /* package private */ static class StreamEdgeSet { + + /** + * Indicates whether all stream edges in this group have their partition counts assigned. + */ + public enum StreamEdgeSetCategory { + /** + * All stream edges in this group have their partition counts assigned. + */ + ALL_PARTITION_COUNT_SET(0), + + /** + * Only some stream edges in this group have their partition counts assigned. + */ + SOME_PARTITION_COUNT_SET(1), + + /** + * No stream edge in this group is assigned a partition count. 
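// ---------------------------------------------------------------------------------
// Worked example (not part of this patch) for validateAndAssignStreamEdgeSetPartitions,
// using hypothetical stream edges grouped by one join operator:
//
//   // {clicks: 16, clicks-repartition: UNKNOWN}
//   //   -> clicks-repartition inherits 16 partitions from the first edge with a known count.
//   // {clicks: 16, impressions: 32}
//   //   -> SamzaException ("Unable to resolve input partitions ... Expected: 16, Actual: 32").
//   // {repartition-a: UNKNOWN, repartition-b: UNKNOWN}
//   //   -> left untouched here; these edges get the default count later in
//   //      setIntermediateStreamPartitions and are then checked to be positive.
// ---------------------------------------------------------------------------------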
+ */ + NO_PARTITION_COUNT_SET(2); + + + private final int sortOrder; + + StreamEdgeSetCategory(int sortOrder) { + this.sortOrder = sortOrder; + } + + public int getSortOrder() { + return sortOrder; + } + } + + private final String setId; + private final Set streamEdges; + private final StreamEdgeSetCategory category; + + StreamEdgeSet(String setId, Set streamEdges, StreamEdgeSetCategory category) { + this.setId = setId; + this.streamEdges = streamEdges; + this.category = category; + } + + Set getStreamEdges() { + return streamEdges; + } + + String getSetId() { + return setId; + } + + StreamEdgeSetCategory getCategory() { + return category; + } + } +} diff --git a/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java b/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java index 5b190954c6..d975188520 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java +++ b/samza-core/src/main/java/org/apache/samza/execution/JobGraph.java @@ -31,10 +31,11 @@ import java.util.Set; import java.util.stream.Collectors; +import org.apache.samza.application.ApplicationDescriptor; +import org.apache.samza.application.ApplicationDescriptorImpl; import org.apache.samza.config.ApplicationConfig; import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; -import org.apache.samza.operators.OperatorSpecGraph; import org.apache.samza.system.StreamSpec; import org.apache.samza.table.TableSpec; import org.slf4j.Logger; @@ -59,16 +60,21 @@ private final Set intermediateStreams = new HashSet<>(); private final Set tables = new HashSet<>(); private final Config config; - private final JobGraphJsonGenerator jsonGenerator = new JobGraphJsonGenerator(); - private final OperatorSpecGraph specGraph; + private final JobGraphJsonGenerator jsonGenerator; + private final JobNodeConfigurationGenerator configGenerator; + private final ApplicationDescriptorImpl appDesc; /** * The JobGraph is only constructed by the {@link ExecutionPlanner}. - * @param config Config + * + * @param config configuration for the application + * @param appDesc {@link ApplicationDescriptorImpl} describing the application */ - JobGraph(Config config, OperatorSpecGraph specGraph) { + JobGraph(Config config, ApplicationDescriptorImpl appDesc) { this.config = config; - this.specGraph = specGraph; + this.appDesc = appDesc; + this.jsonGenerator = new JobGraphJsonGenerator(); + this.configGenerator = new JobNodeConfigurationGenerator(); } @Override @@ -91,11 +97,6 @@ public List getIntermediateStreams() { .collect(Collectors.toList()); } - void addTable(TableSpec tableSpec, JobNode node) { - tables.add(tableSpec); - node.addTable(tableSpec); - } - @Override public String getPlanAsJson() throws Exception { return jsonGenerator.toJson(this); @@ -105,14 +106,11 @@ public String getPlanAsJson() throws Exception { * Returns the config for this application * @return {@link ApplicationConfig} */ + @Override public ApplicationConfig getApplicationConfig() { return new ApplicationConfig(config); } - public OperatorSpecGraph getSpecGraph() { - return specGraph; - } - /** * Add a source stream to a {@link JobNode} * @param streamSpec input stream @@ -152,20 +150,20 @@ void addIntermediateStream(StreamSpec streamSpec, JobNode from, JobNode to) { intermediateStreams.add(edge); } + void addTable(TableSpec tableSpec, JobNode node) { + tables.add(tableSpec); + node.addTable(tableSpec); + } + /** * Get the {@link JobNode}. Create one if it does not exist. 
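// ---------------------------------------------------------------------------------
// Sketch (not part of this patch): the node id is simply "<jobName>-<jobId>", and
// getOrCreateJobNode is idempotent per id via computeIfAbsent. Names are hypothetical.
//
//   JobNode.createJobNameAndId("page-view-counter", "1");     // -> "page-view-counter-1"
//   JobNode a = jobGraph.getOrCreateJobNode("page-view-counter", "1");
//   JobNode b = jobGraph.getOrCreateJobNode("page-view-counter", "1");
//   // a == b: repeated calls return the same JobNode instance.
// ---------------------------------------------------------------------------------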
* @param jobName name of the job * @param jobId id of the job - * @return + * @return {@link JobNode} created with {@code jobName} and {@code jobId} */ JobNode getOrCreateJobNode(String jobName, String jobId) { - String nodeId = JobNode.createId(jobName, jobId); - JobNode node = nodes.get(nodeId); - if (node == null) { - node = new JobNode(jobName, jobId, specGraph, config); - nodes.put(nodeId, node); - } - return node; + String nodeId = JobNode.createJobNameAndId(jobName, jobId); + return nodes.computeIfAbsent(nodeId, k -> new JobNode(jobName, jobId, config, appDesc, configGenerator)); } /** @@ -178,20 +176,13 @@ StreamEdge getOrCreateStreamEdge(StreamSpec streamSpec) { } /** - * Get the {@link StreamEdge} for a {@link StreamSpec}. Create one if it does not exist. - * @param streamSpec spec of the StreamEdge - * @param isIntermediate boolean flag indicating whether it's an intermediate stream + * Get the {@link StreamEdge} for {@code streamId}. + * + * @param streamId the streamId for the {@link StreamEdge} * @return stream edge */ - StreamEdge getOrCreateStreamEdge(StreamSpec streamSpec, boolean isIntermediate) { - String streamId = streamSpec.getId(); - StreamEdge edge = edges.get(streamId); - if (edge == null) { - boolean isBroadcast = specGraph.getBroadcastStreams().contains(streamId); - edge = new StreamEdge(streamSpec, isIntermediate, isBroadcast, config); - edges.put(streamId, edge); - } - return edge; + StreamEdge getStreamEdge(String streamId) { + return edges.get(streamId); } /** @@ -247,6 +238,23 @@ void validate() { validateReachability(); } + /** + * Get the {@link StreamEdge} for a {@link StreamSpec}. Create one if it does not exist. + * @param streamSpec spec of the StreamEdge + * @param isIntermediate boolean flag indicating whether it's an intermediate stream + * @return stream edge + */ + private StreamEdge getOrCreateStreamEdge(StreamSpec streamSpec, boolean isIntermediate) { + String streamId = streamSpec.getId(); + StreamEdge edge = edges.get(streamId); + if (edge == null) { + boolean isBroadcast = appDesc.getBroadcastStreams().contains(streamId); + edge = new StreamEdge(streamSpec, isIntermediate, isBroadcast, config); + edges.put(streamId, edge); + } + return edge; + } + /** * Validate the input streams should have indegree being 0 and outdegree greater than 0 */ @@ -305,7 +313,7 @@ private void validateReachability() { Set unreachable = new HashSet<>(nodes.values()); unreachable.removeAll(reachable); throw new IllegalArgumentException(String.format("Jobs %s cannot be reached from Sources.", - String.join(", ", unreachable.stream().map(JobNode::getId).collect(Collectors.toList())))); + String.join(", ", unreachable.stream().map(JobNode::getJobNameAndId).collect(Collectors.toList())))); } } @@ -325,7 +333,7 @@ Set findReachable() { while (!queue.isEmpty()) { JobNode node = queue.poll(); - node.getOutEdges().stream().flatMap(edge -> edge.getTargetNodes().stream()).forEach(target -> { + node.getOutEdges().values().stream().flatMap(edge -> edge.getTargetNodes().stream()).forEach(target -> { if (!visited.contains(target)) { visited.add(target); queue.offer(target); @@ -351,9 +359,9 @@ List topologicalSort() { Map indegree = new HashMap<>(); Set visited = new HashSet<>(); pnodes.forEach(node -> { - String nid = node.getId(); + String nid = node.getJobNameAndId(); //only count the degrees of intermediate streams - long degree = node.getInEdges().stream().filter(e -> !inputStreams.contains(e)).count(); + long degree = node.getInEdges().values().stream().filter(e -> 
!inputStreams.contains(e)).count(); indegree.put(nid, degree); if (degree == 0L) { @@ -378,8 +386,8 @@ List topologicalSort() { while (!q.isEmpty()) { JobNode node = q.poll(); sortedNodes.add(node); - node.getOutEdges().stream().flatMap(edge -> edge.getTargetNodes().stream()).forEach(n -> { - String nid = n.getId(); + node.getOutEdges().values().stream().flatMap(edge -> edge.getTargetNodes().stream()).forEach(n -> { + String nid = n.getJobNameAndId(); Long degree = indegree.get(nid) - 1; indegree.put(nid, degree); if (degree == 0L && !visited.contains(n)) { @@ -400,7 +408,7 @@ List topologicalSort() { long min = Long.MAX_VALUE; JobNode minNode = null; for (JobNode node : reachable) { - Long degree = indegree.get(node.getId()); + Long degree = indegree.get(node.getJobNameAndId()); if (degree < min) { min = degree; minNode = node; diff --git a/samza-core/src/main/java/org/apache/samza/execution/JobGraphJsonGenerator.java b/samza-core/src/main/java/org/apache/samza/execution/JobGraphJsonGenerator.java index 91453d2dbd..18705e4f41 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/JobGraphJsonGenerator.java +++ b/samza-core/src/main/java/org/apache/samza/execution/JobGraphJsonGenerator.java @@ -32,7 +32,6 @@ import org.apache.samza.config.ApplicationConfig; import org.apache.samza.operators.spec.JoinOperatorSpec; import org.apache.samza.operators.spec.OperatorSpec; -import org.apache.samza.operators.spec.OperatorSpec.OpCode; import org.apache.samza.operators.spec.OutputOperatorSpec; import org.apache.samza.operators.spec.OutputStreamImpl; import org.apache.samza.operators.spec.PartitionByOperatorSpec; @@ -140,7 +139,7 @@ static final class JobGraphJson { jobGraph.getTables().forEach(t -> buildTableJson(t, jobGraphJson.tables)); jobGraphJson.jobs = jobGraph.getJobNodes().stream() - .map(jobNode -> buildJobNodeJson(jobNode)) + .map(this::buildJobNodeJson) .collect(Collectors.toList()); ByteArrayOutputStream out = new ByteArrayOutputStream(); @@ -149,54 +148,12 @@ static final class JobGraphJson { return new String(out.toByteArray()); } - /** - * Create JSON POJO for a {@link JobNode}, including the {@link org.apache.samza.operators.StreamGraph} for this job - * @param jobNode job node in the {@link JobGraph} - * @return {@link org.apache.samza.execution.JobGraphJsonGenerator.JobNodeJson} - */ - private JobNodeJson buildJobNodeJson(JobNode jobNode) { - JobNodeJson job = new JobNodeJson(); - job.jobName = jobNode.getJobName(); - job.jobId = jobNode.getJobId(); - job.operatorGraph = buildOperatorGraphJson(jobNode); - return job; - } - - /** - * Traverse the {@link OperatorSpec} graph and build the operator graph JSON POJO. 
- * @param jobNode job node in the {@link JobGraph} - * @return {@link org.apache.samza.execution.JobGraphJsonGenerator.OperatorGraphJson} - */ - private OperatorGraphJson buildOperatorGraphJson(JobNode jobNode) { - OperatorGraphJson opGraph = new OperatorGraphJson(); - opGraph.inputStreams = new ArrayList<>(); - jobNode.getSpecGraph().getInputOperators().forEach((streamId, operatorSpec) -> { - StreamJson inputJson = new StreamJson(); - opGraph.inputStreams.add(inputJson); - inputJson.streamId = streamId; - inputJson.nextOperatorIds = operatorSpec.getRegisteredOperatorSpecs().stream() - .map(OperatorSpec::getOpId).collect(Collectors.toSet()); - - updateOperatorGraphJson(operatorSpec, opGraph); - }); - - opGraph.outputStreams = new ArrayList<>(); - jobNode.getSpecGraph().getOutputStreams().keySet().forEach(streamId -> { - StreamJson outputJson = new StreamJson(); - outputJson.streamId = streamId; - opGraph.outputStreams.add(outputJson); - }); - return opGraph; - } - - /** - * Traverse the {@link OperatorSpec} graph recursively and update the operator graph JSON POJO. - * @param operatorSpec input - * @param opGraph operator graph to build - */ private void updateOperatorGraphJson(OperatorSpec operatorSpec, OperatorGraphJson opGraph) { - // TODO xiliu: render input operators instead of input streams - if (operatorSpec.getOpCode() != OpCode.INPUT) { + if (operatorSpec == null) { + // task application may not have any defined OperatorSpec + return; + } + if (operatorSpec.getOpCode() != OperatorSpec.OpCode.INPUT) { opGraph.operators.put(operatorSpec.getOpId(), operatorToMap(operatorSpec)); } Collection specs = operatorSpec.getRegisteredOperatorSpecs(); @@ -242,6 +199,46 @@ private Map operatorToMap(OperatorSpec spec) { return map; } + /** + * Create JSON POJO for a {@link JobNode}, including the {@link org.apache.samza.application.ApplicationDescriptorImpl} + * for this job + * + * @param jobNode job node in the {@link JobGraph} + * @return {@link org.apache.samza.execution.JobGraphJsonGenerator.JobNodeJson} + */ + private JobNodeJson buildJobNodeJson(JobNode jobNode) { + JobNodeJson job = new JobNodeJson(); + job.jobName = jobNode.getJobName(); + job.jobId = jobNode.getJobId(); + job.operatorGraph = buildOperatorGraphJson(jobNode); + return job; + } + + /** + * Traverse the {@link OperatorSpec} graph and build the operator graph JSON POJO. 
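// ---------------------------------------------------------------------------------
// Illustrative sketch (not part of this patch) of the operator-graph JSON this
// generator emits for one job node, assuming Jackson serializes the POJOs above
// under their field names; stream and operator ids are hypothetical.
//
//   { "jobName": "page-view-counter", "jobId": "1",
//     "operatorGraph": {
//       "inputStreams":  [ { "streamId": "page-views", "nextOperatorIds": ["map-1"] } ],
//       "outputStreams": [ { "streamId": "view-counts" } ],
//       "operators": { "map-1": { "opCode": "MAP", "opId": "map-1", ... } } } }
// ---------------------------------------------------------------------------------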
+ * @param jobNode job node in the {@link JobGraph} + * @return {@link org.apache.samza.execution.JobGraphJsonGenerator.OperatorGraphJson} + */ + private OperatorGraphJson buildOperatorGraphJson(JobNode jobNode) { + OperatorGraphJson opGraph = new OperatorGraphJson(); + opGraph.inputStreams = new ArrayList<>(); + jobNode.getInEdges().values().forEach(inStream -> { + StreamJson inputJson = new StreamJson(); + opGraph.inputStreams.add(inputJson); + inputJson.streamId = inStream.getStreamSpec().getId(); + inputJson.nextOperatorIds = jobNode.getNextOperatorIds(inputJson.streamId); + updateOperatorGraphJson(jobNode.getInputOperator(inputJson.streamId), opGraph); + }); + + opGraph.outputStreams = new ArrayList<>(); + jobNode.getOutEdges().values().forEach(outStream -> { + StreamJson outputJson = new StreamJson(); + outputJson.streamId = outStream.getStreamSpec().getId(); + opGraph.outputStreams.add(outputJson); + }); + return opGraph; + } + /** * Get or create the JSON POJO for a {@link StreamEdge} * @param edge {@link StreamEdge} @@ -261,15 +258,11 @@ private StreamEdgeJson buildStreamEdgeJson(StreamEdge edge, Map sourceJobs = new ArrayList<>(); - edge.getSourceNodes().forEach(jobNode -> { - sourceJobs.add(jobNode.getJobName()); - }); + edge.getSourceNodes().forEach(jobNode -> sourceJobs.add(jobNode.getJobName())); edgeJson.sourceJobs = sourceJobs; List targetJobs = new ArrayList<>(); - edge.getTargetNodes().forEach(jobNode -> { - targetJobs.add(jobNode.getJobName()); - }); + edge.getTargetNodes().forEach(jobNode -> targetJobs.add(jobNode.getJobName())); edgeJson.targetJobs = targetJobs; streamEdges.put(streamId, edgeJson); @@ -285,12 +278,7 @@ private StreamEdgeJson buildStreamEdgeJson(StreamEdge edge, Map tableSpecs) { String tableId = tableSpec.getId(); - TableSpecJson tableSpecJson = tableSpecs.get(tableId); - if (tableSpecJson == null) { - tableSpecJson = buildTableJson(tableSpec); - tableSpecs.put(tableId, tableSpecJson); - } - return tableSpecJson; + return tableSpecs.computeIfAbsent(tableId, k -> buildTableJson(tableSpec)); } /** diff --git a/samza-core/src/main/java/org/apache/samza/execution/JobNode.java b/samza-core/src/main/java/org/apache/samza/execution/JobNode.java index 47705ee3eb..af556f5e1d 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/JobNode.java +++ b/samza-core/src/main/java/org/apache/samza/execution/JobNode.java @@ -19,45 +19,26 @@ package org.apache.samza.execution; -import java.util.ArrayList; -import java.util.Base64; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; -import java.util.UUID; +import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; - -import org.apache.commons.lang3.StringUtils; +import org.apache.samza.application.ApplicationDescriptor; +import org.apache.samza.application.ApplicationDescriptorImpl; +import org.apache.samza.application.LegacyTaskApplication; import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; -import org.apache.samza.config.MapConfig; -import org.apache.samza.config.SerializerConfig; -import org.apache.samza.config.StorageConfig; -import org.apache.samza.config.StreamConfig; -import org.apache.samza.config.TaskConfig; -import org.apache.samza.config.TaskConfigJava; -import org.apache.samza.operators.OperatorSpecGraph; +import org.apache.samza.operators.KV; import org.apache.samza.operators.spec.InputOperatorSpec; -import org.apache.samza.operators.spec.JoinOperatorSpec; import 
org.apache.samza.operators.spec.OperatorSpec; -import org.apache.samza.operators.spec.OutputStreamImpl; -import org.apache.samza.operators.spec.StatefulOperatorSpec; -import org.apache.samza.operators.spec.WindowOperatorSpec; -import org.apache.samza.table.TableConfigGenerator; -import org.apache.samza.util.MathUtil; import org.apache.samza.serializers.Serde; -import org.apache.samza.serializers.SerializableSerde; import org.apache.samza.table.TableSpec; -import org.apache.samza.util.StreamUtil; -import org.apache.samza.util.Util; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.google.common.base.Joiner; - - /** * A JobNode is a physical execution unit. In RemoteExecutionEnvironment, it's a job that will be submitted * to remote cluster. In LocalExecutionEnvironment, it's a set of StreamProcessors for local execution. @@ -65,64 +46,71 @@ */ public class JobNode { private static final Logger log = LoggerFactory.getLogger(JobNode.class); - private static final String CONFIG_INTERNAL_EXECUTION_PLAN = "samza.internal.execution.plan"; private final String jobName; private final String jobId; - private final String id; - private final OperatorSpecGraph specGraph; - private final List inEdges = new ArrayList<>(); - private final List outEdges = new ArrayList<>(); - private final List tables = new ArrayList<>(); + private final String jobNameAndId; private final Config config; - - JobNode(String jobName, String jobId, OperatorSpecGraph specGraph, Config config) { + private final JobNodeConfigurationGenerator configGenerator; + // The following maps (i.e. inEdges and outEdges) uses the streamId as the key + private final Map inEdges = new HashMap<>(); + private final Map outEdges = new HashMap<>(); + // Similarly, tables uses tableId as the key + private final Map tables = new HashMap<>(); + private final ApplicationDescriptorImpl appDesc; + + JobNode(String jobName, String jobId, Config config, ApplicationDescriptorImpl appDesc, + JobNodeConfigurationGenerator configureGenerator) { this.jobName = jobName; this.jobId = jobId; - this.id = createId(jobName, jobId); - this.specGraph = specGraph; + this.jobNameAndId = createJobNameAndId(jobName, jobId); this.config = config; + this.appDesc = appDesc; + this.configGenerator = configureGenerator; } - public static Config mergeJobConfig(Config fullConfig, Config generatedConfig) { - return new JobConfig(Util.rewriteConfig(extractScopedConfig( - fullConfig, generatedConfig, String.format(JobConfig.CONFIG_JOB_PREFIX(), new JobConfig(fullConfig).getName().get())))); - } - - public OperatorSpecGraph getSpecGraph() { - return this.specGraph; + static String createJobNameAndId(String jobName, String jobId) { + return String.format("%s-%s", jobName, jobId); } - public String getId() { - return id; + String getJobNameAndId() { + return jobNameAndId; } - public String getJobName() { + String getJobName() { return jobName; } - public String getJobId() { + String getJobId() { return jobId; } + Config getConfig() { + return config; + } + void addInEdge(StreamEdge in) { - inEdges.add(in); + inEdges.put(in.getStreamSpec().getId(), in); } void addOutEdge(StreamEdge out) { - outEdges.add(out); + outEdges.put(out.getStreamSpec().getId(), out); } - List getInEdges() { + void addTable(TableSpec tableSpec) { + tables.put(tableSpec.getId(), tableSpec); + } + + Map getInEdges() { return inEdges; } - List getOutEdges() { + Map getOutEdges() { return outEdges; } - void addTable(TableSpec tableSpec) { - tables.add(tableSpec); + Map getTables() { + return 
tables; } /** @@ -130,250 +118,65 @@ void addTable(TableSpec tableSpec) { * @param executionPlanJson JSON representation of the execution plan * @return config of the job */ - public JobConfig generateConfig(String executionPlanJson) { - Map configs = new HashMap<>(); - configs.put(JobConfig.JOB_NAME(), jobName); - configs.put(JobConfig.JOB_ID(), jobId); + JobConfig generateConfig(String executionPlanJson) { + return configGenerator.generateJobConfig(this, executionPlanJson); + } - final List inputs = new ArrayList<>(); - final List broadcasts = new ArrayList<>(); - for (StreamEdge inEdge : inEdges) { - String formattedSystemStream = inEdge.getName(); - if (inEdge.isBroadcast()) { - broadcasts.add(formattedSystemStream + "#0"); - } else { - inputs.add(formattedSystemStream); - } + KV getInputSerdes(String streamId) { + if (!inEdges.containsKey(streamId)) { + return null; } + return appDesc.getStreamSerdes(streamId); + } - if (!broadcasts.isEmpty()) { - // TODO: remove this once we support defining broadcast input stream in high-level - // task.broadcast.input should be generated by the planner in the future. - final String taskBroadcasts = config.get(TaskConfigJava.BROADCAST_INPUT_STREAMS); - if (StringUtils.isNoneEmpty(taskBroadcasts)) { - broadcasts.add(taskBroadcasts); - } - configs.put(TaskConfigJava.BROADCAST_INPUT_STREAMS, Joiner.on(',').join(broadcasts)); + KV getOutputSerde(String streamId) { + if (!outEdges.containsKey(streamId)) { + return null; } + return appDesc.getStreamSerdes(streamId); + } - // set triggering interval if a window or join is defined - if (specGraph.hasWindowOrJoins()) { - if ("-1".equals(config.get(TaskConfig.WINDOW_MS(), "-1"))) { - long triggerInterval = computeTriggerInterval(); - log.info("Using triggering interval: {} for jobName: {}", triggerInterval, jobName); + Collection getReachableOperators() { + Set inputOperatorsInJobNode = inEdges.values().stream().map(inEdge -> + appDesc.getInputOperators().get(inEdge.getStreamSpec().getId())).filter(Objects::nonNull).collect(Collectors.toSet()); + Set reachableOperators = new HashSet<>(); + findReachableOperators(inputOperatorsInJobNode, reachableOperators); + return reachableOperators; + } - configs.put(TaskConfig.WINDOW_MS(), String.valueOf(triggerInterval)); - } + // get all next operators consuming from the input {@code streamId} + Set getNextOperatorIds(String streamId) { + if (!appDesc.getInputOperators().containsKey(streamId) || !inEdges.containsKey(streamId)) { + return new HashSet<>(); } + return appDesc.getInputOperators().get(streamId).getRegisteredOperatorSpecs().stream() + .map(op -> op.getOpId()).collect(Collectors.toSet()); + } - specGraph.getAllOperatorSpecs().forEach(opSpec -> { - if (opSpec instanceof StatefulOperatorSpec) { - ((StatefulOperatorSpec) opSpec).getStoreDescriptors() - .forEach(sd -> configs.putAll(sd.getStorageConfigs())); - // store key and message serdes are configured separately in #addSerdeConfigs - } - }); - - configs.put(CONFIG_INTERNAL_EXECUTION_PLAN, executionPlanJson); - - // write input/output streams to configs - inEdges.stream().filter(StreamEdge::isIntermediate).forEach(edge -> configs.putAll(edge.generateConfig())); - - // write serialized serde instances and stream serde configs to configs - addSerdeConfigs(configs); - - configs.putAll(TableConfigGenerator.generateConfigsForTableSpecs(new MapConfig(configs), tables)); - - // Add side inputs to the inputs and mark the stream as bootstrap - tables.forEach(tableSpec -> { - List sideInputs = tableSpec.getSideInputs(); 
- if (sideInputs != null && !sideInputs.isEmpty()) { - sideInputs.stream() - .map(sideInput -> StreamUtil.getSystemStreamFromNameOrId(config, sideInput)) - .forEach(systemStream -> { - inputs.add(StreamUtil.getNameFromSystemStream(systemStream)); - configs.put(String.format(StreamConfig.STREAM_PREFIX() + StreamConfig.BOOTSTRAP(), - systemStream.getSystem(), systemStream.getStream()), "true"); - }); - } - }); - - configs.put(TaskConfig.INPUT_STREAMS(), Joiner.on(',').join(inputs)); - - log.info("Job {} has generated configs {}", jobName, configs); - - String configPrefix = String.format(JobConfig.CONFIG_JOB_PREFIX(), jobName); - - // Disallow user specified job inputs/outputs. This info comes strictly from the user application. - Map allowedConfigs = new HashMap<>(config); - if (allowedConfigs.containsKey(TaskConfig.INPUT_STREAMS())) { - log.warn("Specifying task inputs in configuration is not allowed with Fluent API. " - + "Ignoring configured value for " + TaskConfig.INPUT_STREAMS()); - allowedConfigs.remove(TaskConfig.INPUT_STREAMS()); + InputOperatorSpec getInputOperator(String inputStreamId) { + if (!inEdges.containsKey(inputStreamId)) { + return null; } - - log.debug("Job {} has allowed configs {}", jobName, allowedConfigs); - return new JobConfig( - Util.rewriteConfig( - extractScopedConfig(new MapConfig(allowedConfigs), new MapConfig(configs), configPrefix))); + return appDesc.getInputOperators().get(inputStreamId); } - /** - * Serializes the {@link Serde} instances for operators, adds them to the provided config, and - * sets the serde configuration for the input/output/intermediate streams appropriately. - * - * We try to preserve the number of Serde instances before and after serialization. However we don't - * guarantee that references shared between these serdes instances (e.g. an Jackson ObjectMapper shared - * between two json serdes) are shared after deserialization too. - * - * Ideally all the user defined objects in the application should be serialized and de-serialized in one pass - * from the same output/input stream so that we can maintain reference sharing relationships. 
- * - * @param configs the configs to add serialized serde instances and stream serde configs to - */ - void addSerdeConfigs(Map configs) { - // collect all key and msg serde instances for streams - Map streamKeySerdes = new HashMap<>(); - Map streamMsgSerdes = new HashMap<>(); - Map inputOperators = specGraph.getInputOperators(); - inEdges.forEach(edge -> { - String streamId = edge.getStreamSpec().getId(); - InputOperatorSpec inputOperatorSpec = inputOperators.get(streamId); - Serde keySerde = inputOperatorSpec.getKeySerde(); - if (keySerde != null) { - streamKeySerdes.put(streamId, keySerde); - } - Serde valueSerde = inputOperatorSpec.getValueSerde(); - if (valueSerde != null) { - streamMsgSerdes.put(streamId, valueSerde); - } - }); - Map outputStreams = specGraph.getOutputStreams(); - outEdges.forEach(edge -> { - String streamId = edge.getStreamSpec().getId(); - OutputStreamImpl outputStream = outputStreams.get(streamId); - Serde keySerde = outputStream.getKeySerde(); - if (keySerde != null) { - streamKeySerdes.put(streamId, keySerde); - } - Serde valueSerde = outputStream.getValueSerde(); - if (valueSerde != null) { - streamMsgSerdes.put(streamId, valueSerde); - } - }); - - // collect all key and msg serde instances for stores - Map storeKeySerdes = new HashMap<>(); - Map storeMsgSerdes = new HashMap<>(); - specGraph.getAllOperatorSpecs().forEach(opSpec -> { - if (opSpec instanceof StatefulOperatorSpec) { - ((StatefulOperatorSpec) opSpec).getStoreDescriptors().forEach(storeDescriptor -> { - storeKeySerdes.put(storeDescriptor.getStoreName(), storeDescriptor.getKeySerde()); - storeMsgSerdes.put(storeDescriptor.getStoreName(), storeDescriptor.getMsgSerde()); - }); - } - }); - - // for each unique stream or store serde instance, generate a unique name and serialize to config - HashSet serdes = new HashSet<>(streamKeySerdes.values()); - serdes.addAll(streamMsgSerdes.values()); - serdes.addAll(storeKeySerdes.values()); - serdes.addAll(storeMsgSerdes.values()); - SerializableSerde serializableSerde = new SerializableSerde<>(); - Base64.Encoder base64Encoder = Base64.getEncoder(); - Map serdeUUIDs = new HashMap<>(); - serdes.forEach(serde -> { - String serdeName = serdeUUIDs.computeIfAbsent(serde, - s -> serde.getClass().getSimpleName() + "-" + UUID.randomUUID().toString()); - configs.putIfAbsent(String.format(SerializerConfig.SERDE_SERIALIZED_INSTANCE(), serdeName), - base64Encoder.encodeToString(serializableSerde.toBytes(serde))); - }); - - // set key and msg serdes for streams to the serde names generated above - streamKeySerdes.forEach((streamId, serde) -> { - String streamIdPrefix = String.format(StreamConfig.STREAM_ID_PREFIX(), streamId); - String keySerdeConfigKey = streamIdPrefix + StreamConfig.KEY_SERDE(); - configs.put(keySerdeConfigKey, serdeUUIDs.get(serde)); - }); - - streamMsgSerdes.forEach((streamId, serde) -> { - String streamIdPrefix = String.format(StreamConfig.STREAM_ID_PREFIX(), streamId); - String valueSerdeConfigKey = streamIdPrefix + StreamConfig.MSG_SERDE(); - configs.put(valueSerdeConfigKey, serdeUUIDs.get(serde)); - }); - - // set key and msg serdes for stores to the serde names generated above - storeKeySerdes.forEach((storeName, serde) -> { - String keySerdeConfigKey = String.format(StorageConfig.KEY_SERDE(), storeName); - configs.put(keySerdeConfigKey, serdeUUIDs.get(serde)); - }); - - storeMsgSerdes.forEach((storeName, serde) -> { - String msgSerdeConfigKey = String.format(StorageConfig.MSG_SERDE(), storeName); - configs.put(msgSerdeConfigKey, 
serdeUUIDs.get(serde)); - }); + boolean isLegacyTaskApplication() { + return LegacyTaskApplication.class.isAssignableFrom(appDesc.getAppClass()); } - /** - * Computes the triggering interval to use during the execution of this {@link JobNode} - */ - private long computeTriggerInterval() { - // Obtain the operator specs from the specGraph - Collection operatorSpecs = specGraph.getAllOperatorSpecs(); - - // Filter out window operators, and obtain a list of their triggering interval values - List windowTimerIntervals = operatorSpecs.stream() - .filter(spec -> spec.getOpCode() == OperatorSpec.OpCode.WINDOW) - .map(spec -> ((WindowOperatorSpec) spec).getDefaultTriggerMs()) - .collect(Collectors.toList()); - - // Filter out the join operators, and obtain a list of their ttl values - List joinTtlIntervals = operatorSpecs.stream() - .filter(spec -> spec instanceof JoinOperatorSpec) - .map(spec -> ((JoinOperatorSpec) spec).getTtlMs()) - .collect(Collectors.toList()); - - // Combine both the above lists - List candidateTimerIntervals = new ArrayList<>(joinTtlIntervals); - candidateTimerIntervals.addAll(windowTimerIntervals); - - if (candidateTimerIntervals.isEmpty()) { - return -1; - } - - // Compute the gcd of the resultant list - return MathUtil.gcd(candidateTimerIntervals); + KV getTableSerdes(String tableId) { + //TODO: SAMZA-1893: should test whether the table is used in the current JobNode + return appDesc.getTableSerdes(tableId); } - /** - * This function extract the subset of configs from the full config, and use it to override the generated configs - * from the job. - * @param fullConfig full config - * @param generatedConfig config generated for the job - * @param configPrefix prefix to extract the subset of the config overrides - * @return config that merges the generated configs and overrides - */ - private static Config extractScopedConfig(Config fullConfig, Config generatedConfig, String configPrefix) { - Config scopedConfig = fullConfig.subset(configPrefix); - - Config[] configPrecedence = new Config[] {fullConfig, generatedConfig, scopedConfig}; - // Strip empty configs so they don't override the configs before them. - Map mergedConfig = new HashMap<>(); - for (Map config : configPrecedence) { - for (Map.Entry property : config.entrySet()) { - String value = property.getValue(); - if (!(value == null || value.isEmpty())) { - mergedConfig.put(property.getKey(), property.getValue()); + private void findReachableOperators(Collection inputOperatorsInJobNode, + Set reachableOperators) { + inputOperatorsInJobNode.forEach(op -> { + if (reachableOperators.contains(op)) { + return; } - } - } - scopedConfig = new MapConfig(mergedConfig); - log.debug("Prefix '{}' has merged config {}", configPrefix, scopedConfig); - - return scopedConfig; - } - - static String createId(String jobName, String jobId) { - return String.format("%s-%s", jobName, jobId); + reachableOperators.add(op); + findReachableOperators(op.getRegisteredOperatorSpecs(), reachableOperators); + }); } } diff --git a/samza-core/src/main/java/org/apache/samza/execution/JobNodeConfigurationGenerator.java b/samza-core/src/main/java/org/apache/samza/execution/JobNodeConfigurationGenerator.java new file mode 100644 index 0000000000..676d28ebf1 --- /dev/null +++ b/samza-core/src/main/java/org/apache/samza/execution/JobNodeConfigurationGenerator.java @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.execution; + +import com.google.common.base.Joiner; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.commons.lang3.StringUtils; +import org.apache.samza.config.Config; +import org.apache.samza.config.JavaTableConfig; +import org.apache.samza.config.JobConfig; +import org.apache.samza.config.MapConfig; +import org.apache.samza.config.SerializerConfig; +import org.apache.samza.config.StorageConfig; +import org.apache.samza.config.StreamConfig; +import org.apache.samza.config.TaskConfig; +import org.apache.samza.config.TaskConfigJava; +import org.apache.samza.operators.KV; +import org.apache.samza.operators.spec.JoinOperatorSpec; +import org.apache.samza.operators.spec.OperatorSpec; +import org.apache.samza.operators.spec.StatefulOperatorSpec; +import org.apache.samza.operators.spec.StoreDescriptor; +import org.apache.samza.operators.spec.WindowOperatorSpec; +import org.apache.samza.serializers.NoOpSerde; +import org.apache.samza.serializers.Serde; +import org.apache.samza.serializers.SerializableSerde; +import org.apache.samza.table.TableConfigGenerator; +import org.apache.samza.table.TableSpec; +import org.apache.samza.util.MathUtil; +import org.apache.samza.util.StreamUtil; +import org.apache.samza.util.Util; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * This class provides methods to generate configuration for a {@link JobNode} + */ +/* package private */ class JobNodeConfigurationGenerator { + + private static final Logger LOG = LoggerFactory.getLogger(JobNodeConfigurationGenerator.class); + + static final String CONFIG_INTERNAL_EXECUTION_PLAN = "samza.internal.execution.plan"; + + static JobConfig mergeJobConfig(Config originalConfig, Config generatedConfig) { + JobConfig jobConfig = new JobConfig(originalConfig); + String jobNameAndId = JobNode.createJobNameAndId(jobConfig.getName().get(), jobConfig.getJobId()); + return new JobConfig(Util.rewriteConfig(extractScopedConfig(originalConfig, generatedConfig, + String.format(JobConfig.CONFIG_OVERRIDE_JOBS_PREFIX(), jobNameAndId)))); + } + + JobConfig generateJobConfig(JobNode jobNode, String executionPlanJson) { + Map configs = new HashMap<>(); + // set up job name and job ID + configs.put(JobConfig.JOB_NAME(), jobNode.getJobName()); + configs.put(JobConfig.JOB_ID(), jobNode.getJobId()); + + Map inEdges = jobNode.getInEdges(); + Map outEdges = jobNode.getOutEdges(); + Collection reachableOperators = jobNode.getReachableOperators(); + List stores = getStoreDescriptors(reachableOperators); + Map reachableTables = getReachableTables(reachableOperators, 
jobNode); + Config config = jobNode.getConfig(); + + // check all inputs to the node for broadcast and input streams + final Set inputs = new HashSet<>(); + final Set broadcasts = new HashSet<>(); + for (StreamEdge inEdge : inEdges.values()) { + String formattedSystemStream = inEdge.getName(); + if (inEdge.isBroadcast()) { + broadcasts.add(formattedSystemStream + "#0"); + } else { + inputs.add(formattedSystemStream); + } + } + + configureBroadcastInputs(configs, config, broadcasts); + + // compute window and join operator intervals in this node + configureWindowInterval(configs, config, reachableOperators); + + // set store configuration for stateful operators. + stores.forEach(sd -> configs.putAll(sd.getStorageConfigs())); + + // set the execution plan in json + configs.put(CONFIG_INTERNAL_EXECUTION_PLAN, executionPlanJson); + + // write intermediate input/output streams to configs + inEdges.values().stream().filter(StreamEdge::isIntermediate).forEach(edge -> configs.putAll(edge.generateConfig())); + + // write serialized serde instances and stream, store, and table serdes to configs + // serde configuration generation has to happen before table configuration, since the serde configuration + // is required when generating configurations for some TableProvider (i.e. local store backed tables) + configureSerdes(configs, inEdges, outEdges, stores, reachableTables.keySet(), jobNode); + + // generate table configuration and potential side input configuration + configureTables(configs, config, reachableTables, inputs); + + // finalize the task.inputs configuration + configs.put(TaskConfig.INPUT_STREAMS(), Joiner.on(',').join(inputs)); + + LOG.info("Job {} has generated configs {}", jobNode.getJobNameAndId(), configs); + + // apply configure rewriters and user configure overrides + return applyConfigureRewritersAndOverrides(configs, config, jobNode); + } + + private Map getReachableTables(Collection reachableOperators, JobNode jobNode) { + // TODO: Fix this in SAMZA-1893. For now, returning all tables for single-job execution plan + return jobNode.getTables(); + } + + private void configureBroadcastInputs(Map configs, Config config, Set broadcastStreams) { + // TODO: SAMZA-1841: remove this once we support defining broadcast input stream in high-level + // task.broadcast.input should be generated by the planner in the future. + if (broadcastStreams.isEmpty()) { + return; + } + final String taskBroadcasts = config.get(TaskConfigJava.BROADCAST_INPUT_STREAMS); + if (StringUtils.isNoneEmpty(taskBroadcasts)) { + broadcastStreams.add(taskBroadcasts); + } + configs.put(TaskConfigJava.BROADCAST_INPUT_STREAMS, Joiner.on(',').join(broadcastStreams)); + } + + private void configureWindowInterval(Map configs, Config config, + Collection reachableOperators) { + if (!reachableOperators.stream().anyMatch(op -> op.getOpCode() == OperatorSpec.OpCode.WINDOW + || op.getOpCode() == OperatorSpec.OpCode.JOIN)) { + return; + } + + // set triggering interval if a window or join is defined. 
Only applies to high-level applications + if ("-1".equals(config.get(TaskConfig.WINDOW_MS(), "-1"))) { + long triggerInterval = computeTriggerInterval(reachableOperators); + LOG.info("Using triggering interval: {}", triggerInterval); + + configs.put(TaskConfig.WINDOW_MS(), String.valueOf(triggerInterval)); + } + } + + /** + * Computes the triggering interval to use during the execution of this {@link JobNode} + */ + private long computeTriggerInterval(Collection reachableOperators) { + List windowTimerIntervals = reachableOperators.stream() + .filter(spec -> spec.getOpCode() == OperatorSpec.OpCode.WINDOW) + .map(spec -> ((WindowOperatorSpec) spec).getDefaultTriggerMs()) + .collect(Collectors.toList()); + + // Filter out the join operators, and obtain a list of their ttl values + List joinTtlIntervals = reachableOperators.stream() + .filter(spec -> spec instanceof JoinOperatorSpec) + .map(spec -> ((JoinOperatorSpec) spec).getTtlMs()) + .collect(Collectors.toList()); + + // Combine both the above lists + List candidateTimerIntervals = new ArrayList<>(joinTtlIntervals); + candidateTimerIntervals.addAll(windowTimerIntervals); + + if (candidateTimerIntervals.isEmpty()) { + return -1; + } + + // Compute the gcd of the resultant list + return MathUtil.gcd(candidateTimerIntervals); + } + + private JobConfig applyConfigureRewritersAndOverrides(Map configs, Config config, JobNode jobNode) { + // Disallow user specified job inputs/outputs. This info comes strictly from the user application. + Map allowedConfigs = new HashMap<>(config); + if (!jobNode.isLegacyTaskApplication()) { + if (allowedConfigs.containsKey(TaskConfig.INPUT_STREAMS())) { + LOG.warn("Specifying task inputs in configuration is not allowed for SamzaApplication. " + + "Ignoring configured value for " + TaskConfig.INPUT_STREAMS()); + allowedConfigs.remove(TaskConfig.INPUT_STREAMS()); + } + } + + LOG.debug("Job {} has allowed configs {}", jobNode.getJobNameAndId(), allowedConfigs); + return mergeJobConfig(new MapConfig(allowedConfigs), new MapConfig(configs)); + } + + /** + * This function extract the subset of configs from the full config, and use it to override the generated configs + * from the job. + * @param fullConfig full config + * @param generatedConfig config generated for the job + * @param configPrefix prefix to extract the subset of the config overrides + * @return config that merges the generated configs and overrides + */ + private static Config extractScopedConfig(Config fullConfig, Config generatedConfig, String configPrefix) { + Config scopedConfig = fullConfig.subset(configPrefix); + + Config[] configPrecedence = new Config[] {fullConfig, generatedConfig, scopedConfig}; + // Strip empty configs so they don't override the configs before them. 
+ Map mergedConfig = new HashMap<>(); + for (Map config : configPrecedence) { + for (Map.Entry property : config.entrySet()) { + String value = property.getValue(); + if (!(value == null || value.isEmpty())) { + mergedConfig.put(property.getKey(), property.getValue()); + } + } + } + scopedConfig = new MapConfig(mergedConfig); + LOG.debug("Prefix '{}' has merged config {}", configPrefix, scopedConfig); + + return scopedConfig; + } + + private List getStoreDescriptors(Collection reachableOperators) { + return reachableOperators.stream().filter(operatorSpec -> operatorSpec instanceof StatefulOperatorSpec) + .map(operatorSpec -> ((StatefulOperatorSpec) operatorSpec).getStoreDescriptors()).flatMap(Collection::stream) + .collect(Collectors.toList()); + } + + private void configureTables(Map configs, Config config, Map tables, Set inputs) { + configs.putAll(TableConfigGenerator.generateConfigsForTableSpecs(new MapConfig(configs), + tables.values().stream().collect(Collectors.toList()))); + + // Add side inputs to the inputs and mark the stream as bootstrap + tables.values().forEach(tableSpec -> { + List sideInputs = tableSpec.getSideInputs(); + if (sideInputs != null && !sideInputs.isEmpty()) { + sideInputs.stream() + .map(sideInput -> StreamUtil.getSystemStreamFromNameOrId(config, sideInput)) + .forEach(systemStream -> { + inputs.add(StreamUtil.getNameFromSystemStream(systemStream)); + configs.put(String.format(StreamConfig.STREAM_PREFIX() + StreamConfig.BOOTSTRAP(), + systemStream.getSystem(), systemStream.getStream()), "true"); + }); + } + }); + } + + /** + * Serializes the {@link Serde} instances for operators, adds them to the provided config, and + * sets the serde configuration for the input/output/intermediate streams appropriately. + * + * We try to preserve the number of Serde instances before and after serialization. However, we don't + * guarantee that references shared between these serde instances (e.g. a Jackson ObjectMapper shared + * between two JSON serdes) are shared after deserialization too. + * + * Ideally, all the user-defined objects in the application should be serialized and de-serialized in one pass + * from the same output/input stream so that we can maintain reference sharing relationships. 
+ * + * @param configs the configs to add serialized serde instances and stream, store, and table serde configs to + */ + private void configureSerdes(Map configs, Map inEdges, Map outEdges, + List stores, Collection tables, JobNode jobNode) { + // collect all key and msg serde instances for streams + Map streamKeySerdes = new HashMap<>(); + Map streamMsgSerdes = new HashMap<>(); + inEdges.keySet().forEach(streamId -> + addSerdes(jobNode.getInputSerdes(streamId), streamId, streamKeySerdes, streamMsgSerdes)); + outEdges.keySet().forEach(streamId -> + addSerdes(jobNode.getOutputSerde(streamId), streamId, streamKeySerdes, streamMsgSerdes)); + + Map storeKeySerdes = new HashMap<>(); + Map storeMsgSerdes = new HashMap<>(); + stores.forEach(storeDescriptor -> { + storeKeySerdes.put(storeDescriptor.getStoreName(), storeDescriptor.getKeySerde()); + storeMsgSerdes.put(storeDescriptor.getStoreName(), storeDescriptor.getMsgSerde()); + }); + + Map tableKeySerdes = new HashMap<>(); + Map tableMsgSerdes = new HashMap<>(); + tables.forEach(tableId -> { + addSerdes(jobNode.getTableSerdes(tableId), tableId, tableKeySerdes, tableMsgSerdes); + }); + + // for each unique stream, store, or table serde instance, generate a unique name and serialize to config + HashSet serdes = new HashSet<>(streamKeySerdes.values()); + serdes.addAll(streamMsgSerdes.values()); + serdes.addAll(storeKeySerdes.values()); + serdes.addAll(storeMsgSerdes.values()); + serdes.addAll(tableKeySerdes.values()); + serdes.addAll(tableMsgSerdes.values()); + SerializableSerde serializableSerde = new SerializableSerde<>(); + Base64.Encoder base64Encoder = Base64.getEncoder(); + Map serdeUUIDs = new HashMap<>(); + serdes.forEach(serde -> { + String serdeName = serdeUUIDs.computeIfAbsent(serde, + s -> serde.getClass().getSimpleName() + "-" + UUID.randomUUID().toString()); + configs.putIfAbsent(String.format(SerializerConfig.SERDE_SERIALIZED_INSTANCE(), serdeName), + base64Encoder.encodeToString(serializableSerde.toBytes(serde))); + }); + + // set key and msg serdes for streams to the serde names generated above + streamKeySerdes.forEach((streamId, serde) -> { + String streamIdPrefix = String.format(StreamConfig.STREAM_ID_PREFIX(), streamId); + String keySerdeConfigKey = streamIdPrefix + StreamConfig.KEY_SERDE(); + configs.put(keySerdeConfigKey, serdeUUIDs.get(serde)); + }); + + streamMsgSerdes.forEach((streamId, serde) -> { + String streamIdPrefix = String.format(StreamConfig.STREAM_ID_PREFIX(), streamId); + String valueSerdeConfigKey = streamIdPrefix + StreamConfig.MSG_SERDE(); + configs.put(valueSerdeConfigKey, serdeUUIDs.get(serde)); + }); + + // set key and msg serdes for stores to the serde names generated above + storeKeySerdes.forEach((storeName, serde) -> { + String keySerdeConfigKey = String.format(StorageConfig.KEY_SERDE(), storeName); + configs.put(keySerdeConfigKey, serdeUUIDs.get(serde)); + }); + + storeMsgSerdes.forEach((storeName, serde) -> { + String msgSerdeConfigKey = String.format(StorageConfig.MSG_SERDE(), storeName); + configs.put(msgSerdeConfigKey, serdeUUIDs.get(serde)); + }); + + // set key and msg serdes for tables to the serde names generated above + tableKeySerdes.forEach((tableId, serde) -> { + String keySerdeConfigKey = String.format(JavaTableConfig.TABLE_KEY_SERDE, tableId); + configs.put(keySerdeConfigKey, serdeUUIDs.get(serde)); + }); + + tableMsgSerdes.forEach((tableId, serde) -> { + String valueSerdeConfigKey = String.format(JavaTableConfig.TABLE_VALUE_SERDE, tableId); + configs.put(valueSerdeConfigKey, serdeUUIDs.get(serde)); + }); + } + 
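(Editorial sketch, not part of the patch.) To make the serde plumbing in configureSerdes above concrete: each distinct serde instance gets a generated name, its serialized form is registered under that name, and streams, stores, and tables then reference the name. A minimal, hypothetical Java sketch for a single value serde; the stream id "pageViews" and the StringSerde are illustrative assumptions, and the calls reuse only classes that appear in this patch (plus org.apache.samza.serializers.StringSerde):

    Map<String, String> configs = new HashMap<>();
    Serde<String> valueSerde = new StringSerde();
    // generate a unique name for this serde instance
    String serdeName = valueSerde.getClass().getSimpleName() + "-" + UUID.randomUUID().toString();
    // 1. register the serialized serde instance under the generated name
    configs.putIfAbsent(String.format(SerializerConfig.SERDE_SERIALIZED_INSTANCE(), serdeName),
        Base64.getEncoder().encodeToString(new SerializableSerde<Serde>().toBytes(valueSerde)));
    // 2. point the stream's msg serde config key at that name
    String streamIdPrefix = String.format(StreamConfig.STREAM_ID_PREFIX(), "pageViews");
    configs.put(streamIdPrefix + StreamConfig.MSG_SERDE(), serdeName);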
+ private void addSerdes(KV serdes, String streamId, Map keySerdeMap, + Map msgSerdeMap) { + if (serdes != null) { + if (serdes.getKey() != null && !(serdes.getKey() instanceof NoOpSerde)) { + keySerdeMap.put(streamId, serdes.getKey()); + } + if (serdes.getValue() != null && !(serdes.getValue() instanceof NoOpSerde)) { + msgSerdeMap.put(streamId, serdes.getValue()); + } + } + } +} diff --git a/samza-core/src/main/java/org/apache/samza/execution/JobPlanner.java b/samza-core/src/main/java/org/apache/samza/execution/JobPlanner.java index a2050e535d..abbec18c68 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/JobPlanner.java +++ b/samza-core/src/main/java/org/apache/samza/execution/JobPlanner.java @@ -20,29 +20,20 @@ import java.io.File; import java.io.PrintWriter; -import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.commons.lang3.StringUtils; -import org.apache.samza.SamzaException; import org.apache.samza.application.ApplicationDescriptor; import org.apache.samza.application.ApplicationDescriptorImpl; -import org.apache.samza.application.StreamApplicationDescriptorImpl; -import org.apache.samza.application.TaskApplicationDescriptorImpl; import org.apache.samza.config.ApplicationConfig; import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; import org.apache.samza.config.MapConfig; import org.apache.samza.config.ShellCommandConfig; import org.apache.samza.config.StreamConfig; -import org.apache.samza.operators.BaseTableDescriptor; -import org.apache.samza.operators.OperatorSpecGraph; -import org.apache.samza.table.TableConfigGenerator; -import org.apache.samza.table.TableSpec; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -64,22 +55,7 @@ public abstract class JobPlanner { this.config = descriptor.getConfig(); } - public List prepareJobs() { - String appId = new ApplicationConfig(appDesc.getConfig()).getGlobalAppId(); - if (appDesc instanceof TaskApplicationDescriptorImpl) { - return Collections.singletonList(prepareTaskJob((TaskApplicationDescriptorImpl) appDesc)); - } else if (appDesc instanceof StreamApplicationDescriptorImpl) { - try { - return prepareStreamJobs((StreamApplicationDescriptorImpl) appDesc); - } catch (Exception e) { - throw new SamzaException("Failed to generate JobConfig for StreamApplication " + appId, e); - } - } - throw new IllegalArgumentException(String.format("ApplicationDescriptorImpl has to be either TaskApplicationDescriptorImpl or " - + "StreamApplicationDescriptorImpl. 
class %s is not supported", appDesc.getClass().getName())); - } - - abstract List prepareStreamJobs(StreamApplicationDescriptorImpl streamAppDesc) throws Exception; + public abstract List prepareJobs(); StreamManager buildAndStartStreamManager(Config config) { StreamManager streamManager = new StreamManager(config); @@ -87,12 +63,12 @@ StreamManager buildAndStartStreamManager(Config config) { return streamManager; } - ExecutionPlan getExecutionPlan(OperatorSpecGraph specGraph) { - return getExecutionPlan(specGraph, null); + ExecutionPlan getExecutionPlan() { + return getExecutionPlan(null); } /* package private */ - ExecutionPlan getExecutionPlan(OperatorSpecGraph specGraph, String runId) { + ExecutionPlan getExecutionPlan(String runId) { // update application configs Map cfg = new HashMap<>(); @@ -101,8 +77,8 @@ ExecutionPlan getExecutionPlan(OperatorSpecGraph specGraph, String runId) { } StreamConfig streamConfig = new StreamConfig(config); - Set inputStreams = new HashSet<>(specGraph.getInputOperators().keySet()); - inputStreams.removeAll(specGraph.getOutputStreams().keySet()); + Set inputStreams = new HashSet<>(appDesc.getInputStreamIds()); + inputStreams.removeAll(appDesc.getOutputStreamIds()); ApplicationConfig.ApplicationMode mode = inputStreams.stream().allMatch(streamConfig::getIsBounded) ? ApplicationConfig.ApplicationMode.BATCH : ApplicationConfig.ApplicationMode.STREAM; cfg.put(ApplicationConfig.APP_MODE, mode.name()); @@ -117,12 +93,12 @@ ExecutionPlan getExecutionPlan(OperatorSpecGraph specGraph, String runId) { // create the physical execution plan and merge with overrides. This works for a single-stage job now // TODO: This should all be consolidated with ExecutionPlanner after fixing SAMZA-1811 - Config mergedConfig = JobNode.mergeJobConfig(config, new MapConfig(cfg)); + Config mergedConfig = JobNodeConfigurationGenerator.mergeJobConfig(config, new MapConfig(cfg)); // creating the StreamManager to get all input/output streams' metadata for planning StreamManager streamManager = buildAndStartStreamManager(mergedConfig); try { ExecutionPlanner planner = new ExecutionPlanner(mergedConfig, streamManager); - return planner.plan(specGraph); + return planner.plan(appDesc); } finally { streamManager.stop(); } @@ -149,25 +125,6 @@ final void writePlanJsonFile(String planJson) { } } - // TODO: SAMZA-1814: the following configuration generation still misses serde configuration generation, - // side input configuration, broadcast input and task inputs configuration generation for low-level task - // applications - // helper method to generate a single node job configuration for low level task applications - private JobConfig prepareTaskJob(TaskApplicationDescriptorImpl taskAppDesc) { - // copy original configure - Map cfg = new HashMap<>(); - // expand system and streams configure - Map systemStreamConfigs = expandSystemStreamConfigs(taskAppDesc); - cfg.putAll(systemStreamConfigs); - // expand table configure - cfg.putAll(expandTableConfigs(cfg, taskAppDesc)); - // adding app.class in the configuration - cfg.put(ApplicationConfig.APP_CLASS, appDesc.getAppClass().getName()); - // create the physical execution plan and merge with overrides. 
This works for a single-stage job now - // TODO: This should all be consolidated with ExecutionPlanner after fixing SAMZA-1811 - return new JobConfig(JobNode.mergeJobConfig(config, new MapConfig(cfg))); - } - private Map expandSystemStreamConfigs(ApplicationDescriptorImpl appDesc) { Map systemStreamConfigs = new HashMap<>(); appDesc.getInputDescriptors().forEach((key, value) -> systemStreamConfigs.putAll(value.toConfig())); @@ -177,12 +134,4 @@ private Map expandSystemStreamConfigs(ApplicationDescriptorImpl< systemStreamConfigs.put(JobConfig.JOB_DEFAULT_SYSTEM(), dsd.getSystemName())); return systemStreamConfigs; } - - private Map expandTableConfigs(Map originConfig, - ApplicationDescriptorImpl appDesc) { - List tableSpecs = new ArrayList<>(); - appDesc.getTableDescriptors().stream().map(td -> ((BaseTableDescriptor) td).getTableSpec()) - .forEach(spec -> tableSpecs.add(spec)); - return TableConfigGenerator.generateConfigsForTableSpecs(new MapConfig(originConfig), tableSpecs); - } } diff --git a/samza-core/src/main/java/org/apache/samza/execution/LocalJobPlanner.java b/samza-core/src/main/java/org/apache/samza/execution/LocalJobPlanner.java index 7996d6bb27..86aca0fa39 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/LocalJobPlanner.java +++ b/samza-core/src/main/java/org/apache/samza/execution/LocalJobPlanner.java @@ -25,7 +25,6 @@ import org.apache.samza.SamzaException; import org.apache.samza.application.ApplicationDescriptor; import org.apache.samza.application.ApplicationDescriptorImpl; -import org.apache.samza.application.StreamApplicationDescriptorImpl; import org.apache.samza.config.ApplicationConfig; import org.apache.samza.config.JobConfig; import org.apache.samza.config.JobCoordinatorConfig; @@ -37,7 +36,7 @@ /** - * Temporarily helper class with specific implementation of {@link JobPlanner#prepareStreamJobs(StreamApplicationDescriptorImpl)} + * Temporarily helper class with specific implementation of {@link JobPlanner#prepareJobs()} * for standalone Samza processors. * * TODO: we need to consolidate this with {@link ExecutionPlanner} after SAMZA-1811. @@ -53,17 +52,23 @@ public LocalJobPlanner(ApplicationDescriptorImpl prepareStreamJobs(StreamApplicationDescriptorImpl streamAppDesc) throws Exception { + public List prepareJobs() { // for high-level DAG, generating the plan and job configs // 1. initialize and plan - ExecutionPlan plan = getExecutionPlan(streamAppDesc.getOperatorSpecGraph()); + ExecutionPlan plan = getExecutionPlan(); - String executionPlanJson = plan.getPlanAsJson(); + String executionPlanJson = ""; + try { + executionPlanJson = plan.getPlanAsJson(); + } catch (Exception e) { + throw new SamzaException("Failed to create plan JSON.", e); + } writePlanJsonFile(executionPlanJson); LOG.info("Execution Plan: \n" + executionPlanJson); String planId = String.valueOf(executionPlanJson.hashCode()); - if (plan.getJobConfigs().isEmpty()) { + List jobConfigs = plan.getJobConfigs(); + if (jobConfigs.isEmpty()) { throw new SamzaException("No jobs in the plan."); } @@ -71,7 +76,7 @@ List prepareStreamJobs(StreamApplicationDescriptorImpl streamAppDesc) // TODO: System generated intermediate streams should have robust naming scheme. See SAMZA-1391 // TODO: this works for single-job applications. 
For multi-job applications, ExecutionPlan should return an AppConfig // to be used for the whole application - JobConfig jobConfig = plan.getJobConfigs().get(0); + JobConfig jobConfig = jobConfigs.get(0); StreamManager streamManager = null; try { // create the StreamManager to create intermediate streams in the plan @@ -82,7 +87,7 @@ List prepareStreamJobs(StreamApplicationDescriptorImpl streamAppDesc) streamManager.stop(); } } - return plan.getJobConfigs(); + return jobConfigs; } /** diff --git a/samza-core/src/main/java/org/apache/samza/execution/OperatorSpecGraphAnalyzer.java b/samza-core/src/main/java/org/apache/samza/execution/OperatorSpecGraphAnalyzer.java index aa1dff92fd..ca912147b1 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/OperatorSpecGraphAnalyzer.java +++ b/samza-core/src/main/java/org/apache/samza/execution/OperatorSpecGraphAnalyzer.java @@ -27,15 +27,14 @@ import java.util.Set; import java.util.function.Consumer; import java.util.function.Function; -import org.apache.samza.operators.OperatorSpecGraph; import org.apache.samza.operators.spec.InputOperatorSpec; import org.apache.samza.operators.spec.JoinOperatorSpec; import org.apache.samza.operators.spec.OperatorSpec; /** - * A utility class that encapsulates the logic for traversing an {@link OperatorSpecGraph} and building - * associations between related {@link OperatorSpec}s. + * A utility class that encapsulates the logic for traversing operators in the graph from the set of {@link InputOperatorSpec} + * and building associations between related {@link OperatorSpec}s. */ /* package private */ class OperatorSpecGraphAnalyzer { @@ -43,14 +42,13 @@ * Returns a grouping of {@link InputOperatorSpec}s by the joins, i.e. {@link JoinOperatorSpec}s, they participate in. */ public static Multimap getJoinToInputOperatorSpecs( - OperatorSpecGraph operatorSpecGraph) { + Collection inputOperatorSpecs) { Multimap joinOpSpecToInputOpSpecs = HashMultimap.create(); // Traverse graph starting from every input operator spec, observing connectivity between input operator specs // and Join operator specs. - Iterable inputOpSpecs = operatorSpecGraph.getInputOperators().values(); - for (InputOperatorSpec inputOpSpec : inputOpSpecs) { + for (InputOperatorSpec inputOpSpec : inputOperatorSpecs) { // Observe all join operator specs reachable from this input operator spec. JoinOperatorSpecVisitor joinOperatorSpecVisitor = new JoinOperatorSpecVisitor(); traverse(inputOpSpec, joinOperatorSpecVisitor, opSpec -> opSpec.getRegisteredOperatorSpecs()); @@ -77,7 +75,7 @@ private static void traverse(OperatorSpec startOpSpec, Consumer vi } /** - * An {@link OperatorSpecGraph} visitor that records all {@link JoinOperatorSpec}s encountered in the graph. 
+ * A visitor that records all {@link JoinOperatorSpec}s encountered in the graph of {@link OperatorSpec}s. */ private static class JoinOperatorSpecVisitor implements Consumer { private Set joinOpSpecs = new HashSet<>(); diff --git a/samza-core/src/main/java/org/apache/samza/execution/RemoteJobPlanner.java b/samza-core/src/main/java/org/apache/samza/execution/RemoteJobPlanner.java index 254ff97c51..54f86d5989 100644 --- a/samza-core/src/main/java/org/apache/samza/execution/RemoteJobPlanner.java +++ b/samza-core/src/main/java/org/apache/samza/execution/RemoteJobPlanner.java @@ -23,7 +23,6 @@ import org.apache.samza.SamzaException; import org.apache.samza.application.ApplicationDescriptor; import org.apache.samza.application.ApplicationDescriptorImpl; -import org.apache.samza.application.StreamApplicationDescriptorImpl; import org.apache.samza.config.ApplicationConfig; import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; @@ -34,7 +33,7 @@ /** - * Temporary helper class with specific implementation of {@link JobPlanner#prepareStreamJobs(StreamApplicationDescriptorImpl)} + * Temporary helper class with specific implementation of {@link JobPlanner#prepareJobs()} * for remote-launched Samza processors (e.g. in YARN). * * TODO: we need to consolidate this class with {@link ExecutionPlanner} after SAMZA-1811. @@ -47,7 +46,7 @@ public RemoteJobPlanner(ApplicationDescriptorImpl prepareStreamJobs(StreamApplicationDescriptorImpl streamAppDesc) throws Exception { + public List prepareJobs() { // for high-level DAG, generate the plan and job configs // TODO: run.id needs to be set for standalone: SAMZA-1531 // run.id is based on current system time with the most significant bits in UUID (8 digits) to avoid collision @@ -55,17 +54,22 @@ LOG.info("The run id for this run is {}", runId); // 1. initialize and plan - ExecutionPlan plan = getExecutionPlan(streamAppDesc.getOperatorSpecGraph(), runId); - writePlanJsonFile(plan.getPlanAsJson()); + ExecutionPlan plan = getExecutionPlan(runId); + try { + writePlanJsonFile(plan.getPlanAsJson()); + } catch (Exception e) { + throw new SamzaException("Failed to create plan JSON.", e); + } - if (plan.getJobConfigs().isEmpty()) { + List jobConfigs = plan.getJobConfigs(); + if (jobConfigs.isEmpty()) { throw new SamzaException("No jobs in the plan."); } // 2. create the necessary streams // TODO: this works for single-job applications. 
For multi-job applications, ExecutionPlan should return an AppConfig // to be used for the whole application - JobConfig jobConfig = plan.getJobConfigs().get(0); + JobConfig jobConfig = jobConfigs.get(0); StreamManager streamManager = null; try { // create the StreamManager to create intermediate streams in the plan @@ -79,7 +83,7 @@ List prepareStreamJobs(StreamApplicationDescriptorImpl streamAppDesc) streamManager.stop(); } } - return plan.getJobConfigs(); + return jobConfigs; } private Config getConfigFromPrevRun() { diff --git a/samza-core/src/main/java/org/apache/samza/operators/BaseTableDescriptor.java b/samza-core/src/main/java/org/apache/samza/operators/BaseTableDescriptor.java index 1e4194a259..1830d1c8a0 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/BaseTableDescriptor.java +++ b/samza-core/src/main/java/org/apache/samza/operators/BaseTableDescriptor.java @@ -72,6 +72,15 @@ public String getTableId() { return tableId; } + /** + * Get the serde assigned to this {@link TableDescriptor} + * + * @return {@link KVSerde} used by this table + */ + public KVSerde getSerde() { + return serde; + } + /** * Generate config for {@link TableSpec}; this method is used internally. * @param tableSpecConfig configuration for the {@link TableSpec} diff --git a/samza-core/src/main/java/org/apache/samza/operators/OperatorSpecGraph.java b/samza-core/src/main/java/org/apache/samza/operators/OperatorSpecGraph.java index b75b1e8538..5329fd7557 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/OperatorSpecGraph.java +++ b/samza-core/src/main/java/org/apache/samza/operators/OperatorSpecGraph.java @@ -30,7 +30,6 @@ import org.apache.samza.operators.spec.OperatorSpec; import org.apache.samza.operators.spec.OutputStreamImpl; import org.apache.samza.serializers.SerializableSerde; -import org.apache.samza.table.TableSpec; /** @@ -45,7 +44,6 @@ public class OperatorSpecGraph implements Serializable { private final Map inputOperators; private final Map outputStreams; private final Set broadcastStreams; - private final Map tables; private final Set allOpSpecs; private final boolean hasWindowOrJoins; @@ -57,7 +55,6 @@ public OperatorSpecGraph(StreamApplicationDescriptorImpl streamAppDesc) { this.inputOperators = streamAppDesc.getInputOperators(); this.outputStreams = streamAppDesc.getOutputStreams(); this.broadcastStreams = streamAppDesc.getBroadcastStreams(); - this.tables = streamAppDesc.getTables(); this.allOpSpecs = Collections.unmodifiableSet(this.findAllOperatorSpecs()); this.hasWindowOrJoins = checkWindowOrJoins(); this.serializedOpSpecGraph = opSpecGraphSerde.toBytes(this); @@ -75,10 +72,6 @@ public Set getBroadcastStreams() { return broadcastStreams; } - public Map getTables() { - return tables; - } - /** * Get all {@link OperatorSpec}s available in this {@link StreamApplicationDescriptorImpl} * diff --git a/samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java b/samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java index 98864d2eeb..b9bb1f6c0a 100644 --- a/samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java +++ b/samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java @@ -75,7 +75,7 @@ public static void main(String[] args) throws Exception { throw new SamzaException("can not find the job name"); } String jobName = jobConfig.getName().get(); - String jobId = jobConfig.getJobId().getOrElse(ScalaJavaUtil.defaultValue("1")); + String jobId = jobConfig.getJobId(); MDC.put("containerName", 
"samza-container-" + containerId); MDC.put("jobName", jobName); MDC.put("jobId", jobId); diff --git a/samza-core/src/main/java/org/apache/samza/table/TableConfigGenerator.java b/samza-core/src/main/java/org/apache/samza/table/TableConfigGenerator.java index 085131c6cd..03be758f10 100644 --- a/samza-core/src/main/java/org/apache/samza/table/TableConfigGenerator.java +++ b/samza-core/src/main/java/org/apache/samza/table/TableConfigGenerator.java @@ -20,22 +20,16 @@ package org.apache.samza.table; import java.util.ArrayList; -import java.util.Base64; import java.util.HashMap; -import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.UUID; import org.apache.samza.config.Config; import org.apache.samza.config.JavaTableConfig; -import org.apache.samza.config.SerializerConfig; import org.apache.samza.operators.BaseTableDescriptor; import org.apache.samza.operators.TableDescriptor; import org.apache.samza.operators.TableImpl; -import org.apache.samza.serializers.Serde; -import org.apache.samza.serializers.SerializableSerde; import org.apache.samza.util.Util; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -66,8 +60,6 @@ static public Map generateConfigsForTableDescs(Config config, Li static public Map generateConfigsForTableSpecs(Config config, List tableSpecs) { Map tableConfigs = new HashMap<>(); - tableConfigs.putAll(generateTableKVSerdeConfigs(tableSpecs)); - tableSpecs.forEach(tableSpec -> { // Add table provider factory config tableConfigs.put(String.format(JavaTableConfig.TABLE_PROVIDER_FACTORY, tableSpec.getId()), @@ -103,44 +95,4 @@ static public List getTableSpecs(List tableDescs) { }); return new ArrayList<>(tableSpecs.keySet()); } - - static private Map generateTableKVSerdeConfigs(List tableSpecs) { - Map serdeConfigs = new HashMap<>(); - - // Collect key and msg serde instances for all the tables - Map tableKeySerdes = new HashMap<>(); - Map tableValueSerdes = new HashMap<>(); - HashSet serdes = new HashSet<>(); - - tableSpecs.forEach(tableSpec -> { - tableKeySerdes.put(tableSpec.getId(), tableSpec.getSerde().getKeySerde()); - tableValueSerdes.put(tableSpec.getId(), tableSpec.getSerde().getValueSerde()); - }); - serdes.addAll(tableKeySerdes.values()); - serdes.addAll(tableValueSerdes.values()); - - // Generate serde names - SerializableSerde serializableSerde = new SerializableSerde<>(); - Base64.Encoder base64Encoder = Base64.getEncoder(); - Map serdeUUIDs = new HashMap<>(); - serdes.forEach(serde -> { - String serdeName = serdeUUIDs.computeIfAbsent(serde, - s -> serde.getClass().getSimpleName() + "-" + UUID.randomUUID().toString()); - serdeConfigs.putIfAbsent(String.format(SerializerConfig.SERDE_SERIALIZED_INSTANCE(), serdeName), - base64Encoder.encodeToString(serializableSerde.toBytes(serde))); - }); - - // Set key and msg serdes for tables to the serde names generated above - tableKeySerdes.forEach((tableId, serde) -> { - String keySerdeConfigKey = String.format(JavaTableConfig.TABLE_KEY_SERDE, tableId); - serdeConfigs.put(keySerdeConfigKey, serdeUUIDs.get(serde)); - }); - - tableValueSerdes.forEach((tableId, serde) -> { - String valueSerdeConfigKey = String.format(JavaTableConfig.TABLE_VALUE_SERDE, tableId); - serdeConfigs.put(valueSerdeConfigKey, serdeUUIDs.get(serde)); - }); - - return serdeConfigs; - } } diff --git a/samza-core/src/main/java/org/apache/samza/zk/ZkJobCoordinatorFactory.java b/samza-core/src/main/java/org/apache/samza/zk/ZkJobCoordinatorFactory.java index 3dad6c171b..41294a3c25 100644 
--- a/samza-core/src/main/java/org/apache/samza/zk/ZkJobCoordinatorFactory.java +++ b/samza-core/src/main/java/org/apache/samza/zk/ZkJobCoordinatorFactory.java @@ -35,7 +35,6 @@ public class ZkJobCoordinatorFactory implements JobCoordinatorFactory { private static final Logger LOG = LoggerFactory.getLogger(ZkJobCoordinatorFactory.class); private static final String JOB_COORDINATOR_ZK_PATH_FORMAT = "%s/%s-%s-coordinationData"; - private static final String DEFAULT_JOB_ID = "1"; private static final String DEFAULT_JOB_NAME = "defaultJob"; /** @@ -68,9 +67,7 @@ public static String getJobCoordinationZkPath(Config config) { String jobName = jobConfig.getName().isDefined() ? jobConfig.getName().get() : DEFAULT_JOB_NAME; - String jobId = jobConfig.getJobId().isDefined() - ? jobConfig.getJobId().get() - : DEFAULT_JOB_ID; + String jobId = jobConfig.getJobId(); return String.format(JOB_COORDINATOR_ZK_PATH_FORMAT, appId, jobName, jobId); } diff --git a/samza-core/src/main/scala/org/apache/samza/config/JobConfig.scala b/samza-core/src/main/scala/org/apache/samza/config/JobConfig.scala index fc8780f5fd..d7b71b5195 100644 --- a/samza-core/src/main/scala/org/apache/samza/config/JobConfig.scala +++ b/samza-core/src/main/scala/org/apache/samza/config/JobConfig.scala @@ -39,7 +39,7 @@ object JobConfig { */ val CONFIG_REWRITERS = "job.config.rewriters" // streaming.job_config_rewriters val CONFIG_REWRITER_CLASS = "job.config.rewriter.%s.class" // streaming.job_config_rewriter_class - regex, system, config - val CONFIG_JOB_PREFIX = "jobs.%s." + val CONFIG_OVERRIDE_JOBS_PREFIX = "jobs.%s." val JOB_NAME = "job.name" // streaming.job_name val JOB_ID = "job.id" // streaming.job_id val SAMZA_FWK_PATH = "samza.fwk.path" @@ -164,7 +164,7 @@ class JobConfig(config: Config) extends ScalaMapConfig(config) with Logging { def getStreamJobFactoryClass = getOption(JobConfig.STREAM_JOB_FACTORY_CLASS) - def getJobId = getOption(JobConfig.JOB_ID) + def getJobId = getOption(JobConfig.JOB_ID).getOrElse("1") def failOnCheckpointValidation = { getBoolean(JobConfig.JOB_FAIL_CHECKPOINT_VALIDATION, true) } diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala index fba7329a5e..417fc18518 100644 --- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala +++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala @@ -101,7 +101,7 @@ object SamzaContainer extends Logging { if(System.getenv(ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR) != null) { val jobNameAndId = ( config.getName.getOrElse(throw new ConfigException("Missing required config: job.name")), - config.getJobId.getOrElse("1") + config.getJobId ) loggedStorageBaseDir = new File(System.getenv(ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR) diff --git a/samza-core/src/main/scala/org/apache/samza/metrics/reporter/MetricsSnapshotReporterFactory.scala b/samza-core/src/main/scala/org/apache/samza/metrics/reporter/MetricsSnapshotReporterFactory.scala index d1e655412d..8a9c021cfa 100644 --- a/samza-core/src/main/scala/org/apache/samza/metrics/reporter/MetricsSnapshotReporterFactory.scala +++ b/samza-core/src/main/scala/org/apache/samza/metrics/reporter/MetricsSnapshotReporterFactory.scala @@ -44,7 +44,6 @@ class MetricsSnapshotReporterFactory extends MetricsReporterFactory with Logging val jobId = config .getJobId - .getOrElse(1.toString) val taskClass = config .getTaskClass diff --git 
a/samza-core/src/main/scala/org/apache/samza/util/CoordinatorStreamUtil.scala b/samza-core/src/main/scala/org/apache/samza/util/CoordinatorStreamUtil.scala index cd74716d18..bfb2271fc6 100644 --- a/samza-core/src/main/scala/org/apache/samza/util/CoordinatorStreamUtil.scala +++ b/samza-core/src/main/scala/org/apache/samza/util/CoordinatorStreamUtil.scala @@ -89,6 +89,6 @@ object CoordinatorStreamUtil { */ private def getJobNameAndId(config: Config) = { (config.getName.getOrElse(throw new ConfigException("Missing required config: job.name")), - config.getJobId.getOrElse("1")) + config.getJobId) } } diff --git a/samza-core/src/test/java/org/apache/samza/application/TestStreamApplicationDescriptorImpl.java b/samza-core/src/test/java/org/apache/samza/application/TestStreamApplicationDescriptorImpl.java index db85e3335e..1fe602308d 100644 --- a/samza-core/src/test/java/org/apache/samza/application/TestStreamApplicationDescriptorImpl.java +++ b/samza-core/src/test/java/org/apache/samza/application/TestStreamApplicationDescriptorImpl.java @@ -522,10 +522,11 @@ public void testGetTable() throws Exception { TableSpec testTableSpec = new TableSpec("t1", KVSerde.of(new NoOpSerde(), new NoOpSerde()), "", new HashMap<>()); when(mockTableDescriptor.getTableSpec()).thenReturn(testTableSpec); when(mockTableDescriptor.getTableId()).thenReturn(testTableSpec.getId()); + when(mockTableDescriptor.getSerde()).thenReturn(testTableSpec.getSerde()); StreamApplicationDescriptorImpl streamAppDesc = new StreamApplicationDescriptorImpl(appDesc -> { appDesc.getTable(mockTableDescriptor); }, mockConfig); - assertNotNull(streamAppDesc.getTables().get(testTableSpec)); + assertNotNull(streamAppDesc.getTables().get(testTableSpec.getId())); } @Test diff --git a/samza-core/src/test/java/org/apache/samza/application/TestTaskApplicationDescriptorImpl.java b/samza-core/src/test/java/org/apache/samza/application/TestTaskApplicationDescriptorImpl.java index 9418c1f056..abe5ce1249 100644 --- a/samza-core/src/test/java/org/apache/samza/application/TestTaskApplicationDescriptorImpl.java +++ b/samza-core/src/test/java/org/apache/samza/application/TestTaskApplicationDescriptorImpl.java @@ -23,12 +23,14 @@ import java.util.List; import java.util.Set; import org.apache.samza.config.Config; +import org.apache.samza.operators.BaseTableDescriptor; import org.apache.samza.operators.ContextManager; import org.apache.samza.operators.TableDescriptor; import org.apache.samza.operators.descriptors.base.stream.InputDescriptor; import org.apache.samza.operators.descriptors.base.stream.OutputDescriptor; import org.apache.samza.operators.descriptors.base.system.SystemDescriptor; import org.apache.samza.runtime.ProcessorLifecycleListenerFactory; +import org.apache.samza.serializers.KVSerde; import org.apache.samza.task.TaskFactory; import org.junit.Before; import org.junit.Test; @@ -64,10 +66,12 @@ public class TestTaskApplicationDescriptorImpl { this.add(mock2); } }; private Set mockTables = new HashSet() { { - TableDescriptor mock1 = mock(TableDescriptor.class); - TableDescriptor mock2 = mock(TableDescriptor.class); + BaseTableDescriptor mock1 = mock(BaseTableDescriptor.class); + BaseTableDescriptor mock2 = mock(BaseTableDescriptor.class); when(mock1.getTableId()).thenReturn("test-table1"); when(mock2.getTableId()).thenReturn("test-table2"); + when(mock1.getSerde()).thenReturn(mock(KVSerde.class)); + when(mock2.getSerde()).thenReturn(mock(KVSerde.class)); this.add(mock1); this.add(mock2); } }; diff --git 
a/samza-core/src/test/java/org/apache/samza/execution/ExecutionPlannerTestBase.java b/samza-core/src/test/java/org/apache/samza/execution/ExecutionPlannerTestBase.java new file mode 100644 index 0000000000..f507c70827 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/execution/ExecutionPlannerTestBase.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.execution; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import org.apache.samza.application.ApplicationDescriptorImpl; +import org.apache.samza.application.LegacyTaskApplication; +import org.apache.samza.application.StreamApplication; +import org.apache.samza.application.StreamApplicationDescriptorImpl; +import org.apache.samza.application.TaskApplication; +import org.apache.samza.config.Config; +import org.apache.samza.config.JobConfig; +import org.apache.samza.config.MapConfig; +import org.apache.samza.operators.KV; +import org.apache.samza.operators.MessageStream; +import org.apache.samza.operators.OutputStream; +import org.apache.samza.operators.descriptors.GenericInputDescriptor; +import org.apache.samza.operators.descriptors.GenericOutputDescriptor; +import org.apache.samza.operators.descriptors.GenericSystemDescriptor; +import org.apache.samza.operators.functions.JoinFunction; +import org.apache.samza.serializers.JsonSerdeV2; +import org.apache.samza.serializers.KVSerde; +import org.apache.samza.serializers.Serde; +import org.apache.samza.serializers.StringSerde; +import org.apache.samza.task.IdentityStreamTask; +import org.junit.Before; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; + + +/** + * Unit test base class to set up commonly used test application and configuration. 
+ */ +class ExecutionPlannerTestBase { + protected StreamApplicationDescriptorImpl mockStreamAppDesc; + protected Config mockConfig; + protected JobNode mockJobNode; + protected KVSerde defaultSerde; + protected GenericSystemDescriptor inputSystemDescriptor; + protected GenericSystemDescriptor outputSystemDescriptor; + protected GenericSystemDescriptor intermediateSystemDescriptor; + protected GenericInputDescriptor> input1Descriptor; + protected GenericInputDescriptor> input2Descriptor; + protected GenericInputDescriptor> intermediateInputDescriptor; + protected GenericInputDescriptor> broadcastInputDesriptor; + protected GenericOutputDescriptor> outputDescriptor; + protected GenericOutputDescriptor> intermediateOutputDescriptor; + + @Before + public void setUp() { + defaultSerde = KVSerde.of(new StringSerde(), new JsonSerdeV2<>()); + inputSystemDescriptor = new GenericSystemDescriptor("input-system", "mockSystemFactoryClassName"); + outputSystemDescriptor = new GenericSystemDescriptor("output-system", "mockSystemFactoryClassName"); + intermediateSystemDescriptor = new GenericSystemDescriptor("intermediate-system", "mockSystemFactoryClassName"); + input1Descriptor = inputSystemDescriptor.getInputDescriptor("input1", defaultSerde); + input2Descriptor = inputSystemDescriptor.getInputDescriptor("input2", defaultSerde); + outputDescriptor = outputSystemDescriptor.getOutputDescriptor("output", defaultSerde); + intermediateInputDescriptor = intermediateSystemDescriptor.getInputDescriptor("jobName-jobId-partition_by-p1", defaultSerde) + .withPhysicalName("jobName-jobId-partition_by-p1"); + intermediateOutputDescriptor = intermediateSystemDescriptor.getOutputDescriptor("jobName-jobId-partition_by-p1", defaultSerde) + .withPhysicalName("jobName-jobId-partition_by-p1"); + broadcastInputDesriptor = intermediateSystemDescriptor.getInputDescriptor("jobName-jobId-broadcast-b1", defaultSerde) + .withPhysicalName("jobName-jobId-broadcast-b1"); + + Map configs = new HashMap<>(); + configs.put(JobConfig.JOB_NAME(), "jobName"); + configs.put(JobConfig.JOB_ID(), "jobId"); + configs.putAll(input1Descriptor.toConfig()); + configs.putAll(input2Descriptor.toConfig()); + configs.putAll(outputDescriptor.toConfig()); + configs.putAll(inputSystemDescriptor.toConfig()); + configs.putAll(outputSystemDescriptor.toConfig()); + configs.putAll(intermediateSystemDescriptor.toConfig()); + configs.put(JobConfig.JOB_DEFAULT_SYSTEM(), intermediateSystemDescriptor.getSystemName()); + mockConfig = spy(new MapConfig(configs)); + + mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig); + } + + String getJobNameAndId() { + return "jobName-jobId"; + } + + void configureJobNode(ApplicationDescriptorImpl mockStreamAppDesc) { + JobGraph jobGraph = new ExecutionPlanner(mockConfig, mock(StreamManager.class)) + .createJobGraph(mockConfig, mockStreamAppDesc); + mockJobNode = spy(jobGraph.getJobNodes().get(0)); + } + + StreamApplication getRepartitionOnlyStreamApplication() { + return appDesc -> { + MessageStream> input1 = appDesc.getInputStream(input1Descriptor); + input1.partitionBy(KV::getKey, KV::getValue, "p1"); + }; + } + + StreamApplication getRepartitionJoinStreamApplication() { + return appDesc -> { + MessageStream> input1 = appDesc.getInputStream(input1Descriptor); + MessageStream> input2 = appDesc.getInputStream(input2Descriptor); + OutputStream> output = appDesc.getOutputStream(outputDescriptor); + JoinFunction> mockJoinFn = mock(JoinFunction.class); + input1 + 
.partitionBy(KV::getKey, KV::getValue, defaultSerde, "p1") + .map(kv -> kv.value) + .join(input2.map(kv -> kv.value), mockJoinFn, + new StringSerde(), new JsonSerdeV2<>(Object.class), new JsonSerdeV2<>(Object.class), + Duration.ofHours(1), "j1") + .sendTo(output); + }; + } + + TaskApplication getTaskApplication() { + return appDesc -> { + appDesc.addInputStream(input1Descriptor); + appDesc.addInputStream(input2Descriptor); + appDesc.addInputStream(intermediateInputDescriptor); + appDesc.addOutputStream(intermediateOutputDescriptor); + appDesc.addOutputStream(outputDescriptor); + appDesc.setTaskFactory(() -> new IdentityStreamTask()); + }; + } + + TaskApplication getLegacyTaskApplication() { + return new LegacyTaskApplication(IdentityStreamTask.class.getName()); + } + + StreamApplication getBroadcastOnlyStreamApplication(Serde serde) { + return appDesc -> { + MessageStream> input = appDesc.getInputStream(input1Descriptor); + if (serde != null) { + input.broadcast(serde, "b1"); + } else { + input.broadcast("b1"); + } + }; + } +} diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java b/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java index 779d299c06..61289afe45 100644 --- a/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java +++ b/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java @@ -24,12 +24,18 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import org.apache.samza.Partition; import org.apache.samza.SamzaException; +import org.apache.samza.application.ApplicationDescriptor; +import org.apache.samza.application.LegacyTaskApplication; +import org.apache.samza.application.SamzaApplication; import org.apache.samza.application.StreamApplicationDescriptorImpl; +import org.apache.samza.application.TaskApplicationDescriptorImpl; import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; import org.apache.samza.config.MapConfig; @@ -37,9 +43,13 @@ import org.apache.samza.operators.KV; import org.apache.samza.operators.MessageStream; import org.apache.samza.operators.OutputStream; +import org.apache.samza.operators.TableDescriptor; import org.apache.samza.operators.descriptors.GenericInputDescriptor; import org.apache.samza.operators.descriptors.GenericOutputDescriptor; import org.apache.samza.operators.descriptors.GenericSystemDescriptor; +import org.apache.samza.operators.descriptors.base.stream.InputDescriptor; +import org.apache.samza.operators.descriptors.base.stream.OutputDescriptor; +import org.apache.samza.operators.descriptors.base.system.SystemDescriptor; import org.apache.samza.operators.functions.JoinFunction; import org.apache.samza.operators.windows.Windows; import org.apache.samza.serializers.KVSerde; @@ -54,8 +64,12 @@ import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class TestExecutionPlanner { @@ -63,6 +77,11 @@ public class TestExecutionPlanner { private static final String DEFAULT_SYSTEM = "test-system"; private static final int 
DEFAULT_PARTITIONS = 10; + private final Set systemDescriptors = new HashSet<>(); + private final Map inputDescriptors = new HashMap<>(); + private final Map outputDescriptors = new HashMap<>(); + private final Set tableDescriptors = new HashSet<>(); + private SystemAdmins systemAdmins; private StreamManager streamManager; private Config config; @@ -78,6 +97,8 @@ public class TestExecutionPlanner { private GenericOutputDescriptor> output1Descriptor; private StreamSpec output2Spec; private GenericOutputDescriptor> output2Descriptor; + private GenericSystemDescriptor system1Descriptor; + private GenericSystemDescriptor system2Descriptor; static SystemAdmin createSystemAdmin(Map streamToPartitions) { @@ -236,20 +257,35 @@ public void setup() { KVSerde kvSerde = new KVSerde<>(new NoOpSerde(), new NoOpSerde()); String mockSystemFactoryClass = "factory.class.name"; - GenericSystemDescriptor system1 = new GenericSystemDescriptor("system1", mockSystemFactoryClass); - GenericSystemDescriptor system2 = new GenericSystemDescriptor("system2", mockSystemFactoryClass); - input1Descriptor = system1.getInputDescriptor("input1", kvSerde); - input2Descriptor = system2.getInputDescriptor("input2", kvSerde); - input3Descriptor = system2.getInputDescriptor("input3", kvSerde); - input4Descriptor = system1.getInputDescriptor("input4", kvSerde); - output1Descriptor = system1.getOutputDescriptor("output1", kvSerde); - output2Descriptor = system2.getOutputDescriptor("output2", kvSerde); + system1Descriptor = new GenericSystemDescriptor("system1", mockSystemFactoryClass); + system2Descriptor = new GenericSystemDescriptor("system2", mockSystemFactoryClass); + input1Descriptor = system1Descriptor.getInputDescriptor("input1", kvSerde); + input2Descriptor = system2Descriptor.getInputDescriptor("input2", kvSerde); + input3Descriptor = system2Descriptor.getInputDescriptor("input3", kvSerde); + input4Descriptor = system1Descriptor.getInputDescriptor("input4", kvSerde); + output1Descriptor = system1Descriptor.getOutputDescriptor("output1", kvSerde); + output2Descriptor = system2Descriptor.getOutputDescriptor("output2", kvSerde); + + // clean and set up sets and maps of descriptors + systemDescriptors.clear(); + inputDescriptors.clear(); + outputDescriptors.clear(); + tableDescriptors.clear(); + systemDescriptors.add(system1Descriptor); + systemDescriptors.add(system2Descriptor); + inputDescriptors.put(input1Descriptor.getStreamId(), input1Descriptor); + inputDescriptors.put(input2Descriptor.getStreamId(), input2Descriptor); + inputDescriptors.put(input3Descriptor.getStreamId(), input3Descriptor); + inputDescriptors.put(input4Descriptor.getStreamId(), input4Descriptor); + outputDescriptors.put(output1Descriptor.getStreamId(), output1Descriptor); + outputDescriptors.put(output2Descriptor.getStreamId(), output2Descriptor); + // set up external partition count Map system1Map = new HashMap<>(); system1Map.put("input1", 64); system1Map.put("output1", 8); - system1Map.put("input4", ExecutionPlanner.MAX_INFERRED_PARTITIONS * 2); + system1Map.put("input4", IntermediateStreamManager.MAX_INFERRED_PARTITIONS * 2); Map system2Map = new HashMap<>(); system2Map.put("input2", 16); system2Map.put("input3", 32); @@ -268,7 +304,7 @@ public void testCreateProcessorGraph() { ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin(); - JobGraph jobGraph = planner.createJobGraph(graphSpec.getOperatorSpecGraph()); + JobGraph jobGraph = 
planner.createJobGraph(graphSpec.getConfig(), graphSpec); assertTrue(jobGraph.getInputStreams().size() == 3); assertTrue(jobGraph.getOutputStreams().size() == 2); assertTrue(jobGraph.getIntermediateStreams().size() == 2); // two streams generated by partitionBy @@ -278,9 +314,9 @@ public void testCreateProcessorGraph() { public void testFetchExistingStreamPartitions() { ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin(); - JobGraph jobGraph = planner.createJobGraph(graphSpec.getOperatorSpecGraph()); + JobGraph jobGraph = planner.createJobGraph(graphSpec.getConfig(), graphSpec); - planner.fetchInputAndOutputStreamPartitions(jobGraph); + ExecutionPlanner.setInputAndOutputStreamPartitionCount(jobGraph, streamManager); assertTrue(jobGraph.getOrCreateStreamEdge(input1Spec).getPartitionCount() == 64); assertTrue(jobGraph.getOrCreateStreamEdge(input2Spec).getPartitionCount() == 16); assertTrue(jobGraph.getOrCreateStreamEdge(input3Spec).getPartitionCount() == 32); @@ -296,7 +332,10 @@ public void testFetchExistingStreamPartitions() { public void testCalculateJoinInputPartitions() { ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin(); - JobGraph jobGraph = (JobGraph) planner.plan(graphSpec.getOperatorSpecGraph()); + JobGraph jobGraph = planner.createJobGraph(graphSpec.getConfig(), graphSpec); + + ExecutionPlanner.setInputAndOutputStreamPartitionCount(jobGraph, streamManager); + new IntermediateStreamManager(config, graphSpec).calculatePartitions(jobGraph); // the partitions should be the same as input1 jobGraph.getIntermediateStreams().forEach(edge -> { @@ -309,7 +348,7 @@ public void testRejectsInvalidJoin() { ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithInvalidJoin(); - planner.plan(graphSpec.getOperatorSpecGraph()); + planner.plan(graphSpec); } @Test @@ -320,7 +359,7 @@ public void testDefaultPartitions() { ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager); StreamApplicationDescriptorImpl graphSpec = createSimpleGraph(); - JobGraph jobGraph = (JobGraph) planner.plan(graphSpec.getOperatorSpecGraph()); + JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); // the partitions should be the same as input1 jobGraph.getIntermediateStreams().forEach(edge -> { @@ -336,7 +375,7 @@ public void testTriggerIntervalForJoins() { ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager); StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin(); - ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph()); + ExecutionPlan plan = planner.plan(graphSpec); List jobConfigs = plan.getJobConfigs(); for (JobConfig config : jobConfigs) { System.out.println(config); @@ -351,7 +390,7 @@ public void testTriggerIntervalForWindowsAndJoins() { ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager); StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoinAndWindow(); - ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph()); + ExecutionPlan plan = planner.plan(graphSpec); List jobConfigs = plan.getJobConfigs(); assertEquals(1, jobConfigs.size()); @@ -368,7 +407,7 @@ public void testTriggerIntervalWithInvalidWindowMs() { ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager); StreamApplicationDescriptorImpl graphSpec = 
createStreamGraphWithJoinAndWindow(); - ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph()); + ExecutionPlan plan = planner.plan(graphSpec); List jobConfigs = plan.getJobConfigs(); assertEquals(1, jobConfigs.size()); @@ -384,7 +423,7 @@ public void testTriggerIntervalForStatelessOperators() { ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager); StreamApplicationDescriptorImpl graphSpec = createSimpleGraph(); - ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph()); + ExecutionPlan plan = planner.plan(graphSpec); List jobConfigs = plan.getJobConfigs(); assertEquals(1, jobConfigs.size()); assertFalse(jobConfigs.get(0).containsKey(TaskConfig.WINDOW_MS())); @@ -399,7 +438,7 @@ public void testTriggerIntervalWhenWindowMsIsConfigured() { ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager); StreamApplicationDescriptorImpl graphSpec = createSimpleGraph(); - ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph()); + ExecutionPlan plan = planner.plan(graphSpec); List jobConfigs = plan.getJobConfigs(); assertEquals(1, jobConfigs.size()); assertEquals("2000", jobConfigs.get(0).get(TaskConfig.WINDOW_MS())); @@ -409,7 +448,7 @@ public void testTriggerIntervalWhenWindowMsIsConfigured() { public void testCalculateIntStreamPartitions() { ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); StreamApplicationDescriptorImpl graphSpec = createSimpleGraph(); - JobGraph jobGraph = (JobGraph) planner.plan(graphSpec.getOperatorSpecGraph()); + JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); // the partitions should be the same as input1 jobGraph.getIntermediateStreams().forEach(edge -> { @@ -430,15 +469,15 @@ public void testMaxPartition() { edge.setPartitionCount(16); edges.add(edge); - assertEquals(32, ExecutionPlanner.maxPartitions(edges)); + assertEquals(32, IntermediateStreamManager.maxPartitions(edges)); edges = Collections.emptyList(); - assertEquals(StreamEdge.PARTITIONS_UNKNOWN, ExecutionPlanner.maxPartitions(edges)); + assertEquals(StreamEdge.PARTITIONS_UNKNOWN, IntermediateStreamManager.maxPartitions(edges)); } @Test public void testMaxPartitionLimit() throws Exception { - int partitionLimit = ExecutionPlanner.MAX_INFERRED_PARTITIONS; + int partitionLimit = IntermediateStreamManager.MAX_INFERRED_PARTITIONS; ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { @@ -447,11 +486,99 @@ public void testMaxPartitionLimit() throws Exception { input1.partitionBy(m -> m.key, m -> m.value, "p1").map(kv -> kv).sendTo(output1); }, config); - JobGraph jobGraph = (JobGraph) planner.plan(graphSpec.getOperatorSpecGraph()); + JobGraph jobGraph = (JobGraph) planner.plan(graphSpec); // the partitions should be the same as input1 jobGraph.getIntermediateStreams().forEach(edge -> { assertEquals(partitionLimit, edge.getPartitionCount()); // max of input1 and output1 }); } + + @Test + public void testCreateJobGraphForTaskApplication() { + TaskApplicationDescriptorImpl taskAppDesc = mock(TaskApplicationDescriptorImpl.class); + // add interemediate streams + String intermediateStream1 = "intermediate-stream1"; + String intermediateBroadcast = "intermediate-broadcast1"; + // intermediate stream1, not broadcast + GenericInputDescriptor> intermediateInput1 = system1Descriptor.getInputDescriptor( + intermediateStream1, new KVSerde<>(new NoOpSerde(), new NoOpSerde())); + GenericOutputDescriptor> intermediateOutput1 = 
system1Descriptor.getOutputDescriptor( + intermediateStream1, new KVSerde<>(new NoOpSerde(), new NoOpSerde())); + // intermediate stream2, broadcast + GenericInputDescriptor> intermediateBroacastInput1 = system1Descriptor.getInputDescriptor( + intermediateBroadcast, new KVSerde<>(new NoOpSerde<>(), new NoOpSerde<>())); + GenericOutputDescriptor> intermediateBroacastOutput1 = system1Descriptor.getOutputDescriptor( + intermediateBroadcast, new KVSerde<>(new NoOpSerde<>(), new NoOpSerde<>())); + inputDescriptors.put(intermediateStream1, intermediateInput1); + outputDescriptors.put(intermediateStream1, intermediateOutput1); + inputDescriptors.put(intermediateBroadcast, intermediateBroacastInput1); + outputDescriptors.put(intermediateBroadcast, intermediateBroacastOutput1); + Set broadcastStreams = new HashSet<>(); + broadcastStreams.add(intermediateBroadcast); + + when(taskAppDesc.getInputDescriptors()).thenReturn(inputDescriptors); + when(taskAppDesc.getInputStreamIds()).thenReturn(inputDescriptors.keySet()); + when(taskAppDesc.getOutputDescriptors()).thenReturn(outputDescriptors); + when(taskAppDesc.getOutputStreamIds()).thenReturn(outputDescriptors.keySet()); + when(taskAppDesc.getTableDescriptors()).thenReturn(Collections.emptySet()); + when(taskAppDesc.getSystemDescriptors()).thenReturn(systemDescriptors); + when(taskAppDesc.getBroadcastStreams()).thenReturn(broadcastStreams); + doReturn(MockTaskApplication.class).when(taskAppDesc).getAppClass(); + + Map systemStreamConfigs = new HashMap<>(); + inputDescriptors.forEach((key, value) -> systemStreamConfigs.putAll(value.toConfig())); + outputDescriptors.forEach((key, value) -> systemStreamConfigs.putAll(value.toConfig())); + systemDescriptors.forEach(sd -> systemStreamConfigs.putAll(sd.toConfig())); + + ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); + JobGraph jobGraph = planner.createJobGraph(config, taskAppDesc); + assertEquals(1, jobGraph.getJobNodes().size()); + assertTrue(jobGraph.getInputStreams().stream().map(edge -> edge.getName()) + .filter(streamId -> inputDescriptors.containsKey(streamId)).collect(Collectors.toList()).isEmpty()); + Set intermediateStreams = new HashSet<>(inputDescriptors.keySet()); + jobGraph.getInputStreams().forEach(edge -> { + if (intermediateStreams.contains(edge.getStreamSpec().getId())) { + intermediateStreams.remove(edge.getStreamSpec().getId()); + } + }); + assertEquals(new HashSet() { { this.add(intermediateStream1); this.add(intermediateBroadcast); } }.toArray(), + intermediateStreams.toArray()); + } + + @Test + public void testCreateJobGraphForLegacyTaskApplication() { + TaskApplicationDescriptorImpl taskAppDesc = mock(TaskApplicationDescriptorImpl.class); + + when(taskAppDesc.getInputDescriptors()).thenReturn(new HashMap<>()); + when(taskAppDesc.getOutputDescriptors()).thenReturn(new HashMap<>()); + when(taskAppDesc.getTableDescriptors()).thenReturn(new HashSet<>()); + when(taskAppDesc.getSystemDescriptors()).thenReturn(new HashSet<>()); + when(taskAppDesc.getBroadcastStreams()).thenReturn(new HashSet<>()); + doReturn(LegacyTaskApplication.class).when(taskAppDesc).getAppClass(); + + Map systemStreamConfigs = new HashMap<>(); + inputDescriptors.forEach((key, value) -> systemStreamConfigs.putAll(value.toConfig())); + outputDescriptors.forEach((key, value) -> systemStreamConfigs.putAll(value.toConfig())); + systemDescriptors.forEach(sd -> systemStreamConfigs.putAll(sd.toConfig())); + + ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); + JobGraph jobGraph = 
planner.createJobGraph(config, taskAppDesc); + assertEquals(1, jobGraph.getJobNodes().size()); + JobNode jobNode = jobGraph.getJobNodes().get(0); + assertEquals("test-app", jobNode.getJobName()); + assertEquals("test-app-1", jobNode.getJobNameAndId()); + assertEquals(0, jobNode.getInEdges().size()); + assertEquals(0, jobNode.getOutEdges().size()); + assertEquals(0, jobNode.getTables().size()); + assertEquals(config, jobNode.getConfig()); + } + + public static class MockTaskApplication implements SamzaApplication { + + @Override + public void describe(ApplicationDescriptor appDesc) { + + } + } } diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestIntermediateStreamManager.java b/samza-core/src/test/java/org/apache/samza/execution/TestIntermediateStreamManager.java new file mode 100644 index 0000000000..bc1570976e --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/execution/TestIntermediateStreamManager.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
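// ---------------------------------------------------------------------------
// Sketch of the planning call sequence these planner tests exercise (illustrative
// only; "config", "streamManager" and "appDesc" stand for the Config, StreamManager
// and ApplicationDescriptorImpl built in the individual tests):
//
//   ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
//
//   // end-to-end planning, as the plan-based tests use it:
//   ExecutionPlan plan = planner.plan(appDesc);
//   List<JobConfig> jobConfigs = plan.getJobConfigs();
//
//   // or the individual steps, as the unit tests drive them:
//   JobGraph jobGraph = planner.createJobGraph(config, appDesc);
//   ExecutionPlanner.setInputAndOutputStreamPartitionCount(jobGraph, streamManager);
//   new IntermediateStreamManager(config, appDesc).calculatePartitions(jobGraph);
// ---------------------------------------------------------------------------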
+ */ +package org.apache.samza.execution; + +import org.apache.samza.application.StreamApplicationDescriptorImpl; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link IntermediateStreamManager} + */ +public class TestIntermediateStreamManager extends ExecutionPlannerTestBase { + + @Test + public void testCalculateRepartitionJoinTopicPartitions() { + mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig); + IntermediateStreamManager partitionPlanner = new IntermediateStreamManager(mockConfig, mockStreamAppDesc); + JobGraph mockGraph = new ExecutionPlanner(mockConfig, mock(StreamManager.class)) + .createJobGraph(mockConfig, mockStreamAppDesc); + // set the input stream partitions + mockGraph.getInputStreams().forEach(inEdge -> { + if (inEdge.getStreamSpec().getId().equals(input1Descriptor.getStreamId())) { + inEdge.setPartitionCount(6); + } else if (inEdge.getStreamSpec().getId().equals(input2Descriptor.getStreamId())) { + inEdge.setPartitionCount(5); + } + }); + partitionPlanner.calculatePartitions(mockGraph); + assertEquals(1, mockGraph.getIntermediateStreamEdges().size()); + assertEquals(5, mockGraph.getIntermediateStreamEdges().stream() + .filter(inEdge -> inEdge.getStreamSpec().getId().equals(intermediateInputDescriptor.getStreamId())) + .findFirst().get().getPartitionCount()); + } + + @Test + public void testCalculateRepartitionIntermediateTopicPartitions() { + mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionOnlyStreamApplication(), mockConfig); + IntermediateStreamManager partitionPlanner = new IntermediateStreamManager(mockConfig, mockStreamAppDesc); + JobGraph mockGraph = new ExecutionPlanner(mockConfig, mock(StreamManager.class)) + .createJobGraph(mockConfig, mockStreamAppDesc); + // set the input stream partitions + mockGraph.getInputStreams().forEach(inEdge -> inEdge.setPartitionCount(7)); + partitionPlanner.calculatePartitions(mockGraph); + assertEquals(1, mockGraph.getIntermediateStreamEdges().size()); + assertEquals(7, mockGraph.getIntermediateStreamEdges().stream() + .filter(inEdge -> inEdge.getStreamSpec().getId().equals(intermediateInputDescriptor.getStreamId())) + .findFirst().get().getPartitionCount()); + } + +} diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestJobGraph.java b/samza-core/src/test/java/org/apache/samza/execution/TestJobGraph.java index ed35d6725b..4de0485a8a 100644 --- a/samza-core/src/test/java/org/apache/samza/execution/TestJobGraph.java +++ b/samza-core/src/test/java/org/apache/samza/execution/TestJobGraph.java @@ -19,12 +19,11 @@ package org.apache.samza.execution; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; -import org.apache.samza.operators.OperatorSpecGraph; +import org.apache.samza.application.StreamApplicationDescriptorImpl; import org.apache.samza.system.StreamSpec; import org.junit.Before; import org.junit.Test; @@ -32,7 +31,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class TestJobGraph { @@ -61,9 +59,8 @@ private StreamSpec genStream() { * 2 9 10 */ private void createGraph1() { - OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class); - when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet()); - graph1 = new JobGraph(null, 
specGraph); + StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class); + graph1 = new JobGraph(null, appDesc); JobNode n2 = graph1.getOrCreateJobNode("2", "1"); JobNode n3 = graph1.getOrCreateJobNode("3", "1"); @@ -96,9 +93,8 @@ private void createGraph1() { * |<---6 <--| <> */ private void createGraph2() { - OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class); - when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet()); - graph2 = new JobGraph(null, specGraph); + StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class); + graph2 = new JobGraph(null, appDesc); JobNode n1 = graph2.getOrCreateJobNode("1", "1"); JobNode n2 = graph2.getOrCreateJobNode("2", "1"); @@ -125,9 +121,8 @@ private void createGraph2() { * 1<->1 -> 2<->2 */ private void createGraph3() { - OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class); - when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet()); - graph3 = new JobGraph(null, specGraph); + StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class); + graph3 = new JobGraph(null, appDesc); JobNode n1 = graph3.getOrCreateJobNode("1", "1"); JobNode n2 = graph3.getOrCreateJobNode("2", "1"); @@ -143,9 +138,8 @@ private void createGraph3() { * 1<->1 */ private void createGraph4() { - OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class); - when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet()); - graph4 = new JobGraph(null, specGraph); + StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class); + graph4 = new JobGraph(null, appDesc); JobNode n1 = graph4.getOrCreateJobNode("1", "1"); @@ -163,9 +157,8 @@ public void setup() { @Test public void testAddSource() { - OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class); - when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet()); - JobGraph graph = new JobGraph(null, specGraph); + StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class); + JobGraph graph = new JobGraph(null, appDesc); /** * s1 -> 1 @@ -206,9 +199,8 @@ public void testAddSink() { * 2 -> s2 * 2 -> s3 */ - OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class); - when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet()); - JobGraph graph = new JobGraph(null, specGraph); + StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class); + JobGraph graph = new JobGraph(null, appDesc); JobNode n1 = graph.getOrCreateJobNode("1", "1"); JobNode n2 = graph.getOrCreateJobNode("2", "1"); StreamSpec s1 = genStream(); diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestJobGraphJsonGenerator.java b/samza-core/src/test/java/org/apache/samza/execution/TestJobGraphJsonGenerator.java index ae6e25e5ee..c207118f96 100644 --- a/samza-core/src/test/java/org/apache/samza/execution/TestJobGraphJsonGenerator.java +++ b/samza-core/src/test/java/org/apache/samza/execution/TestJobGraphJsonGenerator.java @@ -20,9 +20,14 @@ package org.apache.samza.execution; import java.time.Duration; +import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; import org.apache.samza.application.StreamApplicationDescriptorImpl; +import org.apache.samza.config.ApplicationConfig; import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; import org.apache.samza.config.MapConfig; @@ -40,10 +45,13 
@@ import org.apache.samza.serializers.NoOpSerde; import org.apache.samza.serializers.Serde; import org.apache.samza.serializers.StringSerde; +import org.apache.samza.system.StreamSpec; import org.apache.samza.system.SystemAdmin; import org.apache.samza.system.SystemAdmins; import org.apache.samza.testUtils.StreamTestUtils; import org.codehaus.jackson.map.ObjectMapper; +import org.hamcrest.Matchers; +import org.junit.Before; import org.junit.Test; import static org.apache.samza.execution.TestExecutionPlanner.*; @@ -51,16 +59,68 @@ import static org.mockito.Mockito.*; +/** + * Unit test for {@link JobGraphJsonGenerator} + */ public class TestJobGraphJsonGenerator { + private Config mockConfig; + private JobNode mockJobNode; + private StreamSpec input1Spec; + private StreamSpec input2Spec; + private StreamSpec outputSpec; + private StreamSpec repartitionSpec; + private KVSerde defaultSerde; + private GenericSystemDescriptor inputSystemDescriptor; + private GenericSystemDescriptor outputSystemDescriptor; + private GenericSystemDescriptor intermediateSystemDescriptor; + private GenericInputDescriptor> input1Descriptor; + private GenericInputDescriptor> input2Descriptor; + private GenericOutputDescriptor> outputDescriptor; - public class PageViewEvent { - String getCountry() { - return ""; - } + @Before + public void setUp() { + input1Spec = new StreamSpec("input1", "input1", "input-system"); + input2Spec = new StreamSpec("input2", "input2", "input-system"); + outputSpec = new StreamSpec("output", "output", "output-system"); + repartitionSpec = + new StreamSpec("jobName-jobId-partition_by-p1", "partition_by-p1", "intermediate-system"); + + + defaultSerde = KVSerde.of(new StringSerde(), new JsonSerdeV2<>()); + inputSystemDescriptor = new GenericSystemDescriptor("input-system", "mockSystemFactoryClassName"); + outputSystemDescriptor = new GenericSystemDescriptor("output-system", "mockSystemFactoryClassName"); + intermediateSystemDescriptor = new GenericSystemDescriptor("intermediate-system", "mockSystemFactoryClassName"); + input1Descriptor = inputSystemDescriptor.getInputDescriptor("input1", defaultSerde); + input2Descriptor = inputSystemDescriptor.getInputDescriptor("input2", defaultSerde); + outputDescriptor = outputSystemDescriptor.getOutputDescriptor("output", defaultSerde); + + Map configs = new HashMap<>(); + configs.put(JobConfig.JOB_NAME(), "jobName"); + configs.put(JobConfig.JOB_ID(), "jobId"); + mockConfig = spy(new MapConfig(configs)); + + mockJobNode = mock(JobNode.class); + StreamEdge input1Edge = new StreamEdge(input1Spec, false, false, mockConfig); + StreamEdge input2Edge = new StreamEdge(input2Spec, false, false, mockConfig); + StreamEdge outputEdge = new StreamEdge(outputSpec, false, false, mockConfig); + StreamEdge repartitionEdge = new StreamEdge(repartitionSpec, true, false, mockConfig); + Map inputEdges = new HashMap<>(); + inputEdges.put(input1Descriptor.getStreamId(), input1Edge); + inputEdges.put(input2Descriptor.getStreamId(), input2Edge); + inputEdges.put(repartitionSpec.getId(), repartitionEdge); + Map outputEdges = new HashMap<>(); + outputEdges.put(outputDescriptor.getStreamId(), outputEdge); + outputEdges.put(repartitionSpec.getId(), repartitionEdge); + when(mockJobNode.getInEdges()).thenReturn(inputEdges); + when(mockJobNode.getOutEdges()).thenReturn(outputEdges); + when(mockJobNode.getConfig()).thenReturn(mockConfig); + when(mockJobNode.getJobName()).thenReturn("jobName"); + when(mockJobNode.getJobId()).thenReturn("jobId"); + 
when(mockJobNode.getJobNameAndId()).thenReturn(JobNode.createJobNameAndId("jobName", "jobId")); } @Test - public void test() throws Exception { + public void testRepartitionedJoinStreamApplication() throws Exception { /** * the graph looks like the following. @@ -142,7 +202,7 @@ public void test() throws Exception { }, config); ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); - ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph()); + ExecutionPlan plan = planner.plan(graphSpec); String json = plan.getPlanAsJson(); System.out.println(json); @@ -157,7 +217,7 @@ public void test() throws Exception { } @Test - public void test2() throws Exception { + public void testRepartitionedWindowStreamApplication() throws Exception { Map configMap = new HashMap<>(); configMap.put(JobConfig.JOB_NAME(), "test-app"); configMap.put(JobConfig.JOB_DEFAULT_SYSTEM(), "test-system"); @@ -202,7 +262,7 @@ public void test2() throws Exception { }, config); ExecutionPlanner planner = new ExecutionPlanner(config, streamManager); - ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph()); + ExecutionPlan plan = planner.plan(graphSpec); String json = plan.getPlanAsJson(); System.out.println(json); @@ -222,4 +282,75 @@ public void test2() throws Exception { assertEquals(operatorGraphJson.operators.get("test-app-1-send_to-5").get("outputStreamId"), "PageViewCount"); } + + @Test + public void testTaskApplication() throws Exception { + JobGraphJsonGenerator jsonGenerator = new JobGraphJsonGenerator(); + JobGraph mockJobGraph = mock(JobGraph.class); + ApplicationConfig mockAppConfig = mock(ApplicationConfig.class); + when(mockAppConfig.getAppName()).thenReturn("testTaskApp"); + when(mockAppConfig.getAppId()).thenReturn("testTaskAppId"); + when(mockJobGraph.getApplicationConfig()).thenReturn(mockAppConfig); + // compute the three disjoint sets of the JobGraph: input only, output only, and intermediate streams + Set inEdges = new HashSet<>(mockJobNode.getInEdges().values()); + Set outEdges = new HashSet<>(mockJobNode.getOutEdges().values()); + Set intermediateEdges = new HashSet<>(inEdges); + // intermediate streams are the intersection between input and output + intermediateEdges.retainAll(outEdges); + // remove all intermediate streams from input + inEdges.removeAll(intermediateEdges); + // remove all intermediate streams from output + outEdges.removeAll(intermediateEdges); + // set the return values for mockJobGraph + when(mockJobGraph.getInputStreams()).thenReturn(inEdges); + when(mockJobGraph.getOutputStreams()).thenReturn(outEdges); + when(mockJobGraph.getIntermediateStreamEdges()).thenReturn(intermediateEdges); + when(mockJobGraph.getJobNodes()).thenReturn(Collections.singletonList(mockJobNode)); + String graphJson = jsonGenerator.toJson(mockJobGraph); + ObjectMapper objectMapper = new ObjectMapper(); + JobGraphJsonGenerator.JobGraphJson jsonObject = objectMapper.readValue(graphJson.getBytes(), JobGraphJsonGenerator.JobGraphJson.class); + assertEquals("testTaskAppId", jsonObject.applicationId); + assertEquals("testTaskApp", jsonObject.applicationName); + Set inStreamIds = inEdges.stream().map(stream -> stream.getStreamSpec().getId()).collect(Collectors.toSet()); + assertThat(jsonObject.sourceStreams.keySet(), Matchers.containsInAnyOrder(inStreamIds.toArray())); + Set outStreamIds = outEdges.stream().map(stream -> stream.getStreamSpec().getId()).collect(Collectors.toSet()); + assertThat(jsonObject.sinkStreams.keySet(), Matchers.containsInAnyOrder(outStreamIds.toArray())); + 
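// Rough shape of the JSON being asserted in this test (field names as read from
// JobGraphJsonGenerator.JobGraphJson here; keys and values are illustrative and
// assume Jackson's default field naming):
//   {
//     "applicationName": "testTaskApp",
//     "applicationId": "testTaskAppId",
//     "sourceStreams":       { <input-only streamIds>: ... },
//     "sinkStreams":         { <output-only streamIds>: ... },
//     "intermediateStreams": { <streamIds that are both input and output>: ... },
//     "jobs": [ { "jobName": "jobName", "jobId": "jobId",
//                 "operatorGraph": { "inputStreams": [...], "outputStreams": [...], "operators": {...} } } ]
//   }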
Set intStreamIds = intermediateEdges.stream().map(stream -> stream.getStreamSpec().getId()).collect(Collectors.toSet()); + assertThat(jsonObject.intermediateStreams.keySet(), Matchers.containsInAnyOrder(intStreamIds.toArray())); + JobGraphJsonGenerator.JobNodeJson expectedNodeJson = new JobGraphJsonGenerator.JobNodeJson(); + expectedNodeJson.jobId = mockJobNode.getJobId(); + expectedNodeJson.jobName = mockJobNode.getJobName(); + assertEquals(1, jsonObject.jobs.size()); + JobGraphJsonGenerator.JobNodeJson actualNodeJson = jsonObject.jobs.get(0); + assertEquals(expectedNodeJson.jobId, actualNodeJson.jobId); + assertEquals(expectedNodeJson.jobName, actualNodeJson.jobName); + assertEquals(3, actualNodeJson.operatorGraph.inputStreams.size()); + assertEquals(2, actualNodeJson.operatorGraph.outputStreams.size()); + assertEquals(0, actualNodeJson.operatorGraph.operators.size()); + } + + @Test + public void testLegacyTaskApplication() throws Exception { + JobGraphJsonGenerator jsonGenerator = new JobGraphJsonGenerator(); + JobGraph mockJobGraph = mock(JobGraph.class); + ApplicationConfig mockAppConfig = mock(ApplicationConfig.class); + when(mockAppConfig.getAppName()).thenReturn("testTaskApp"); + when(mockAppConfig.getAppId()).thenReturn("testTaskAppId"); + when(mockJobGraph.getApplicationConfig()).thenReturn(mockAppConfig); + String graphJson = jsonGenerator.toJson(mockJobGraph); + ObjectMapper objectMapper = new ObjectMapper(); + JobGraphJsonGenerator.JobGraphJson jsonObject = objectMapper.readValue(graphJson.getBytes(), JobGraphJsonGenerator.JobGraphJson.class); + assertEquals("testTaskAppId", jsonObject.applicationId); + assertEquals("testTaskApp", jsonObject.applicationName); + JobGraphJsonGenerator.JobNodeJson expectedNodeJson = new JobGraphJsonGenerator.JobNodeJson(); + expectedNodeJson.jobId = mockJobNode.getJobId(); + expectedNodeJson.jobName = mockJobNode.getJobName(); + assertEquals(0, jsonObject.jobs.size()); + } + + public class PageViewEvent { + String getCountry() { + return ""; + } + } } diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestJobNode.java b/samza-core/src/test/java/org/apache/samza/execution/TestJobNode.java deleted file mode 100644 index 163b094960..0000000000 --- a/samza-core/src/test/java/org/apache/samza/execution/TestJobNode.java +++ /dev/null @@ -1,228 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.samza.execution; - -import java.time.Duration; -import java.util.Base64; -import java.util.HashMap; -import java.util.Map; -import java.util.stream.Collectors; -import org.apache.samza.application.StreamApplicationDescriptorImpl; -import org.apache.samza.config.Config; -import org.apache.samza.config.JobConfig; -import org.apache.samza.config.MapConfig; -import org.apache.samza.config.SerializerConfig; -import org.apache.samza.operators.KV; -import org.apache.samza.operators.MessageStream; -import org.apache.samza.operators.OutputStream; -import org.apache.samza.operators.descriptors.GenericInputDescriptor; -import org.apache.samza.operators.descriptors.GenericOutputDescriptor; -import org.apache.samza.operators.descriptors.GenericSystemDescriptor; -import org.apache.samza.operators.functions.JoinFunction; -import org.apache.samza.operators.impl.store.TimestampedValueSerde; -import org.apache.samza.serializers.JsonSerdeV2; -import org.apache.samza.serializers.KVSerde; -import org.apache.samza.serializers.Serde; -import org.apache.samza.serializers.SerializableSerde; -import org.apache.samza.serializers.StringSerde; -import org.apache.samza.system.StreamSpec; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.mockito.Matchers.anyString; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.*; - -public class TestJobNode { - - @Test - public void testAddSerdeConfigs() { - StreamSpec input1Spec = new StreamSpec("input1", "input1", "input-system"); - StreamSpec input2Spec = new StreamSpec("input2", "input2", "input-system"); - StreamSpec outputSpec = new StreamSpec("output", "output", "output-system"); - StreamSpec partitionBySpec = - new StreamSpec("jobName-jobId-partition_by-p1", "partition_by-p1", "intermediate-system"); - - Config mockConfig = mock(Config.class); - when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn("jobName"); - when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("jobId"); - - StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { - KVSerde serde = KVSerde.of(new StringSerde(), new JsonSerdeV2<>()); - GenericSystemDescriptor sd = new GenericSystemDescriptor("system1", "mockSystemFactoryClass"); - GenericInputDescriptor> inputDescriptor1 = sd.getInputDescriptor("input1", serde); - GenericInputDescriptor> inputDescriptor2 = sd.getInputDescriptor("input2", serde); - GenericOutputDescriptor> outputDescriptor = sd.getOutputDescriptor("output", serde); - MessageStream> input1 = appDesc.getInputStream(inputDescriptor1); - MessageStream> input2 = appDesc.getInputStream(inputDescriptor2); - OutputStream> output = appDesc.getOutputStream(outputDescriptor); - JoinFunction> mockJoinFn = mock(JoinFunction.class); - input1 - .partitionBy(KV::getKey, KV::getValue, serde, "p1") - .map(kv -> kv.value) - .join(input2.map(kv -> kv.value), mockJoinFn, - new StringSerde(), new JsonSerdeV2<>(Object.class), new JsonSerdeV2<>(Object.class), - Duration.ofHours(1), "j1") - .sendTo(output); - }, mockConfig); - - JobNode jobNode = new JobNode("jobName", "jobId", graphSpec.getOperatorSpecGraph(), mockConfig); - Config config = new MapConfig(); - StreamEdge input1Edge = new StreamEdge(input1Spec, false, false, config); - StreamEdge input2Edge = new StreamEdge(input2Spec, false, false, config); - StreamEdge outputEdge = new StreamEdge(outputSpec, false, false, config); - StreamEdge repartitionEdge = new 
StreamEdge(partitionBySpec, true, false, config); - jobNode.addInEdge(input1Edge); - jobNode.addInEdge(input2Edge); - jobNode.addOutEdge(outputEdge); - jobNode.addInEdge(repartitionEdge); - jobNode.addOutEdge(repartitionEdge); - - Map configs = new HashMap<>(); - jobNode.addSerdeConfigs(configs); - - MapConfig mapConfig = new MapConfig(configs); - Config serializers = mapConfig.subset("serializers.registry.", true); - - // make sure that the serializers deserialize correctly - SerializableSerde serializableSerde = new SerializableSerde<>(); - Map deserializedSerdes = serializers.entrySet().stream().collect(Collectors.toMap( - e -> e.getKey().replace(SerializerConfig.SERIALIZED_INSTANCE_SUFFIX(), ""), - e -> serializableSerde.fromBytes(Base64.getDecoder().decode(e.getValue().getBytes())) - )); - assertEquals(5, serializers.size()); // 2 default + 3 specific for join - - String input1KeySerde = mapConfig.get("streams.input1.samza.key.serde"); - String input1MsgSerde = mapConfig.get("streams.input1.samza.msg.serde"); - assertTrue("Serialized serdes should contain input1 key serde", - deserializedSerdes.containsKey(input1KeySerde)); - assertTrue("Serialized input1 key serde should be a StringSerde", - input1KeySerde.startsWith(StringSerde.class.getSimpleName())); - assertTrue("Serialized serdes should contain input1 msg serde", - deserializedSerdes.containsKey(input1MsgSerde)); - assertTrue("Serialized input1 msg serde should be a JsonSerdeV2", - input1MsgSerde.startsWith(JsonSerdeV2.class.getSimpleName())); - - String input2KeySerde = mapConfig.get("streams.input2.samza.key.serde"); - String input2MsgSerde = mapConfig.get("streams.input2.samza.msg.serde"); - assertTrue("Serialized serdes should contain input2 key serde", - deserializedSerdes.containsKey(input2KeySerde)); - assertTrue("Serialized input2 key serde should be a StringSerde", - input2KeySerde.startsWith(StringSerde.class.getSimpleName())); - assertTrue("Serialized serdes should contain input2 msg serde", - deserializedSerdes.containsKey(input2MsgSerde)); - assertTrue("Serialized input2 msg serde should be a JsonSerdeV2", - input2MsgSerde.startsWith(JsonSerdeV2.class.getSimpleName())); - - String outputKeySerde = mapConfig.get("streams.output.samza.key.serde"); - String outputMsgSerde = mapConfig.get("streams.output.samza.msg.serde"); - assertTrue("Serialized serdes should contain output key serde", - deserializedSerdes.containsKey(outputKeySerde)); - assertTrue("Serialized output key serde should be a StringSerde", - outputKeySerde.startsWith(StringSerde.class.getSimpleName())); - assertTrue("Serialized serdes should contain output msg serde", - deserializedSerdes.containsKey(outputMsgSerde)); - assertTrue("Serialized output msg serde should be a JsonSerdeV2", - outputMsgSerde.startsWith(JsonSerdeV2.class.getSimpleName())); - - String partitionByKeySerde = mapConfig.get("streams.jobName-jobId-partition_by-p1.samza.key.serde"); - String partitionByMsgSerde = mapConfig.get("streams.jobName-jobId-partition_by-p1.samza.msg.serde"); - assertTrue("Serialized serdes should contain intermediate stream key serde", - deserializedSerdes.containsKey(partitionByKeySerde)); - assertTrue("Serialized intermediate stream key serde should be a StringSerde", - partitionByKeySerde.startsWith(StringSerde.class.getSimpleName())); - assertTrue("Serialized serdes should contain intermediate stream msg serde", - deserializedSerdes.containsKey(partitionByMsgSerde)); - assertTrue( - "Serialized intermediate stream msg serde should be a JsonSerdeV2", - 
partitionByMsgSerde.startsWith(JsonSerdeV2.class.getSimpleName())); - - String leftJoinStoreKeySerde = mapConfig.get("stores.jobName-jobId-join-j1-L.key.serde"); - String leftJoinStoreMsgSerde = mapConfig.get("stores.jobName-jobId-join-j1-L.msg.serde"); - assertTrue("Serialized serdes should contain left join store key serde", - deserializedSerdes.containsKey(leftJoinStoreKeySerde)); - assertTrue("Serialized left join store key serde should be a StringSerde", - leftJoinStoreKeySerde.startsWith(StringSerde.class.getSimpleName())); - assertTrue("Serialized serdes should contain left join store msg serde", - deserializedSerdes.containsKey(leftJoinStoreMsgSerde)); - assertTrue("Serialized left join store msg serde should be a TimestampedValueSerde", - leftJoinStoreMsgSerde.startsWith(TimestampedValueSerde.class.getSimpleName())); - - String rightJoinStoreKeySerde = mapConfig.get("stores.jobName-jobId-join-j1-R.key.serde"); - String rightJoinStoreMsgSerde = mapConfig.get("stores.jobName-jobId-join-j1-R.msg.serde"); - assertTrue("Serialized serdes should contain right join store key serde", - deserializedSerdes.containsKey(rightJoinStoreKeySerde)); - assertTrue("Serialized right join store key serde should be a StringSerde", - rightJoinStoreKeySerde.startsWith(StringSerde.class.getSimpleName())); - assertTrue("Serialized serdes should contain right join store msg serde", - deserializedSerdes.containsKey(rightJoinStoreMsgSerde)); - assertTrue("Serialized right join store msg serde should be a TimestampedValueSerde", - rightJoinStoreMsgSerde.startsWith(TimestampedValueSerde.class.getSimpleName())); - } - - @Test - public void testAddSerdeConfigsForRepartitionWithNoDefaultSystem() { - StreamSpec inputSpec = new StreamSpec("input", "input", "input-system"); - StreamSpec partitionBySpec = - new StreamSpec("jobName-jobId-partition_by-p1", "partition_by-p1", "intermediate-system"); - - Config mockConfig = mock(Config.class); - when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn("jobName"); - when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("jobId"); - - StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { - GenericSystemDescriptor sd = new GenericSystemDescriptor("system1", "mockSystemFactoryClassName"); - GenericInputDescriptor> inputDescriptor1 = - sd.getInputDescriptor("input", KVSerde.of(new StringSerde(), new JsonSerdeV2<>())); - MessageStream> input = appDesc.getInputStream(inputDescriptor1); - input.partitionBy(KV::getKey, KV::getValue, "p1"); - }, mockConfig); - - JobNode jobNode = new JobNode("jobName", "jobId", graphSpec.getOperatorSpecGraph(), mockConfig); - Config config = new MapConfig(); - StreamEdge input1Edge = new StreamEdge(inputSpec, false, false, config); - StreamEdge repartitionEdge = new StreamEdge(partitionBySpec, true, false, config); - jobNode.addInEdge(input1Edge); - jobNode.addInEdge(repartitionEdge); - jobNode.addOutEdge(repartitionEdge); - - Map configs = new HashMap<>(); - jobNode.addSerdeConfigs(configs); - - MapConfig mapConfig = new MapConfig(configs); - Config serializers = mapConfig.subset("serializers.registry.", true); - - // make sure that the serializers deserialize correctly - SerializableSerde serializableSerde = new SerializableSerde<>(); - Map deserializedSerdes = serializers.entrySet().stream().collect(Collectors.toMap( - e -> e.getKey().replace(SerializerConfig.SERIALIZED_INSTANCE_SUFFIX(), ""), - e -> serializableSerde.fromBytes(Base64.getDecoder().decode(e.getValue().getBytes())) - )); - 
assertEquals(2, serializers.size()); // 2 input stream - - String partitionByKeySerde = mapConfig.get("streams.jobName-jobId-partition_by-p1.samza.key.serde"); - String partitionByMsgSerde = mapConfig.get("streams.jobName-jobId-partition_by-p1.samza.msg.serde"); - assertTrue("Serialized serdes should not contain intermediate stream key serde", - !deserializedSerdes.containsKey(partitionByKeySerde)); - assertTrue("Serialized serdes should not contain intermediate stream msg serde", - !deserializedSerdes.containsKey(partitionByMsgSerde)); - } -} diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestJobNodeConfigurationGenerator.java b/samza-core/src/test/java/org/apache/samza/execution/TestJobNodeConfigurationGenerator.java new file mode 100644 index 0000000000..f351c4472b --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/execution/TestJobNodeConfigurationGenerator.java @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.execution; + +import com.google.common.base.Joiner; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.samza.application.StreamApplicationDescriptorImpl; +import org.apache.samza.application.TaskApplicationDescriptorImpl; +import org.apache.samza.config.Config; +import org.apache.samza.config.ConfigRewriter; +import org.apache.samza.config.JobConfig; +import org.apache.samza.config.MapConfig; +import org.apache.samza.config.SerializerConfig; +import org.apache.samza.config.TaskConfig; +import org.apache.samza.config.TaskConfigJava; +import org.apache.samza.container.SamzaContainerContext; +import org.apache.samza.operators.BaseTableDescriptor; +import org.apache.samza.operators.KV; +import org.apache.samza.operators.TableDescriptor; +import org.apache.samza.operators.descriptors.GenericInputDescriptor; +import org.apache.samza.operators.impl.store.TimestampedValueSerde; +import org.apache.samza.serializers.JsonSerdeV2; +import org.apache.samza.serializers.KVSerde; +import org.apache.samza.serializers.Serde; +import org.apache.samza.serializers.SerializableSerde; +import org.apache.samza.serializers.StringSerde; +import org.apache.samza.system.StreamSpec; +import org.apache.samza.table.Table; +import org.apache.samza.table.TableProvider; +import org.apache.samza.table.TableProviderFactory; +import org.apache.samza.table.TableSpec; +import org.apache.samza.task.TaskContext; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; 
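// Layout of the serde-related configuration that JobNodeConfigurationGenerator emits and that
// the tests below decode (a sketch; the registry suffix is shown symbolically, and values are
// Base64-encoded java-serialized Serde instances):
//   serializers.registry.<serdeName><SerializerConfig.SERIALIZED_INSTANCE_SUFFIX()> = <Base64 bytes>
//   streams.<streamId>.samza.key.serde = <serdeName of the key serde>
//   streams.<streamId>.samza.msg.serde = <serdeName of the message serde>
//   stores.<jobName-jobId>-join-<joinId>-L.key.serde / .msg.serde   (and -R for the right join side)
// Decoding, as the removed TestJobNode did and as validateAndGetDeserializedSerdes(...) is
// expected to mirror:
//   Config serializers = jobConfig.subset("serializers.registry.", true);
//   Serde deserialized = new SerializableSerde<Serde>().fromBytes(
//       Base64.getDecoder().decode(serializedValue.getBytes()));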
+import static org.mockito.Mockito.when; + + +/** + * Unit test for {@link JobNodeConfigurationGenerator} + */ +public class TestJobNodeConfigurationGenerator extends ExecutionPlannerTestBase { + + @Test + public void testConfigureSerdesWithRepartitionJoinApplication() { + mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig); + configureJobNode(mockStreamAppDesc); + // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson"); + + // Verify the results + Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges()); + validateJobConfig(expectedJobConfig, jobConfig); + // additional, check the computed window.ms for join + assertEquals("3600000", jobConfig.get(TaskConfig.WINDOW_MS())); + Map deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 5); + validateStreamConfigures(jobConfig, deserializedSerdes); + validateJoinStoreConfigures(jobConfig, deserializedSerdes); + } + + @Test + public void testConfigureSerdesForRepartitionWithNoDefaultSystem() { + // set the application to RepartitionOnlyStreamApplication + mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionOnlyStreamApplication(), mockConfig); + configureJobNode(mockStreamAppDesc); + + // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson"); + + // Verify the results + Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges()); + validateJobConfig(expectedJobConfig, jobConfig); + + Map deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 2); + validateStreamConfigures(jobConfig, null); + + String partitionByKeySerde = jobConfig.get("streams.jobName-jobId-partition_by-p1.samza.key.serde"); + String partitionByMsgSerde = jobConfig.get("streams.jobName-jobId-partition_by-p1.samza.msg.serde"); + assertTrue("Serialized serdes should not contain intermediate stream key serde", + !deserializedSerdes.containsKey(partitionByKeySerde)); + assertTrue("Serialized serdes should not contain intermediate stream msg serde", + !deserializedSerdes.containsKey(partitionByMsgSerde)); + } + + @Test + public void testGenerateJobConfigWithTaskApplication() { + // set the application to TaskApplication, which still wire up all input/output/intermediate streams + TaskApplicationDescriptorImpl taskAppDesc = new TaskApplicationDescriptorImpl(getTaskApplication(), mockConfig); + configureJobNode(taskAppDesc); + // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson"); + + // Verify the results + Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges()); + validateJobConfig(expectedJobConfig, jobConfig); + Map deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 2); + validateStreamConfigures(jobConfig, deserializedSerdes); + } + + @Test + public void testGenerateJobConfigWithLegacyTaskApplication() { + TaskApplicationDescriptorImpl taskAppDesc = new 
TaskApplicationDescriptorImpl(getLegacyTaskApplication(), mockConfig); + configureJobNode(taskAppDesc); + Map originConfig = new HashMap<>(mockConfig); + + // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, ""); + // jobConfig should be exactly the same as original config + Map generatedConfig = new HashMap<>(jobConfig); + assertEquals(originConfig, generatedConfig); + } + + @Test + public void testBroadcastStreamApplication() { + // set the application to BroadcastStreamApplication + mockStreamAppDesc = new StreamApplicationDescriptorImpl(getBroadcastOnlyStreamApplication(defaultSerde), mockConfig); + configureJobNode(mockStreamAppDesc); + + // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson"); + Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges()); + validateJobConfig(expectedJobConfig, jobConfig); + Map deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 2); + validateStreamSerdeConfigure(broadcastInputDesriptor.getStreamId(), jobConfig, deserializedSerdes); + validateIntermediateStreamConfigure(broadcastInputDesriptor.getStreamId(), broadcastInputDesriptor.getPhysicalName().get(), jobConfig); + } + + @Test + public void testBroadcastStreamApplicationWithoutSerde() { + // set the application to BroadcastStreamApplication withoutSerde + mockStreamAppDesc = new StreamApplicationDescriptorImpl(getBroadcastOnlyStreamApplication(null), mockConfig); + configureJobNode(mockStreamAppDesc); + + // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson"); + Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges()); + validateJobConfig(expectedJobConfig, jobConfig); + Map deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 2); + validateIntermediateStreamConfigure(broadcastInputDesriptor.getStreamId(), broadcastInputDesriptor.getPhysicalName().get(), jobConfig); + + String keySerde = jobConfig.get(String.format("streams.%s.samza.key.serde", broadcastInputDesriptor.getStreamId())); + String msgSerde = jobConfig.get(String.format("streams.%s.samza.msg.serde", broadcastInputDesriptor.getStreamId())); + assertTrue("Serialized serdes should not contain intermediate stream key serde", + !deserializedSerdes.containsKey(keySerde)); + assertTrue("Serialized serdes should not contain intermediate stream msg serde", + !deserializedSerdes.containsKey(msgSerde)); + } + + @Test + public void testStreamApplicationWithTableAndSideInput() { + mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig); + // add table to the RepartitionJoinStreamApplication + GenericInputDescriptor> sideInput1 = inputSystemDescriptor.getInputDescriptor("sideInput1", defaultSerde); + BaseTableDescriptor mockTableDescriptor = mock(BaseTableDescriptor.class); + TableSpec mockTableSpec = mock(TableSpec.class); + when(mockTableSpec.getId()).thenReturn("testTable"); + 
when(mockTableSpec.getSerde()).thenReturn((KVSerde) defaultSerde); + when(mockTableSpec.getTableProviderFactoryClassName()).thenReturn(MockTableProviderFactory.class.getName()); + List sideInputs = new ArrayList<>(); + sideInputs.add(sideInput1.getStreamId()); + when(mockTableSpec.getSideInputs()).thenReturn(sideInputs); + when(mockTableDescriptor.getTableId()).thenReturn("testTable"); + when(mockTableDescriptor.getTableSpec()).thenReturn(mockTableSpec); + when(mockTableDescriptor.getSerde()).thenReturn(defaultSerde); + // add side input and terminate at table in the appplication + mockStreamAppDesc.getInputStream(sideInput1).sendTo(mockStreamAppDesc.getTable(mockTableDescriptor)); + StreamEdge sideInputEdge = new StreamEdge(new StreamSpec(sideInput1.getStreamId(), "sideInput1", + inputSystemDescriptor.getSystemName()), false, false, mockConfig); + // need to put the sideInput related stream configuration to the original config + // TODO: this is confusing since part of the system and stream related configuration is generated outside the JobGraphConfigureGenerator + // It would be nice if all system and stream related configuration is generated in one place and only intermediate stream + // configuration is generated by JobGraphConfigureGenerator + Map configs = new HashMap<>(mockConfig); + configs.putAll(sideInputEdge.generateConfig()); + mockConfig = spy(new MapConfig(configs)); + configureJobNode(mockStreamAppDesc); + + // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson"); + Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges()); + validateJobConfig(expectedJobConfig, jobConfig); + Map deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 5); + validateTableConfigure(jobConfig, deserializedSerdes, mockTableDescriptor); + } + + @Test + public void testTaskApplicationWithTableAndSideInput() { + // add table to the RepartitionJoinStreamApplication + GenericInputDescriptor> sideInput1 = inputSystemDescriptor.getInputDescriptor("sideInput1", defaultSerde); + BaseTableDescriptor mockTableDescriptor = mock(BaseTableDescriptor.class); + TableSpec mockTableSpec = mock(TableSpec.class); + when(mockTableSpec.getId()).thenReturn("testTable"); + when(mockTableSpec.getSerde()).thenReturn((KVSerde) defaultSerde); + when(mockTableSpec.getTableProviderFactoryClassName()).thenReturn(MockTableProviderFactory.class.getName()); + List sideInputs = new ArrayList<>(); + sideInputs.add(sideInput1.getStreamId()); + when(mockTableSpec.getSideInputs()).thenReturn(sideInputs); + when(mockTableDescriptor.getTableId()).thenReturn("testTable"); + when(mockTableDescriptor.getTableSpec()).thenReturn(mockTableSpec); + when(mockTableDescriptor.getSerde()).thenReturn(defaultSerde); + StreamEdge sideInputEdge = new StreamEdge(new StreamSpec(sideInput1.getStreamId(), "sideInput1", + inputSystemDescriptor.getSystemName()), false, false, mockConfig); + // need to put the sideInput related stream configuration to the original config + // TODO: this is confusing since part of the system and stream related configuration is generated outside the JobGraphConfigureGenerator + // It would be nice if all system and stream related configuration is generated in one place and only intermediate stream + // configuration is generated by JobGraphConfigureGenerator + Map configs = new 
HashMap<>(mockConfig); + configs.putAll(sideInputEdge.generateConfig()); + mockConfig = spy(new MapConfig(configs)); + + // set the application to TaskApplication, which still wire up all input/output/intermediate streams + TaskApplicationDescriptorImpl taskAppDesc = new TaskApplicationDescriptorImpl(getTaskApplication(), mockConfig); + // add table to the task application + taskAppDesc.addTable(mockTableDescriptor); + taskAppDesc.addInputStream(inputSystemDescriptor.getInputDescriptor("sideInput1", defaultSerde)); + configureJobNode(taskAppDesc); + + // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson"); + + // Verify the results + Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges()); + validateJobConfig(expectedJobConfig, jobConfig); + Map deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 2); + validateStreamConfigures(jobConfig, deserializedSerdes); + validateTableConfigure(jobConfig, deserializedSerdes, mockTableDescriptor); + } + + @Test + public void testTaskInputsRemovedFromOriginalConfig() { + Map configs = new HashMap<>(mockConfig); + configs.put(TaskConfig.INPUT_STREAMS(), "not.allowed1,not.allowed2"); + mockConfig = spy(new MapConfig(configs)); + + mockStreamAppDesc = new StreamApplicationDescriptorImpl(getBroadcastOnlyStreamApplication(defaultSerde), mockConfig); + configureJobNode(mockStreamAppDesc); + + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson"); + Config expectedConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges()); + validateJobConfig(expectedConfig, jobConfig); + } + + @Test + public void testTaskInputsRetainedForLegacyTaskApplication() { + Map originConfig = new HashMap<>(mockConfig); + originConfig.put(TaskConfig.INPUT_STREAMS(), "must.retain1,must.retain2"); + mockConfig = new MapConfig(originConfig); + TaskApplicationDescriptorImpl taskAppDesc = new TaskApplicationDescriptorImpl(getLegacyTaskApplication(), mockConfig); + configureJobNode(taskAppDesc); + + // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, ""); + // jobConfig should be exactly the same as original config + Map generatedConfig = new HashMap<>(jobConfig); + assertEquals(originConfig, generatedConfig); + } + + @Test + public void testOverrideConfigs() { + Map configs = new HashMap<>(mockConfig); + String streamCfgToOverride = String.format("streams.%s.samza.system", intermediateInputDescriptor.getStreamId()); + String overrideCfgKey = String.format(JobConfig.CONFIG_OVERRIDE_JOBS_PREFIX(), getJobNameAndId()) + streamCfgToOverride; + configs.put(overrideCfgKey, "customized-system"); + mockConfig = spy(new MapConfig(configs)); + mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig); + configureJobNode(mockStreamAppDesc); + + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson"); + Config expectedConfig = 
getExpectedJobConfig(mockConfig, mockJobNode.getInEdges()); + validateJobConfig(expectedConfig, jobConfig); + assertEquals("customized-system", jobConfig.get(streamCfgToOverride)); + } + + @Test + public void testConfigureRewriter() { + Map configs = new HashMap<>(mockConfig); + String streamCfgToOverride = String.format("streams.%s.samza.system", intermediateInputDescriptor.getStreamId()); + String overrideCfgKey = String.format(JobConfig.CONFIG_OVERRIDE_JOBS_PREFIX(), getJobNameAndId()) + streamCfgToOverride; + configs.put(overrideCfgKey, "customized-system"); + configs.put(String.format(JobConfig.CONFIG_REWRITER_CLASS(), "mock"), MockConfigRewriter.class.getName()); + configs.put(JobConfig.CONFIG_REWRITERS(), "mock"); + configs.put(String.format("job.config.rewriter.mock.%s", streamCfgToOverride), "rewritten-system"); + mockConfig = spy(new MapConfig(configs)); + mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig); + configureJobNode(mockStreamAppDesc); + + JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator(); + JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson"); + Config expectedConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges()); + validateJobConfig(expectedConfig, jobConfig); + assertEquals("rewritten-system", jobConfig.get(streamCfgToOverride)); + } + + private void validateTableConfigure(JobConfig jobConfig, Map deserializedSerdes, + TableDescriptor tableDescriptor) { + Config tableConfig = jobConfig.subset(String.format("tables.%s.", tableDescriptor.getTableId())); + assertEquals(MockTableProviderFactory.class.getName(), tableConfig.get("provider.factory")); + MockTableProvider mockTableProvider = + (MockTableProvider) new MockTableProviderFactory().getTableProvider(((BaseTableDescriptor) tableDescriptor).getTableSpec()); + assertEquals(mockTableProvider.configMap.get("mock.table.provider.config"), jobConfig.get("mock.table.provider.config")); + validateTableSerdeConfigure(tableDescriptor.getTableId(), jobConfig, deserializedSerdes); + } + + private Config getExpectedJobConfig(Config originConfig, Map inputEdges) { + Map configMap = new HashMap<>(originConfig); + Set inputs = new HashSet<>(); + Set broadcasts = new HashSet<>(); + for (StreamEdge inputEdge : inputEdges.values()) { + if (inputEdge.isBroadcast()) { + broadcasts.add(inputEdge.getName() + "#0"); + } else { + inputs.add(inputEdge.getName()); + } + } + if (!inputs.isEmpty()) { + configMap.put(TaskConfig.INPUT_STREAMS(), Joiner.on(',').join(inputs)); + } + if (!broadcasts.isEmpty()) { + configMap.put(TaskConfigJava.BROADCAST_INPUT_STREAMS, Joiner.on(',').join(broadcasts)); + } + return new MapConfig(configMap); + } + + private Map validateAndGetDeserializedSerdes(Config jobConfig, int numSerdes) { + Config serializers = jobConfig.subset("serializers.registry.", true); + // make sure that the serializers deserialize correctly + SerializableSerde serializableSerde = new SerializableSerde<>(); + assertEquals(numSerdes, serializers.size()); + return serializers.entrySet().stream().collect(Collectors.toMap( + e -> e.getKey().replace(SerializerConfig.SERIALIZED_INSTANCE_SUFFIX(), ""), + e -> serializableSerde.fromBytes(Base64.getDecoder().decode(e.getValue().getBytes())) + )); + } + + private void validateJobConfig(Config expectedConfig, JobConfig jobConfig) { + assertEquals(expectedConfig.get(JobConfig.JOB_NAME()), jobConfig.getName().get()); + 
assertEquals(expectedConfig.get(JobConfig.JOB_ID()), jobConfig.getJobId()); + assertEquals("testJobGraphJson", jobConfig.get(JobNodeConfigurationGenerator.CONFIG_INTERNAL_EXECUTION_PLAN)); + assertEquals(expectedConfig.get(TaskConfig.INPUT_STREAMS()), jobConfig.get(TaskConfig.INPUT_STREAMS())); + assertEquals(expectedConfig.get(TaskConfigJava.BROADCAST_INPUT_STREAMS), jobConfig.get(TaskConfigJava.BROADCAST_INPUT_STREAMS)); + } + + private void validateStreamSerdeConfigure(String streamId, Config config, Map deserializedSerdes) { + Config streamConfig = config.subset(String.format("streams.%s.samza.", streamId)); + String keySerdeName = streamConfig.get("key.serde"); + String valueSerdeName = streamConfig.get("msg.serde"); + assertTrue(String.format("Serialized serdes should contain %s key serde", streamId), deserializedSerdes.containsKey(keySerdeName)); + assertTrue(String.format("Serialized %s key serde should be a StringSerde", streamId), keySerdeName.startsWith(StringSerde.class.getSimpleName())); + assertTrue(String.format("Serialized serdes should contain %s msg serde", streamId), deserializedSerdes.containsKey(valueSerdeName)); + assertTrue(String.format("Serialized %s msg serde should be a JsonSerdeV2", streamId), valueSerdeName.startsWith(JsonSerdeV2.class.getSimpleName())); + } + + private void validateTableSerdeConfigure(String tableId, Config config, Map deserializedSerdes) { + Config streamConfig = config.subset(String.format("tables.%s.", tableId)); + String keySerdeName = streamConfig.get("key.serde"); + String valueSerdeName = streamConfig.get("value.serde"); + assertTrue(String.format("Serialized serdes should contain %s key serde", tableId), deserializedSerdes.containsKey(keySerdeName)); + assertTrue(String.format("Serialized %s key serde should be a StringSerde", tableId), keySerdeName.startsWith(StringSerde.class.getSimpleName())); + assertTrue(String.format("Serialized serdes should contain %s value serde", tableId), deserializedSerdes.containsKey(valueSerdeName)); + assertTrue(String.format("Serialized %s msg serde should be a JsonSerdeV2", tableId), valueSerdeName.startsWith(JsonSerdeV2.class.getSimpleName())); + } + + private void validateIntermediateStreamConfigure(String streamId, String physicalName, Config config) { + Config intStreamConfig = config.subset(String.format("streams.%s.", streamId), true); + assertEquals("intermediate-system", intStreamConfig.get("samza.system")); + assertEquals(String.valueOf(Integer.MAX_VALUE), intStreamConfig.get("samza.priority")); + assertEquals("true", intStreamConfig.get("samza.delete.committed.messages")); + assertEquals(physicalName, intStreamConfig.get("samza.physical.name")); + assertEquals("true", intStreamConfig.get("samza.intermediate")); + assertEquals("oldest", intStreamConfig.get("samza.offset.default")); + } + + private void validateStreamConfigures(Config config, Map deserializedSerdes) { + + if (deserializedSerdes != null) { + validateStreamSerdeConfigure(input1Descriptor.getStreamId(), config, deserializedSerdes); + validateStreamSerdeConfigure(input2Descriptor.getStreamId(), config, deserializedSerdes); + validateStreamSerdeConfigure(outputDescriptor.getStreamId(), config, deserializedSerdes); + validateStreamSerdeConfigure(intermediateInputDescriptor.getStreamId(), config, deserializedSerdes); + } + + // generated stream config for intermediate stream + String physicalName = intermediateInputDescriptor.getPhysicalName().isPresent() ? 
+ intermediateInputDescriptor.getPhysicalName().get() : null; + validateIntermediateStreamConfigure(intermediateInputDescriptor.getStreamId(), physicalName, config); + } + + private void validateJoinStoreConfigures(MapConfig mapConfig, Map deserializedSerdes) { + String leftJoinStoreKeySerde = mapConfig.get("stores.jobName-jobId-join-j1-L.key.serde"); + String leftJoinStoreMsgSerde = mapConfig.get("stores.jobName-jobId-join-j1-L.msg.serde"); + assertTrue("Serialized serdes should contain left join store key serde", + deserializedSerdes.containsKey(leftJoinStoreKeySerde)); + assertTrue("Serialized left join store key serde should be a StringSerde", + leftJoinStoreKeySerde.startsWith(StringSerde.class.getSimpleName())); + assertTrue("Serialized serdes should contain left join store msg serde", + deserializedSerdes.containsKey(leftJoinStoreMsgSerde)); + assertTrue("Serialized left join store msg serde should be a TimestampedValueSerde", + leftJoinStoreMsgSerde.startsWith(TimestampedValueSerde.class.getSimpleName())); + + String rightJoinStoreKeySerde = mapConfig.get("stores.jobName-jobId-join-j1-R.key.serde"); + String rightJoinStoreMsgSerde = mapConfig.get("stores.jobName-jobId-join-j1-R.msg.serde"); + assertTrue("Serialized serdes should contain right join store key serde", + deserializedSerdes.containsKey(rightJoinStoreKeySerde)); + assertTrue("Serialized right join store key serde should be a StringSerde", + rightJoinStoreKeySerde.startsWith(StringSerde.class.getSimpleName())); + assertTrue("Serialized serdes should contain right join store msg serde", + deserializedSerdes.containsKey(rightJoinStoreMsgSerde)); + assertTrue("Serialized right join store msg serde should be a TimestampedValueSerde", + rightJoinStoreMsgSerde.startsWith(TimestampedValueSerde.class.getSimpleName())); + + Config leftJoinStoreConfig = mapConfig.subset("stores.jobName-jobId-join-j1-L.", true); + validateJoinStoreConfigure(leftJoinStoreConfig, "jobName-jobId-join-j1-L"); + Config rightJoinStoreConfig = mapConfig.subset("stores.jobName-jobId-join-j1-R.", true); + validateJoinStoreConfigure(rightJoinStoreConfig, "jobName-jobId-join-j1-R"); + } + + private void validateJoinStoreConfigure(Config joinStoreConfig, String changelogName) { + assertEquals("org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory", joinStoreConfig.get("factory")); + assertEquals(changelogName, joinStoreConfig.get("changelog")); + assertEquals("delete", joinStoreConfig.get("changelog.kafka.cleanup.policy")); + assertEquals("3600000", joinStoreConfig.get("changelog.kafka.retention.ms")); + assertEquals("3600000", joinStoreConfig.get("rocksdb.ttl.ms")); + } + + private static class MockTableProvider implements TableProvider { + private final Map configMap; + + MockTableProvider(Map configMap) { + this.configMap = configMap; + } + + @Override + public void init(SamzaContainerContext containerContext, TaskContext taskContext) { + + } + + @Override + public Table getTable() { + return null; + } + + @Override + public Map generateConfig(Config jobConfig, Map generatedConfig) { + return configMap; + } + + @Override + public void close() { + + } + } + + public static class MockTableProviderFactory implements TableProviderFactory { + + @Override + public TableProvider getTableProvider(TableSpec tableSpec) { + Map configMap = new HashMap<>(); + configMap.put("mock.table.provider.config", "mock.config.value"); + return new MockTableProvider(configMap); + } + } + + public static class MockConfigRewriter implements ConfigRewriter { + + @Override + 
public Config rewrite(String name, Config config) { + Map configMap = new HashMap<>(config); + configMap.putAll(config.subset(String.format("job.config.rewriter.%s.", name))); + return new MapConfig(configMap); + } + } +} diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestRemoteJobPlanner.java b/samza-core/src/test/java/org/apache/samza/execution/TestRemoteJobPlanner.java index 988fb341f8..85921f4def 100644 --- a/samza-core/src/test/java/org/apache/samza/execution/TestRemoteJobPlanner.java +++ b/samza-core/src/test/java/org/apache/samza/execution/TestRemoteJobPlanner.java @@ -69,7 +69,7 @@ public void testStreamCreation() ApplicationConfig mockAppConfig = mock(ApplicationConfig.class); when(mockAppConfig.getAppMode()).thenReturn(ApplicationConfig.ApplicationMode.STREAM); when(plan.getApplicationConfig()).thenReturn(mockAppConfig); - doReturn(plan).when(remotePlanner).getExecutionPlan(any(), any()); + doReturn(plan).when(remotePlanner).getExecutionPlan(any()); remotePlanner.prepareJobs(); diff --git a/samza-core/src/test/java/org/apache/samza/operators/TestOperatorSpecGraph.java b/samza-core/src/test/java/org/apache/samza/operators/TestOperatorSpecGraph.java index a5b15b8b1e..57ae6d87c6 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/TestOperatorSpecGraph.java +++ b/samza-core/src/test/java/org/apache/samza/operators/TestOperatorSpecGraph.java @@ -117,7 +117,6 @@ public void testConstructor() { OperatorSpecGraph specGraph = new OperatorSpecGraph(mockAppDesc); assertEquals(specGraph.getInputOperators(), inputOpSpecMap); assertEquals(specGraph.getOutputStreams(), outputStrmMap); - assertTrue(specGraph.getTables().isEmpty()); assertTrue(!specGraph.hasWindowOrJoins()); assertEquals(specGraph.getAllOperatorSpecs(), this.allOpSpecs); } diff --git a/samza-core/src/test/java/org/apache/samza/operators/spec/OperatorSpecTestUtils.java b/samza-core/src/test/java/org/apache/samza/operators/spec/OperatorSpecTestUtils.java index 7704a5b034..a34fdc386e 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/spec/OperatorSpecTestUtils.java +++ b/samza-core/src/test/java/org/apache/samza/operators/spec/OperatorSpecTestUtils.java @@ -53,7 +53,6 @@ enum TestEnum { public static void assertClonedGraph(OperatorSpecGraph originalGraph, OperatorSpecGraph clonedGraph) { assertClonedInputs(originalGraph.getInputOperators(), clonedGraph.getInputOperators()); assertClonedOutputs(originalGraph.getOutputStreams(), clonedGraph.getOutputStreams()); - assertClonedTables(originalGraph.getTables(), clonedGraph.getTables()); assertAllOperators(originalGraph.getAllOperatorSpecs(), clonedGraph.getAllOperatorSpecs()); } diff --git a/samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java b/samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java index 19ee74f9a7..fd0ddf859d 100644 --- a/samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java +++ b/samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java @@ -25,10 +25,10 @@ import java.util.Map; import org.apache.samza.application.ApplicationDescriptor; import org.apache.samza.application.ApplicationDescriptorImpl; +import org.apache.samza.application.LegacyTaskApplication; import org.apache.samza.application.SamzaApplication; import org.apache.samza.application.ApplicationDescriptorUtil; import org.apache.samza.application.StreamApplication; -import org.apache.samza.application.TaskApplication; import 
org.apache.samza.config.ApplicationConfig; import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; @@ -37,7 +37,6 @@ import org.apache.samza.processor.StreamProcessor; import org.apache.samza.execution.LocalJobPlanner; import org.apache.samza.task.IdentityStreamTask; -import org.apache.samza.task.StreamTaskFactory; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -73,8 +72,9 @@ public void testRunStreamTask() final Map cfgs = new HashMap<>(); cfgs.put(ApplicationConfig.APP_PROCESSOR_ID_GENERATOR_CLASS, UUIDGenerator.class.getName()); cfgs.put(JobConfig.JOB_NAME(), "test-task-job"); + cfgs.put(JobConfig.JOB_ID(), "jobId"); config = new MapConfig(cfgs); - mockApp = (TaskApplication) appDesc -> appDesc.setTaskFactory((StreamTaskFactory) () -> new IdentityStreamTask()); + mockApp = new LegacyTaskApplication(IdentityStreamTask.class.getName()); prepareTest(); StreamProcessor sp = mock(StreamProcessor.class); @@ -186,7 +186,8 @@ public void testWaitForFinishTimesout() { } private void prepareTest() { - ApplicationDescriptorImpl appDesc = ApplicationDescriptorUtil.getAppDescriptor(mockApp, config); + ApplicationDescriptorImpl appDesc = + ApplicationDescriptorUtil.getAppDescriptor(mockApp, config); localPlanner = spy(new LocalJobPlanner(appDesc)); runner = spy(new LocalApplicationRunner(appDesc, localPlanner)); } diff --git a/samza-core/src/test/java/org/apache/samza/runtime/TestRemoteApplicationRunner.java b/samza-core/src/test/java/org/apache/samza/runtime/TestRemoteApplicationRunner.java index ae525fb84f..702cbfbb1b 100644 --- a/samza-core/src/test/java/org/apache/samza/runtime/TestRemoteApplicationRunner.java +++ b/samza-core/src/test/java/org/apache/samza/runtime/TestRemoteApplicationRunner.java @@ -124,7 +124,7 @@ public ApplicationStatus waitForStatus(ApplicationStatus status, long timeoutMs) @Override public ApplicationStatus getStatus() { - String jobId = c.getJobId().get(); + String jobId = c.getJobId(); switch (jobId) { case "newJob": return ApplicationStatus.New; diff --git a/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemFactory.scala b/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemFactory.scala index 05d717a78d..3f5f11c647 100644 --- a/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemFactory.scala +++ b/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemFactory.scala @@ -35,7 +35,7 @@ class HdfsSystemFactory extends SystemFactory with Logging { def getProducer(systemName: String, config: Config, registry: MetricsRegistry) = { val jobConfig = new JobConfig(config) val jobName = jobConfig.getName.getOrElse(throw new ConfigException("Missing job name.")) - val jobId = jobConfig.getJobId.getOrElse("1") + val jobId = jobConfig.getJobId val clientId = getClientId("samza-producer", jobName, jobId) val metrics = new HdfsSystemProducerMetrics(systemName, registry) diff --git a/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManagerFactory.scala b/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManagerFactory.scala index 8d4098f50f..2999800c4b 100644 --- a/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManagerFactory.scala +++ b/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManagerFactory.scala @@ -32,7 +32,7 @@ class KafkaCheckpointManagerFactory extends CheckpointManagerFactory with Loggin def getCheckpointManager(config: Config, registry: 
MetricsRegistry): CheckpointManager = { val jobName = config.getName.getOrElse(throw new SamzaException("Missing job name in configs")) - val jobId = config.getJobId.getOrElse("1") + val jobId = config.getJobId val kafkaConfig = new KafkaConfig(config) val checkpointSystemName = kafkaConfig.getCheckpointSystem.getOrElse( diff --git a/samza-kafka/src/main/scala/org/apache/samza/config/KafkaConsumerConfig.java b/samza-kafka/src/main/scala/org/apache/samza/config/KafkaConsumerConfig.java index 3fa66e5c44..6cebc2827b 100644 --- a/samza-kafka/src/main/scala/org/apache/samza/config/KafkaConsumerConfig.java +++ b/samza-kafka/src/main/scala/org/apache/samza/config/KafkaConsumerConfig.java @@ -31,7 +31,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Option; -import scala.runtime.AbstractFunction0; /** @@ -126,11 +125,7 @@ static String getConsumerGroupId(Config config) { } String jobName = (String) jobNameOption.get(); - Option jobIdOption = jobConfig.getJobId(); - String jobId = "1"; - if (! jobIdOption.isEmpty()) { - jobId = (String) jobIdOption.get(); - } + String jobId = jobConfig.getJobId(); return String.format("%s-%s", jobName, jobId); } @@ -156,11 +151,7 @@ static String getConsumerClientId(String id, Config config) { } String jobName = (String) jobNameOption.get(); - Option jobIdOption = jobConfig.getJobId(); - String jobId = "1"; - if (! jobIdOption.isEmpty()) { - jobId = (String) jobIdOption.get(); - } + String jobId = jobConfig.getJobId(); return String.format("%s-%s-%s", id.replaceAll("\\W", "_"), jobName.replaceAll("\\W", "_"), jobId.replaceAll("\\W", "_")); diff --git a/samza-kafka/src/main/scala/org/apache/samza/util/KafkaUtil.scala b/samza-kafka/src/main/scala/org/apache/samza/util/KafkaUtil.scala index 601ffa25b9..2d09301926 100644 --- a/samza-kafka/src/main/scala/org/apache/samza/util/KafkaUtil.scala +++ b/samza-kafka/src/main/scala/org/apache/samza/util/KafkaUtil.scala @@ -40,7 +40,7 @@ object KafkaUtil extends Logging { def getClientId(id: String, config: Config): String = getClientId( id, config.getName.getOrElse(throw new ConfigException("Missing job name.")), - config.getJobId.getOrElse("1")) + config.getJobId) def getClientId(id: String, jobName: String, jobId: String): String = "%s-%s-%s" format diff --git a/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java b/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java index 07f4f55882..823190581f 100644 --- a/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java +++ b/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java @@ -89,10 +89,15 @@ protected Map generateCommonStoreConfig(Config jobConfig, Map storeConfig = new HashMap<>(); - // We assume the configuration for serde are already generated for this table, - // so we simply carry them over to store configuration. - // - JavaTableConfig tableConfig = new JavaTableConfig(new MapConfig(generatedConfig)); + // serde configurations for tables are generated at top level by JobNodeConfigurationGenerator and are included + // in the global jobConfig. generatedConfig has all table specific configuration generated from TableSpec, such + // as TableProviderFactory, sideInputs, etc. 
+ // Merge the global jobConfig and generatedConfig to get full access to configuration needed to create local + // store configuration + Map mergedConfigMap = new HashMap<>(jobConfig); + mergedConfigMap.putAll(generatedConfig); + JobConfig mergedJobConfig = new JobConfig(new MapConfig(mergedConfigMap)); + JavaTableConfig tableConfig = new JavaTableConfig(mergedJobConfig); String keySerde = tableConfig.getKeySerde(tableSpec.getId()); storeConfig.put(String.format(StorageConfig.KEY_SERDE(), tableSpec.getId()), keySerde); @@ -116,9 +121,7 @@ protected Map generateCommonStoreConfig(Config jobConfig, Map configs; - private Class taskClass; - private StreamApplication app; + private SamzaApplication app; /* * inMemoryScope is a unique global key per TestRunner, this key when configured with {@link InMemorySystemDescriptor} * provides an isolated state to run with in memory system @@ -112,7 +108,7 @@ private TestRunner(Class taskClass) { this(); Preconditions.checkNotNull(taskClass); configs.put(TaskConfig.TASK_CLASS(), taskClass.getName()); - this.taskClass = taskClass; + this.app = new LegacyTaskApplication(taskClass.getName()); } /** @@ -158,6 +154,17 @@ public TestRunner addConfigs(Map config) { return this; } + /** + * Only adds a config from {@code config} to samza job {@code configs} if they dont exist in it. + * @param config configs for the application + * @return this {@link TestRunner} + */ + public TestRunner addConfigs(Map config, String configPrefix) { + Preconditions.checkNotNull(config); + config.forEach((key, value) -> this.configs.putIfAbsent(String.format("%s%s", configPrefix, key), value)); + return this; + } + /** * Adds a config to {@code configs} if its not already present. Overrides a config value for which key is already * exisiting in {@code configs} @@ -168,7 +175,7 @@ public TestRunner addConfigs(Map config) { public TestRunner addOverrideConfig(String key, String value) { Preconditions.checkNotNull(key); Preconditions.checkNotNull(value); - String configKeyPrefix = String.format(JobConfig.CONFIG_JOB_PREFIX(), JOB_NAME); + String configKeyPrefix = String.format(JobConfig.CONFIG_OVERRIDE_JOBS_PREFIX(), getJobNameAndId()); configs.put(String.format("%s%s", configKeyPrefix, key), value); return this; } @@ -192,6 +199,10 @@ public TestRunner addInputStream(InMemoryInputDescriptor des return this; } + private String getJobNameAndId() { + return String.format("%s-%s", JOB_NAME, configs.getOrDefault(JobConfig.JOB_ID(), "1")); + } + /** * Adds the provided input stream with mock data to the test application. * @param descriptor describes the stream that is supposed to be input to Samza application @@ -243,11 +254,10 @@ public TestRunner addOutputStream(InMemoryOutputDescriptor streamDescriptor, int * @throws SamzaException if Samza job fails with exception and returns UnsuccessfulFinish as the statuscode */ public void run(Duration timeout) { - Preconditions.checkState((app == null && taskClass != null) || (app != null && taskClass == null), + Preconditions.checkState(app != null, "TestRunner should run for Low Level Task api or High Level Application Api"); Preconditions.checkState(!timeout.isZero() || !timeout.isNegative(), "Timeouts should be positive"); - SamzaApplication testApp = app == null ? 
(TaskApplication) appDesc -> appDesc.setTaskFactory(createTaskFactory()) : app; - final LocalApplicationRunner runner = new LocalApplicationRunner(testApp, new MapConfig(configs)); + final LocalApplicationRunner runner = new LocalApplicationRunner(app, new MapConfig(configs)); runner.run(); boolean timedOut = !runner.waitForFinish(timeout); Assert.assertFalse("Timed out waiting for application to finish", timedOut); @@ -326,28 +336,6 @@ public static Map> consumeS entry -> entry.getValue().stream().map(e -> (StreamMessageType) e.getMessage()).collect(Collectors.toList()))); } - private TaskFactory createTaskFactory() { - if (StreamTask.class.isAssignableFrom(taskClass)) { - return (StreamTaskFactory) () -> { - try { - return (StreamTask) taskClass.newInstance(); - } catch (InstantiationException | IllegalAccessException e) { - throw new SamzaException(String.format("Failed to instantiate StreamTask class %s", taskClass.getName()), e); - } - }; - } else if (AsyncStreamTask.class.isAssignableFrom(taskClass)) { - return (AsyncStreamTaskFactory) () -> { - try { - return (AsyncStreamTask) taskClass.newInstance(); - } catch (InstantiationException | IllegalAccessException e) { - throw new SamzaException(String.format("Failed to instantiate AsyncStreamTask class %s", taskClass.getName()), e); - } - }; - } - throw new SamzaException(String.format("Not supported task.class %s. task.class has to implement either StreamTask " - + "or AsyncStreamTask", taskClass.getName())); - } - /** * Creates an in memory stream with {@link InMemorySystemFactory} and feeds its partition with stream of messages * @param partitonData key of the map represents partitionId and value represents @@ -367,7 +355,7 @@ private void initializeInMemoryInputStream(InMemoryInputDesc InMemorySystemDescriptor imsd = (InMemorySystemDescriptor) descriptor.getSystemDescriptor(); imsd.withInMemoryScope(this.inMemoryScope); addConfigs(descriptor.toConfig()); - addConfigs(descriptor.getSystemDescriptor().toConfig()); + addConfigs(descriptor.getSystemDescriptor().toConfig(), String.format(JobConfig.CONFIG_OVERRIDE_JOBS_PREFIX(), getJobNameAndId())); StreamSpec spec = new StreamSpec(descriptor.getStreamId(), streamName, systemName, partitonData.size()); SystemFactory factory = new InMemorySystemFactory(); Config config = new MapConfig(descriptor.toConfig(), descriptor.getSystemDescriptor().toConfig()); @@ -381,7 +369,7 @@ private void initializeInMemoryInputStream(InMemoryInputDesc producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), key, value)); }); producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), null, - new EndOfStreamMessage(null))); + new EndOfStreamMessage(null))); }); } } diff --git a/samza-test/src/main/java/org/apache/samza/test/framework/system/InMemorySystemDescriptor.java b/samza-test/src/main/java/org/apache/samza/test/framework/system/InMemorySystemDescriptor.java index 92b23ef96f..e6e423f5af 100644 --- a/samza-test/src/main/java/org/apache/samza/test/framework/system/InMemorySystemDescriptor.java +++ b/samza-test/src/main/java/org/apache/samza/test/framework/system/InMemorySystemDescriptor.java @@ -29,7 +29,6 @@ import org.apache.samza.system.SystemStreamMetadata; import org.apache.samza.system.inmemory.InMemorySystemFactory; import org.apache.samza.config.JavaSystemConfig; -import org.apache.samza.test.framework.TestRunner; /** @@ -60,9 +59,6 @@ public class InMemorySystemDescriptor extends SystemDescriptor * **/ - private static final 
String CONFIG_OVERRIDE_PREFIX = "jobs.%s."; - private static final String DEFAULT_STREAM_OFFSET_DEFAULT_CONFIG_KEY = "systems.%s.default.stream.samza.offset.default"; - private String inMemoryScope; /** @@ -106,11 +102,7 @@ public InMemorySystemDescriptor withInMemoryScope(String inMemoryScope) { public Map toConfig() { HashMap configs = new HashMap<>(super.toConfig()); configs.put(InMemorySystemConfig.INMEMORY_SCOPE, this.inMemoryScope); - configs.put(String.format(CONFIG_OVERRIDE_PREFIX + JavaSystemConfig.SYSTEM_FACTORY_FORMAT, TestRunner.JOB_NAME, getSystemName()), - FACTORY_CLASS_NAME); - configs.put( - String.format(CONFIG_OVERRIDE_PREFIX + DEFAULT_STREAM_OFFSET_DEFAULT_CONFIG_KEY, TestRunner.JOB_NAME, - getSystemName()), SystemStreamMetadata.OffsetType.OLDEST.toString()); + configs.put(String.format(JavaSystemConfig.SYSTEM_FACTORY_FORMAT, getSystemName()), FACTORY_CLASS_NAME); return configs; } diff --git a/samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java b/samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java index d123cee7bf..6186ca7271 100644 --- a/samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java +++ b/samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java @@ -96,8 +96,6 @@ public void testWithTableDescriptorsProviderClass() throws Exception { Assert.assertEquals(storageConfig.getStoreNames().get(0), localTableId); Assert.assertEquals(storageConfig.getStorageFactoryClassName(localTableId), RocksDbKeyValueStorageEngineFactory.class.getName()); - Assert.assertTrue(storageConfig.getStorageKeySerde(localTableId).startsWith("StringSerde")); - Assert.assertTrue(storageConfig.getStorageMsgSerde(localTableId).startsWith("StringSerde")); Config storeConfig = resultConfig.subset("stores." + localTableId + ".", true); Assert.assertEquals(4, storeConfig.size()); Assert.assertEquals(4096, storeConfig.getInt("rocksdb.block.size.bytes")); @@ -107,10 +105,6 @@ public void testWithTableDescriptorsProviderClass() throws Exception { RocksDbTableProviderFactory.class.getName()); Assert.assertEquals(tableConfig.getTableProviderFactory(remoteTableId), RemoteTableProviderFactory.class.getName()); - Assert.assertTrue(tableConfig.getKeySerde(localTableId).startsWith("StringSerde")); - Assert.assertTrue(tableConfig.getValueSerde(localTableId).startsWith("StringSerde")); - Assert.assertTrue(tableConfig.getKeySerde(remoteTableId).startsWith("StringSerde")); - Assert.assertTrue(tableConfig.getValueSerde(remoteTableId).startsWith("LongSerde")); Assert.assertEquals(tableConfig.getTableProviderFactory(localTableId), RocksDbTableProviderFactory.class.getName()); Assert.assertEquals(tableConfig.getTableProviderFactory(remoteTableId), RemoteTableProviderFactory.class.getName()); } diff --git a/samza-yarn/src/main/java/org/apache/samza/validation/YarnJobValidationTool.java b/samza-yarn/src/main/java/org/apache/samza/validation/YarnJobValidationTool.java index b30b896371..4adb93a45b 100644 --- a/samza-yarn/src/main/java/org/apache/samza/validation/YarnJobValidationTool.java +++ b/samza-yarn/src/main/java/org/apache/samza/validation/YarnJobValidationTool.java @@ -76,7 +76,7 @@ public YarnJobValidationTool(JobConfig config, YarnClient client, MetricsValidat this.config = config; this.client = client; String name = this.config.getName().get(); - String jobId = this.config.getJobId().nonEmpty()? 
this.config.getJobId().get() : "1"; + String jobId = this.config.getJobId(); this.jobName = name + "_" + jobId; this.validator = validator; } diff --git a/samza-yarn/src/main/scala/org/apache/samza/job/yarn/YarnJob.scala b/samza-yarn/src/main/scala/org/apache/samza/job/yarn/YarnJob.scala index d3354489aa..1d72a88184 100644 --- a/samza-yarn/src/main/scala/org/apache/samza/job/yarn/YarnJob.scala +++ b/samza-yarn/src/main/scala/org/apache/samza/job/yarn/YarnJob.scala @@ -67,7 +67,7 @@ class YarnJob(config: Config, hadoopConfig: Configuration) extends StreamJob { } envMapWithJavaHome }), - Some("%s_%s" format(config.getName.get, config.getJobId.getOrElse(1))) + Some("%s_%s" format(config.getName.get, config.getJobId)) ) } catch { case e: Throwable => @@ -169,7 +169,7 @@ class YarnJob(config: Config, hadoopConfig: Configuration) extends StreamJob { // Get by name config.getName match { case Some(jobName) => - val applicationName = "%s_%s" format(jobName, config.getJobId.getOrElse(1)) + val applicationName = "%s_%s" format(jobName, config.getJobId) logger.info("Fetching status from YARN for application name %s" format applicationName) val applicationIds = client.getActiveApplicationIds(applicationName)