diff --git a/sdk/pom.xml b/sdk/pom.xml index d7e10a53a8..2639bc13db 100644 --- a/sdk/pom.xml +++ b/sdk/pom.xml @@ -685,6 +685,13 @@ ${guava.version} + + com.google.guava + guava-testlib + ${guava.version} + test + + com.google.protobuf protobuf-java diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/package-info.java new file mode 100644 index 0000000000..b5bcf18eda --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/package-info.java @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines a {@link com.google.cloud.dataflow.sdk.coders.Coder} + * for Protocol Buffers messages, {@code ProtoCoder}. + * + * @see com.google.cloud.dataflow.sdk.coders.protobuf.ProtoCoder + */ +package com.google.cloud.dataflow.sdk.coders.protobuf; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIO.java index 7ecccf14a6..7d59b09c8d 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIO.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIO.java @@ -72,7 +72,7 @@ *

Reading from Cloud Bigtable

* *

The Bigtable source returns a set of rows from a single table, returning a - * {@code PCollection&lt;Row&gt;}. + * {@code PCollection<Row>}. * *

To configure a Cloud Bigtable source, you must supply a table id and a {@link BigtableOptions} * or builder configured with the project and other information necessary to identify the diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/package-info.java new file mode 100644 index 0000000000..112a954d71 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines transforms for reading and writing from Google Cloud Bigtable. + * + * @see com.google.cloud.dataflow.sdk.io.bigtable.BigtableIO + */ +package com.google.cloud.dataflow.sdk.io.bigtable; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java index 923033d5da..8ff1fa9783 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java @@ -137,7 +137,8 @@ * *

{@link Default @Default} represents a set of annotations that can be used to annotate getter * properties on {@link PipelineOptions} with information representing the default value to be - * returned if no value is specified. + * returned if no value is specified. Any default implementation (using the {@code default} keyword) + * is ignored. * *

{@link Hidden @Hidden} hides an option from being listed when {@code --help} * is invoked via {@link PipelineOptionsFactory#fromArgs(String[])}. diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java index e77b89f9a4..4781d1c829 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java @@ -16,6 +16,8 @@ package com.google.cloud.dataflow.sdk.options; +import static com.google.common.base.Preconditions.checkArgument; + import com.google.cloud.dataflow.sdk.options.Validation.Required; import com.google.cloud.dataflow.sdk.runners.PipelineRunner; import com.google.cloud.dataflow.sdk.runners.PipelineRunnerRegistrar; @@ -443,13 +445,21 @@ Class getProxyClass() { private static final Map>> SUPPORTED_PIPELINE_RUNNERS; /** Classes that are used as the boundary in the stack trace to find the callers class name. */ - private static final Set PIPELINE_OPTIONS_FACTORY_CLASSES = ImmutableSet.of( - PipelineOptionsFactory.class.getName(), - Builder.class.getName()); + private static final Set PIPELINE_OPTIONS_FACTORY_CLASSES = + ImmutableSet.of(PipelineOptionsFactory.class.getName(), Builder.class.getName()); /** Methods that are ignored when validating the proxy class. */ private static final Set IGNORED_METHODS; + /** A predicate that accepts only non-synthetic methods, per {@link Method#isSynthetic()}. */ + private static final Predicate NOT_SYNTHETIC_PREDICATE = + new Predicate() { + @Override + public boolean apply(Method input) { + return !input.isSynthetic(); + } + }; + /** The set of options that have been registered and visible to the user. 
*/ private static final Set> REGISTERED_OPTIONS = Sets.newConcurrentHashSet(); @@ -662,7 +672,9 @@ public static void printHelp(PrintStream out, Class i Preconditions.checkNotNull(iface); validateWellFormed(iface, REGISTERED_OPTIONS); - Iterable methods = ReflectHelpers.getClosureOfMethodsOnInterface(iface); + Iterable methods = + Iterables.filter( + ReflectHelpers.getClosureOfMethodsOnInterface(iface), NOT_SYNTHETIC_PREDICATE); ListMultimap, Method> ifaceToMethods = ArrayListMultimap.create(); for (Method method : methods) { // Process only methods that are not marked as hidden. @@ -876,7 +888,8 @@ private static List getPropertyDescriptors(Class beanClas throws IntrospectionException { // The sorting is important to make this method stable. SortedSet methods = Sets.newTreeSet(MethodComparator.INSTANCE); - methods.addAll(Arrays.asList(beanClass.getMethods())); + methods.addAll( + Collections2.filter(Arrays.asList(beanClass.getMethods()), NOT_SYNTHETIC_PREDICATE)); SortedMap propertyNamesToGetters = getPropertyNamesToGetters(methods); List descriptors = Lists.newArrayList(); @@ -1017,8 +1030,9 @@ private static List validateClass(Class klass) throws IntrospectionException { Set methods = Sets.newHashSet(IGNORED_METHODS); // Ignore static methods, "equals", "hashCode", "toString" and "as" on the generated class. 
+ // Ignore synthetic methods for (Method method : klass.getMethods()) { - if (Modifier.isStatic(method.getModifiers())) { + if (Modifier.isStatic(method.getModifiers()) || method.isSynthetic()) { methods.add(method); } } @@ -1035,6 +1049,7 @@ private static List validateClass(Class interfaceMethods = FluentIterable .from(ReflectHelpers.getClosureOfMethodsOnInterface(iface)) + .filter(NOT_SYNTHETIC_PREDICATE) .toSortedSet(MethodComparator.INSTANCE); SortedSetMultimap methodNameToMethodMap = TreeMultimap.create(MethodNameComparator.INSTANCE, MethodComparator.INSTANCE); @@ -1059,10 +1074,13 @@ private static List validateClass(Class allInterfaceMethods = FluentIterable - .from(ReflectHelpers.getClosureOfMethodsOnInterfaces(validatedPipelineOptionsInterfaces)) - .append(ReflectHelpers.getClosureOfMethodsOnInterface(iface)) - .toSortedSet(MethodComparator.INSTANCE); + Iterable allInterfaceMethods = + FluentIterable.from( + ReflectHelpers.getClosureOfMethodsOnInterfaces( + validatedPipelineOptionsInterfaces)) + .append(ReflectHelpers.getClosureOfMethodsOnInterface(iface)) + .filter(NOT_SYNTHETIC_PREDICATE) + .toSortedSet(MethodComparator.INSTANCE); SortedSetMultimap methodNameToAllMethodMap = TreeMultimap.create(MethodNameComparator.INSTANCE, MethodComparator.INSTANCE); for (Method method : allInterfaceMethods) { @@ -1146,7 +1164,10 @@ private static List validateClass(Class unknownMethods = new TreeSet<>(MethodComparator.INSTANCE); - unknownMethods.addAll(Sets.difference(Sets.newHashSet(klass.getMethods()), methods)); + unknownMethods.addAll( + Sets.filter( + Sets.difference(Sets.newHashSet(klass.getMethods()), methods), + NOT_SYNTHETIC_PREDICATE)); Preconditions.checkArgument(unknownMethods.isEmpty(), "Methods %s on [%s] do not conform to being bean properties.", FluentIterable.from(unknownMethods).transform(ReflectHelpers.METHOD_FORMATTER), @@ -1391,7 +1412,10 @@ private static ListMultimap parseCommandLine( * split up each string on ','. * *

We special case the "runner" option. It is mapped to the class of the {@link PipelineRunner} - * based off of the {@link PipelineRunner}s simple class name or fully qualified class name. + * based off of the {@link PipelineRunner PipelineRunners} simple class name. If the provided + * runner name is not registered via a {@link PipelineRunnerRegistrar}, we attempt to obtain the + * class that the name represents using {@link Class#forName(String)} and use the result class if + * it subclasses {@link PipelineRunner}. * *

If strict parsing is enabled, unknown options or options that cannot be converted to * the expected java type using an {@link ObjectMapper} will be ignored. @@ -1442,10 +1466,26 @@ public boolean apply(@Nullable String input) { JavaType type = MAPPER.getTypeFactory().constructType(method.getGenericReturnType()); if ("runner".equals(entry.getKey())) { String runner = Iterables.getOnlyElement(entry.getValue()); - Preconditions.checkArgument(SUPPORTED_PIPELINE_RUNNERS.containsKey(runner), - "Unknown 'runner' specified '%s', supported pipeline runners %s", - runner, Sets.newTreeSet(SUPPORTED_PIPELINE_RUNNERS.keySet())); - convertedOptions.put("runner", SUPPORTED_PIPELINE_RUNNERS.get(runner)); + if (SUPPORTED_PIPELINE_RUNNERS.containsKey(runner)) { + convertedOptions.put("runner", SUPPORTED_PIPELINE_RUNNERS.get(runner)); + } else { + try { + Class runnerClass = Class.forName(runner); + checkArgument( + PipelineRunner.class.isAssignableFrom(runnerClass), + "Class '%s' does not implement PipelineRunner. 
Supported pipeline runners %s", + runner, + Sets.newTreeSet(SUPPORTED_PIPELINE_RUNNERS.keySet())); + convertedOptions.put("runner", runnerClass); + } catch (ClassNotFoundException e) { + String msg = + String.format( + "Unknown 'runner' specified '%s', supported pipeline runners %s", + runner, + Sets.newTreeSet(SUPPORTED_PIPELINE_RUNNERS.keySet())); + throw new IllegalArgumentException(msg, e); + } + } } else if ((returnType.isArray() && (SIMPLE_TYPES.contains(returnType.getComponentType()) || returnType.getComponentType().isEnum())) || Collection.class.isAssignableFrom(returnType)) { diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java index d0cc4e53d5..0feae957f8 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java @@ -952,6 +952,9 @@ private void groupByKeyHelper( context.addInput( PropertyNames.SERIALIZED_FN, byteArrayToJsonString(serializeToByteArray(windowingStrategy))); + context.addInput( + PropertyNames.IS_MERGING_WINDOW_FN, + !windowingStrategy.getWindowFn().isNonMerging()); } }); diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java index 1c0279897a..eaea3ed293 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java @@ -15,10 +15,11 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.BoundedSource.BoundedReader; import 
com.google.cloud.dataflow.sdk.io.Read.Bounded; import com.google.cloud.dataflow.sdk.io.Source.Reader; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.PTransform; @@ -79,8 +80,7 @@ private TransformEvaluator getTransformEvaluator( @SuppressWarnings("unchecked") private Queue> getTransformEvaluatorQueue( final AppliedPTransform, Bounded> transform, - final InProcessEvaluationContext evaluationContext) - throws IOException { + final InProcessEvaluationContext evaluationContext) { // Key by the application and the context the evaluation is occurring in (which call to // Pipeline#run). EvaluatorKey key = new EvaluatorKey(transform, evaluationContext); @@ -102,21 +102,25 @@ private Queue> getTransformEvaluatorQueu return evaluatorQueue; } + /** + * A {@link BoundedReadEvaluator} produces elements from an underlying {@link BoundedSource}, + * discarding all input elements. Within the call to {@link #finishBundle()}, the evaluator + * creates the {@link BoundedReader} and consumes all available input. + * + *

A {@link BoundedReadEvaluator} should only be created once per {@link BoundedSource}, and + * each evaluator should only be called once per evaluation of the pipeline. Otherwise, the source + * may produce duplicate elements. + */ private static class BoundedReadEvaluator implements TransformEvaluator { private final AppliedPTransform, Bounded> transform; private final InProcessEvaluationContext evaluationContext; - private final Reader reader; private boolean contentsRemaining; public BoundedReadEvaluator( AppliedPTransform, Bounded> transform, - InProcessEvaluationContext evaluationContext) - throws IOException { + InProcessEvaluationContext evaluationContext) { this.transform = transform; this.evaluationContext = evaluationContext; - reader = - transform.getTransform().getSource().createReader(evaluationContext.getPipelineOptions()); - contentsRemaining = reader.start(); } @Override @@ -124,17 +128,25 @@ public void processElement(WindowedValue element) {} @Override public InProcessTransformResult finishBundle() throws IOException { - UncommittedBundle output = evaluationContext.createRootBundle(transform.getOutput()); - while (contentsRemaining) { - output.add( - WindowedValue.timestampedValueInGlobalWindow( - reader.getCurrent(), reader.getCurrentTimestamp())); - contentsRemaining = reader.advance(); + try (final Reader reader = + transform + .getTransform() + .getSource() + .createReader(evaluationContext.getPipelineOptions())) { + contentsRemaining = reader.start(); + UncommittedBundle output = + evaluationContext.createRootBundle(transform.getOutput()); + while (contentsRemaining) { + output.add( + WindowedValue.timestampedValueInGlobalWindow( + reader.getCurrent(), reader.getCurrentTimestamp())); + contentsRemaining = reader.advance(); + } + // try-with-resources closes the reader; do not also close it explicitly. + return StepTransformResult.withHold(transform, BoundedWindow.TIMESTAMP_MAX_VALUE) + .addOutput(output) + .build(); } - return StepTransformResult - .withHold(transform, 
BoundedWindow.TIMESTAMP_MAX_VALUE) - .addOutput(output) - .build(); } } } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java new file mode 100644 index 0000000000..3350d2b4d5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** + * A {@link ExecutorServiceFactory} that produces cached thread pools via + * {@link Executors#newCachedThreadPool()}. 
+ */ +class CachedThreadPoolExecutorServiceFactory + implements DefaultValueFactory, ExecutorServiceFactory { + private static final CachedThreadPoolExecutorServiceFactory INSTANCE = + new CachedThreadPoolExecutorServiceFactory(); + + @Override + public ExecutorServiceFactory create(PipelineOptions options) { + return INSTANCE; + } + + @Override + public ExecutorService create() { + return Executors.newCachedThreadPool(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java new file mode 100644 index 0000000000..2792631560 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; + +/** + * A callback for completing a bundle of input. + */ +interface CompletionCallback { + /** + * Handle a successful result. + */ + void handleResult(CommittedBundle inputBundle, InProcessTransformResult result); + + /** + * Handle a result that terminated abnormally due to the provided {@link Throwable}. 
+ */ + void handleThrowable(CommittedBundle inputBundle, Throwable t); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java new file mode 100644 index 0000000000..c602b23c41 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.PValue; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** + * Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the + * {@link Pipeline}. This is used to schedule consuming {@link PTransform PTransforms} to consume + * input after the upstream transform has produced and committed output. 
+ */ +public class ConsumerTrackingPipelineVisitor implements PipelineVisitor { + private Map>> valueToConsumers = new HashMap<>(); + private Collection> rootTransforms = new ArrayList<>(); + private Collection> views = new ArrayList<>(); + private Map, String> stepNames = new HashMap<>(); + private Set toFinalize = new HashSet<>(); + private int numTransforms = 0; + private boolean finalized = false; + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempting to traverse a pipeline (node %s) with a %s " + + "which has already visited a Pipeline and is finalized", + node.getFullName(), + ConsumerTrackingPipelineVisitor.class.getSimpleName()); + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempting to traverse a pipeline (node %s) with a %s which is already finalized", + node.getFullName(), + ConsumerTrackingPipelineVisitor.class.getSimpleName()); + if (node.isRootNode()) { + finalized = true; + } + } + + @Override + public void visitTransform(TransformTreeNode node) { + toFinalize.removeAll(node.getInput().expand()); + AppliedPTransform appliedTransform = getAppliedTransform(node); + stepNames.put(appliedTransform, genStepName()); + if (node.getInput().expand().isEmpty()) { + rootTransforms.add(appliedTransform); + } else { + for (PValue value : node.getInput().expand()) { + valueToConsumers.get(value).add(appliedTransform); + } + } + } + + private AppliedPTransform getAppliedTransform(TransformTreeNode node) { + @SuppressWarnings({"rawtypes", "unchecked"}) + AppliedPTransform application = AppliedPTransform.of( + node.getFullName(), node.getInput(), node.getOutput(), (PTransform) node.getTransform()); + return application; + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + toFinalize.add(value); + for (PValue expandedValue : value.expand()) { + valueToConsumers.put(expandedValue, new ArrayList>()); + 
if (expandedValue instanceof PCollectionView) { + views.add((PCollectionView) expandedValue); + } + expandedValue.recordAsOutput(getAppliedTransform(producer)); + } + value.recordAsOutput(getAppliedTransform(producer)); + } + + private String genStepName() { + return String.format("s%s", numTransforms++); + } + + + /** + * Returns a mapping of each fully-expanded {@link PValue} to each + * {@link AppliedPTransform} that consumes it. For each AppliedPTransform in the collection + * returned from {@code getValueToConsumers().get(PValue)}, + * {@code AppliedPTransform#getInput().expand()} will contain the argument {@link PValue}. + */ + public Map>> getValueToConsumers() { + checkState( + finalized, + "Can't call getValueToConsumers before the Pipeline has been completely traversed"); + + return valueToConsumers; + } + + /** + * Returns the mapping for each {@link AppliedPTransform} in the {@link Pipeline} to a unique step + * name. + */ + public Map, String> getStepNames() { + checkState( + finalized, "Can't call getStepNames before the Pipeline has been completely traversed"); + + return stepNames; + } + + /** + * Returns the root transforms of the {@link Pipeline}. A root {@link AppliedPTransform} consumes + * a {@link PInput} where the {@link PInput#expand()} returns an empty collection. + */ + public Collection> getRootTransforms() { + checkState( + finalized, + "Can't call getRootTransforms before the Pipeline has been completely traversed"); + + return rootTransforms; + } + + /** + * Returns all of the {@link PCollectionView PCollectionViews} contained in the visited + * {@link Pipeline}. + */ + public Collection> getViews() { + checkState(finalized, "Can't call getViews before the Pipeline has been completely traversed"); + + return views; + } + + /** + * Returns all of the {@link PValue PValues} that have been produced but not consumed. 
These + * {@link PValue PValues} should be finalized by the {@link PipelineRunner} before the + * {@link Pipeline} is executed. + */ + public Set getUnfinalizedPValues() { + checkState( + finalized, + "Can't call getUnfinalizedPValues before the Pipeline has been completely traversed"); + + return toFinalize; + } +} + + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EvaluatorKey.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EvaluatorKey.java index 745f8f2718..307bc5cdb5 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EvaluatorKey.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EvaluatorKey.java @@ -15,7 +15,6 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import java.util.Objects; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java new file mode 100644 index 0000000000..480bcdefed --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import java.util.concurrent.ExecutorService; + +/** + * A factory that creates {@link ExecutorService ExecutorServices}. + * {@link ExecutorService ExecutorServices} created by this factory should be independent of one + * another (e.g., if any executor is shut down the remaining executors should continue to process + * work). + */ +public interface ExecutorServiceFactory { + /** + * Create a new {@link ExecutorService}. + */ + ExecutorService create(); +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java new file mode 100644 index 0000000000..c72a1155f5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java @@ -0,0 +1,394 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.FiredTimers; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.util.KeyedWorkItem; +import com.google.cloud.dataflow.sdk.util.KeyedWorkItems; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.MoreObjects; +import com.google.common.base.Optional; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; + +import javax.annotation.Nullable; + +/** + * An {@link InProcessExecutor} that uses an underlying {@link ExecutorService} and + * {@link InProcessEvaluationContext} to execute a {@link Pipeline}. 
+ */ +final class ExecutorServiceParallelExecutor implements InProcessExecutor { + private static final Logger LOG = LoggerFactory.getLogger(ExecutorServiceParallelExecutor.class); + + private final ExecutorService executorService; + + private final Map>> valueToConsumers; + private final Set keyedPValues; + private final TransformEvaluatorRegistry registry; + private final InProcessEvaluationContext evaluationContext; + + private final ConcurrentMap currentEvaluations; + private final ConcurrentMap, Boolean> scheduledExecutors; + + /** Bundles and failures reported by completion callbacks; polled by {@link MonitorRunnable}. */ + private final Queue allUpdates; + /** Failures and the completion signal; taken by {@link #awaitCompletion()}. */ + private final BlockingQueue visibleUpdates; + + private final TransformExecutorService parallelExecutorService; + private final CompletionCallback defaultCompletionCallback; + + private Collection> rootNodes; + + public static ExecutorServiceParallelExecutor create( + ExecutorService executorService, + Map>> valueToConsumers, + Set keyedPValues, + TransformEvaluatorRegistry registry, + InProcessEvaluationContext context) { + return new ExecutorServiceParallelExecutor( + executorService, valueToConsumers, keyedPValues, registry, context); + } + + private ExecutorServiceParallelExecutor( + ExecutorService executorService, + Map>> valueToConsumers, + Set keyedPValues, + TransformEvaluatorRegistry registry, + InProcessEvaluationContext context) { + this.executorService = executorService; + this.valueToConsumers = valueToConsumers; + this.keyedPValues = keyedPValues; + this.registry = registry; + this.evaluationContext = context; + + currentEvaluations = new ConcurrentHashMap<>(); + scheduledExecutors = new ConcurrentHashMap<>(); + + this.allUpdates = new ConcurrentLinkedQueue<>(); + this.visibleUpdates = new ArrayBlockingQueue<>(20); + + parallelExecutorService = + TransformExecutorServices.parallel(executorService, scheduledExecutors); + defaultCompletionCallback = new DefaultCompletionCallback(); + } + + /** + * Stores the root transforms and submits the monitor to the executor service. + */ + @Override + public void start(Collection> roots) { + rootNodes = ImmutableList.copyOf(roots); + 
Runnable monitorRunnable = new MonitorRunnable(); + executorService.submit(monitorRunnable); + } + + /** + * Schedules {@code consumer} to evaluate {@code bundle}, reporting the outcome to + * {@code onComplete}. The bundle is {@code null} when the monitor schedules a root transform. + */ + @SuppressWarnings("unchecked") + public void scheduleConsumption( + AppliedPTransform consumer, + @Nullable CommittedBundle bundle, + CompletionCallback onComplete) { + evaluateBundle(consumer, bundle, onComplete); + } + + /** + * Creates a {@link TransformExecutor} for the transform and bundle and schedules it: on the + * serial executor service for the step and key when the bundle belongs to a keyed + * {@link PValue}, and on the shared parallel executor service otherwise. + */ + private void evaluateBundle( + final AppliedPTransform transform, + @Nullable final CommittedBundle bundle, + final CompletionCallback onComplete) { + TransformExecutorService transformExecutor; + if (bundle != null && isKeyed(bundle.getPCollection())) { + final StepAndKey stepAndKey = StepAndKey.of(transform, bundle.getKey()); + transformExecutor = getSerialExecutorService(stepAndKey); + } else { + transformExecutor = parallelExecutorService; + } + TransformExecutor callable = + TransformExecutor.create( + registry, evaluationContext, bundle, transform, onComplete, transformExecutor); + transformExecutor.schedule(callable); + } + + private boolean isKeyed(PValue pvalue) { + return keyedPValues.contains(pvalue); + } + + /** Schedules every consumer of the bundle's PCollection with the default callback. */ + private void scheduleConsumers(CommittedBundle bundle) { + for (AppliedPTransform consumer : valueToConsumers.get(bundle.getPCollection())) { + scheduleConsumption(consumer, bundle, defaultCompletionCallback); + } + } + + /** + * Returns the serial executor service registered for the step and key, registering a new one + * first if none is present. + */ + private TransformExecutorService getSerialExecutorService(StepAndKey stepAndKey) { + if (!currentEvaluations.containsKey(stepAndKey)) { + currentEvaluations.putIfAbsent( + stepAndKey, TransformExecutorServices.serial(executorService, scheduledExecutors)); + } + return currentEvaluations.get(stepAndKey); + } + + /** + * Takes visible updates until one reports completion, then shuts down the executor service. + * Rethrows the throwable of the first update that carries one. + */ + @Override + public void awaitCompletion() throws Throwable { + VisibleExecutorUpdate update; + do { + update = visibleUpdates.take(); + if (update.throwable.isPresent()) { + throw update.throwable.get(); + } + } while (!update.isDone()); + executorService.shutdown(); + } + + /** + * The default {@link CompletionCallback}. 
The default completion callback is used to complete + * transform evaluations that are triggered due to the arrival of elements from an upstream + * transform, or for a source transform. + */ + private class DefaultCompletionCallback implements CompletionCallback { + @Override + public void handleResult(CommittedBundle inputBundle, InProcessTransformResult result) { + Iterable> resultBundles = + evaluationContext.handleResult(inputBundle, Collections.emptyList(), result); + for (CommittedBundle outputBundle : resultBundles) { + allUpdates.offer(ExecutorUpdate.fromBundle(outputBundle)); + } + } + + @Override + public void handleThrowable(CommittedBundle inputBundle, Throwable t) { + allUpdates.offer(ExecutorUpdate.fromThrowable(t)); + } + } + + /** + * A {@link CompletionCallback} where the completed bundle was produced to deliver some collection + * of {@link TimerData timers}. When the evaluator completes successfully, reports all of the + * timers used to create the input to the {@link InProcessEvaluationContext evaluation context} + * as part of the result. + */ + private class TimerCompletionCallback implements CompletionCallback { + private final Iterable timers; + + private TimerCompletionCallback(Iterable timers) { + this.timers = timers; + } + + @Override + public void handleResult(CommittedBundle inputBundle, InProcessTransformResult result) { + Iterable> resultBundles = + evaluationContext.handleResult(inputBundle, timers, result); + for (CommittedBundle outputBundle : resultBundles) { + allUpdates.offer(ExecutorUpdate.fromBundle(outputBundle)); + } + } + + @Override + public void handleThrowable(CommittedBundle inputBundle, Throwable t) { + allUpdates.offer(ExecutorUpdate.fromThrowable(t)); + } + } + + /** + * An internal status update on the state of the executor. + * + * <p>Carries either a produced bundle or a throwable reported by a completion callback. 
+ */ + private static class ExecutorUpdate { + private final Optional> bundle; + private final Optional throwable; + + public static ExecutorUpdate fromBundle(CommittedBundle bundle) { + return new ExecutorUpdate(bundle, null); + } + + public static ExecutorUpdate fromThrowable(Throwable t) { + return new ExecutorUpdate(null, t); + } + + private ExecutorUpdate(CommittedBundle producedBundle, Throwable throwable) { + this.bundle = Optional.fromNullable(producedBundle); + this.throwable = Optional.fromNullable(throwable); + } + + public Optional> getBundle() { + return bundle; + } + + public Optional getException() { + return throwable; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(ExecutorUpdate.class) + .add("bundle", bundle) + .add("exception", throwable) + .toString(); + } + } + + /** + * An update of interest to the user. Used in {@link #awaitCompletion} to decide whether to + * return normally or throw an exception. + */ + private static class VisibleExecutorUpdate { + private final Optional throwable; + private final boolean done; + + public static VisibleExecutorUpdate fromThrowable(Throwable e) { + return new VisibleExecutorUpdate(false, e); + } + + public static VisibleExecutorUpdate finished() { + return new VisibleExecutorUpdate(true, null); + } + + private VisibleExecutorUpdate(boolean done, @Nullable Throwable exception) { + this.throwable = Optional.fromNullable(exception); + this.done = done; + } + + public boolean isDone() { + return done; + } + } + + /** + * The monitor task. Each run handles at most one pending {@link ExecutorUpdate}, delivers fired + * timers, and schedules the root transforms again if no scheduled executor has a new or + * runnable thread. Unless the evaluation context is done, the task then resubmits itself. + */ + private class MonitorRunnable implements Runnable { + private final String runnableName = + String.format( + "%s$%s-monitor", + evaluationContext.getPipelineOptions().getAppName(), + ExecutorServiceParallelExecutor.class.getSimpleName()); + + @Override + public void run() { + String oldName = Thread.currentThread().getName(); + Thread.currentThread().setName(runnableName); + try { + ExecutorUpdate update = allUpdates.poll(); + if (update != null) { + 
LOG.debug("Executor Update: {}", update); + if (update.getBundle().isPresent()) { + scheduleConsumers(update.getBundle().get()); + } else if (update.getException().isPresent()) { + visibleUpdates.offer(VisibleExecutorUpdate.fromThrowable(update.getException().get())); + } + } + fireTimers(); + mightNeedMoreWork(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.error("Monitor died due to being interrupted"); + while (!visibleUpdates.offer(VisibleExecutorUpdate.fromThrowable(e))) { + visibleUpdates.poll(); + } + } catch (Throwable t) { + LOG.error("Monitor thread died due to throwable", t); + while (!visibleUpdates.offer(VisibleExecutorUpdate.fromThrowable(t))) { + visibleUpdates.poll(); + } + } finally { + if (!shouldShutdown()) { + // The monitor thread should always be scheduled; but we only need to be scheduled once + executorService.submit(this); + } + Thread.currentThread().setName(oldName); + } + } + + /** + * Schedules one keyed bundle of timer work for each transform, key, and time domain that has + * fired timers. + */ + private void fireTimers() throws Exception { + try { + for (Map.Entry, Map> transformTimers : + evaluationContext.extractFiredTimers().entrySet()) { + AppliedPTransform transform = transformTimers.getKey(); + for (Map.Entry keyTimers : transformTimers.getValue().entrySet()) { + for (TimeDomain domain : TimeDomain.values()) { + Collection delivery = keyTimers.getValue().getTimers(domain); + if (delivery.isEmpty()) { + continue; + } + KeyedWorkItem work = + KeyedWorkItems.timersWorkItem(keyTimers.getKey(), delivery); + @SuppressWarnings({"unchecked", "rawtypes"}) + CommittedBundle bundle = + InProcessBundle.>keyed( + (PCollection) transform.getInput(), keyTimers.getKey()) + .add(WindowedValue.valueInEmptyWindows(work)) + .commit(Instant.now()); + scheduleConsumption(transform, bundle, new TimerCompletionCallback(delivery)); + } + } + } + } catch (Exception e) { + LOG.error("Internal Error while delivering timers", e); + throw e; + } + } + + /** + * Returns true if the evaluation context is done, after publishing the finished update and + * shutting down the executor service; returns false otherwise. + */ + private boolean shouldShutdown() { + if (evaluationContext.isDone()) { + 
LOG.debug("Pipeline is finished. Shutting down."); + while (!visibleUpdates.offer(VisibleExecutorUpdate.finished())) { + visibleUpdates.poll(); + } + executorService.shutdown(); + return true; + } + return false; + } + + /** + * Schedules every root transform again unless a scheduled executor has a thread that is new or + * runnable. + */ + private void mightNeedMoreWork() { + synchronized (scheduledExecutors) { + for (TransformExecutor executor : scheduledExecutors.keySet()) { + Thread thread = executor.getThread(); + if (thread != null) { + switch (thread.getState()) { + case BLOCKED: + case WAITING: + case TERMINATED: + case TIMED_WAITING: + break; + default: + return; + } + } + } + } + // All current TransformExecutors are blocked; add more work from the roots. + for (AppliedPTransform root : rootNodes) { + scheduleConsumption(root, null, defaultCompletionCallback); + } + } + } +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactory.java index 14428888e2..bde1df45e9 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactory.java @@ -16,7 +16,6 @@ package com.google.cloud.dataflow.sdk.runners.inprocess; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.Flatten; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java index 0347281749..b1d5d35869 100644 --- 
a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java @@ -22,7 +22,6 @@ import com.google.cloud.dataflow.sdk.coders.IterableCoder; import com.google.cloud.dataflow.sdk.coders.KvCoder; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.runners.inprocess.StepTransformResult.Builder; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; @@ -60,8 +59,10 @@ public TransformEvaluator forApplication( AppliedPTransform application, CommittedBundle inputBundle, InProcessEvaluationContext evaluationContext) { - return createEvaluator( - (AppliedPTransform) application, (CommittedBundle) inputBundle, evaluationContext); + @SuppressWarnings({"cast", "unchecked", "rawtypes"}) + TransformEvaluator evaluator = createEvaluator( + (AppliedPTransform) application, (CommittedBundle) inputBundle, evaluationContext); + return evaluator; } private TransformEvaluator>> createEvaluator( @@ -184,7 +185,7 @@ public static final class InProcessGroupByKey extends ForwardingPTransform>, PCollection>>> { private final GroupByKey original; - public InProcessGroupByKey(GroupByKey from) { + private InProcessGroupByKey(GroupByKey from) { this.original = from; } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java index e280e22d2b..094526d962 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java +++ 
b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java @@ -866,7 +866,7 @@ private void updatePending( * {@link #getWatermarks(AppliedPTransform)}, the output watermark will be equal to * {@link BoundedWindow#TIMESTAMP_MAX_VALUE}. */ - public boolean isDone() { + public boolean allWatermarksAtPositiveInfinity() { for (Map.Entry, TransformWatermarks> watermarksEntry : transformToWatermarks.entrySet()) { Instant endOfTime = THE_END_OF_TIME.get(); @@ -1209,8 +1209,11 @@ public TimerUpdateBuilder deletedTimer(TimerData deletedTimer) { * and deletedTimers. */ public TimerUpdate build() { - return new TimerUpdate(key, ImmutableSet.copyOf(completedTimers), - ImmutableSet.copyOf(setTimers), ImmutableSet.copyOf(deletedTimers)); + return new TimerUpdate( + key, + ImmutableSet.copyOf(completedTimers), + ImmutableSet.copyOf(setTimers), + ImmutableSet.copyOf(deletedTimers)); } } @@ -1245,6 +1248,13 @@ Iterable getDeletedTimers() { return deletedTimers; } + /** + * Returns a {@link TimerUpdate} that is like this one, but with the specified completed timers. + */ + public TimerUpdate withCompletedTimers(Iterable completedTimers) { + return new TimerUpdate(this.key, completedTimers, setTimers, deletedTimers); + } + @Override public int hashCode() { return Objects.hash(key, completedTimers, setTimers, deletedTimers); diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundle.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundle.java index cc20161097..112ba17d14 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundle.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundle.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2015 Google Inc. + * Copyright (C) 2016 Google Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. 
You may obtain a copy of @@ -22,7 +22,6 @@ import com.google.cloud.dataflow.sdk.util.WindowedValue; import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.common.base.MoreObjects; -import com.google.common.base.MoreObjects.ToStringHelper; import com.google.common.collect.ImmutableList; import org.joda.time.Instant; @@ -64,6 +63,11 @@ private InProcessBundle(PCollection pcollection, boolean keyed, Object key) { this.elements = ImmutableList.builder(); } + @Override + public PCollection getPCollection() { + return pcollection; + } + @Override public InProcessBundle add(WindowedValue element) { checkState(!committed, "Can't add element %s to committed bundle %s", element, this); @@ -105,12 +109,12 @@ public Instant getSynchronizedProcessingOutputWatermark() { @Override public String toString() { - ToStringHelper toStringHelper = - MoreObjects.toStringHelper(this).add("pcollection", pcollection); - if (keyed) { - toStringHelper = toStringHelper.add("key", key); - } - return toStringHelper.add("elements", elements).toString(); + return MoreObjects.toStringHelper(this) + .omitNullValues() + .add("pcollection", pcollection) + .add("key", key) + .add("elements", committedElements) + .toString(); } }; } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContext.java new file mode 100644 index 0000000000..2908fba818 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContext.java @@ -0,0 +1,383 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.FiredTimers; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TransformWatermarks; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.PCollectionViewWriter; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.SideInputReader; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.state.CopyOnAccessInMemoryStateInternals; +import com.google.cloud.dataflow.sdk.values.PCollection; 
+import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import javax.annotation.Nullable; + +/** + * The evaluation context for a specific pipeline being executed by the + * {@link InProcessPipelineRunner}. Contains state shared within the execution across all + * transforms. + * + *

<p>{@link InProcessEvaluationContext} contains shared state for an execution of the + * {@link InProcessPipelineRunner} that can be used while evaluating a {@link PTransform}. This + * consists of views into underlying state and watermark implementations, access to read and write + * {@link PCollectionView PCollectionViews}, and constructing {@link CounterSet CounterSets} and + * {@link ExecutionContext ExecutionContexts}. This includes executing callbacks asynchronously when + * state changes to the appropriate point (e.g. when a {@link PCollectionView} is requested and + * known to be empty). + * + *

<p>{@link InProcessEvaluationContext} also handles results by committing output bundles based + * on the current global state and updating the global state appropriately. This includes updating + * the per-{@link StepAndKey} state, updating global watermarks, and executing any callbacks that + * can be executed. + */ +class InProcessEvaluationContext { + /** The step name for each {@link AppliedPTransform} in the {@link Pipeline}. */ + private final Map, String> stepNames; + + /** The options that were used to create this {@link Pipeline}. */ + private final InProcessPipelineOptions options; + + /** The current processing time and event time watermarks and timers. */ + private final InMemoryWatermarkManager watermarkManager; + + /** Executes callbacks based on the progression of the watermark. */ + private final WatermarkCallbackExecutor callbackExecutor; + + /** The committed state internals, by applied PTransform and key. */ + private final ConcurrentMap> + applicationStateInternals; + + /** Writes to and creates readers for {@link PCollectionView PCollectionViews}. */ + private final InProcessSideInputContainer sideInputContainer; + + /** The counters merged from all handled results. */ + private final CounterSet mergedCounters; + + public static InProcessEvaluationContext create( + InProcessPipelineOptions options, + Collection> rootTransforms, + Map>> valueToConsumers, + Map, String> stepNames, + Collection> views) { + return new InProcessEvaluationContext( + options, rootTransforms, valueToConsumers, stepNames, views); + } + + private InProcessEvaluationContext( + InProcessPipelineOptions options, + Collection> rootTransforms, + Map>> valueToConsumers, + Map, String> stepNames, + Collection> views) { + this.options = checkNotNull(options); + checkNotNull(rootTransforms); + checkNotNull(valueToConsumers); + checkNotNull(stepNames); + checkNotNull(views); + this.stepNames = stepNames; + + /* NOTE(review): uses a new NanosOffsetClock rather than options.getClock() - confirm. */ + this.watermarkManager = + InMemoryWatermarkManager.create( + NanosOffsetClock.create(), rootTransforms, valueToConsumers); + this.sideInputContainer = InProcessSideInputContainer.create(this, views); + + 
this.applicationStateInternals = new ConcurrentHashMap<>(); + this.mergedCounters = new CounterSet(); + + this.callbackExecutor = WatermarkCallbackExecutor.create(); + } + + /** + * Handle the provided {@link InProcessTransformResult}, produced after evaluating the provided + * {@link CommittedBundle} (potentially null, if the result of a root {@link PTransform}). + * + *

<p>The result is the output of running the transform contained in the + * {@link InProcessTransformResult} on the contents of the provided bundle. + * + * @param completedBundle the bundle that was processed to produce the result. Potentially + * {@code null} if the transform that produced the result is a root + * transform + * @param completedTimers the timers that were delivered to produce the {@code completedBundle}, + * or an empty iterable if no timers were delivered + * @param result the result of evaluating the input bundle + * @return the non-empty committed bundles produced from the output bundles of {@code result} + */ + public synchronized Iterable> handleResult( + @Nullable CommittedBundle completedBundle, + Iterable completedTimers, + InProcessTransformResult result) { + Iterable> committedBundles = + commitBundles(result.getOutputBundles()); + // Update watermarks and timers + watermarkManager.updateWatermarks( + completedBundle, + result.getTransform(), + result.getTimerUpdate().withCompletedTimers(completedTimers), + committedBundles, + result.getWatermarkHold()); + fireAllAvailableCallbacks(); + // Update counters + if (result.getCounters() != null) { + mergedCounters.merge(result.getCounters()); + } + // Update state internals + CopyOnAccessInMemoryStateInternals theirState = result.getState(); + if (theirState != null) { + CopyOnAccessInMemoryStateInternals committedState = theirState.commit(); + StepAndKey stepAndKey = + StepAndKey.of( + result.getTransform(), completedBundle == null ? 
null : completedBundle.getKey()); + if (!committedState.isEmpty()) { + applicationStateInternals.put(stepAndKey, committedState); + } else { + applicationStateInternals.remove(stepAndKey); + } + } + return committedBundles; + } + + /** + * Commits each bundle at the synchronized processing output time of its producing transform, + * returning only the committed bundles that contain elements. + */ + private Iterable> commitBundles( + Iterable> bundles) { + ImmutableList.Builder> completed = ImmutableList.builder(); + for (UncommittedBundle inProgress : bundles) { + AppliedPTransform producing = + inProgress.getPCollection().getProducingTransformInternal(); + TransformWatermarks watermarks = watermarkManager.getWatermarks(producing); + CommittedBundle committed = + inProgress.commit(watermarks.getSynchronizedProcessingOutputTime()); + // Empty bundles don't impact watermarks and shouldn't trigger downstream execution, so + // filter them out + if (!Iterables.isEmpty(committed.getElements())) { + completed.add(committed); + } + } + return completed.build(); + } + + private void fireAllAvailableCallbacks() { + for (AppliedPTransform transform : stepNames.keySet()) { + fireAvailableCallbacks(transform); + } + } + + /** Passes the output watermark of the transform to the callback executor. */ + private void fireAvailableCallbacks(AppliedPTransform producingTransform) { + TransformWatermarks watermarks = watermarkManager.getWatermarks(producingTransform); + callbackExecutor.fireForWatermark(producingTransform, watermarks.getOutputWatermark()); + } + + /** + * Create an unkeyed {@link UncommittedBundle} for use by a source. + */ + public UncommittedBundle createRootBundle(PCollection output) { + return InProcessBundle.unkeyed(output); + } + + /** + * Create a {@link UncommittedBundle} whose elements belong to the specified {@link + * PCollection}. The bundle shares the key of {@code input} if the input is keyed. + */ + public UncommittedBundle createBundle(CommittedBundle input, PCollection output) { + return input.isKeyed() + ? InProcessBundle.keyed(output, input.getKey()) + : InProcessBundle.unkeyed(output); + } + + /** + * Create a {@link UncommittedBundle} for {@code output} keyed by {@code key}. For use by + * {@link InProcessGroupByKeyOnly} {@link PTransform PTransforms}. 
+ */ + public UncommittedBundle createKeyedBundle( + CommittedBundle input, Object key, PCollection output) { + return InProcessBundle.keyed(output, key); + } + + /** + * Create a {@link PCollectionViewWriter}, whose elements will be used in the provided + * {@link PCollectionView}. + */ + public PCollectionViewWriter createPCollectionViewWriter( + PCollection> input, final PCollectionView output) { + return new PCollectionViewWriter() { + @Override + public void add(Iterable> values) { + sideInputContainer.write(output, values); + } + }; + } + + /** + * Schedule a callback to be executed after output would be produced for the given window + * if there had been input. + * + *

<p>Output would be produced when the watermark for a {@link PValue} passes the point at + * which the trigger for the specified window (with the specified windowing strategy) must have + * fired from the perspective of that {@link PValue}, as specified by the value of + * {@link Trigger#getWatermarkThatGuaranteesFiring(BoundedWindow)} for the trigger of the + * {@link WindowingStrategy}. When the callback has fired, either values will have been produced + * for a key in that window, the window is empty, or all elements in the window are late. The + * callback will be executed regardless of whether values have been produced. + */ + public void scheduleAfterOutputWouldBeProduced( + PValue value, + BoundedWindow window, + WindowingStrategy windowingStrategy, + Runnable runnable) { + AppliedPTransform producing = getProducing(value); + callbackExecutor.callOnGuaranteedFiring(producing, window, windowingStrategy, runnable); + + /* NOTE(review): registered under getProducing(value) but fired for lookupProducing(value); + * confirm that the two always agree and that the lookup cannot return null here. */ + fireAvailableCallbacks(lookupProducing(value)); + } + + /** + * Returns the producing transform recorded on the value, falling back to a search of the steps. + */ + private AppliedPTransform getProducing(PValue value) { + if (value.getProducingTransformInternal() != null) { + return value.getProducingTransformInternal(); + } + return lookupProducing(value); + } + + /** + * Returns the first step whose output is or expands to contain the value, or null if none does. + */ + private AppliedPTransform lookupProducing(PValue value) { + for (AppliedPTransform transform : stepNames.keySet()) { + if (transform.getOutput().equals(value) || transform.getOutput().expand().contains(value)) { + return transform; + } + } + return null; + } + + /** + * Get the options used by this {@link Pipeline}. + */ + public InProcessPipelineOptions getPipelineOptions() { + return options; + } + + /** + * Get an {@link ExecutionContext} for the provided {@link AppliedPTransform} and key. 
+ */ + public InProcessExecutionContext getExecutionContext( + AppliedPTransform application, Object key) { + StepAndKey stepAndKey = StepAndKey.of(application, key); + return new InProcessExecutionContext( + options.getClock(), + key, + (CopyOnAccessInMemoryStateInternals) applicationStateInternals.get(stepAndKey), + watermarkManager.getWatermarks(application)); + } + + /** + * Get all of the steps used in this {@link Pipeline}. + */ + public Collection> getSteps() { + return stepNames.keySet(); + } + + /** + * Get the step name for the provided application, or {@code null} if none is registered. + */ + public String getStepName(AppliedPTransform application) { + return stepNames.get(application); + } + + /** + * Returns a {@link SideInputReader} capable of reading the provided + * {@link PCollectionView PCollectionViews}. + * @param sideInputs the {@link PCollectionView PCollectionViews} the result should be able to + * read + * @return a {@link SideInputReader} that can read all of the provided + * {@link PCollectionView PCollectionViews} + */ + public SideInputReader createSideInputReader(final List> sideInputs) { + return sideInputContainer.createReaderForViews(sideInputs); + } + + /** + * Create a {@link CounterSet} for this {@link Pipeline}. The {@link CounterSet} is independent + * of all other {@link CounterSet CounterSets} created by this call. + * + * <p>The {@link InProcessEvaluationContext} is responsible for unifying the counters present in + * all created {@link CounterSet CounterSets} when the transforms that call this method + * complete. + */ + public CounterSet createCounterSet() { + return new CounterSet(); + } + + /** + * Returns all of the counters that have been merged into this context via calls to + * {@link CounterSet#merge(CounterSet)}. + */ + public CounterSet getCounters() { + return mergedCounters; + } + + /** + * Extracts all timers that have been fired and have not already been extracted. + * + *

<p>This is a destructive operation. Timers will only appear in the result of this method once + * for each time they are set. + */ + public Map, Map> extractFiredTimers() { + return watermarkManager.extractFiredTimers(); + } + + /** + * Returns true if all watermarks are at positive infinity, unless shutdown at the maximum + * watermark is disabled and the pipeline contains an unbounded {@link PCollection}. + */ + public boolean isDone() { + if (!options.isShutdownUnboundedProducersWithMaxWatermark() && containsUnboundedPCollection()) { + return false; + } + if (!watermarkManager.allWatermarksAtPositiveInfinity()) { + return false; + } + return true; + } + + /** Returns true if any step has an unbounded {@link PCollection} among its expanded inputs. */ + private boolean containsUnboundedPCollection() { + for (AppliedPTransform transform : stepNames.keySet()) { + for (PValue value : transform.getInput().expand()) { + if (value instanceof PCollection + && ((PCollection) value).isBounded().equals(IsBounded.UNBOUNDED)) { + return true; + } + } + } + return false; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessExecutor.java new file mode 100644 index 0000000000..7b60bca17d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessExecutor.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import java.util.Collection; + +/** + * An executor that schedules and executes {@link AppliedPTransform AppliedPTransforms} for both + * source and intermediate {@link PTransform PTransforms}. + */ +interface InProcessExecutor { + /** + * Starts this executor. The provided collection is the collection of root transforms to + * initially schedule. + * + * @param rootTransforms the root transforms to initially schedule + */ + void start(Collection> rootTransforms); + + /** + * Blocks until the job being executed enters a terminal state. A job is completed after all + * root {@link AppliedPTransform AppliedPTransforms} have completed, and all + * {@link CommittedBundle Bundles} have been consumed. Jobs may also terminate abnormally. + * + * @throws Throwable if an executor thread throws anything; the throwable is transferred to the + * waiting thread and rethrown + */ + void awaitCompletion() throws Throwable; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java index d659d962f0..5ee0e88b5b 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java @@ -15,10 +15,76 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.Hidden; 
import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.Validation.Required; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * Options that can be used to configure the {@link InProcessPipelineRunner}. */ -public interface InProcessPipelineOptions extends PipelineOptions {} +public interface InProcessPipelineOptions extends PipelineOptions, ApplicationNameOptions { + /** + * Gets the {@link ExecutorServiceFactory} to use to create instances of {@link ExecutorService} + * to execute {@link PTransform PTransforms}. + * + *

Note that any {@link ExecutorService} returned by the factory must ensure that + * it cannot enter a state in which it will not schedule additional pending work unless currently + * scheduled work completes, as this may cause the {@link Pipeline} to cease processing. + * + *

Defaults to a {@link CachedThreadPoolExecutorServiceFactory}, which produces instances of + * {@link Executors#newCachedThreadPool()}. + */ + @JsonIgnore + @Required + @Hidden + @Default.InstanceFactory(CachedThreadPoolExecutorServiceFactory.class) + ExecutorServiceFactory getExecutorServiceFactory(); + + void setExecutorServiceFactory(ExecutorServiceFactory executorService); + + /** + * Gets the {@link Clock} used by this pipeline. The clock is used in place of accessing the + * system time when time values are required by the evaluator. + */ + @Default.InstanceFactory(NanosOffsetClock.Factory.class) + @JsonIgnore + @Required + @Hidden + @Description( + "The processing time source used by the pipeline. When the current time is " + + "needed by the evaluator, the result of clock#now() is used.") + Clock getClock(); + + void setClock(Clock clock); + + @Default.Boolean(false) + @Description( + "If the pipeline should shut down producers which have reached the maximum " + + "representable watermark. If this is set to true, a pipeline in which all PTransforms " + + "have reached the maximum watermark will be shut down, even if there are unbounded " + + "sources that could produce additional (late) data. By default, if the pipeline " + + "contains any unbounded PCollections, it will run until explicitly shut down.") + boolean isShutdownUnboundedProducersWithMaxWatermark(); + + void setShutdownUnboundedProducersWithMaxWatermark(boolean shutdown); + + @Default.Boolean(true) + @Description( + "If the pipeline should block awaiting completion of the pipeline. If set to true, " + + "a call to Pipeline#run() will block until all PTransforms are complete. Otherwise, " + + "the Pipeline will execute asynchronously. 
If set to false, the completion of the " + + "pipeline can be awaited on by use of InProcessPipelineResult#awaitCompletion().") + boolean isBlockOnRun(); + void setBlockOnRun(boolean b); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java index 124de46b94..a1c8756c5e 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2015 Google Inc. + * Copyright (C) 2016 Google Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of @@ -15,35 +15,46 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; -import static com.google.common.base.Preconditions.checkArgument; - import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.PipelineResult; import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.AggregatorPipelineExtractor; +import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException; +import com.google.cloud.dataflow.sdk.runners.AggregatorValues; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKey; -import com.google.cloud.dataflow.sdk.runners.inprocess.ViewEvaluatorFactory.InProcessCreatePCollectionView; +import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; import 
com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; import com.google.cloud.dataflow.sdk.transforms.GroupByKey; -import com.google.cloud.dataflow.sdk.transforms.GroupByKey.GroupByKeyOnly; import com.google.cloud.dataflow.sdk.transforms.PTransform; import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView; -import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; -import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; -import com.google.cloud.dataflow.sdk.util.ExecutionContext; -import com.google.cloud.dataflow.sdk.util.SideInputReader; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.util.MapAggregatorValues; import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.UserCodeException; import com.google.cloud.dataflow.sdk.util.WindowedValue; -import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.common.Counter; import com.google.cloud.dataflow.sdk.util.common.CounterSet; import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.joda.time.Instant; -import java.util.List; +import java.util.Collection; +import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; import javax.annotation.Nullable; @@ -52,28 +63,25 @@ * {@link PCollection PCollections}. 
*/ @Experimental -public class InProcessPipelineRunner { - @SuppressWarnings({"rawtypes", "unused"}) +public class InProcessPipelineRunner + extends PipelineRunner { + /** + * The default set of transform overrides to use in the {@link InProcessPipelineRunner}. + * + *

A transform override must have a single-argument constructor that takes an instance of the + * type of transform it is overriding. + */ + @SuppressWarnings("rawtypes") private static Map, Class> defaultTransformOverrides = ImmutableMap., Class>builder() + .put(Create.Values.class, InProcessCreate.class) .put(GroupByKey.class, InProcessGroupByKey.class) - .put(CreatePCollectionView.class, InProcessCreatePCollectionView.class) + .put( + CreatePCollectionView.class, + ViewEvaluatorFactory.InProcessCreatePCollectionView.class) .build(); - private static Map, TransformEvaluatorFactory> defaultEvaluatorFactories = - new ConcurrentHashMap<>(); - - /** - * Register a default transform evaluator. - */ - public static > void registerTransformEvaluatorFactory( - Class clazz, TransformEvaluatorFactory evaluator) { - checkArgument(defaultEvaluatorFactories.put(clazz, evaluator) == null, - "Defining a default factory %s to evaluate Transforms of type %s multiple times", evaluator, - clazz); - } - /** * Part of a {@link PCollection}. Elements are output to a bundle, which will cause them to be * executed by {@link PTransform PTransforms} that consume the {@link PCollection} this bundle is @@ -82,6 +90,11 @@ public class InProcessPipelineRunner { * @param the type of elements that can be added to this bundle */ public static interface UncommittedBundle { + /** + * Returns the PCollection that the elements of this {@link UncommittedBundle} belong to. + */ + PCollection getPCollection(); + /** * Outputs an element to this bundle. * @@ -108,14 +121,13 @@ public static interface UncommittedBundle { * @param the type of elements contained within this bundle */ public static interface CommittedBundle { - /** - * @return the PCollection that the elements of this bundle belong to + * Returns the PCollection that the elements of this bundle belong to. */ PCollection getPCollection(); /** - * Returns weather this bundle is keyed. 
A bundle that is part of a {@link PCollection} that + * Returns whether this bundle is keyed. A bundle that is part of a {@link PCollection} that * occurs after a {@link GroupByKey} is keyed by the result of the last {@link GroupByKey}. */ boolean isKeyed(); @@ -124,11 +136,12 @@ public static interface CommittedBundle { * Returns the (possibly null) key that was output in the most recent {@link GroupByKey} in the * execution of this bundle. */ - @Nullable Object getKey(); + @Nullable + Object getKey(); /** - * @return an {@link Iterable} containing all of the elements that have been added to this - * {@link CommittedBundle} + * Returns an {@link Iterable} containing all of the elements that have been added to this + * {@link CommittedBundle}. */ Iterable> getElements(); @@ -154,107 +167,177 @@ public static interface PCollectionViewWriter { void add(Iterable> values); } - /** - * The evaluation context for the {@link InProcessPipelineRunner}. Contains state shared within - * the current evaluation. - */ - public static interface InProcessEvaluationContext { - /** - * Create a {@link UncommittedBundle} for use by a source. - */ - UncommittedBundle createRootBundle(PCollection output); + //////////////////////////////////////////////////////////////////////////////////////////////// + private final InProcessPipelineOptions options; - /** - * Create a {@link UncommittedBundle} whose elements belong to the specified {@link - * PCollection}. - */ - UncommittedBundle createBundle(CommittedBundle input, PCollection output); + public static InProcessPipelineRunner fromOptions(PipelineOptions options) { + return new InProcessPipelineRunner(options.as(InProcessPipelineOptions.class)); + } - /** - * Create a {@link UncommittedBundle} with the specified keys at the specified step. For use by - * {@link GroupByKeyOnly} {@link PTransform PTransforms}. 
- */ - UncommittedBundle createKeyedBundle( - CommittedBundle input, Object key, PCollection output); + private InProcessPipelineRunner(InProcessPipelineOptions options) { + this.options = options; + } - /** - * Create a bundle whose elements will be used in a PCollectionView. - */ - PCollectionViewWriter createPCollectionViewWriter( - PCollection> input, PCollectionView output); + /** + * Returns the {@link PipelineOptions} used to create this {@link InProcessPipelineRunner}. + */ + public InProcessPipelineOptions getPipelineOptions() { + return options; + } - /** - * Get the options used by this {@link Pipeline}. - */ - InProcessPipelineOptions getPipelineOptions(); + @Override + public OutputT apply( + PTransform transform, InputT input) { + Class overrideClass = defaultTransformOverrides.get(transform.getClass()); + if (overrideClass != null) { + // It is the responsibility of whoever constructs overrides to ensure this is type safe. + @SuppressWarnings("unchecked") + Class> transformClass = + (Class>) transform.getClass(); - /** - * Get an {@link ExecutionContext} for the provided application. - */ - InProcessExecutionContext getExecutionContext( - AppliedPTransform application, @Nullable Object key); + @SuppressWarnings("unchecked") + Class> customTransformClass = + (Class>) overrideClass; - /** - * Get the Step Name for the provided application. - */ - String getStepName(AppliedPTransform application); + PTransform customTransform = + InstanceBuilder.ofType(customTransformClass) + .withArg(transformClass, transform) + .build(); - /** - * @param sideInputs the {@link PCollectionView PCollectionViews} the result should be able to - * read - * @return a {@link SideInputReader} that can read all of the provided - * {@link PCollectionView PCollectionViews} - */ - SideInputReader createSideInputReader(List> sideInputs); + // This overrides the contents of the apply method without changing the TransformTreeNode that + // is generated by the PCollection application. 
+ return super.apply(customTransform, input); + } else { + return super.apply(transform, input); + } + } - /** - * Schedules a callback after the watermark for a {@link PValue} after the trigger for the - * specified window (with the specified windowing strategy) must have fired from the perspective - * of that {@link PValue}, as specified by the value of - * {@link Trigger#getWatermarkThatGuaranteesFiring(BoundedWindow)} for the trigger of the - * {@link WindowingStrategy}. - */ - void callAfterOutputMustHaveBeenProduced(PValue value, BoundedWindow window, - WindowingStrategy windowingStrategy, Runnable runnable); + @Override + public InProcessPipelineResult run(Pipeline pipeline) { + ConsumerTrackingPipelineVisitor consumerTrackingVisitor = new ConsumerTrackingPipelineVisitor(); + pipeline.traverseTopologically(consumerTrackingVisitor); + for (PValue unfinalized : consumerTrackingVisitor.getUnfinalizedPValues()) { + unfinalized.finishSpecifying(); + } + @SuppressWarnings("rawtypes") + KeyedPValueTrackingVisitor keyedPValueVisitor = + KeyedPValueTrackingVisitor.create( + ImmutableSet.>of( + GroupByKey.class, InProcessGroupByKeyOnly.class)); + pipeline.traverseTopologically(keyedPValueVisitor); - /** - * Create a {@link CounterSet} for this {@link Pipeline}. The {@link CounterSet} is independent - * of all other {@link CounterSet CounterSets} created by this call. - * - * The {@link InProcessEvaluationContext} is responsible for unifying the counters present in - * all created {@link CounterSet CounterSets} when the transforms that call this method - * complete. 
- */ - CounterSet createCounterSet(); + InProcessEvaluationContext context = + InProcessEvaluationContext.create( + getPipelineOptions(), + consumerTrackingVisitor.getRootTransforms(), + consumerTrackingVisitor.getValueToConsumers(), + consumerTrackingVisitor.getStepNames(), + consumerTrackingVisitor.getViews()); - /** - * Returns all of the counters that have been merged into this context via calls to - * {@link CounterSet#merge(CounterSet)}. - */ - CounterSet getCounters(); + // independent executor service for each run + ExecutorService executorService = + context.getPipelineOptions().getExecutorServiceFactory().create(); + InProcessExecutor executor = + ExecutorServiceParallelExecutor.create( + executorService, + consumerTrackingVisitor.getValueToConsumers(), + keyedPValueVisitor.getKeyedPValues(), + TransformEvaluatorRegistry.defaultRegistry(), + context); + executor.start(consumerTrackingVisitor.getRootTransforms()); + + Map, Collection>> aggregatorSteps = + new AggregatorPipelineExtractor(pipeline).getAggregatorSteps(); + InProcessPipelineResult result = + new InProcessPipelineResult(executor, context, aggregatorSteps); + if (options.isBlockOnRun()) { + try { + result.awaitCompletion(); + } catch (UserCodeException userException) { + throw new PipelineExecutionException(userException.getCause()); + } catch (Throwable t) { + Throwables.propagate(t); + } + } + return result; } /** - * An executor that schedules and executes {@link AppliedPTransform AppliedPTransforms} for both - * source and intermediate {@link PTransform PTransforms}. + * The result of running a {@link Pipeline} with the {@link InProcessPipelineRunner}. + * + * Reports the pipeline state and aggregator values, and supports blocking until completion.
*/ - public static interface InProcessExecutor { - /** - * @param root the root {@link AppliedPTransform} to schedule - */ - void scheduleRoot(AppliedPTransform root); + public static class InProcessPipelineResult implements PipelineResult { + private final InProcessExecutor executor; + private final InProcessEvaluationContext evaluationContext; + private final Map, Collection>> aggregatorSteps; + private State state; - /** - * @param consumer the {@link AppliedPTransform} to schedule - * @param bundle the input bundle to the consumer - */ - void scheduleConsumption(AppliedPTransform consumer, CommittedBundle bundle); + private InProcessPipelineResult( + InProcessExecutor executor, + InProcessEvaluationContext evaluationContext, + Map, Collection>> aggregatorSteps) { + this.executor = executor; + this.evaluationContext = evaluationContext; + this.aggregatorSteps = aggregatorSteps; + // Only ever constructed after the executor has started. + this.state = State.RUNNING; + } + + @Override + public State getState() { + return state; + } + + @Override + public AggregatorValues getAggregatorValues(Aggregator aggregator) + throws AggregatorRetrievalException { + CounterSet counters = evaluationContext.getCounters(); + Collection> steps = aggregatorSteps.get(aggregator); + Map stepValues = new HashMap<>(); + for (AppliedPTransform transform : evaluationContext.getSteps()) { + if (steps.contains(transform.getTransform())) { + String stepName = + String.format( + "user-%s-%s", evaluationContext.getStepName(transform), aggregator.getName()); + Counter counter = (Counter) counters.getExistingCounter(stepName); + if (counter != null) { + stepValues.put(transform.getFullName(), counter.getAggregate()); + } + } + } + return new MapAggregatorValues<>(stepValues); + } /** - * Blocks until the job being executed enters a terminal state. 
A job is completed after all - * root {@link AppliedPTransform AppliedPTransforms} have completed, and all - * {@link CommittedBundle Bundles} have been consumed. Jobs may also terminate abnormally. + * Blocks until the {@link Pipeline} execution represented by this + * {@link InProcessPipelineResult} is complete, returning the terminal state. + * + *

If the pipeline terminates abnormally by throwing an exception, this will rethrow the + * exception. Future calls to {@link #getState()} will return + * {@link com.google.cloud.dataflow.sdk.PipelineResult.State#FAILED}. + * + *

NOTE: if the {@link Pipeline} contains an {@link IsBounded#UNBOUNDED unbounded} + * {@link PCollection}, and the {@link PipelineRunner} was created with + * {@link InProcessPipelineOptions#isShutdownUnboundedProducersWithMaxWatermark()} set to false, + * this method will never return. + * + * See also {@link InProcessExecutor#awaitCompletion()}. */ - void awaitCompletion(); + public State awaitCompletion() throws Throwable { + if (!state.isTerminal()) { + try { + executor.awaitCompletion(); + state = State.DONE; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw e; + } catch (Throwable t) { + state = State.FAILED; + throw t; + } + } + return state; + } } } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainer.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainer.java index bf9a2e1c53..37c9fcfa65 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainer.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainer.java @@ -17,7 +17,6 @@ import static com.google.common.base.Preconditions.checkArgument; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; import com.google.cloud.dataflow.sdk.util.PCollectionViewWindow; @@ -26,6 +25,7 @@ import com.google.cloud.dataflow.sdk.util.WindowingStrategy; import com.google.cloud.dataflow.sdk.values.PCollectionView; import com.google.common.base.MoreObjects; +import com.google.common.base.Throwables; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; @@ -89,7 +89,7 @@ private InProcessSideInputContainer(InProcessEvaluationContext 
context, * the provided argument. The returned {@link InProcessSideInputContainer} is unmodifiable without * casting, but will change as this {@link InProcessSideInputContainer} is modified. */ - public SideInputReader withViews(Collection> newContainedViews) { + public SideInputReader createReaderForViews(Collection> newContainedViews) { if (!containedViews.containsAll(newContainedViews)) { Set> currentlyContained = ImmutableSet.copyOf(containedViews); Set> newRequested = ImmutableSet.copyOf(newContainedViews); @@ -108,8 +108,20 @@ public SideInputReader withViews(Collection> newContainedView * *

The provided iterable is expected to contain only a single window and pane. */ - public void write(PCollectionView view, Iterable> values) - throws ExecutionException { + public void write(PCollectionView view, Iterable> values) { + Map>> valuesPerWindow = + indexValuesByWindow(values); + for (Map.Entry>> windowValues : + valuesPerWindow.entrySet()) { + updatePCollectionViewWindowValues(view, windowValues.getKey(), windowValues.getValue()); + } + } + + /** + * Index the provided values by all {@link BoundedWindow windows} in which they appear. + */ + private Map>> indexValuesByWindow( + Iterable> values) { Map>> valuesPerWindow = new HashMap<>(); for (WindowedValue value : values) { for (BoundedWindow window : value.getWindows()) { @@ -121,29 +133,40 @@ public void write(PCollectionView view, Iterable> windowValues.add(value); } } - for (Map.Entry>> windowValues : - valuesPerWindow.entrySet()) { - PCollectionViewWindow windowedView = PCollectionViewWindow.of(view, windowValues.getKey()); - SettableFuture>> future = viewByWindows.get(windowedView); + return valuesPerWindow; + } + + /** + * Set the value of the {@link PCollectionView} in the {@link BoundedWindow} to be based on the + * specified values, if the values are part of a later pane than currently exist within the + * {@link PCollectionViewWindow}. + */ + private void updatePCollectionViewWindowValues( + PCollectionView view, BoundedWindow window, Collection> windowValues) { + PCollectionViewWindow windowedView = PCollectionViewWindow.of(view, window); + SettableFuture>> future = null; + try { + future = viewByWindows.get(windowedView); if (future.isDone()) { - try { - Iterator> existingValues = future.get().iterator(); - PaneInfo newPane = windowValues.getValue().iterator().next().getPane(); - // The current value may have no elements, if no elements were produced for the window, - // but we are recieving late data. 
- if (!existingValues.hasNext() - || newPane.getIndex() > existingValues.next().getPane().getIndex()) { - viewByWindows.invalidate(windowedView); - viewByWindows.get(windowedView).set(windowValues.getValue()); - } - } catch (InterruptedException e) { - // TODO: Handle meaningfully. This should never really happen when the result remains - // useful, but the result could be available and the thread can still be interrupted. - Thread.currentThread().interrupt(); + Iterator> existingValues = future.get().iterator(); + PaneInfo newPane = windowValues.iterator().next().getPane(); + // The current value may have no elements, if no elements were produced for the window, + // but we are receiving late data. + if (!existingValues.hasNext() + || newPane.getIndex() > existingValues.next().getPane().getIndex()) { + viewByWindows.invalidate(windowedView); + viewByWindows.get(windowedView).set(windowValues); } } else { - future.set(windowValues.getValue()); + future.set(windowValues); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + if (future != null && !future.isDone()) { + future.set(Collections.>emptyList()); } + } catch (ExecutionException e) { + Throwables.propagate(e.getCause()); } } @@ -165,7 +188,7 @@ public T get(final PCollectionView view, final BoundedWindow window) { viewByWindows.get(windowedView); WindowingStrategy windowingStrategy = view.getWindowingStrategyInternal(); - evaluationContext.callAfterOutputMustHaveBeenProduced( + evaluationContext.scheduleAfterOutputWouldBeProduced( view, window, windowingStrategy, new Runnable() { @Override public void run() { diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java new file mode 100644 index 0000000000..23a8c0f506 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java
@@ -0,0 +1,95 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PValue; + +import java.util.HashSet; +import java.util.Set; + +/** + * A pipeline visitor that tracks all keyed {@link PValue PValues}. A {@link PValue} is keyed if it + * is the result of a {@link PTransform} that produces keyed outputs. A {@link PTransform} that + * produces keyed outputs is assumed to colocate output elements that share a key. + * + *

All {@link GroupByKey} transforms, or their runner-specific implementation primitive, produce + * keyed output. + */ +// TODO: Handle Key-preserving transforms when appropriate and more aggressively make PTransforms +// unkeyed +class KeyedPValueTrackingVisitor implements PipelineVisitor { + @SuppressWarnings("rawtypes") + private final Set> producesKeyedOutputs; + private final Set keyedValues; + private boolean finalized; + + public static KeyedPValueTrackingVisitor create( + @SuppressWarnings("rawtypes") Set> producesKeyedOutputs) { + return new KeyedPValueTrackingVisitor(producesKeyedOutputs); + } + + private KeyedPValueTrackingVisitor( + @SuppressWarnings("rawtypes") Set> producesKeyedOutputs) { + this.producesKeyedOutputs = producesKeyedOutputs; + this.keyedValues = new HashSet<>(); + } + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempted to use a %s that has already been finalized on a pipeline (visiting node %s)", + KeyedPValueTrackingVisitor.class.getSimpleName(), + node); + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempted to use a %s that has already been finalized on a pipeline (visiting node %s)", + KeyedPValueTrackingVisitor.class.getSimpleName(), + node); + if (node.isRootNode()) { + finalized = true; + } else if (producesKeyedOutputs.contains(node.getTransform().getClass())) { + keyedValues.addAll(node.getExpandedOutputs()); + } + } + + @Override + public void visitTransform(TransformTreeNode node) {} + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + if (producesKeyedOutputs.contains(producer.getTransform().getClass())) { + keyedValues.addAll(value.expand()); + } + } + + public Set getKeyedPValues() { + checkState( + finalized, "can't call getKeyedPValues before a Pipeline has been completely traversed"); + return keyedValues; + } +} + diff --git 
a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactory.java index e3ae1a028c..24142c2151 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactory.java @@ -17,7 +17,6 @@ import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessExecutionContext.InProcessStepContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.runners.inprocess.ParDoInProcessEvaluator.BundleOutputManager; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactory.java index cd79c219bd..af5914bab0 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactory.java @@ -17,7 +17,6 @@ import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessExecutionContext.InProcessStepContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import 
com.google.cloud.dataflow.sdk.runners.inprocess.ParDoInProcessEvaluator.BundleOutputManager; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/StepAndKey.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/StepAndKey.java new file mode 100644 index 0000000000..15955724eb --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/StepAndKey.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.common.base.MoreObjects; + +import java.util.Objects; + +/** + * A (Step, Key) pair. This is useful as a map key or cache key for things that are available + * per-step in a keyed manner (e.g. State). + */ +final class StepAndKey { + private final AppliedPTransform step; + private final Object key; + + /** + * Create a new {@link StepAndKey} with the provided step and key. 
+ */ + public static StepAndKey of(AppliedPTransform step, Object key) { + return new StepAndKey(step, key); + } + + private StepAndKey(AppliedPTransform step, Object key) { + this.step = step; + this.key = key; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(StepAndKey.class) + .add("step", step.getFullName()) + .add("key", key) + .toString(); + } + + @Override + public int hashCode() { + return Objects.hash(step, key); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (!(other instanceof StepAndKey)) { + return false; + } else { + StepAndKey that = (StepAndKey) other; + return Objects.equals(this.step, that.step) + && Objects.equals(this.key, that.key); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorFactory.java index 3b672e0def..860ddfe48f 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorFactory.java @@ -16,7 +16,6 @@ package com.google.cloud.dataflow.sdk.runners.inprocess; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.DoFn; import com.google.cloud.dataflow.sdk.transforms.PTransform; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorRegistry.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorRegistry.java new file mode 100644 index 0000000000..0c8cb7e80a --- /dev/null +++ 
b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorRegistry.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Flatten.FlattenPCollectionList; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A {@link TransformEvaluatorFactory} that delegates to primitive {@link TransformEvaluatorFactory} + * implementations based on the type of {@link PTransform} of the application. 
+ */ +class TransformEvaluatorRegistry implements TransformEvaluatorFactory { + public static TransformEvaluatorRegistry defaultRegistry() { + @SuppressWarnings("rawtypes") + ImmutableMap, TransformEvaluatorFactory> primitives = + ImmutableMap., TransformEvaluatorFactory>builder() + .put(Read.Bounded.class, new BoundedReadEvaluatorFactory()) + .put(Read.Unbounded.class, new UnboundedReadEvaluatorFactory()) + .put(ParDo.Bound.class, new ParDoSingleEvaluatorFactory()) + .put(ParDo.BoundMulti.class, new ParDoMultiEvaluatorFactory()) + .put( + GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly.class, + new GroupByKeyEvaluatorFactory()) + .put(FlattenPCollectionList.class, new FlattenEvaluatorFactory()) + .put(ViewEvaluatorFactory.WriteView.class, new ViewEvaluatorFactory()) + .build(); + return new TransformEvaluatorRegistry(primitives); + } + + // the TransformEvaluatorFactories can construct instances of all generic types of transform, + // so all instances of a primitive can be handled with the same evaluator factory. 
+ @SuppressWarnings("rawtypes") + private final Map, TransformEvaluatorFactory> factories; + + private TransformEvaluatorRegistry( + @SuppressWarnings("rawtypes") + Map, TransformEvaluatorFactory> factories) { + this.factories = factories; + } + + @Override + public TransformEvaluator forApplication( + AppliedPTransform application, + @Nullable CommittedBundle inputBundle, + InProcessEvaluationContext evaluationContext) + throws Exception { + TransformEvaluatorFactory factory = factories.get(application.getTransform().getClass()); + return factory.forApplication(application, inputBundle, evaluationContext); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java new file mode 100644 index 0000000000..d630749387 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.common.base.Throwables; + +import java.util.concurrent.Callable; + +import javax.annotation.Nullable; + +/** + * A {@link Callable} responsible for constructing a {@link TransformEvaluator} from a + * {@link TransformEvaluatorFactory} and evaluating it on some bundle of input, and registering + * the result using a registered {@link CompletionCallback}. + * + *

A {@link TransformExecutor} that is currently executing also provides access to the thread + * that it is being executed on. + */ +class TransformExecutor implements Callable { + public static TransformExecutor create( + TransformEvaluatorFactory factory, + InProcessEvaluationContext evaluationContext, + CommittedBundle inputBundle, + AppliedPTransform transform, + CompletionCallback completionCallback, + TransformExecutorService transformEvaluationState) { + return new TransformExecutor<>( + factory, + evaluationContext, + inputBundle, + transform, + completionCallback, + transformEvaluationState); + } + + private final TransformEvaluatorFactory evaluatorFactory; + private final InProcessEvaluationContext evaluationContext; + + /** The transform that will be evaluated. */ + private final AppliedPTransform transform; + /** The inputs this {@link TransformExecutor} will deliver to the transform. */ + private final CommittedBundle inputBundle; + + private final CompletionCallback onComplete; + private final TransformExecutorService transformEvaluationState; + + private Thread thread; + + private TransformExecutor( + TransformEvaluatorFactory factory, + InProcessEvaluationContext evaluationContext, + CommittedBundle inputBundle, + AppliedPTransform transform, + CompletionCallback completionCallback, + TransformExecutorService transformEvaluationState) { + this.evaluatorFactory = factory; + this.evaluationContext = evaluationContext; + + this.inputBundle = inputBundle; + this.transform = transform; + + this.onComplete = completionCallback; + + this.transformEvaluationState = transformEvaluationState; + } + + @Override + public InProcessTransformResult call() { + this.thread = Thread.currentThread(); + try { + TransformEvaluator evaluator = + evaluatorFactory.forApplication(transform, inputBundle, evaluationContext); + if (inputBundle != null) { + for (WindowedValue value : inputBundle.getElements()) { + evaluator.processElement(value); + } + } + 
InProcessTransformResult result = evaluator.finishBundle(); + onComplete.handleResult(inputBundle, result); + return result; + } catch (Throwable t) { + onComplete.handleThrowable(inputBundle, t); + throw Throwables.propagate(t); + } finally { + this.thread = null; + transformEvaluationState.complete(this); + } + } + + /** + * If this {@link TransformExecutor} is currently executing, return the thread it is executing in. + * Otherwise, return null. + */ + @Nullable + public Thread getThread() { + return this.thread; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorService.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorService.java new file mode 100644 index 0000000000..3f00da6ebe --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorService.java @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +/** + * Schedules and completes {@link TransformExecutor TransformExecutors}, controlling concurrency as + * appropriate for the {@link StepAndKey} the executor exists for. + */ +interface TransformExecutorService { + /** + * Schedule the provided work to be eventually executed. + */ + void schedule(TransformExecutor work); + + /** + * Finish executing the provided work. 
This may cause additional + * {@link TransformExecutor TransformExecutors} to be evaluated. + */ + void complete(TransformExecutor completed); +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServices.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServices.java new file mode 100644 index 0000000000..34efdf694e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServices.java @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.common.base.MoreObjects; + +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Static factory methods for constructing instances of {@link TransformExecutorService}. + */ +final class TransformExecutorServices { + private TransformExecutorServices() { + // Do not instantiate + } + + /** + * Returns an EvaluationState that evaluates {@link TransformExecutor TransformExecutors} in + * parallel. 
+ */ + public static TransformExecutorService parallel( + ExecutorService executor, Map, Boolean> scheduled) { + return new ParallelEvaluationState(executor, scheduled); + } + + /** + * Returns an EvaluationState that evaluates {@link TransformExecutor TransformExecutors} in + * serial. + */ + public static TransformExecutorService serial( + ExecutorService executor, Map, Boolean> scheduled) { + return new SerialEvaluationState(executor, scheduled); + } + + /** + * A {@link TransformExecutorService} with unlimited parallelism. Any {@link TransformExecutor} + * scheduled will be immediately submitted to the {@link ExecutorService}. + * + *

A principal use of this is for the evaluation of an unkeyed Step. Unkeyed computations are + * processed in parallel. + */ + private static class ParallelEvaluationState implements TransformExecutorService { + private final ExecutorService executor; + private final Map, Boolean> scheduled; + + private ParallelEvaluationState( + ExecutorService executor, Map, Boolean> scheduled) { + this.executor = executor; + this.scheduled = scheduled; + } + + @Override + public void schedule(TransformExecutor work) { + executor.submit(work); + scheduled.put(work, true); + } + + @Override + public void complete(TransformExecutor completed) { + scheduled.remove(completed); + } + } + + /** + * A {@link TransformExecutorService} with a single work queue. Any {@link TransformExecutor} + * scheduled will be placed on the work queue. Only one item of work will be submitted to the + * {@link ExecutorService} at any time. + * + *

A principal use of this is for the serial evaluation of a (Step, Key) pair. + * Keyed computations are processed serially per step. + */ + private static class SerialEvaluationState implements TransformExecutorService { + private final ExecutorService executor; + private final Map, Boolean> scheduled; + + private AtomicReference> currentlyEvaluating; + private final Queue> workQueue; + + private SerialEvaluationState( + ExecutorService executor, Map, Boolean> scheduled) { + this.scheduled = scheduled; + this.executor = executor; + this.currentlyEvaluating = new AtomicReference<>(); + this.workQueue = new ConcurrentLinkedQueue<>(); + } + + /** + * Schedules the work, adding it to the work queue if there is a bundle currently being + * evaluated and scheduling it immediately otherwise. + */ + @Override + public void schedule(TransformExecutor work) { + workQueue.offer(work); + updateCurrentlyEvaluating(); + } + + @Override + public void complete(TransformExecutor completed) { + if (!currentlyEvaluating.compareAndSet(completed, null)) { + throw new IllegalStateException( + "Finished work " + + completed + + " but could not complete due to unexpected currently executing " + + currentlyEvaluating.get()); + } + scheduled.remove(completed); + updateCurrentlyEvaluating(); + } + + private void updateCurrentlyEvaluating() { + if (currentlyEvaluating.get() == null) { + // Only synchronize if we need to update what's currently evaluating + synchronized (this) { + TransformExecutor newWork = workQueue.poll(); + if (newWork != null) { + if (currentlyEvaluating.compareAndSet(null, newWork)) { + scheduled.put(newWork, true); + executor.submit(newWork); + } else { + workQueue.offer(newWork); + } + } + } + } + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(SerialEvaluationState.class) + .add("currentlyEvaluating", currentlyEvaluating) + .add("workQueue", workQueue) + .toString(); + } + } +} + diff --git 
a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java index 4beac337d6..549afabcc0 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java @@ -21,7 +21,6 @@ import com.google.cloud.dataflow.sdk.io.UnboundedSource.UnboundedReader; import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.PTransform; @@ -100,6 +99,16 @@ private Queue> getTransformEvaluatorQu return evaluatorQueue; } + /** + * A {@link UnboundedReadEvaluator} produces elements from an underlying {@link UnboundedSource}, + * discarding all input elements. Within the call to {@link #finishBundle()}, the evaluator + * creates the {@link UnboundedReader} and consumes some currently available input. + * + *

Calls to {@link UnboundedReadEvaluator} are not internally thread-safe, and should only be + * used by a single thread at a time. Each {@link UnboundedReadEvaluator} maintains its own + * checkpoint, and constructs its reader from the current checkpoint in each call to + * {@link #finishBundle()}. + */ private static class UnboundedReadEvaluator implements TransformEvaluator { private static final int ARBITRARY_MAX_ELEMENTS = 10; private final AppliedPTransform, Unbounded> transform; @@ -123,28 +132,29 @@ public void processElement(WindowedValue element) {} @Override public InProcessTransformResult finishBundle() throws IOException { UncommittedBundle output = evaluationContext.createRootBundle(transform.getOutput()); - UnboundedReader reader = - createReader( - transform.getTransform().getSource(), evaluationContext.getPipelineOptions()); - int numElements = 0; - if (reader.start()) { - do { - output.add( - WindowedValue.timestampedValueInGlobalWindow( - reader.getCurrent(), reader.getCurrentTimestamp())); - numElements++; - } while (numElements < ARBITRARY_MAX_ELEMENTS && reader.advance()); + try (UnboundedReader reader = + createReader( + transform.getTransform().getSource(), evaluationContext.getPipelineOptions());) { + int numElements = 0; + if (reader.start()) { + do { + output.add( + WindowedValue.timestampedValueInGlobalWindow( + reader.getCurrent(), reader.getCurrentTimestamp())); + numElements++; + } while (numElements < ARBITRARY_MAX_ELEMENTS && reader.advance()); + } + checkpointMark = reader.getCheckpointMark(); + checkpointMark.finalizeCheckpoint(); + // TODO: When exercising create initial splits, make this the minimum watermark across all + // existing readers + StepTransformResult result = + StepTransformResult.withHold(transform, reader.getWatermark()) + .addOutput(output) + .build(); + evaluatorQueue.offer(this); + return result; } - checkpointMark = reader.getCheckpointMark(); - checkpointMark.finalizeCheckpoint(); - // TODO: When exercising 
create initial splits, make this the minimum watermark across all - // existing readers - StepTransformResult result = - StepTransformResult.withHold(transform, reader.getWatermark()) - .addOutput(output) - .build(); - evaluatorQueue.offer(this); - return result; } private UnboundedReader createReader( diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactory.java index f47cd1de98..314d81f6aa 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactory.java @@ -17,7 +17,6 @@ import com.google.cloud.dataflow.sdk.coders.KvCoder; import com.google.cloud.dataflow.sdk.coders.VoidCoder; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.PCollectionViewWriter; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.GroupByKey; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutor.java new file mode 100644 index 0000000000..27d59b9a64 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutor.java @@ -0,0 +1,143 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.common.collect.ComparisonChain; +import com.google.common.collect.Ordering; + +import org.joda.time.Instant; + +import java.util.PriorityQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** + * Executes callbacks that occur based on the progression of the watermark per-step. + * + *

Callbacks are registered by calls to + * {@link #callOnGuaranteedFiring(AppliedPTransform, BoundedWindow, WindowingStrategy, Runnable)}, + * and are executed after a call to {@link #fireForWatermark(AppliedPTransform, Instant)} with the + * same {@link AppliedPTransform} and a watermark sufficient to ensure that the trigger for the + * windowing strategy would have been produced. + * + *

NOTE: {@link WatermarkCallbackExecutor} does not track the latest observed watermark for any + * {@link AppliedPTransform} - any call to + * {@link #callOnGuaranteedFiring(AppliedPTransform, BoundedWindow, WindowingStrategy, Runnable)} + * that could have potentially already fired should be followed by a call to + * {@link #fireForWatermark(AppliedPTransform, Instant)} for the same transform with the current + * value of the watermark. + */ +class WatermarkCallbackExecutor { + /** + * Create a new {@link WatermarkCallbackExecutor}. + */ + public static WatermarkCallbackExecutor create() { + return new WatermarkCallbackExecutor(); + } + + private final ConcurrentMap, PriorityQueue> + callbacks; + private final ExecutorService executor; + + private WatermarkCallbackExecutor() { + this.callbacks = new ConcurrentHashMap<>(); + this.executor = Executors.newSingleThreadExecutor(); + } + + /** + * Execute the provided {@link Runnable} after the next call to + * {@link #fireForWatermark(AppliedPTransform, Instant)} where the window is guaranteed to have + * produced output. + */ + public void callOnGuaranteedFiring( + AppliedPTransform step, + BoundedWindow window, + WindowingStrategy windowingStrategy, + Runnable runnable) { + WatermarkCallback callback = + WatermarkCallback.onGuaranteedFiring(window, windowingStrategy, runnable); + + PriorityQueue callbackQueue = callbacks.get(step); + if (callbackQueue == null) { + callbackQueue = new PriorityQueue<>(11, new CallbackOrdering()); + if (callbacks.putIfAbsent(step, callbackQueue) != null) { + callbackQueue = callbacks.get(step); + } + } + + synchronized (callbackQueue) { + callbackQueue.offer(callback); + } + } + + /** + * Schedule all pending callbacks that must have produced output by the time of the provided + * watermark. 
+ */ + public void fireForWatermark(AppliedPTransform step, Instant watermark) { + PriorityQueue callbackQueue = callbacks.get(step); + if (callbackQueue == null) { + return; + } + synchronized (callbackQueue) { + while (!callbackQueue.isEmpty() && callbackQueue.peek().shouldFire(watermark)) { + executor.submit(callbackQueue.poll().getCallback()); + } + } + } + + private static class WatermarkCallback { + public static WatermarkCallback onGuaranteedFiring( + BoundedWindow window, WindowingStrategy strategy, Runnable callback) { + @SuppressWarnings("unchecked") + Instant firingAfter = + strategy.getTrigger().getSpec().getWatermarkThatGuaranteesFiring((W) window); + return new WatermarkCallback(firingAfter, callback); + } + + private final Instant fireAfter; + private final Runnable callback; + + private WatermarkCallback(Instant fireAfter, Runnable callback) { + this.fireAfter = fireAfter; + this.callback = callback; + } + + public boolean shouldFire(Instant currentWatermark) { + return currentWatermark.isAfter(fireAfter) + || currentWatermark.equals(BoundedWindow.TIMESTAMP_MAX_VALUE); + } + + public Runnable getCallback() { + return callback; + } + } + + private static class CallbackOrdering extends Ordering { + @Override + public int compare(WatermarkCallback left, WatermarkCallback right) { + return ComparisonChain.start() + .compare(left.fireAfter, right.fireAfter) + .compare(left.callback, right.callback, Ordering.arbitrary()) + .result(); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java index cc0347a124..b8d20e303f 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java @@ -1690,21 +1690,9 @@ public List> getSideInputs() { @Override public PCollection> apply(PCollection> input) { - if (fn instanceof RequiresContextInternal) { - return 
input - .apply(GroupByKey.create(fewKeys)) - .apply(ParDo.of(new DoFn>, KV>>() { - @Override - public void processElement(ProcessContext c) throws Exception { - c.output(c.element()); - } - })) - .apply(Combine.groupedValues(fn).withSideInputs(sideInputs)); - } else { - return input - .apply(GroupByKey.create(fewKeys)) - .apply(Combine.groupedValues(fn).withSideInputs(sideInputs)); - } + return input + .apply(GroupByKey.create(fewKeys)) + .apply(Combine.groupedValues(fn).withSideInputs(sideInputs)); } } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java new file mode 100644 index 0000000000..656c010d91 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java @@ -0,0 +1,1100 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.GlobalCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.PerKeyCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** + * Static utility methods that create combine function instances. + */ +public class CombineFns { + + /** + * Returns a {@link ComposeKeyedCombineFnBuilder} to construct a composed + * {@link PerKeyCombineFn}. + * + *

The same {@link TupleTag} cannot be used in a composition multiple times. + * + *

Example: + *

{ @code
+   * PCollection<KV<K, Integer>> latencies = ...;
+   *
+   * TupleTag<Integer> maxLatencyTag = new TupleTag<Integer>();
+   * TupleTag<Double> meanLatencyTag = new TupleTag<Double>();
+   *
+   * SimpleFunction<Integer, Integer> identityFn =
+   *     new SimpleFunction<Integer, Integer>() {
+   *       @Override
+   *       public Integer apply(Integer input) {
+   *           return input;
+   *       }};
+   * PCollection<KV<K, CoCombineResult>> maxAndMean = latencies.apply(
+   *     Combine.perKey(
+   *         CombineFns.composeKeyed()
+   *            .with(identityFn, new MaxIntegerFn(), maxLatencyTag)
+   *            .with(identityFn, new MeanFn(), meanLatencyTag)));
+   *
+   * PCollection<T> finalResultCollection = maxAndMean
+   *     .apply(ParDo.of(
+   *         new DoFn<KV<K, CoCombineResult>, T>() {
+   *           @Override
+   *           public void processElement(ProcessContext c) throws Exception {
+   *             KV<K, CoCombineResult> e = c.element();
+   *             Integer maxLatency = e.getValue().get(maxLatencyTag);
+   *             Double meanLatency = e.getValue().get(meanLatencyTag);
+   *             .... Do Something ....
+   *             c.output(...some T...);
+   *           }
+   *         }));
+   * } 
+ */ + public static ComposeKeyedCombineFnBuilder composeKeyed() { + return new ComposeKeyedCombineFnBuilder(); + } + + /** + * Returns a {@link ComposeCombineFnBuilder} to construct a composed + * {@link GlobalCombineFn}. + * + *

The same {@link TupleTag} cannot be used in a composition multiple times. + * + *

Example: + *

{ @code
+   * PCollection<Integer> globalLatencies = ...;
+   *
+   * TupleTag<Integer> maxLatencyTag = new TupleTag<Integer>();
+   * TupleTag<Double> meanLatencyTag = new TupleTag<Double>();
+   *
+   * SimpleFunction<Integer, Integer> identityFn =
+   *     new SimpleFunction<Integer, Integer>() {
+   *       @Override
+   *       public Integer apply(Integer input) {
+   *           return input;
+   *       }};
+   * PCollection<CoCombineResult> maxAndMean = globalLatencies.apply(
+   *     Combine.globally(
+   *         CombineFns.compose()
+   *            .with(identityFn, new MaxIntegerFn(), maxLatencyTag)
+   *            .with(identityFn, new MeanFn(), meanLatencyTag)));
+   *
+   * PCollection<T> finalResultCollection = maxAndMean
+   *     .apply(ParDo.of(
+   *         new DoFn<CoCombineResult, T>() {
+   *           @Override
+   *           public void processElement(ProcessContext c) throws Exception {
+   *             CoCombineResult e = c.element();
+   *             Integer maxLatency = e.get(maxLatencyTag);
+   *             Double meanLatency = e.get(meanLatencyTag);
+   *             .... Do Something ....
+   *             c.output(...some T...);
+   *           }
+   *         }));
+   * } 
+ */ + public static ComposeCombineFnBuilder compose() { + return new ComposeCombineFnBuilder(); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A builder class to construct a composed {@link PerKeyCombineFn}. + */ + public static class ComposeKeyedCombineFnBuilder { + /** + * Returns a {@link ComposedKeyedCombineFn} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. + * + *

The {@link ComposedKeyedCombineFn} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code keyedCombineFn}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public ComposedKeyedCombineFn with( + SimpleFunction extractInputFn, + KeyedCombineFn keyedCombineFn, + TupleTag outputTag) { + return new ComposedKeyedCombineFn() + .with(extractInputFn, keyedCombineFn, outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. + * + *

The {@link ComposedKeyedCombineFnWithContext} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code keyedCombineFnWithContext}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + KeyedCombineFnWithContext keyedCombineFnWithContext, + TupleTag outputTag) { + return new ComposedKeyedCombineFnWithContext() + .with(extractInputFn, keyedCombineFnWithContext, outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFn} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. + */ + public ComposedKeyedCombineFn with( + SimpleFunction extractInputFn, + CombineFn combineFn, + TupleTag outputTag) { + return with(extractInputFn, combineFn.asKeyedFn(), outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. + */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + CombineFnWithContext combineFnWithContext, + TupleTag outputTag) { + return with(extractInputFn, combineFnWithContext.asKeyedFn(), outputTag); + } + } + + /** + * A builder class to construct a composed {@link GlobalCombineFn}. + */ + public static class ComposeCombineFnBuilder { + /** + * Returns a {@link ComposedCombineFn} that can take additional + * {@link GlobalCombineFn GlobalCombineFns} and apply them as a single combine function. + * + *

The {@link ComposedCombineFn} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code combineFn}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public ComposedCombineFn with( + SimpleFunction extractInputFn, + CombineFn combineFn, + TupleTag outputTag) { + return new ComposedCombineFn() + .with(extractInputFn, combineFn, outputTag); + } + + /** + * Returns a {@link ComposedCombineFnWithContext} that can take additional + * {@link GlobalCombineFn GlobalCombineFns} and apply them as a single combine function. + * + *

The {@link ComposedCombineFnWithContext} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code combineFnWithContext}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public ComposedCombineFnWithContext with( + SimpleFunction extractInputFn, + CombineFnWithContext combineFnWithContext, + TupleTag outputTag) { + return new ComposedCombineFnWithContext() + .with(extractInputFn, combineFnWithContext, outputTag); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A tuple of outputs produced by a composed combine functions. + * + *

See {@link #compose()} or {@link #composeKeyed()}) for details. + */ + public static class CoCombineResult implements Serializable { + + private enum NullValue { + INSTANCE; + } + + private final Map, Object> valuesMap; + + /** + * Creates a {@link CoCombineResult} from the given tag-to-value mappings. + * + *

Null values in {@code valuesMap} are permitted: each one is stored internally as the + * {@code NullValue.INSTANCE} sentinel, so its {@link TupleTag} stays in the key set and + * {@link #get} returns {@code null} for it. + * + * @throws NullPointerException if any key in {@code valuesMap} is null + */ + CoCombineResult(Map, Object> valuesMap) { + ImmutableMap.Builder, Object> builder = ImmutableMap.builder(); + for (Entry, Object> entry : valuesMap.entrySet()) { + if (entry.getValue() != null) { + builder.put(entry); + } else { + builder.put(entry.getKey(), NullValue.INSTANCE); + } + } + this.valuesMap = builder.build(); + } + + /** + * Returns the value represented by the given {@link TupleTag}. + * + *

It is an error to request a tuple tag that is not present in the {@link CoCombineResult}. + * + * @return the value stored for {@code tag}, or {@code null} if a null value was stored for it + * @throws IllegalArgumentException if {@code tag} is not in this result + */ + @SuppressWarnings("unchecked") + public V get(TupleTag tag) { + checkArgument( + valuesMap.keySet().contains(tag), "TupleTag " + tag + " is not in the CoCombineResult"); + Object value = valuesMap.get(tag); + if (value == NullValue.INSTANCE) { + return null; + } else { + return (V) value; + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A composed {@link CombineFn} that applies multiple {@link CombineFn CombineFns}. + * + *

For each {@link CombineFn} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public static class ComposedCombineFn extends CombineFn { + + private final List> combineFns; + private final List> extractInputFns; + private final List> outputTags; + private final int combineFnCount; + + private ComposedCombineFn() { + this.extractInputFns = ImmutableList.of(); + this.combineFns = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedCombineFn( + ImmutableList> extractInputFns, + ImmutableList> combineFns, + ImmutableList> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedExtractInputFns = (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedCombineFns = (List) combineFns; + this.combineFns = castedCombineFns; + + this.outputTags = outputTags; + this.combineFnCount = this.combineFns.size(); + } + + /** + * Returns a {@link ComposedCombineFn} with an additional {@link CombineFn}. + */ + public ComposedCombineFn with( + SimpleFunction extractInputFn, + CombineFn combineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedCombineFn<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(combineFns) + .add(combineFn) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedCombineFnWithContext} with an additional + * {@link CombineFnWithContext}. 
+ */ + public ComposedCombineFnWithContext with( + SimpleFunction extractInputFn, + CombineFnWithContext combineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + List> fnsWithContext = Lists.newArrayList(); + for (CombineFn fn : combineFns) { + fnsWithContext.add(toFnWithContext(fn)); + } + return new ComposedCombineFnWithContext<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(fnsWithContext) + .add(combineFn) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + @Override + public Object[] createAccumulator() { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = combineFns.get(i).createAccumulator(); + } + return accumsArray; + } + + @Override + public Object[] addInput(Object[] accumulator, DataT value) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = combineFns.get(i).addInput(accumulator[i], input); + } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(Iterable accumulators) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(); + } else { + // Reuses the first accumulator, and overwrites its values. + // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. 
+ Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = combineFns.get(i).mergeAccumulators(new ProjectionIterable(accumulators, i)); + } + return accum; + } + } + + @Override + public CoCombineResult extractOutput(Object[] accumulator) { + Map, Object> valuesMap = Maps.newHashMap(); + for (int i = 0; i < combineFnCount; ++i) { + valuesMap.put( + outputTags.get(i), + combineFns.get(i).extractOutput(accumulator[i])); + } + return new CoCombineResult(valuesMap); + } + + @Override + public Object[] compact(Object[] accumulator) { + for (int i = 0; i < combineFnCount; ++i) { + accumulator[i] = combineFns.get(i).compact(accumulator[i]); + } + return accumulator; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder dataCoder) + throws CannotProvideCoderException { + List> coders = Lists.newArrayList(); + for (int i = 0; i < combineFnCount; ++i) { + Coder inputCoder = + registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder); + coders.add(combineFns.get(i).getAccumulatorCoder(registry, inputCoder)); + } + return new ComposedAccumulatorCoder(coders); + } + } + + /** + * A composed {@link CombineFnWithContext} that applies multiple + * {@link CombineFnWithContext CombineFnWithContexts}. + * + *

For each {@link CombineFnWithContext} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public static class ComposedCombineFnWithContext + extends CombineFnWithContext { + + private final List> extractInputFns; + private final List> combineFnWithContexts; + private final List> outputTags; + private final int combineFnCount; + + private ComposedCombineFnWithContext() { + this.extractInputFns = ImmutableList.of(); + this.combineFnWithContexts = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedCombineFnWithContext( + ImmutableList> extractInputFns, + ImmutableList> combineFnWithContexts, + ImmutableList> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedExtractInputFns = + (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"rawtypes", "unchecked"}) + List> castedCombineFnWithContexts + = (List) combineFnWithContexts; + this.combineFnWithContexts = castedCombineFnWithContexts; + + this.outputTags = outputTags; + this.combineFnCount = this.combineFnWithContexts.size(); + } + + /** + * Returns a {@link ComposedCombineFnWithContext} with an additional {@link GlobalCombineFn}. 
+ */ + public ComposedCombineFnWithContext with( + SimpleFunction extractInputFn, + GlobalCombineFn globalCombineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedCombineFnWithContext<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(combineFnWithContexts) + .add(toFnWithContext(globalCombineFn)) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + @Override + public Object[] createAccumulator(Context c) { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = combineFnWithContexts.get(i).createAccumulator(c); + } + return accumsArray; + } + + @Override + public Object[] addInput(Object[] accumulator, DataT value, Context c) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = combineFnWithContexts.get(i).addInput(accumulator[i], input, c); + } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(Iterable accumulators, Context c) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(c); + } else { + // Reuses the first accumulator, and overwrites its values. + // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. 
+ Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = combineFnWithContexts.get(i).mergeAccumulators( + new ProjectionIterable(accumulators, i), c); + } + return accum; + } + } + + @Override + public CoCombineResult extractOutput(Object[] accumulator, Context c) { + Map, Object> valuesMap = Maps.newHashMap(); + for (int i = 0; i < combineFnCount; ++i) { + valuesMap.put( + outputTags.get(i), + combineFnWithContexts.get(i).extractOutput(accumulator[i], c)); + } + return new CoCombineResult(valuesMap); + } + + @Override + public Object[] compact(Object[] accumulator, Context c) { + for (int i = 0; i < combineFnCount; ++i) { + accumulator[i] = combineFnWithContexts.get(i).compact(accumulator[i], c); + } + return accumulator; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder dataCoder) + throws CannotProvideCoderException { + List> coders = Lists.newArrayList(); + for (int i = 0; i < combineFnCount; ++i) { + Coder inputCoder = + registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder); + coders.add(combineFnWithContexts.get(i).getAccumulatorCoder(registry, inputCoder)); + } + return new ComposedAccumulatorCoder(coders); + } + } + + /** + * A composed {@link KeyedCombineFn} that applies multiple {@link KeyedCombineFn KeyedCombineFns}. + * + *

For each {@link KeyedCombineFn} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public static class ComposedKeyedCombineFn + extends KeyedCombineFn { + + private final List> extractInputFns; + private final List> keyedCombineFns; + private final List> outputTags; + private final int combineFnCount; + + private ComposedKeyedCombineFn() { + this.extractInputFns = ImmutableList.of(); + this.keyedCombineFns = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedKeyedCombineFn( + ImmutableList> extractInputFns, + ImmutableList> keyedCombineFns, + ImmutableList> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedExtractInputFns = (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedKeyedCombineFns = + (List) keyedCombineFns; + this.keyedCombineFns = castedKeyedCombineFns; + this.outputTags = outputTags; + this.combineFnCount = this.keyedCombineFns.size(); + } + + /** + * Returns a {@link ComposedKeyedCombineFn} with an additional {@link KeyedCombineFn}. + */ + public ComposedKeyedCombineFn with( + SimpleFunction extractInputFn, + KeyedCombineFn keyedCombineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedKeyedCombineFn<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(keyedCombineFns) + .add(keyedCombineFn) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link KeyedCombineFnWithContext}. 
+ */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + KeyedCombineFnWithContext keyedCombineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + List> fnsWithContext = + Lists.newArrayList(); + for (KeyedCombineFn fn : keyedCombineFns) { + fnsWithContext.add(toFnWithContext(fn)); + } + return new ComposedKeyedCombineFnWithContext<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(fnsWithContext) + .add(keyedCombineFn) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedKeyedCombineFn} with an additional {@link CombineFn}. + */ + public ComposedKeyedCombineFn with( + SimpleFunction extractInputFn, + CombineFn keyedCombineFn, + TupleTag outputTag) { + return with(extractInputFn, keyedCombineFn.asKeyedFn(), outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link CombineFnWithContext}. 
+ */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + CombineFnWithContext keyedCombineFn, + TupleTag outputTag) { + return with(extractInputFn, keyedCombineFn.asKeyedFn(), outputTag); + } + + @Override + public Object[] createAccumulator(K key) { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = keyedCombineFns.get(i).createAccumulator(key); + } + return accumsArray; + } + + @Override + public Object[] addInput(K key, Object[] accumulator, DataT value) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = keyedCombineFns.get(i).addInput(key, accumulator[i], input); + } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(K key, final Iterable accumulators) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(key); + } else { + // Reuses the first accumulator, and overwrites its values. + // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. 
+ Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = keyedCombineFns.get(i).mergeAccumulators( + key, new ProjectionIterable(accumulators, i)); + } + return accum; + } + } + + @Override + public CoCombineResult extractOutput(K key, Object[] accumulator) { + Map, Object> valuesMap = Maps.newHashMap(); + for (int i = 0; i < combineFnCount; ++i) { + valuesMap.put( + outputTags.get(i), + keyedCombineFns.get(i).extractOutput(key, accumulator[i])); + } + return new CoCombineResult(valuesMap); + } + + @Override + public Object[] compact(K key, Object[] accumulator) { + for (int i = 0; i < combineFnCount; ++i) { + accumulator[i] = keyedCombineFns.get(i).compact(key, accumulator[i]); + } + return accumulator; + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder keyCoder, Coder dataCoder) + throws CannotProvideCoderException { + List> coders = Lists.newArrayList(); + for (int i = 0; i < combineFnCount; ++i) { + Coder inputCoder = + registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder); + coders.add(keyedCombineFns.get(i).getAccumulatorCoder(registry, keyCoder, inputCoder)); + } + return new ComposedAccumulatorCoder(coders); + } + } + + /** + * A composed {@link KeyedCombineFnWithContext} that applies multiple + * {@link KeyedCombineFnWithContext KeyedCombineFnWithContexts}. + * + *

For each {@link KeyedCombineFnWithContext} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public static class ComposedKeyedCombineFnWithContext + extends KeyedCombineFnWithContext { + + private final List> extractInputFns; + private final List> keyedCombineFns; + private final List> outputTags; + private final int combineFnCount; + + private ComposedKeyedCombineFnWithContext() { + this.extractInputFns = ImmutableList.of(); + this.keyedCombineFns = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedKeyedCombineFnWithContext( + ImmutableList> extractInputFns, + ImmutableList> keyedCombineFns, + ImmutableList> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedExtractInputFns = + (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedKeyedCombineFns = + (List) keyedCombineFns; + this.keyedCombineFns = castedKeyedCombineFns; + this.outputTags = outputTags; + this.combineFnCount = this.keyedCombineFns.size(); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link PerKeyCombineFn}. + */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + PerKeyCombineFn perKeyCombineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedKeyedCombineFnWithContext<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(keyedCombineFns) + .add(toFnWithContext(perKeyCombineFn)) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link GlobalCombineFn}. 
+ */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + GlobalCombineFn perKeyCombineFn, + TupleTag outputTag) { + return with(extractInputFn, perKeyCombineFn.asKeyedFn(), outputTag); + } + + @Override + public Object[] createAccumulator(K key, Context c) { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = keyedCombineFns.get(i).createAccumulator(key, c); + } + return accumsArray; + } + + @Override + public Object[] addInput(K key, Object[] accumulator, DataT value, Context c) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = keyedCombineFns.get(i).addInput(key, accumulator[i], input, c); + } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(K key, Iterable accumulators, Context c) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(key, c); + } else { + // Reuses the first accumulator, and overwrites its values. + // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. 
+ Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = keyedCombineFns.get(i).mergeAccumulators( + key, new ProjectionIterable(accumulators, i), c); + } + return accum; + } + } + + @Override + public CoCombineResult extractOutput(K key, Object[] accumulator, Context c) { + Map, Object> valuesMap = Maps.newHashMap(); + for (int i = 0; i < combineFnCount; ++i) { + valuesMap.put( + outputTags.get(i), + keyedCombineFns.get(i).extractOutput(key, accumulator[i], c)); + } + return new CoCombineResult(valuesMap); + } + + @Override + public Object[] compact(K key, Object[] accumulator, Context c) { + for (int i = 0; i < combineFnCount; ++i) { + accumulator[i] = keyedCombineFns.get(i).compact(key, accumulator[i], c); + } + return accumulator; + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder keyCoder, Coder dataCoder) + throws CannotProvideCoderException { + List> coders = Lists.newArrayList(); + for (int i = 0; i < combineFnCount; ++i) { + Coder inputCoder = + registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder); + coders.add(keyedCombineFns.get(i).getAccumulatorCoder( + registry, keyCoder, inputCoder)); + } + return new ComposedAccumulatorCoder(coders); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + private static class ProjectionIterable implements Iterable { + private final Iterable iterable; + private final int column; + + private ProjectionIterable(Iterable iterable, int column) { + this.iterable = iterable; + this.column = column; + } + + @Override + public Iterator iterator() { + final Iterator iter = iterable.iterator(); + return new Iterator() { + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Object next() { + return iter.next()[column]; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + } + + private static class ComposedAccumulatorCoder 
extends StandardCoder { + private List> coders; + private int codersCount; + + public ComposedAccumulatorCoder(List> coders) { + this.coders = ImmutableList.copyOf(coders); + this.codersCount = coders.size(); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + @JsonCreator + public static ComposedAccumulatorCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + return new ComposedAccumulatorCoder((List) components); + } + + @Override + public void encode(Object[] value, OutputStream outStream, Context context) + throws CoderException, IOException { + checkArgument(value.length == codersCount); + Context nestedContext = context.nested(); + for (int i = 0; i < codersCount; ++i) { + coders.get(i).encode(value[i], outStream, nestedContext); + } + } + + @Override + public Object[] decode(InputStream inStream, Context context) + throws CoderException, IOException { + Object[] ret = new Object[codersCount]; + Context nestedContext = context.nested(); + for (int i = 0; i < codersCount; ++i) { + ret[i] = coders.get(i).decode(inStream, nestedContext); + } + return ret; + } + + @Override + public List> getCoderArguments() { + return coders; + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + for (int i = 0; i < codersCount; ++i) { + coders.get(i).verifyDeterministic(); + } + } + } + + @SuppressWarnings("unchecked") + private static CombineFnWithContext + toFnWithContext(GlobalCombineFn globalCombineFn) { + if (globalCombineFn instanceof CombineFnWithContext) { + return (CombineFnWithContext) globalCombineFn; + } else { + final CombineFn combineFn = + (CombineFn) globalCombineFn; + return new CombineFnWithContext() { + @Override + public AccumT createAccumulator(Context c) { + return combineFn.createAccumulator(); + } + @Override + public AccumT addInput(AccumT accumulator, InputT input, Context c) { + return combineFn.addInput(accumulator, input); + } + @Override + public AccumT mergeAccumulators(Iterable 
accumulators, Context c) { + return combineFn.mergeAccumulators(accumulators); + } + @Override + public OutputT extractOutput(AccumT accumulator, Context c) { + return combineFn.extractOutput(accumulator); + } + @Override + public AccumT compact(AccumT accumulator, Context c) { + return combineFn.compact(accumulator); + } + @Override + public OutputT defaultValue() { + return combineFn.defaultValue(); + } + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return combineFn.getAccumulatorCoder(registry, inputCoder); + } + @Override + public Coder getDefaultOutputCoder( + CoderRegistry registry, Coder inputCoder) throws CannotProvideCoderException { + return combineFn.getDefaultOutputCoder(registry, inputCoder); + } + }; + } + } + + private static KeyedCombineFnWithContext + toFnWithContext(PerKeyCombineFn perKeyCombineFn) { + if (perKeyCombineFn instanceof KeyedCombineFnWithContext) { + @SuppressWarnings("unchecked") + KeyedCombineFnWithContext keyedCombineFnWithContext = + (KeyedCombineFnWithContext) perKeyCombineFn; + return keyedCombineFnWithContext; + } else { + @SuppressWarnings("unchecked") + final KeyedCombineFn keyedCombineFn = + (KeyedCombineFn) perKeyCombineFn; + return new KeyedCombineFnWithContext() { + @Override + public AccumT createAccumulator(K key, Context c) { + return keyedCombineFn.createAccumulator(key); + } + @Override + public AccumT addInput(K key, AccumT accumulator, InputT value, Context c) { + return keyedCombineFn.addInput(key, accumulator, value); + } + @Override + public AccumT mergeAccumulators(K key, Iterable accumulators, Context c) { + return keyedCombineFn.mergeAccumulators(key, accumulators); + } + @Override + public OutputT extractOutput(K key, AccumT accumulator, Context c) { + return keyedCombineFn.extractOutput(key, accumulator); + } + @Override + public AccumT compact(K key, AccumT accumulator, Context c) { + return keyedCombineFn.compact(key, 
accumulator); + } + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException { + return keyedCombineFn.getAccumulatorCoder(registry, keyCoder, inputCoder); + } + @Override + public Coder getDefaultOutputCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException { + return keyedCombineFn.getDefaultOutputCoder(registry, keyCoder, inputCoder); + } + }; + } + } + + private static void checkUniqueness( + List> registeredTags, TupleTag outputTag) { + checkArgument( + !registeredTags.contains(outputTag), + "Cannot compose with tuple tag %s because it is already present in the composition.", + outputTag); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java index af06cc8796..5ba9992143 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java @@ -24,6 +24,8 @@ import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData; +import com.google.cloud.dataflow.sdk.transforms.display.HasDisplayData; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; import com.google.cloud.dataflow.sdk.util.WindowingInternals; @@ -69,7 +71,7 @@ * @param the type of the (main) input elements * @param the type of the (main) output elements */ -public abstract class DoFn implements Serializable { +public abstract class DoFn implements Serializable, HasDisplayData { /** * Information accessible to all methods in this {@code DoFn}. 
@@ -366,6 +368,15 @@ public void startBundle(Context c) throws Exception { public void finishBundle(Context c) throws Exception { } + /** + * {@inheritDoc} + * + *

By default, does not register any display data. Implementors may override this method + * to provide their own display metadata. + */ + @Override + public void populateDisplayData(DisplayData.Builder builder) { + } ///////////////////////////////////////////////////////////////////////////// diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java index 8a7450997a..d4496b8c74 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java @@ -19,6 +19,8 @@ import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder; +import com.google.cloud.dataflow.sdk.transforms.display.HasDisplayData; import com.google.cloud.dataflow.sdk.util.StringUtils; import com.google.cloud.dataflow.sdk.values.PInput; import com.google.cloud.dataflow.sdk.values.POutput; @@ -168,7 +170,7 @@ * @param the type of the output of this PTransform */ public abstract class PTransform - implements Serializable /* See the note above */ { + implements Serializable /* See the note above */, HasDisplayData { /** * Applies this {@code PTransform} on the given {@code InputT}, and returns its * {@code Output}. @@ -309,4 +311,14 @@ public Coder getDefaultOutputCoder( Coder defaultOutputCoder = (Coder) getDefaultOutputCoder(input); return defaultOutputCoder; } + + /** + * {@inheritDoc} + * + *

By default, does not register any display data. Implementors may override this method + * to provide their own display metadata. + */ + @Override + public void populateDisplayData(Builder builder) { + } } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java index 0922767adc..c77ac4447c 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java @@ -22,6 +22,7 @@ import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.coders.CoderException; import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder; import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; import com.google.cloud.dataflow.sdk.util.DirectModeExecutionContext; import com.google.cloud.dataflow.sdk.util.DirectSideInputReader; @@ -787,6 +788,18 @@ protected String getKindString() { } } + /** + * {@inheritDoc} + * + *

{@link ParDo} registers its internal {@link DoFn} as a subcomponent for display metadata. + * {@link DoFn} implementations can register display data by overriding + * {@link DoFn#populateDisplayData}. + */ + @Override + public void populateDisplayData(Builder builder) { + builder.include(fn); + } + public DoFn getFn() { return fn; } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayData.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayData.java new file mode 100644 index 0000000000..05fa7c7881 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayData.java @@ -0,0 +1,517 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.display; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; + +import org.apache.avro.reflect.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.format.DateTimeFormatter; +import org.joda.time.format.ISODateTimeFormat; + +import java.util.Collection; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** + * Static display metadata associated with a pipeline component. Display data is useful for + * pipeline runner UIs and diagnostic dashboards to display details about + * {@link PTransform PTransforms} that make up a pipeline. + * + *

Components specify their display data by implementing the {@link HasDisplayData} + * interface. + */ +public class DisplayData { + private static final DisplayData EMPTY = new DisplayData(Maps.newHashMap()); + private static final DateTimeFormatter TIMESTAMP_FORMATTER = ISODateTimeFormat.dateTime(); + + private final ImmutableMap entries; + + private DisplayData(Map entries) { + this.entries = ImmutableMap.copyOf(entries); + } + + /** + * Default empty {@link DisplayData} instance. + */ + public static DisplayData none() { + return EMPTY; + } + + /** + * Collect the {@link DisplayData} from a component. This will traverse all subcomponents + * specified via {@link Builder#include} in the given component. Data in this component will be in + * a namespace derived from the component. + */ + public static DisplayData from(HasDisplayData component) { + checkNotNull(component); + return InternalBuilder.forRoot(component).build(); + } + + public Collection items() { + return entries.values(); + } + + public Map asMap() { + return entries; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + boolean isFirstLine = true; + for (Map.Entry entry : entries.entrySet()) { + if (isFirstLine) { + isFirstLine = false; + } else { + builder.append("\n"); + } + + builder.append(entry); + } + + return builder.toString(); + } + + /** + * Utility to build up display metadata from a component and its included + * subcomponents. + */ + public interface Builder { + /** + * Include display metadata from the specified subcomponent. For example, a {@link ParDo} + * transform includes display metadata from the encapsulated {@link DoFn}. + * + * @return A builder instance to continue to build in a fluent-style. + */ + Builder include(HasDisplayData subComponent); + + /** + * Register the given string display metadata. 
The metadata item will be registered with type + * {@link DisplayData.Type#STRING}, and is identified by the specified key and namespace from + * the current transform or component. + */ + ItemBuilder add(String key, String value); + + /** + * Register the given numeric display metadata. The metadata item will be registered with type + * {@link DisplayData.Type#INTEGER}, and is identified by the specified key and namespace from + * the current transform or component. + */ + ItemBuilder add(String key, long value); + + /** + * Register the given floating point display metadata. The metadata item will be registered with + * type {@link DisplayData.Type#FLOAT}, and is identified by the specified key and namespace + * from the current transform or component. + */ + ItemBuilder add(String key, double value); + + /** + * Register the given timestamp display metadata. The metadata item will be registered with type + * {@link DisplayData.Type#TIMESTAMP}, and is identified by the specified key and namespace from + * the current transform or component. + */ + ItemBuilder add(String key, Instant value); + + /** + * Register the given duration display metadata. The metadata item will be registered with type + * {@link DisplayData.Type#DURATION}, and is identified by the specified key and namespace from + * the current transform or component. + */ + ItemBuilder add(String key, Duration value); + + /** + * Register the given class display metadata. The metadata item will be registered with type + * {@link DisplayData.Type#JAVA_CLASS}, and is identified by the specified key and namespace + * from the current transform or component. + */ + ItemBuilder add(String key, Class value); + } + + /** + * Utility to append optional fields to display metadata, or register additional display metadata + * items. + */ + public interface ItemBuilder extends Builder { + /** + * Add a human-readable label to describe the most-recently added metadata field. 
+ * A label is optional; if unspecified, UIs should display the metadata key to identify the + * display item. + * + *

Specifying a null value will clear the label if it was previously defined. + */ + ItemBuilder withLabel(@Nullable String label); + + /** + * Add a link URL to the most-recently added display metadata. A link URL is optional and + * can be provided to point the reader to additional details about the metadata. + * + *

Specifying a null value will clear the URL if it was previously defined. + */ + ItemBuilder withLinkUrl(@Nullable String url); + } + + /** + * A display metadata item. DisplayData items are registered via {@link Builder#add} within + * {@link HasDisplayData#populateDisplayData} implementations. Each metadata item is uniquely + * identified by the specified key and namespace generated from the registering component's + * class name. + */ + public static class Item { + private final String key; + private final String ns; + private final Type type; + private final String value; + private final String shortValue; + private final String label; + private final String url; + + private static Item create(String namespace, String key, Type type, T value) { + FormattedItemValue formatted = type.format(value); + return new Item( + namespace, key, type, formatted.getLongValue(), formatted.getShortValue(), null, null); + } + + private Item( + String namespace, + String key, + Type type, + String value, + String shortValue, + String url, + String label) { + this.ns = namespace; + this.key = key; + this.type = type; + this.value = value; + this.shortValue = shortValue; + this.url = url; + this.label = label; + } + + public String getNamespace() { + return ns; + } + + public String getKey() { + return key; + } + + /** + * Retrieve the {@link DisplayData.Type} of display metadata. All metadata conforms to a + * predefined set of allowed types. + */ + public Type getType() { + return type; + } + + /** + * Retrieve the value of the metadata item. + */ + public String getValue() { + return value; + } + + /** + * Return the optional short value for an item. Types may provide a short value to be displayed + * instead of or in addition to the full {@link Item#value}. + * + *

Some display data types will not provide a short value, in which case the return value + * will be null. + */ + @Nullable + public String getShortValue() { + return shortValue; + } + + /** + * Retrieve the optional label for an item. The label is a human-readable description of what + * the metadata represents. UIs may choose to display the label instead of the item key. + * + *

If no label was specified, this will return {@code null}. + */ + @Nullable + public String getLabel() { + return label; + } + + /** + * Retrieve the optional link URL for an item. The URL points to an address where the reader + * can find additional context for the display metadata. + * + *

If no URL was specified, this will return {@code null}. + */ + @Nullable + public String getUrl() { + return url; + } + + @Override + public String toString() { + return getValue(); + } + + private Item withLabel(String label) { + return new Item(this.ns, this.key, this.type, this.value, this.shortValue, this.url, label); + } + + private Item withUrl(String url) { + return new Item(this.ns, this.key, this.type, this.value, this.shortValue, url, this.label); + } + } + + /** + * Unique identifier for a display metadata item within a component. + * Identifiers are composed of the key they are registered with and a namespace generated from + * the class of the component which registered the item. + * + *

Display metadata registered with the same key from different components will have different + * namespaces and thus will both be represented in the composed {@link DisplayData}. If a + * single component registers multiple metadata items with the same key, registration fails + * with an {@link IllegalArgumentException}. + */ + public static class Identifier { + private final String ns; + private final String key; + + static Identifier of(Class namespace, String key) { + return new Identifier(namespace.getName(), key); + } + + private Identifier(String ns, String key) { + this.ns = ns; + this.key = key; + } + + public String getNamespace() { + return ns; + } + + public String getKey() { + return key; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof Identifier) { + Identifier that = (Identifier) obj; + return Objects.equals(this.ns, that.ns) + && Objects.equals(this.key, that.key); + } + + return false; + } + + @Override + public int hashCode() { + return Objects.hash(ns, key); + } + + @Override + public String toString() { + return String.format("%s:%s", ns, key); + } + } + + /** + * Display metadata type. 
+ */ + enum Type { + STRING { + @Override + FormattedItemValue format(Object value) { + return new FormattedItemValue((String) value); + } + }, + INTEGER { + @Override + FormattedItemValue format(Object value) { + return new FormattedItemValue(Long.toString((long) value)); + } + }, + FLOAT { + @Override + FormattedItemValue format(Object value) { + return new FormattedItemValue(Double.toString((Double) value)); + } + }, + TIMESTAMP() { + @Override + FormattedItemValue format(Object value) { + return new FormattedItemValue((TIMESTAMP_FORMATTER.print((Instant) value))); + } + }, + DURATION { + @Override + FormattedItemValue format(Object value) { + return new FormattedItemValue(Long.toString(((Duration) value).getMillis())); + } + }, + JAVA_CLASS { + @Override + FormattedItemValue format(Object value) { + Class clazz = (Class) value; + return new FormattedItemValue(clazz.getName(), clazz.getSimpleName()); + } + }; + + /** + * Format the display metadata value into a long string representation, and optionally + * a shorter representation for display. + * + *

Internal-only. Value objects can be safely cast to the expected Java type. + */ + abstract FormattedItemValue format(Object value); + } + + private static class FormattedItemValue { + private final String shortValue; + private final String longValue; + + private FormattedItemValue(String longValue) { + this(longValue, null); + } + + private FormattedItemValue(String longValue, String shortValue) { + this.longValue = longValue; + this.shortValue = shortValue; + } + + private String getLongValue () { + return this.longValue; + } + + private String getShortValue() { + return this.shortValue; + } + } + + private static class InternalBuilder implements ItemBuilder { + private final Map entries; + private final Set visited; + + private Class latestNs; + private Item latestItem; + private Identifier latestIdentifier; + + private InternalBuilder() { + this.entries = Maps.newHashMap(); + this.visited = Sets.newIdentityHashSet(); + } + + private static InternalBuilder forRoot(HasDisplayData instance) { + InternalBuilder builder = new InternalBuilder(); + builder.include(instance); + return builder; + } + + @Override + public Builder include(HasDisplayData subComponent) { + checkNotNull(subComponent); + boolean newComponent = visited.add(subComponent); + if (newComponent) { + Class prevNs = this.latestNs; + this.latestNs = subComponent.getClass(); + subComponent.populateDisplayData(this); + this.latestNs = prevNs; + } + + return this; + } + + @Override + public ItemBuilder add(String key, String value) { + checkNotNull(value); + return addItem(key, Type.STRING, value); + } + + @Override + public ItemBuilder add(String key, long value) { + return addItem(key, Type.INTEGER, value); + } + + @Override + public ItemBuilder add(String key, double value) { + return addItem(key, Type.FLOAT, value); + } + + @Override + public ItemBuilder add(String key, Instant value) { + checkNotNull(value); + return addItem(key, Type.TIMESTAMP, value); + } + + @Override + public ItemBuilder 
add(String key, Duration value) { + checkNotNull(value); + return addItem(key, Type.DURATION, value); + } + + @Override + public ItemBuilder add(String key, Class value) { + checkNotNull(value); + return addItem(key, Type.JAVA_CLASS, value); + } + + private ItemBuilder addItem(String key, Type type, T value) { + checkNotNull(key); + checkArgument(!key.isEmpty()); + + Identifier id = Identifier.of(latestNs, key); + if (entries.containsKey(id)) { + throw new IllegalArgumentException("DisplayData key already exists. All display data " + + "for a component must be registered with a unique key.\nKey: " + id); + } + Item item = Item.create(id.getNamespace(), key, type, value); + entries.put(id, item); + + latestItem = item; + latestIdentifier = id; + + return this; + } + + @Override + public ItemBuilder withLabel(String label) { + latestItem = latestItem.withLabel(label); + entries.put(latestIdentifier, latestItem); + return this; + } + + @Override + public ItemBuilder withLinkUrl(String url) { + latestItem = latestItem.withUrl(url); + entries.put(latestIdentifier, latestItem); + return this; + } + + private DisplayData build() { + return new DisplayData(this.entries); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/HasDisplayData.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/HasDisplayData.java new file mode 100644 index 0000000000..b2eca3d881 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/HasDisplayData.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.display; + +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +/** + * Marker interface for {@link PTransform PTransforms} and components used within + * {@link PTransform PTransforms} to specify display metadata to be used within UIs and diagnostic + * tools. + * + *

Display metadata is optional and may be collected during pipeline construction. It should + * only be used for informational purposes. Tools and components should not assume that display + * data will always be collected, or that collected display data will always be displayed. + */ +public interface HasDisplayData { + /** + * Register display metadata for the given transform or component. Metadata can be registered + * directly on the provided builder, as well as via included sub-components. + * + *

+   * {@code
+   * @Override
+   * public void populateDisplayData(DisplayData.Builder builder) {
+   *  builder
+   *     .include(subComponent)
+   *     .add("minFilter", 42)
+   *     .add("topic", "projects/myproject/topics/mytopic")
+   *       .withLabel("Pub/Sub Topic")
+   *     .add("serviceInstance", "myservice.com/fizzbang")
+   *       .withLinkUrl("http://www.myservice.com/fizzbang");
+   * }
+   * }
+   * 
+ * + * @param builder The builder to populate with display metadata. + */ + void populateDisplayData(DisplayData.Builder builder); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermark.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermark.java index da16db99c6..fac2c2841b 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermark.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermark.java @@ -80,7 +80,7 @@ public static FromEndOfWindow pastEndOfWindow() { public interface AfterWatermarkEarly extends TriggerBuilder { /** * Creates a new {@code Trigger} like the this, except that it fires repeatedly whenever - * the given {@code Trigger} fires before the watermark has passed the end of the window. + * the given {@code Trigger} fires after the watermark has passed the end of the window. */ TriggerBuilder withLateFirings(OnceTrigger lateTrigger); } @@ -91,7 +91,7 @@ public interface AfterWatermarkEarly extends TriggerBui public interface AfterWatermarkLate extends TriggerBuilder { /** * Creates a new {@code Trigger} like the this, except that it fires repeatedly whenever - * the given {@code Trigger} fires after the watermark has passed the end of the window. + * the given {@code Trigger} fires before the watermark has passed the end of the window. 
*/ TriggerBuilder withEarlyFirings(OnceTrigger earlyTrigger); } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java index 5611fabe28..ec6518976b 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java @@ -65,6 +65,7 @@ public class PropertyNames { public static final String INPUTS = "inputs"; public static final String INPUT_CODER = "input_coder"; public static final String IS_GENERATED = "is_generated"; + public static final String IS_MERGING_WINDOW_FN = "is_merging_window_fn"; public static final String IS_PAIR_LIKE = "is_pair_like"; public static final String IS_STREAM_LIKE = "is_stream_like"; public static final String IS_WRAPPER = "is_wrapper"; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java index e687f27989..045a8ad0f2 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java @@ -25,8 +25,12 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; import com.google.cloud.dataflow.sdk.testing.RestoreSystemProperties; import com.google.common.collect.ArrayListMultimap; @@ -824,6 +828,14 @@ public void 
testSettingRunner() { assertEquals(BlockingDataflowPipelineRunner.class, options.getRunner()); } + @Test + public void testSettingRunnerFullName() { + String[] args = + new String[] {String.format("--runner=%s", DataflowPipelineRunner.class.getName())}; + PipelineOptions opts = PipelineOptionsFactory.fromArgs(args).create(); + assertEquals(opts.getRunner(), DataflowPipelineRunner.class); + } + @Test public void testSettingUnknownRunner() { String[] args = new String[] {"--runner=UnknownRunner"}; @@ -834,6 +846,30 @@ public void testSettingUnknownRunner() { PipelineOptionsFactory.fromArgs(args).create(); } + private static class ExampleTestRunner extends PipelineRunner { + @Override + public PipelineResult run(Pipeline pipeline) { + return null; + } + } + + @Test + public void testSettingRunnerCanonicalClassNameNotInSupportedExists() { + String[] args = new String[] {String.format("--runner=%s", ExampleTestRunner.class.getName())}; + PipelineOptions opts = PipelineOptionsFactory.fromArgs(args).create(); + assertEquals(opts.getRunner(), ExampleTestRunner.class); + } + + @Test + public void testSettingRunnerCanonicalClassNameNotInSupportedNotPipelineRunner() { + String[] args = new String[] {"--runner=java.lang.String"}; + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("does not implement PipelineRunner"); + expectedException.expectMessage("java.lang.String"); + + PipelineOptionsFactory.fromArgs(args).create(); + } + @Test public void testUsingArgumentWithUnknownPropertyIsNotAllowed() { String[] args = new String[] {"--unknownProperty=value"}; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java index 9f22fbbe9e..e641dd6181 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java +++ 
b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java @@ -18,24 +18,39 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.BoundedSource.BoundedReader; import com.google.cloud.dataflow.sdk.io.CountingSource; import com.google.cloud.dataflow.sdk.io.Read; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.io.Read.Bounded; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.util.WindowedValue; import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; +import org.joda.time.Instant; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mock; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.NoSuchElementException; /** * Tests for {@link BoundedReadEvaluatorFactory}. 
@@ -45,7 +60,7 @@ public class BoundedReadEvaluatorFactoryTest { private BoundedSource source; private PCollection longs; private TransformEvaluatorFactory factory; - private InProcessEvaluationContext context; + @Mock private InProcessEvaluationContext context; @Before public void setup() { @@ -146,6 +161,125 @@ public void boundedSourceEvaluatorSimultaneousEvaluations() throws Exception { gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L))); } + @Test + public void boundedSourceEvaluatorClosesReader() throws Exception { + TestSource source = new TestSource<>(BigEndianLongCoder.of(), 1L, 2L, 3L); + + TestPipeline p = TestPipeline.create(); + PCollection pcollection = p.apply(Read.from(source)); + AppliedPTransform sourceTransform = pcollection.getProducingTransformInternal(); + + UncommittedBundle output = InProcessBundle.unkeyed(longs); + when(context.createRootBundle(pcollection)).thenReturn(output); + + TransformEvaluator evaluator = factory.forApplication(sourceTransform, null, context); + evaluator.finishBundle(); + CommittedBundle committed = output.commit(Instant.now()); + assertThat(committed.getElements(), containsInAnyOrder(gw(2L), gw(3L), gw(1L))); + assertThat(TestSource.readerClosed, is(true)); + } + + @Test + public void boundedSourceEvaluatorNoElementsClosesReader() throws Exception { + TestSource source = new TestSource<>(BigEndianLongCoder.of()); + + TestPipeline p = TestPipeline.create(); + PCollection pcollection = p.apply(Read.from(source)); + AppliedPTransform sourceTransform = pcollection.getProducingTransformInternal(); + + UncommittedBundle output = InProcessBundle.unkeyed(longs); + when(context.createRootBundle(pcollection)).thenReturn(output); + + TransformEvaluator evaluator = factory.forApplication(sourceTransform, null, context); + evaluator.finishBundle(); + CommittedBundle committed = output.commit(Instant.now()); + assertThat(committed.getElements(), emptyIterable()); + assertThat(TestSource.readerClosed, 
is(true)); + } + + private static class TestSource extends BoundedSource { + private static boolean readerClosed; + private final Coder coder; + private final T[] elems; + + public TestSource(Coder coder, T... elems) { + this.elems = elems; + this.coder = coder; + readerClosed = false; + } + + @Override + public List> splitIntoBundles( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + return ImmutableList.of(this); + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + return 0; + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public BoundedSource.BoundedReader createReader(PipelineOptions options) throws IOException { + return new TestReader<>(this, elems); + } + + @Override + public void validate() { + } + + @Override + public Coder getDefaultOutputCoder() { + return coder; + } + } + + private static class TestReader extends BoundedReader { + private final BoundedSource source; + private final List elems; + private int index; + + public TestReader(BoundedSource source, T... 
elems) { + this.source = source; + this.elems = Arrays.asList(elems); + this.index = -1; + } + + @Override + public BoundedSource getCurrentSource() { + return source; + } + + @Override + public boolean start() throws IOException { + return advance(); + } + + @Override + public boolean advance() throws IOException { + if (elems.size() > index + 1) { + index++; + return true; + } + return false; + } + + @Override + public T getCurrent() throws NoSuchElementException { + return elems.get(index); + } + + @Override + public void close() throws IOException { + TestSource.readerClosed = true; + } + } + private static WindowedValue gw(Long elem) { return WindowedValue.valueInGlobalWindow(elem); } diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java new file mode 100644 index 0000000000..d921f6cdb5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java @@ -0,0 +1,233 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.emptyIterable; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.io.CountingInput; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.PValue; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.List; + +/** + * Tests for {@link ConsumerTrackingPipelineVisitor}. 
+ */ +@RunWith(JUnit4.class) +public class ConsumerTrackingPipelineVisitorTest implements Serializable { + @Rule public transient ExpectedException thrown = ExpectedException.none(); + + private transient TestPipeline p = TestPipeline.create(); + private transient ConsumerTrackingPipelineVisitor visitor = new ConsumerTrackingPipelineVisitor(); + + @Test + public void getViewsReturnsViews() { + PCollectionView> listView = + p.apply("listCreate", Create.of("foo", "bar")) + .apply( + ParDo.of( + new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) + throws Exception { + c.output(Integer.toString(c.element().length())); + } + })) + .apply(View.asList()); + PCollectionView singletonView = + p.apply("singletonCreate", Create.of(1, 2, 3)).apply(View.asSingleton()); + p.traverseTopologically(visitor); + assertThat( + visitor.getViews(), + Matchers.>containsInAnyOrder(listView, singletonView)); + } + + @Test + public void getRootTransformsContainsPBegins() { + PCollection created = p.apply(Create.of("foo", "bar")); + PCollection counted = p.apply(CountingInput.upTo(1234L)); + PCollection unCounted = p.apply(CountingInput.unbounded()); + p.traverseTopologically(visitor); + assertThat( + visitor.getRootTransforms(), + Matchers.>containsInAnyOrder( + created.getProducingTransformInternal(), + counted.getProducingTransformInternal(), + unCounted.getProducingTransformInternal())); + } + + @Test + public void getRootTransformsContainsEmptyFlatten() { + PCollection empty = + PCollectionList.empty(p).apply(Flatten.pCollections()); + p.traverseTopologically(visitor); + assertThat( + visitor.getRootTransforms(), + Matchers.>containsInAnyOrder( + empty.getProducingTransformInternal())); + } + + @Test + public void getValueToConsumersSucceeds() { + PCollection created = p.apply(Create.of("1", "2", "3")); + PCollection transformed = + created.apply( + ParDo.of( + new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) + throws Exception { + 
c.output(Integer.toString(c.element().length())); + } + })); + + PCollection flattened = + PCollectionList.of(created).and(transformed).apply(Flatten.pCollections()); + + p.traverseTopologically(visitor); + + assertThat( + visitor.getValueToConsumers().get(created), + Matchers.>containsInAnyOrder( + transformed.getProducingTransformInternal(), + flattened.getProducingTransformInternal())); + assertThat( + visitor.getValueToConsumers().get(transformed), + Matchers.>containsInAnyOrder( + flattened.getProducingTransformInternal())); + assertThat(visitor.getValueToConsumers().get(flattened), emptyIterable()); + } + + @Test + public void getUnfinalizedPValuesContainsDanglingOutputs() { + PCollection created = p.apply(Create.of("1", "2", "3")); + PCollection transformed = + created.apply( + ParDo.of( + new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) + throws Exception { + c.output(Integer.toString(c.element().length())); + } + })); + + p.traverseTopologically(visitor); + assertThat(visitor.getUnfinalizedPValues(), Matchers.contains(transformed)); + } + + @Test + public void getUnfinalizedPValuesEmpty() { + p.apply(Create.of("1", "2", "3")) + .apply( + ParDo.of( + new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) + throws Exception { + c.output(Integer.toString(c.element().length())); + } + })) + .apply( + new PTransform() { + @Override + public PDone apply(PInput input) { + return PDone.in(input.getPipeline()); + } + }); + + p.traverseTopologically(visitor); + assertThat(visitor.getUnfinalizedPValues(), emptyIterable()); + } + + @Test + public void traverseMultipleTimesThrows() { + p.apply(Create.of(1, 2, 3)); + + p.traverseTopologically(visitor); + thrown.expect(IllegalStateException.class); + thrown.expectMessage(ConsumerTrackingPipelineVisitor.class.getSimpleName()); + thrown.expectMessage("is finalized"); + p.traverseTopologically(visitor); + } + + @Test + public void traverseIndependentPathsSucceeds() { + 
p.apply("left", Create.of(1, 2, 3)); + p.apply("right", Create.of("foo", "bar", "baz")); + + p.traverseTopologically(visitor); + } + + @Test + public void getRootTransformsWithoutVisitingThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getRootTransforms"); + visitor.getRootTransforms(); + } + @Test + public void getStepNamesWithoutVisitingThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getStepNames"); + visitor.getStepNames(); + } + @Test + public void getUnfinalizedPValuesWithoutVisitingThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getUnfinalizedPValues"); + visitor.getUnfinalizedPValues(); + } + + @Test + public void getValueToConsumersWithoutVisitingThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getValueToConsumers"); + visitor.getValueToConsumers(); + } + + @Test + public void getViewsWithoutVisitingThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getViews"); + visitor.getViews(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactoryTest.java index bf25970aff..0120b9880d 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactoryTest.java @@ -22,7 +22,6 @@ import static org.mockito.Mockito.when; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import 
com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactoryTest.java index 5c9e824afe..4ced82f8c7 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactoryTest.java @@ -23,7 +23,6 @@ import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.coders.KvCoder; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Create; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManagerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManagerTest.java index 24251522d8..52398cf73a 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManagerTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManagerTest.java @@ -1047,6 +1047,18 @@ public void timerUpdateBuilderWithCompletedAfterBuildNotAddedToBuilt() { assertThat(built.getCompletedTimers(), emptyIterable()); } + @Test + public void 
timerUpdateWithCompletedTimersNotAddedToExisting() { + TimerUpdateBuilder builder = TimerUpdate.builder(null); + TimerData timer = TimerData.of(StateNamespaces.global(), Instant.now(), TimeDomain.EVENT_TIME); + + TimerUpdate built = builder.build(); + assertThat(built.getCompletedTimers(), emptyIterable()); + assertThat( + built.withCompletedTimers(ImmutableList.of(timer)).getCompletedTimers(), contains(timer)); + assertThat(built.getCompletedTimers(), emptyIterable()); + } + private static Matcher earlierThan(final Instant laterInstant) { return new BaseMatcher() { @Override diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContextTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContextTest.java new file mode 100644 index 0000000000..149096040a --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContextTest.java @@ -0,0 +1,436 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.FiredTimers; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessExecutionContext.InProcessStepContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.PCollectionViewWriter; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo.Timing; +import com.google.cloud.dataflow.sdk.util.SideInputReader; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import 
com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.state.BagState; +import com.google.cloud.dataflow.sdk.util.state.CopyOnAccessInMemoryStateInternals; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +/** + * Tests for {@link InProcessEvaluationContext}. 
+ */ +@RunWith(JUnit4.class) +public class InProcessEvaluationContextTest { + private TestPipeline p; + private InProcessEvaluationContext context; + private PCollection created; + private PCollection> downstream; + private PCollectionView> view; + + @Before + public void setup() { + InProcessPipelineRunner runner = + InProcessPipelineRunner.fromOptions(PipelineOptionsFactory.create()); + p = TestPipeline.create(); + created = p.apply(Create.of(1, 2, 3)); + downstream = created.apply(WithKeys.of("foo")); + view = created.apply(View.asIterable()); + Collection> rootTransforms = + ImmutableList.>of(created.getProducingTransformInternal()); + Map>> valueToConsumers = new HashMap<>(); + valueToConsumers.put( + created, + ImmutableList.>of( + downstream.getProducingTransformInternal(), view.getProducingTransformInternal())); + valueToConsumers.put(downstream, ImmutableList.>of()); + valueToConsumers.put(view, ImmutableList.>of()); + + Map, String> stepNames = new HashMap<>(); + stepNames.put(created.getProducingTransformInternal(), "s1"); + stepNames.put(downstream.getProducingTransformInternal(), "s2"); + stepNames.put(view.getProducingTransformInternal(), "s3"); + + Collection> views = ImmutableList.>of(view); + context = InProcessEvaluationContext.create( + runner.getPipelineOptions(), + rootTransforms, + valueToConsumers, + stepNames, + views); + } + + @Test + public void writeToViewWriterThenReadReads() { + PCollectionViewWriter> viewWriter = + context.createPCollectionViewWriter( + PCollection.>createPrimitiveOutputInternal( + p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED), + view); + BoundedWindow window = new TestBoundedWindow(new Instant(1024L)); + BoundedWindow second = new TestBoundedWindow(new Instant(899999L)); + WindowedValue firstValue = + WindowedValue.of(1, new Instant(1222), window, PaneInfo.ON_TIME_AND_ONLY_FIRING); + WindowedValue secondValue = + WindowedValue.of( + 2, new Instant(8766L), second, PaneInfo.createPane(true, false, 
Timing.ON_TIME, 0, 0)); + Iterable> values = ImmutableList.of(firstValue, secondValue); + viewWriter.add(values); + + SideInputReader reader = + context.createSideInputReader(ImmutableList.>of(view)); + assertThat(reader.get(view, window), containsInAnyOrder(1)); + assertThat(reader.get(view, second), containsInAnyOrder(2)); + + WindowedValue overrittenSecondValue = + WindowedValue.of( + 4444, new Instant(8677L), second, PaneInfo.createPane(false, true, Timing.LATE, 1, 1)); + viewWriter.add(Collections.singleton(overrittenSecondValue)); + assertThat(reader.get(view, second), containsInAnyOrder(4444)); + } + + @Test + public void getExecutionContextSameStepSameKeyState() { + InProcessExecutionContext fooContext = + context.getExecutionContext(created.getProducingTransformInternal(), "foo"); + + StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); + + InProcessStepContext stepContext = fooContext.getOrCreateStepContext("s1", "s1", null); + stepContext.stateInternals().state(StateNamespaces.global(), intBag).add(1); + + context.handleResult( + InProcessBundle.keyed(created, "foo").commit(Instant.now()), + ImmutableList.of(), + StepTransformResult.withoutHold(created.getProducingTransformInternal()) + .withState(stepContext.commitState()) + .build()); + + InProcessExecutionContext secondFooContext = + context.getExecutionContext(created.getProducingTransformInternal(), "foo"); + assertThat( + secondFooContext + .getOrCreateStepContext("s1", "s1", null) + .stateInternals() + .state(StateNamespaces.global(), intBag) + .read(), + contains(1)); + } + + + @Test + public void getExecutionContextDifferentKeysIndependentState() { + InProcessExecutionContext fooContext = + context.getExecutionContext(created.getProducingTransformInternal(), "foo"); + + StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); + + fooContext + .getOrCreateStepContext("s1", "s1", null) + .stateInternals() + .state(StateNamespaces.global(), intBag) + .add(1); + + 
InProcessExecutionContext barContext = + context.getExecutionContext(created.getProducingTransformInternal(), "bar"); + assertThat(barContext, not(equalTo(fooContext))); + assertThat( + barContext + .getOrCreateStepContext("s1", "s1", null) + .stateInternals() + .state(StateNamespaces.global(), intBag) + .read(), + emptyIterable()); + } + + @Test + public void getExecutionContextDifferentStepsIndependentState() { + String myKey = "foo"; + InProcessExecutionContext fooContext = + context.getExecutionContext(created.getProducingTransformInternal(), myKey); + + StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); + + fooContext + .getOrCreateStepContext("s1", "s1", null) + .stateInternals() + .state(StateNamespaces.global(), intBag) + .add(1); + + InProcessExecutionContext barContext = + context.getExecutionContext(downstream.getProducingTransformInternal(), myKey); + assertThat( + barContext + .getOrCreateStepContext("s1", "s1", null) + .stateInternals() + .state(StateNamespaces.global(), intBag) + .read(), + emptyIterable()); + } + + @Test + public void handleResultMergesCounters() { + CounterSet counters = context.createCounterSet(); + Counter myCounter = Counter.longs("foo", AggregationKind.SUM); + counters.addCounter(myCounter); + + myCounter.addValue(4L); + InProcessTransformResult result = + StepTransformResult.withoutHold(created.getProducingTransformInternal()) + .withCounters(counters) + .build(); + context.handleResult(null, ImmutableList.of(), result); + assertThat((Long) context.getCounters().getExistingCounter("foo").getAggregate(), equalTo(4L)); + + CounterSet againCounters = context.createCounterSet(); + Counter myLongCounterAgain = Counter.longs("foo", AggregationKind.SUM); + againCounters.add(myLongCounterAgain); + myLongCounterAgain.addValue(8L); + + InProcessTransformResult secondResult = + StepTransformResult.withoutHold(downstream.getProducingTransformInternal()) + .withCounters(againCounters) + .build(); + context.handleResult( + 
InProcessBundle.unkeyed(created).commit(Instant.now()), + ImmutableList.of(), + secondResult); + assertThat((Long) context.getCounters().getExistingCounter("foo").getAggregate(), equalTo(12L)); + } + + @Test + public void handleResultStoresState() { + String myKey = "foo"; + InProcessExecutionContext fooContext = + context.getExecutionContext(downstream.getProducingTransformInternal(), myKey); + + StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); + + CopyOnAccessInMemoryStateInternals state = + fooContext.getOrCreateStepContext("s1", "s1", null).stateInternals(); + BagState bag = state.state(StateNamespaces.global(), intBag); + bag.add(1); + bag.add(2); + bag.add(4); + + InProcessTransformResult stateResult = + StepTransformResult.withoutHold(downstream.getProducingTransformInternal()) + .withState(state) + .build(); + + context.handleResult( + InProcessBundle.keyed(created, myKey).commit(Instant.now()), + ImmutableList.of(), + stateResult); + + InProcessExecutionContext afterResultContext = + context.getExecutionContext(downstream.getProducingTransformInternal(), myKey); + + CopyOnAccessInMemoryStateInternals afterResultState = + afterResultContext.getOrCreateStepContext("s1", "s1", null).stateInternals(); + assertThat(afterResultState.state(StateNamespaces.global(), intBag).read(), contains(1, 2, 4)); + } + + @Test + public void callAfterOutputMustHaveBeenProducedAfterEndOfWatermarkCallsback() throws Exception { + final CountDownLatch callLatch = new CountDownLatch(1); + Runnable callback = + new Runnable() { + @Override + public void run() { + callLatch.countDown(); + } + }; + + // Should call back after the end of the global window + context.scheduleAfterOutputWouldBeProduced( + downstream, GlobalWindow.INSTANCE, WindowingStrategy.globalDefault(), callback); + + InProcessTransformResult result = + StepTransformResult.withHold(created.getProducingTransformInternal(), new Instant(0)) + .build(); + + context.handleResult(null, ImmutableList.of(), 
result); + + // Difficult to demonstrate that we took no action in a multithreaded world; poll for a bit + // will likely be flaky if this logic is broken + assertThat(callLatch.await(500L, TimeUnit.MILLISECONDS), is(false)); + + InProcessTransformResult finishedResult = + StepTransformResult.withoutHold(created.getProducingTransformInternal()).build(); + context.handleResult(null, ImmutableList.of(), finishedResult); + // Obtain the value via blocking call + assertThat(callLatch.await(1, TimeUnit.SECONDS), is(true)); + } + + @Test + public void callAfterOutputMustHaveBeenProducedAlreadyAfterCallsImmediately() throws Exception { + InProcessTransformResult finishedResult = + StepTransformResult.withoutHold(created.getProducingTransformInternal()).build(); + context.handleResult(null, ImmutableList.of(), finishedResult); + + final CountDownLatch callLatch = new CountDownLatch(1); + Runnable callback = + new Runnable() { + @Override + public void run() { + callLatch.countDown(); + } + }; + context.scheduleAfterOutputWouldBeProduced( + downstream, GlobalWindow.INSTANCE, WindowingStrategy.globalDefault(), callback); + assertThat(callLatch.await(1, TimeUnit.SECONDS), is(true)); + } + + @Test + public void extractFiredTimersExtractsTimers() { + InProcessTransformResult holdResult = + StepTransformResult.withHold(created.getProducingTransformInternal(), new Instant(0)) + .build(); + context.handleResult(null, ImmutableList.of(), holdResult); + + String key = "foo"; + TimerData toFire = + TimerData.of(StateNamespaces.global(), new Instant(100L), TimeDomain.EVENT_TIME); + InProcessTransformResult timerResult = + StepTransformResult.withoutHold(downstream.getProducingTransformInternal()) + .withState(CopyOnAccessInMemoryStateInternals.withUnderlying(key, null)) + .withTimerUpdate(TimerUpdate.builder(key).setTimer(toFire).build()) + .build(); + + // haven't added any timers, must be empty + assertThat(context.extractFiredTimers().entrySet(), emptyIterable()); + 
context.handleResult( + InProcessBundle.keyed(created, key).commit(Instant.now()), + ImmutableList.of(), + timerResult); + + // timer hasn't fired + assertThat(context.extractFiredTimers().entrySet(), emptyIterable()); + + InProcessTransformResult advanceResult = + StepTransformResult.withoutHold(created.getProducingTransformInternal()).build(); + // Should cause the downstream timer to fire + context.handleResult(null, ImmutableList.of(), advanceResult); + + Map, Map> fired = context.extractFiredTimers(); + assertThat( + fired, + Matchers.>hasKey(downstream.getProducingTransformInternal())); + Map downstreamFired = + fired.get(downstream.getProducingTransformInternal()); + assertThat(downstreamFired, Matchers.hasKey(key)); + + FiredTimers firedForKey = downstreamFired.get(key); + assertThat(firedForKey.getTimers(TimeDomain.PROCESSING_TIME), emptyIterable()); + assertThat(firedForKey.getTimers(TimeDomain.SYNCHRONIZED_PROCESSING_TIME), emptyIterable()); + assertThat(firedForKey.getTimers(TimeDomain.EVENT_TIME), contains(toFire)); + + // Don't reextract timers + assertThat(context.extractFiredTimers().entrySet(), emptyIterable()); + } + + @Test + public void createBundleUnkeyedResultUnkeyed() { + CommittedBundle> newBundle = + context + .createBundle(InProcessBundle.unkeyed(created).commit(Instant.now()), downstream) + .commit(Instant.now()); + assertThat(newBundle.isKeyed(), is(false)); + } + + @Test + public void createBundleKeyedResultPropagatesKey() { + CommittedBundle> newBundle = + context + .createBundle(InProcessBundle.keyed(created, "foo").commit(Instant.now()), downstream) + .commit(Instant.now()); + assertThat(newBundle.isKeyed(), is(true)); + assertThat(newBundle.getKey(), Matchers.equalTo("foo")); + } + + @Test + public void createRootBundleUnkeyed() { + assertThat(context.createRootBundle(created).commit(Instant.now()).isKeyed(), is(false)); + } + + @Test + public void createKeyedBundleKeyed() { + CommittedBundle> keyedBundle = + context + 
.createKeyedBundle( + InProcessBundle.unkeyed(created).commit(Instant.now()), "foo", downstream) + .commit(Instant.now()); + assertThat(keyedBundle.isKeyed(), is(true)); + assertThat(keyedBundle.getKey(), Matchers.equalTo("foo")); + } + + private static class TestBoundedWindow extends BoundedWindow { + private final Instant ts; + + public TestBoundedWindow(Instant ts) { + this.ts = ts; + } + + @Override + public Instant maxTimestamp() { + return ts; + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java new file mode 100644 index 0000000000..adb64cd625 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessPipelineResult; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.SimpleFunction; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Tests for basic {@link InProcessPipelineRunner} functionality. + */ +@RunWith(JUnit4.class) +public class InProcessPipelineRunnerTest implements Serializable { + @Test + public void wordCountShouldSucceed() throws Throwable { + Pipeline p = getPipeline(); + + PCollection> counts = + p.apply(Create.of("foo", "bar", "foo", "baz", "bar", "foo")) + .apply(MapElements.via(new SimpleFunction() { + @Override + public String apply(String input) { + return input; + } + })) + .apply(Count.perElement()); + PCollection countStrs = + counts.apply(MapElements.via(new SimpleFunction, String>() { + @Override + public String apply(KV input) { + String str = String.format("%s: %s", input.getKey(), input.getValue()); + return str; + } + })); + + DataflowAssert.that(countStrs).containsInAnyOrder("baz: 1", "bar: 2", "foo: 3"); + + InProcessPipelineResult result = ((InProcessPipelineResult) p.run()); + result.awaitCompletion(); + } + + private Pipeline getPipeline() { + PipelineOptions opts = PipelineOptionsFactory.create(); + opts.setRunner(InProcessPipelineRunner.class); + + Pipeline p = 
Pipeline.create(opts); + return p; + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainerTest.java index 4cfe782936..16b4eb7d07 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainerTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainerTest.java @@ -24,7 +24,6 @@ import com.google.cloud.dataflow.sdk.coders.KvCoder; import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Create; import com.google.cloud.dataflow.sdk.transforms.Mean; @@ -137,7 +136,7 @@ public void getAfterWriteReturnsPaneInWindow() throws Exception { container.write(mapView, ImmutableList.>of(one, two)); Map viewContents = - container.withViews(ImmutableList.>of(mapView)) + container.createReaderForViews(ImmutableList.>of(mapView)) .get(mapView, firstWindow); assertThat(viewContents, hasEntry("one", 1)); assertThat(viewContents, hasEntry("two", 2)); @@ -153,7 +152,7 @@ public void getReturnsLatestPaneInWindow() throws Exception { container.write(mapView, ImmutableList.>of(one, two)); Map viewContents = - container.withViews(ImmutableList.>of(mapView)) + container.createReaderForViews(ImmutableList.>of(mapView)) .get(mapView, secondWindow); assertThat(viewContents, hasEntry("one", 1)); assertThat(viewContents, hasEntry("two", 2)); @@ -164,7 +163,7 @@ public void getReturnsLatestPaneInWindow() throws Exception { container.write(mapView, ImmutableList.>of(three)); Map overwrittenViewContents = - container.withViews(ImmutableList.>of(mapView)) + container.createReaderForViews(ImmutableList.>of(mapView)) .get(mapView, 
secondWindow); assertThat(overwrittenViewContents, hasEntry("three", 3)); assertThat(overwrittenViewContents.size(), is(1)); @@ -176,15 +175,18 @@ public void getReturnsLatestPaneInWindow() throws Exception { */ @Test public void getBlocksUntilPaneAvailable() throws Exception { - BoundedWindow window = new BoundedWindow() { - @Override - public Instant maxTimestamp() { - return new Instant(1024L); - } - }; + BoundedWindow window = + new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(1024L); + } + }; Future singletonFuture = - getFutureOfView(container.withViews(ImmutableList.>of(singletonView)), - singletonView, window); + getFutureOfView( + container.createReaderForViews(ImmutableList.>of(singletonView)), + singletonView, + window); WindowedValue singletonValue = WindowedValue.of(4.75, new Instant(475L), window, PaneInfo.ON_TIME_AND_ONLY_FIRING); @@ -203,7 +205,7 @@ public Instant maxTimestamp() { } }; SideInputReader newReader = - container.withViews(ImmutableList.>of(singletonView)); + container.createReaderForViews(ImmutableList.>of(singletonView)); Future singletonFuture = getFutureOfView(newReader, singletonView, window); WindowedValue singletonValue = @@ -216,25 +218,31 @@ public Instant maxTimestamp() { @Test public void withPCollectionViewsErrorsForContainsNotInViews() { - PCollectionView>> newView = PCollectionViews.multimapView(pipeline, - WindowingStrategy.globalDefault(), KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); + PCollectionView>> newView = + PCollectionViews.multimapView( + pipeline, + WindowingStrategy.globalDefault(), + KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); thrown.expect(IllegalArgumentException.class); thrown.expectMessage("with unknown views " + ImmutableList.of(newView).toString()); - container.withViews(ImmutableList.>of(newView)); + container.createReaderForViews(ImmutableList.>of(newView)); } @Test public void withViewsForViewNotInContainerFails() { - PCollectionView>> 
newView = PCollectionViews.multimapView(pipeline, - WindowingStrategy.globalDefault(), KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); + PCollectionView>> newView = + PCollectionViews.multimapView( + pipeline, + WindowingStrategy.globalDefault(), + KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); thrown.expect(IllegalArgumentException.class); thrown.expectMessage("unknown views"); thrown.expectMessage(newView.toString()); - container.withViews(ImmutableList.>of(newView)); + container.createReaderForViews(ImmutableList.>of(newView)); } @Test @@ -242,7 +250,7 @@ public void getOnReaderForViewNotInReaderFails() { thrown.expect(IllegalArgumentException.class); thrown.expectMessage("unknown view: " + iterableView.toString()); - container.withViews(ImmutableList.>of(mapView)) + container.createReaderForViews(ImmutableList.>of(mapView)) .get(iterableView, GlobalWindow.INSTANCE); } @@ -255,11 +263,11 @@ public void writeForMultipleElementsInDifferentWindowsSucceeds() throws Exceptio PaneInfo.ON_TIME_AND_ONLY_FIRING); container.write(singletonView, ImmutableList.of(firstWindowedValue, secondWindowedValue)); assertThat( - container.withViews(ImmutableList.>of(singletonView)) + container.createReaderForViews(ImmutableList.>of(singletonView)) .get(singletonView, firstWindow), equalTo(2.875)); assertThat( - container.withViews(ImmutableList.>of(singletonView)) + container.createReaderForViews(ImmutableList.>of(singletonView)) .get(singletonView, secondWindow), equalTo(4.125)); } @@ -274,7 +282,7 @@ public void writeForMultipleIdenticalElementsInSameWindowSucceeds() throws Excep container.write(iterableView, ImmutableList.of(firstValue, secondValue)); assertThat( - container.withViews(ImmutableList.>of(iterableView)) + container.createReaderForViews(ImmutableList.>of(iterableView)) .get(iterableView, firstWindow), contains(44, 44)); } @@ -286,11 +294,11 @@ public void writeForElementInMultipleWindowsSucceeds() throws Exception { ImmutableList.of(firstWindow, 
secondWindow), PaneInfo.ON_TIME_AND_ONLY_FIRING); container.write(singletonView, ImmutableList.of(multiWindowedValue)); assertThat( - container.withViews(ImmutableList.>of(singletonView)) + container.createReaderForViews(ImmutableList.>of(singletonView)) .get(singletonView, firstWindow), equalTo(2.875)); assertThat( - container.withViews(ImmutableList.>of(singletonView)) + container.createReaderForViews(ImmutableList.>of(singletonView)) .get(singletonView, secondWindow), equalTo(2.875)); } @@ -306,7 +314,7 @@ public void finishDoesNotOverwriteWrittenElements() throws Exception { immediatelyInvokeCallback(mapView, secondWindow); Map viewContents = - container.withViews(ImmutableList.>of(mapView)) + container.createReaderForViews(ImmutableList.>of(mapView)) .get(mapView, secondWindow); assertThat(viewContents, hasEntry("one", 1)); @@ -317,8 +325,11 @@ public void finishDoesNotOverwriteWrittenElements() throws Exception { @Test public void finishOnPendingViewsSetsEmptyElements() throws Exception { immediatelyInvokeCallback(mapView, secondWindow); - Future> mapFuture = getFutureOfView( - container.withViews(ImmutableList.>of(mapView)), mapView, secondWindow); + Future> mapFuture = + getFutureOfView( + container.createReaderForViews(ImmutableList.>of(mapView)), + mapView, + secondWindow); assertThat(mapFuture.get().isEmpty(), is(true)); } @@ -329,18 +340,21 @@ public void finishOnPendingViewsSetsEmptyElements() throws Exception { */ private void immediatelyInvokeCallback(PCollectionView view, BoundedWindow window) { doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - Object callback = invocation.getArguments()[3]; - Runnable callbackRunnable = (Runnable) callback; - callbackRunnable.run(); - return null; - } - }) + new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + Object callback = invocation.getArguments()[3]; + Runnable callbackRunnable = (Runnable) callback; + 
callbackRunnable.run(); + return null; + } + }) .when(context) - .callAfterOutputMustHaveBeenProduced(Mockito.eq(view), Mockito.eq(window), - Mockito.eq(view.getWindowingStrategyInternal()), Mockito.any(Runnable.class)); + .scheduleAfterOutputWouldBeProduced( + Mockito.eq(view), + Mockito.eq(window), + Mockito.eq(view.getWindowingStrategyInternal()), + Mockito.any(Runnable.class)); } private Future getFutureOfView(final SideInputReader myReader, diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java new file mode 100644 index 0000000000..0aaccc2848 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java @@ -0,0 +1,189 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.Keys; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableSet; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Collections; +import java.util.Set; + +/** + * Tests for {@link KeyedPValueTrackingVisitor}. 
+ */ +@RunWith(JUnit4.class) +public class KeyedPValueTrackingVisitorTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + + private KeyedPValueTrackingVisitor visitor; + private Pipeline p; + + @Before + public void setup() { + PipelineOptions options = PipelineOptionsFactory.create(); + + p = Pipeline.create(options); + @SuppressWarnings("rawtypes") + Set> producesKeyed = + ImmutableSet.>of(PrimitiveKeyer.class, CompositeKeyer.class); + visitor = KeyedPValueTrackingVisitor.create(producesKeyed); + } + + @Test + public void primitiveProducesKeyedOutputUnkeyedInputKeyedOutput() { + PCollection keyed = + p.apply(Create.of(1, 2, 3)).apply(new PrimitiveKeyer()); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), hasItem(keyed)); + } + + @Test + public void primitiveProducesKeyedOutputKeyedInputKeyedOutut() { + PCollection keyed = + p.apply(Create.of(1, 2, 3)) + .apply("firstKey", new PrimitiveKeyer()) + .apply("secondKey", new PrimitiveKeyer()); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), hasItem(keyed)); + } + + @Test + public void compositeProducesKeyedOutputUnkeyedInputKeyedOutput() { + PCollection keyed = + p.apply(Create.of(1, 2, 3)).apply(new CompositeKeyer()); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), hasItem(keyed)); + } + + @Test + public void compositeProducesKeyedOutputKeyedInputKeyedOutut() { + PCollection keyed = + p.apply(Create.of(1, 2, 3)) + .apply("firstKey", new CompositeKeyer()) + .apply("secondKey", new CompositeKeyer()); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), hasItem(keyed)); + } + + + @Test + public void noInputUnkeyedOutput() { + PCollection>> unkeyed = + p.apply( + Create.of(KV.>of(-1, Collections.emptyList())) + .withCoder(KvCoder.of(VarIntCoder.of(), IterableCoder.of(VoidCoder.of())))); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed))); 
+ } + + @Test + public void keyedInputNotProducesKeyedOutputUnkeyedOutput() { + PCollection onceKeyed = + p.apply(Create.of(1, 2, 3)) + .apply(new PrimitiveKeyer()) + .apply(ParDo.of(new IdentityFn())); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), not(hasItem(onceKeyed))); + } + + @Test + public void unkeyedInputNotProducesKeyedOutputUnkeyedOutput() { + PCollection unkeyed = + p.apply(Create.of(1, 2, 3)).apply(ParDo.of(new IdentityFn())); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed))); + } + + @Test + public void traverseMultipleTimesThrows() { + p.apply( + Create.>of( + KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null)) + .withCoder(KvCoder.of(VarIntCoder.of(), VoidCoder.of()))) + .apply(GroupByKey.create()) + .apply(Keys.create()); + + p.traverseTopologically(visitor); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("already been finalized"); + thrown.expectMessage(KeyedPValueTrackingVisitor.class.getSimpleName()); + p.traverseTopologically(visitor); + } + + @Test + public void getKeyedPValuesBeforeTraverseThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getKeyedPValues"); + visitor.getKeyedPValues(); + } + + private static class PrimitiveKeyer extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection input) { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()) + .setCoder(input.getCoder()); + } + } + + private static class CompositeKeyer extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection input) { + return input.apply(new PrimitiveKeyer()).apply(ParDo.of(new IdentityFn())); + } + } + + private static class IdentityFn extends DoFn { + @Override + public void processElement(DoFn.ProcessContext c) throws Exception { + 
c.output(c.element()); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactoryTest.java index 033f9de204..66430b6a7f 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactoryTest.java @@ -26,7 +26,6 @@ import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Create; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactoryTest.java index ae599bab62..3b928b9077 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactoryTest.java @@ -26,7 +26,6 @@ import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import 
com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Create; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServicesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServicesTest.java new file mode 100644 index 0000000000..2c66dc2c32 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServicesTest.java @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import com.google.common.util.concurrent.MoreExecutors; + +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; + +/** + * Tests for {@link TransformExecutorServices}. + */ +@RunWith(JUnit4.class) +public class TransformExecutorServicesTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + + private ExecutorService executorService; + private Map, Boolean> scheduled; + + @Before + public void setup() { + executorService = MoreExecutors.newDirectExecutorService(); + scheduled = new ConcurrentHashMap<>(); + } + + @Test + public void parallelScheduleMultipleSchedulesBothImmediately() { + @SuppressWarnings("unchecked") + TransformExecutor first = mock(TransformExecutor.class); + @SuppressWarnings("unchecked") + TransformExecutor second = mock(TransformExecutor.class); + + TransformExecutorService parallel = + TransformExecutorServices.parallel(executorService, scheduled); + parallel.schedule(first); + parallel.schedule(second); + + verify(first).call(); + verify(second).call(); + assertThat( + scheduled, + Matchers.allOf( + Matchers., Boolean>hasEntry(first, true), + Matchers., Boolean>hasEntry(second, true))); + + parallel.complete(first); + assertThat(scheduled, Matchers., Boolean>hasEntry(second, true)); + assertThat( + scheduled, + not( + Matchers., Boolean>hasEntry( + Matchers.>equalTo(first), 
any(Boolean.class)))); + parallel.complete(second); + assertThat(scheduled.isEmpty(), is(true)); + } + + @Test + public void serialScheduleTwoWaitsForFirstToComplete() { + @SuppressWarnings("unchecked") + TransformExecutor first = mock(TransformExecutor.class); + @SuppressWarnings("unchecked") + TransformExecutor second = mock(TransformExecutor.class); + + TransformExecutorService serial = TransformExecutorServices.serial(executorService, scheduled); + serial.schedule(first); + verify(first).call(); + + serial.schedule(second); + verify(second, never()).call(); + + assertThat(scheduled, Matchers., Boolean>hasEntry(first, true)); + assertThat( + scheduled, + not( + Matchers., Boolean>hasEntry( + Matchers.>equalTo(second), any(Boolean.class)))); + + serial.complete(first); + verify(second).call(); + assertThat(scheduled, Matchers., Boolean>hasEntry(second, true)); + assertThat( + scheduled, + not( + Matchers., Boolean>hasEntry( + Matchers.>equalTo(first), any(Boolean.class)))); + + serial.complete(second); + } + + @Test + public void serialCompleteNotExecutingTaskThrows() { + @SuppressWarnings("unchecked") + TransformExecutor first = mock(TransformExecutor.class); + @SuppressWarnings("unchecked") + TransformExecutor second = mock(TransformExecutor.class); + + TransformExecutorService serial = TransformExecutorServices.serial(executorService, scheduled); + serial.schedule(first); + thrown.expect(IllegalStateException.class); + thrown.expectMessage("unexpected currently executing"); + + serial.complete(second); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java new file mode 100644 index 0000000000..bd6325278a --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java @@ -0,0 +1,312 @@ +/* + * Copyright (C) 2016 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.util.concurrent.MoreExecutors; + +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Tests for {@link TransformExecutor}. 
+ */ +@RunWith(JUnit4.class) +public class TransformExecutorTest { + private PCollection created; + private PCollection> downstream; + + private CountDownLatch evaluatorCompleted; + + private RegisteringCompletionCallback completionCallback; + private TransformExecutorService transformEvaluationState; + @Mock private InProcessEvaluationContext evaluationContext; + @Mock private TransformEvaluatorRegistry registry; + private Map, Boolean> scheduled; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + + scheduled = new HashMap<>(); + transformEvaluationState = + TransformExecutorServices.parallel(MoreExecutors.newDirectExecutorService(), scheduled); + + evaluatorCompleted = new CountDownLatch(1); + completionCallback = new RegisteringCompletionCallback(evaluatorCompleted); + + TestPipeline p = TestPipeline.create(); + created = p.apply(Create.of("foo", "spam", "third")); + downstream = created.apply(WithKeys.of(3)); + } + + @Test + public void callWithNullInputBundleFinishesBundleAndCompletes() throws Exception { + final InProcessTransformResult result = + StepTransformResult.withoutHold(created.getProducingTransformInternal()).build(); + final AtomicBoolean finishCalled = new AtomicBoolean(false); + TransformEvaluator evaluator = + new TransformEvaluator() { + @Override + public void processElement(WindowedValue element) throws Exception { + throw new IllegalArgumentException("Shouldn't be called"); + } + + @Override + public InProcessTransformResult finishBundle() throws Exception { + finishCalled.set(true); + return result; + } + }; + + when(registry.forApplication(created.getProducingTransformInternal(), null, evaluationContext)) + .thenReturn(evaluator); + + TransformExecutor executor = + TransformExecutor.create( + registry, + evaluationContext, + null, + created.getProducingTransformInternal(), + completionCallback, + transformEvaluationState); + executor.call(); + + assertThat(finishCalled.get(), is(true)); + 
assertThat(completionCallback.handledResult, equalTo(result)); + assertThat(completionCallback.handledThrowable, is(nullValue())); + assertThat(scheduled, not(Matchers.>hasKey(executor))); + } + + @Test + public void inputBundleProcessesEachElementFinishesAndCompletes() throws Exception { + final InProcessTransformResult result = + StepTransformResult.withoutHold(downstream.getProducingTransformInternal()).build(); + final Collection> elementsProcessed = new ArrayList<>(); + TransformEvaluator evaluator = + new TransformEvaluator() { + @Override + public void processElement(WindowedValue element) throws Exception { + elementsProcessed.add(element); + return; + } + + @Override + public InProcessTransformResult finishBundle() throws Exception { + return result; + } + }; + + WindowedValue foo = WindowedValue.valueInGlobalWindow("foo"); + WindowedValue spam = WindowedValue.valueInGlobalWindow("spam"); + WindowedValue third = WindowedValue.valueInGlobalWindow("third"); + CommittedBundle inputBundle = + InProcessBundle.unkeyed(created).add(foo).add(spam).add(third).commit(Instant.now()); + when( + registry.forApplication( + downstream.getProducingTransformInternal(), inputBundle, evaluationContext)) + .thenReturn(evaluator); + + TransformExecutor executor = + TransformExecutor.create( + registry, + evaluationContext, + inputBundle, + downstream.getProducingTransformInternal(), + completionCallback, + transformEvaluationState); + + Executors.newSingleThreadExecutor().submit(executor); + + evaluatorCompleted.await(); + + assertThat(elementsProcessed, containsInAnyOrder(spam, third, foo)); + assertThat(completionCallback.handledResult, equalTo(result)); + assertThat(completionCallback.handledThrowable, is(nullValue())); + assertThat(scheduled, not(Matchers.>hasKey(executor))); + } + + @Test + public void processElementThrowsExceptionCallsback() throws Exception { + final InProcessTransformResult result = + 
StepTransformResult.withoutHold(downstream.getProducingTransformInternal()).build(); + final Exception exception = new Exception(); + TransformEvaluator evaluator = + new TransformEvaluator() { + @Override + public void processElement(WindowedValue element) throws Exception { + throw exception; + } + + @Override + public InProcessTransformResult finishBundle() throws Exception { + return result; + } + }; + + WindowedValue foo = WindowedValue.valueInGlobalWindow("foo"); + CommittedBundle inputBundle = + InProcessBundle.unkeyed(created).add(foo).commit(Instant.now()); + when( + registry.forApplication( + downstream.getProducingTransformInternal(), inputBundle, evaluationContext)) + .thenReturn(evaluator); + + TransformExecutor executor = + TransformExecutor.create( + registry, + evaluationContext, + inputBundle, + downstream.getProducingTransformInternal(), + completionCallback, + transformEvaluationState); + Executors.newSingleThreadExecutor().submit(executor); + + evaluatorCompleted.await(); + + assertThat(completionCallback.handledResult, is(nullValue())); + assertThat(completionCallback.handledThrowable, Matchers.equalTo(exception)); + assertThat(scheduled, not(Matchers.>hasKey(executor))); + } + + @Test + public void finishBundleThrowsExceptionCallsback() throws Exception { + final Exception exception = new Exception(); + TransformEvaluator evaluator = + new TransformEvaluator() { + @Override + public void processElement(WindowedValue element) throws Exception {} + + @Override + public InProcessTransformResult finishBundle() throws Exception { + throw exception; + } + }; + + CommittedBundle inputBundle = InProcessBundle.unkeyed(created).commit(Instant.now()); + when( + registry.forApplication( + downstream.getProducingTransformInternal(), inputBundle, evaluationContext)) + .thenReturn(evaluator); + + TransformExecutor executor = + TransformExecutor.create( + registry, + evaluationContext, + inputBundle, + downstream.getProducingTransformInternal(), + 
completionCallback, + transformEvaluationState); + Executors.newSingleThreadExecutor().submit(executor); + + evaluatorCompleted.await(); + + assertThat(completionCallback.handledResult, is(nullValue())); + assertThat(completionCallback.handledThrowable, Matchers.equalTo(exception)); + assertThat(scheduled, not(Matchers.>hasKey(executor))); + } + + @Test + public void duringCallGetThreadIsNonNull() throws Exception { + final InProcessTransformResult result = + StepTransformResult.withoutHold(downstream.getProducingTransformInternal()).build(); + final CountDownLatch testLatch = new CountDownLatch(1); + final CountDownLatch evaluatorLatch = new CountDownLatch(1); + TransformEvaluator evaluator = + new TransformEvaluator() { + @Override + public void processElement(WindowedValue element) throws Exception { + throw new IllegalArgumentException("Shouldn't be called"); + } + + @Override + public InProcessTransformResult finishBundle() throws Exception { + testLatch.countDown(); + evaluatorLatch.await(); + return result; + } + }; + + when(registry.forApplication(created.getProducingTransformInternal(), null, evaluationContext)) + .thenReturn(evaluator); + + TransformExecutor executor = + TransformExecutor.create( + registry, + evaluationContext, + null, + created.getProducingTransformInternal(), + completionCallback, + transformEvaluationState); + + Executors.newSingleThreadExecutor().submit(executor); + testLatch.await(); + assertThat(executor.getThread(), not(nullValue())); + + // Finish the execution so everything can get closed down cleanly. 
+ evaluatorLatch.countDown(); + } + + private static class RegisteringCompletionCallback implements CompletionCallback { + private InProcessTransformResult handledResult = null; + private Throwable handledThrowable = null; + private final CountDownLatch onMethod; + + private RegisteringCompletionCallback(CountDownLatch onMethod) { + this.onMethod = onMethod; + } + + @Override + public void handleResult(CommittedBundle inputBundle, InProcessTransformResult result) { + handledResult = result; + onMethod.countDown(); + } + + @Override + public void handleThrowable(CommittedBundle inputBundle, Throwable t) { + handledThrowable = t; + onMethod.countDown(); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java index f139c5648e..20a7d60bce 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java @@ -18,21 +18,30 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; import com.google.cloud.dataflow.sdk.io.CountingSource; import com.google.cloud.dataflow.sdk.io.Read; import com.google.cloud.dataflow.sdk.io.UnboundedSource; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; 
+import com.google.cloud.dataflow.sdk.io.UnboundedSource.CheckpointMark; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; import com.google.cloud.dataflow.sdk.util.WindowedValue; import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; import org.hamcrest.Matchers; import org.joda.time.DateTime; @@ -42,6 +51,15 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; /** * Tests for {@link UnboundedReadEvaluatorFactory}. 
*/ @@ -112,6 +130,41 @@ public void unboundedSourceInMemoryTransformEvaluatorMultipleSequentialCalls() t tgw(15L), tgw(13L), tgw(10L))); } + @Test + public void boundedSourceEvaluatorClosesReader() throws Exception { + TestUnboundedSource source = + new TestUnboundedSource<>(BigEndianLongCoder.of(), 1L, 2L, 3L); + + TestPipeline p = TestPipeline.create(); + PCollection pcollection = p.apply(Read.from(source)); + AppliedPTransform sourceTransform = pcollection.getProducingTransformInternal(); + + when(context.createRootBundle(pcollection)).thenReturn(output); + + TransformEvaluator evaluator = factory.forApplication(sourceTransform, null, context); + evaluator.finishBundle(); + CommittedBundle committed = output.commit(Instant.now()); + assertThat(ImmutableList.copyOf(committed.getElements()), hasSize(3)); + assertThat(TestUnboundedSource.readerClosedCount, equalTo(1)); + } + + @Test + public void boundedSourceEvaluatorNoElementsClosesReader() throws Exception { + TestUnboundedSource source = new TestUnboundedSource<>(BigEndianLongCoder.of()); + + TestPipeline p = TestPipeline.create(); + PCollection pcollection = p.apply(Read.from(source)); + AppliedPTransform sourceTransform = pcollection.getProducingTransformInternal(); + + when(context.createRootBundle(pcollection)).thenReturn(output); + + TransformEvaluator evaluator = factory.forApplication(sourceTransform, null, context); + evaluator.finishBundle(); + CommittedBundle committed = output.commit(Instant.now()); + assertThat(committed.getElements(), emptyIterable()); + assertThat(TestUnboundedSource.readerClosedCount, equalTo(1)); + } + // TODO: Once the source is split into multiple sources before evaluating, this test will have to // be updated. 
/** @@ -157,4 +210,118 @@ public Instant apply(Long input) { return new Instant(input); } } + + private static class TestUnboundedSource extends UnboundedSource { + static int readerClosedCount; + private final Coder coder; + private final List elems; + + public TestUnboundedSource(Coder coder, T... elems) { + readerClosedCount = 0; + this.coder = coder; + this.elems = Arrays.asList(elems); + } + + @Override + public List> generateInitialSplits( + int desiredNumSplits, PipelineOptions options) throws Exception { + return ImmutableList.of(this); + } + + @Override + public UnboundedSource.UnboundedReader createReader( + PipelineOptions options, TestCheckpointMark checkpointMark) { + return new TestUnboundedReader(elems); + } + + @Override + @Nullable + public Coder getCheckpointMarkCoder() { + return new TestCheckpointMark.Coder(); + } + + @Override + public void validate() {} + + @Override + public Coder getDefaultOutputCoder() { + return coder; + } + + private class TestUnboundedReader extends UnboundedReader { + private final List elems; + private int index; + + public TestUnboundedReader(List elems) { + this.elems = elems; + this.index = -1; + } + + @Override + public boolean start() throws IOException { + return advance(); + } + + @Override + public boolean advance() throws IOException { + if (index + 1 < elems.size()) { + index++; + return true; + } + return false; + } + + @Override + public Instant getWatermark() { + return Instant.now(); + } + + @Override + public CheckpointMark getCheckpointMark() { + return new TestCheckpointMark(); + } + + @Override + public UnboundedSource getCurrentSource() { + TestUnboundedSource source = TestUnboundedSource.this; + return source; + } + + @Override + public T getCurrent() throws NoSuchElementException { + return elems.get(index); + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + return Instant.now(); + } + + @Override + public void close() throws IOException { + 
readerClosedCount++; + } + } + } + + private static class TestCheckpointMark implements CheckpointMark { + @Override + public void finalizeCheckpoint() throws IOException {} + + public static class Coder extends AtomicCoder { + @Override + public void encode( + TestCheckpointMark value, + OutputStream outStream, + com.google.cloud.dataflow.sdk.coders.Coder.Context context) + throws CoderException, IOException {} + + @Override + public TestCheckpointMark decode( + InputStream inStream, com.google.cloud.dataflow.sdk.coders.Coder.Context context) + throws CoderException, IOException { + return new TestCheckpointMark(); + } + } + } } diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactoryTest.java index 2f5cd0fb88..2f5bdde6ad 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactoryTest.java @@ -25,7 +25,6 @@ import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; import com.google.cloud.dataflow.sdk.coders.VoidCoder; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.PCollectionViewWriter; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Create; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutorTest.java new file mode 100644 index 0000000000..be3e062fbd --- /dev/null +++ 
b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutorTest.java @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +/** + * Tests for {@link WatermarkCallbackExecutor}. 
+ */ +@RunWith(JUnit4.class) +public class WatermarkCallbackExecutorTest { + private WatermarkCallbackExecutor executor = WatermarkCallbackExecutor.create(); + private AppliedPTransform create; + private AppliedPTransform sum; + + @Before + public void setup() { + TestPipeline p = TestPipeline.create(); + PCollection created = p.apply(Create.of(1, 2, 3)); + create = created.getProducingTransformInternal(); + sum = created.apply(Sum.integersGlobally()).getProducingTransformInternal(); + } + + @Test + public void onGuaranteedFiringFiresAfterTrigger() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + executor.callOnGuaranteedFiring( + create, + GlobalWindow.INSTANCE, + WindowingStrategy.globalDefault(), + new CountDownLatchCallback(latch)); + + executor.fireForWatermark(create, BoundedWindow.TIMESTAMP_MAX_VALUE); + assertThat(latch.await(500, TimeUnit.MILLISECONDS), equalTo(true)); + } + + @Test + public void multipleCallbacksShouldFireFires() throws Exception { + CountDownLatch latch = new CountDownLatch(2); + WindowFn windowFn = FixedWindows.of(Duration.standardMinutes(10)); + IntervalWindow window = + new IntervalWindow(new Instant(0L), new Instant(0L).plus(Duration.standardMinutes(10))); + executor.callOnGuaranteedFiring( + create, window, WindowingStrategy.of(windowFn), new CountDownLatchCallback(latch)); + executor.callOnGuaranteedFiring( + create, window, WindowingStrategy.of(windowFn), new CountDownLatchCallback(latch)); + + executor.fireForWatermark(create, new Instant(0L).plus(Duration.standardMinutes(10))); + assertThat(latch.await(500, TimeUnit.MILLISECONDS), equalTo(true)); + } + + @Test + public void noCallbacksShouldFire() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + WindowFn windowFn = FixedWindows.of(Duration.standardMinutes(10)); + IntervalWindow window = + new IntervalWindow(new Instant(0L), new Instant(0L).plus(Duration.standardMinutes(10))); + executor.callOnGuaranteedFiring( + create, window, 
WindowingStrategy.of(windowFn), new CountDownLatchCallback(latch)); + + executor.fireForWatermark(create, new Instant(0L).plus(Duration.standardMinutes(5))); + assertThat(latch.await(500, TimeUnit.MILLISECONDS), equalTo(false)); + } + + @Test + public void unrelatedStepShouldNotFire() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + WindowFn windowFn = FixedWindows.of(Duration.standardMinutes(10)); + IntervalWindow window = + new IntervalWindow(new Instant(0L), new Instant(0L).plus(Duration.standardMinutes(10))); + executor.callOnGuaranteedFiring( + sum, window, WindowingStrategy.of(windowFn), new CountDownLatchCallback(latch)); + + executor.fireForWatermark(create, new Instant(0L).plus(Duration.standardMinutes(20))); + assertThat(latch.await(500, TimeUnit.MILLISECONDS), equalTo(false)); + } + + private static class CountDownLatchCallback implements Runnable { + private final CountDownLatch latch; + + public CountDownLatchCallback(CountDownLatch latch) { + this.latch = latch; + } + + @Override + public void run() { + latch.countDown(); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java new file mode 100644 index 0000000000..ad37708677 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java @@ -0,0 +1,413 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.NullableCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Combine.BinaryCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFns.CoCombineResult; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.Max.MaxIntegerFn; +import com.google.cloud.dataflow.sdk.transforms.Min.MinIntegerFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +/** + * Unit tests for {@link CombineFns}. 
+ */ +@RunWith(JUnit4.class) +public class CombineFnsTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testDuplicatedTags() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag tag = new TupleTag(); + CombineFns.compose() + .with(new GetIntegerFunction(), new MaxIntegerFn(), tag) + .with(new GetIntegerFunction(), new MinIntegerFn(), tag); + } + + @Test + public void testDuplicatedTagsKeyed() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag tag = new TupleTag(); + CombineFns.composeKeyed() + .with(new GetIntegerFunction(), new MaxIntegerFn(), tag) + .with(new GetIntegerFunction(), new MinIntegerFn(), tag); + } + + @Test + public void testDuplicatedTagsWithContext() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag tag = new TupleTag(); + CombineFns.compose() + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */).forKey("G", StringUtf8Coder.of()), + tag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */).forKey("G", StringUtf8Coder.of()), + tag); + } + + @Test + public void testDuplicatedTagsWithContextKeyed() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag tag = new TupleTag(); + CombineFns.composeKeyed() + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */), + tag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */), + tag); + } + + @Test + @Category(RunnableOnService.class) + public void testComposedCombine() { + Pipeline p = TestPipeline.create(); + 
p.getCoderRegistry().registerCoder(UserString.class, UserStringCoder.of()); + + PCollection>> perKeyInput = p.apply( + Create.timestamped( + Arrays.asList( + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(4, UserString.of("4"))), + KV.of("b", KV.of(1, UserString.of("1"))), + KV.of("b", KV.of(13, UserString.of("13")))), + Arrays.asList(0L, 4L, 7L, 10L, 16L)) + .withCoder(KvCoder.of( + StringUtf8Coder.of(), + KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of())))); + + TupleTag maxIntTag = new TupleTag(); + TupleTag concatStringTag = new TupleTag(); + PCollection>> combineGlobally = perKeyInput + .apply(Values.>create()) + .apply(Combine.globally(CombineFns.compose() + .with( + new GetIntegerFunction(), + new MaxIntegerFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatString(), + concatStringTag))) + .apply(WithKeys.of("global")) + .apply( + "ExtractGloballyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + + PCollection>> combinePerKey = perKeyInput + .apply(Combine.perKey(CombineFns.composeKeyed() + .with( + new GetIntegerFunction(), + new MaxIntegerFn().asKeyedFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatString().asKeyedFn(), + concatStringTag))) + .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + DataflowAssert.that(combineGlobally).containsInAnyOrder( + KV.of("global", KV.of(13, "111134"))); + DataflowAssert.that(combinePerKey).containsInAnyOrder( + KV.of("a", KV.of(4, "114")), + KV.of("b", KV.of(13, "113"))); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testComposedCombineWithContext() { + Pipeline p = TestPipeline.create(); + p.getCoderRegistry().registerCoder(UserString.class, UserStringCoder.of()); + + PCollectionView view = p + .apply(Create.of("I")) + .apply(View.asSingleton()); + + PCollection>> perKeyInput = p.apply( + Create.timestamped( + 
Arrays.asList( + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(4, UserString.of("4"))), + KV.of("b", KV.of(1, UserString.of("1"))), + KV.of("b", KV.of(13, UserString.of("13")))), + Arrays.asList(0L, 4L, 7L, 10L, 16L)) + .withCoder(KvCoder.of( + StringUtf8Coder.of(), + KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of())))); + + TupleTag maxIntTag = new TupleTag(); + TupleTag concatStringTag = new TupleTag(); + PCollection>> combineGlobally = perKeyInput + .apply(Values.>create()) + .apply(Combine.globally(CombineFns.compose() + .with( + new GetIntegerFunction(), + new MaxIntegerFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(view).forKey("G", StringUtf8Coder.of()), + concatStringTag)) + .withoutDefaults() + .withSideInputs(ImmutableList.of(view))) + .apply(WithKeys.of("global")) + .apply( + "ExtractGloballyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + + PCollection>> combinePerKey = perKeyInput + .apply(Combine.perKey(CombineFns.composeKeyed() + .with( + new GetIntegerFunction(), + new MaxIntegerFn().asKeyedFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(view), + concatStringTag)) + .withSideInputs(ImmutableList.of(view))) + .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + DataflowAssert.that(combineGlobally).containsInAnyOrder( + KV.of("global", KV.of(13, "111134GI"))); + DataflowAssert.that(combinePerKey).containsInAnyOrder( + KV.of("a", KV.of(4, "114Ia")), + KV.of("b", KV.of(13, "113Ib"))); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testComposedCombineNullValues() { + Pipeline p = TestPipeline.create(); + p.getCoderRegistry().registerCoder(UserString.class, NullableCoder.of(UserStringCoder.of())); + p.getCoderRegistry().registerCoder(String.class, NullableCoder.of(StringUtf8Coder.of())); + + PCollection>> 
perKeyInput = p.apply( + Create.timestamped( + Arrays.asList( + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(4, UserString.of("4"))), + KV.of("b", KV.of(1, UserString.of("1"))), + KV.of("b", KV.of(13, UserString.of("13")))), + Arrays.asList(0L, 4L, 7L, 10L, 16L)) + .withCoder(KvCoder.of( + StringUtf8Coder.of(), + KvCoder.of( + BigEndianIntegerCoder.of(), NullableCoder.of(UserStringCoder.of()))))); + + TupleTag maxIntTag = new TupleTag(); + TupleTag concatStringTag = new TupleTag(); + + PCollection>> combinePerKey = perKeyInput + .apply(Combine.perKey(CombineFns.composeKeyed() + .with( + new GetIntegerFunction(), + new MaxIntegerFn().asKeyedFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new OutputNullString().asKeyedFn(), + concatStringTag))) + .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + DataflowAssert.that(combinePerKey).containsInAnyOrder( + KV.of("a", KV.of(4, (String) null)), + KV.of("b", KV.of(13, (String) null))); + p.run(); + } + + private static class UserString implements Serializable { + private String strValue; + + static UserString of(String strValue) { + UserString ret = new UserString(); + ret.strValue = strValue; + return ret; + } + } + + private static class UserStringCoder extends StandardCoder { + public static UserStringCoder of() { + return INSTANCE; + } + + private static final UserStringCoder INSTANCE = new UserStringCoder(); + + @Override + public void encode(UserString value, OutputStream outStream, Context context) + throws CoderException, IOException { + StringUtf8Coder.of().encode(value.strValue, outStream, context); + } + + @Override + public UserString decode(InputStream inStream, Context context) + throws CoderException, IOException { + return UserString.of(StringUtf8Coder.of().decode(inStream, context)); + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public void 
verifyDeterministic() throws NonDeterministicException {} + } + + private static class GetIntegerFunction + extends SimpleFunction, Integer> { + @Override + public Integer apply(KV input) { + return input.getKey(); + } + } + + private static class GetUserStringFunction + extends SimpleFunction, UserString> { + @Override + public UserString apply(KV input) { + return input.getValue(); + } + } + + private static class ConcatString extends BinaryCombineFn { + @Override + public UserString apply(UserString left, UserString right) { + String retStr = left.strValue + right.strValue; + char[] chars = retStr.toCharArray(); + Arrays.sort(chars); + return UserString.of(new String(chars)); + } + } + + private static class OutputNullString extends BinaryCombineFn { + @Override + public UserString apply(UserString left, UserString right) { + return null; + } + } + + private static class ConcatStringWithContext + extends KeyedCombineFnWithContext { + private final PCollectionView view; + + private ConcatStringWithContext(PCollectionView view) { + this.view = view; + } + + @Override + public UserString createAccumulator(String key, CombineWithContext.Context c) { + return UserString.of(key + c.sideInput(view)); + } + + @Override + public UserString addInput( + String key, UserString accumulator, UserString input, CombineWithContext.Context c) { + assertThat(accumulator.strValue, Matchers.startsWith(key + c.sideInput(view))); + accumulator.strValue += input.strValue; + return accumulator; + } + + @Override + public UserString mergeAccumulators( + String key, Iterable accumulators, CombineWithContext.Context c) { + String keyPrefix = key + c.sideInput(view); + String all = keyPrefix; + for (UserString accumulator : accumulators) { + assertThat(accumulator.strValue, Matchers.startsWith(keyPrefix)); + all += accumulator.strValue.substring(keyPrefix.length()); + accumulator.strValue = "cleared in mergeAccumulators"; + } + return UserString.of(all); + } + + @Override + public 
UserString extractOutput( + String key, UserString accumulator, CombineWithContext.Context c) { + assertThat(accumulator.strValue, Matchers.startsWith(key + c.sideInput(view))); + char[] chars = accumulator.strValue.toCharArray(); + Arrays.sort(chars); + return UserString.of(new String(chars)); + } + } + + private static class ExtractResultDoFn + extends DoFn, KV>>{ + + private final TupleTag maxIntTag; + private final TupleTag concatStringTag; + + ExtractResultDoFn(TupleTag maxIntTag, TupleTag concatStringTag) { + this.maxIntTag = maxIntTag; + this.concatStringTag = concatStringTag; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + UserString userString = c.element().getValue().get(concatStringTag); + KV value = KV.of( + c.element().getValue().get(maxIntTag), + userString == null ? null : userString.strValue); + c.output(KV.of(c.element().getKey(), value)); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTest.java index a709a233ab..dabad7bae7 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTest.java @@ -17,14 +17,17 @@ package com.google.cloud.dataflow.sdk.transforms; import static org.hamcrest.CoreMatchers.isA; +import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThat; import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; import com.google.cloud.dataflow.sdk.testing.RunnableOnService; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; import com.google.cloud.dataflow.sdk.transforms.Max.MaxIntegerFn; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData; import org.junit.Rule; import 
org.junit.Test; @@ -188,4 +191,16 @@ private TestPipeline createTestPipeline(DoFn return pipeline; } + + @Test + public void testPopulateDisplayDataDefaultBehavior() { + DoFn usesDefault = + new DoFn() { + @Override + public void processElement(ProcessContext c) throws Exception {} + }; + + DisplayData data = DisplayData.from(usesDefault); + assertThat(data.items(), empty()); + } } diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PTransformTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PTransformTest.java new file mode 100644 index 0000000000..cea1b38df7 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PTransformTest.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.Matchers.empty; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link PTransform} base class. 
+ */ +@RunWith(JUnit4.class) +public class PTransformTest { + @Test + public void testPopulateDisplayDataDefaultBehavior() { + PTransform, PCollection> transform = + new PTransform, PCollection>() {}; + DisplayData displayData = DisplayData.from(transform); + assertThat(displayData.items(), empty()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java index f3f9bde92d..1ff46e41c2 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java @@ -16,6 +16,8 @@ package com.google.cloud.dataflow.sdk.transforms; +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasKey; import static com.google.cloud.dataflow.sdk.util.SerializableUtils.serializeToByteArray; import static com.google.cloud.dataflow.sdk.util.StringUtils.byteArrayToJsonString; import static com.google.cloud.dataflow.sdk.util.StringUtils.jsonStringToByteArray; @@ -39,6 +41,9 @@ import com.google.cloud.dataflow.sdk.testing.RunnableOnService; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.DoFn.RequiresWindowAccess; +import com.google.cloud.dataflow.sdk.transforms.ParDo.Bound; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder; import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; import com.google.cloud.dataflow.sdk.transforms.windowing.Window; import com.google.cloud.dataflow.sdk.util.IllegalMutationException; @@ -1515,4 +1520,22 @@ public void testMutatingInputCoderDoFnError() throws Exception { thrown.expectMessage("must not be mutated"); pipeline.run(); } + + @Test + public void 
testIncludesDoFnDisplayData() { + Bound parDo = + ParDo.of( + new DoFn() { + @Override + public void processElement(ProcessContext c) {} + + @Override + public void populateDisplayData(Builder builder) { + builder.add("foo", "bar"); + } + }); + + DisplayData displayData = DisplayData.from(parDo); + assertThat(displayData, hasDisplayItem(hasKey("foo"))); + } } diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchers.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchers.java new file mode 100644 index 0000000000..2753aafa7b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchers.java @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.display; + +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Item; + +import org.hamcrest.Description; +import org.hamcrest.FeatureMatcher; +import org.hamcrest.Matcher; +import org.hamcrest.Matchers; +import org.hamcrest.TypeSafeDiagnosingMatcher; + +import java.util.Collection; + +/** + * Hamcrest matcher for making assertions on {@link DisplayData} instances. + */ +public class DisplayDataMatchers { + /** + * Do not instantiate. + */ + private DisplayDataMatchers() {} + + /** + * Creates a matcher that matches if the examined {@link DisplayData} contains any items. 
+ */ + public static Matcher hasDisplayItem() { + return hasDisplayItem(Matchers.any(DisplayData.Item.class)); + } + + /** + * Creates a matcher that matches if the examined {@link DisplayData} contains any item + * matching the specified {@code itemMatcher}. + */ + public static Matcher hasDisplayItem(Matcher itemMatcher) { + return new HasDisplayDataItemMatcher(itemMatcher); + } + + private static class HasDisplayDataItemMatcher extends TypeSafeDiagnosingMatcher { + private final Matcher itemMatcher; + + private HasDisplayDataItemMatcher(Matcher itemMatcher) { + this.itemMatcher = itemMatcher; + } + + @Override + public void describeTo(Description description) { + description.appendText("display data with item: "); + itemMatcher.describeTo(description); + } + + @Override + protected boolean matchesSafely(DisplayData data, Description mismatchDescription) { + Collection items = data.items(); + boolean isMatch = Matchers.hasItem(itemMatcher).matches(items); + if (!isMatch) { + mismatchDescription.appendText("found " + items.size() + " non-matching items"); + } + + return isMatch; + } + } + + /** + * Creates a matcher that matches if the examined {@link DisplayData.Item} contains a key + * with the specified value. + */ + public static Matcher hasKey(String key) { + return hasKey(Matchers.is(key)); + } + + /** + * Creates a matcher that matches if the examined {@link DisplayData.Item} contains a key + * matching the specified key matcher. 
+ */ + public static Matcher hasKey(Matcher keyMatcher) { + return new FeatureMatcher(keyMatcher, "with key", "key") { + @Override + protected String featureValueOf(DisplayData.Item actual) { + return actual.getKey(); + } + }; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchersTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchersTest.java new file mode 100644 index 0000000000..2636cf85c8 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchersTest.java @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.display; + +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasKey; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.core.StringStartsWith.startsWith; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.StringDescription; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link DisplayDataMatchers}. + */ +@RunWith(JUnit4.class) +public class DisplayDataMatchersTest { + @Test + public void testHasDisplayItem() { + Matcher matcher = hasDisplayItem(); + + assertFalse(matcher.matches(DisplayData.none())); + assertTrue(matcher.matches(createDisplayDataWithItem("foo", "bar"))); + } + + @Test + public void testHasDisplayItemDescription() { + Matcher matcher = hasDisplayItem(); + Description desc = new StringDescription(); + Description mismatchDesc = new StringDescription(); + + matcher.describeTo(desc); + matcher.describeMismatch(DisplayData.none(), mismatchDesc); + + assertThat(desc.toString(), startsWith("display data with item: ")); + assertThat(mismatchDesc.toString(), containsString("found 0 non-matching items")); + } + + @Test + public void testHasKey() { + Matcher matcher = hasDisplayItem(hasKey("foo")); + + assertTrue(matcher.matches(createDisplayDataWithItem("foo", "bar"))); + assertFalse(matcher.matches(createDisplayDataWithItem("fooz", "bar"))); + } + + private DisplayData createDisplayDataWithItem(final String 
key, final String value) { + return DisplayData.from( + new PTransform, PCollection>() { + @Override + public void populateDisplayData(Builder builder) { + builder.add(key, value); + } + }); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataTest.java new file mode 100644 index 0000000000..13dd618657 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataTest.java @@ -0,0 +1,633 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.display; + +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasKey; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.everyItem; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isEmptyOrNullString; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Item; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.testing.EqualsTester; + +import org.hamcrest.CustomTypeSafeMatcher; +import org.hamcrest.FeatureMatcher; +import org.hamcrest.Matcher; +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.format.DateTimeFormatter; +import org.joda.time.format.ISODateTimeFormat; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.Collection; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * Tests for {@link DisplayData} class. 
+ */ +@RunWith(JUnit4.class) +public class DisplayDataTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + private static final DateTimeFormatter ISO_FORMATTER = ISODateTimeFormat.dateTime(); + + @Test + public void testTypicalUsage() { + final HasDisplayData subComponent1 = + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("ExpectedAnswer", 42); + } + }; + + final HasDisplayData subComponent2 = + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("Location", "Seattle").add("Forecast", "Rain"); + } + }; + + PTransform transform = + new PTransform, PCollection>() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .include(subComponent1) + .include(subComponent2) + .add("MinSproggles", 200) + .withLabel("Mimimum Required Sproggles") + .add("LazerOrientation", "NORTH") + .add("TimeBomb", Instant.now().plus(Duration.standardDays(1))) + .add("FilterLogic", subComponent1.getClass()) + .add("ServiceUrl", "google.com/fizzbang") + .withLinkUrl("http://www.google.com/fizzbang"); + } + }; + + DisplayData data = DisplayData.from(transform); + + assertThat(data.items(), not(empty())); + assertThat( + data.items(), + everyItem( + allOf( + hasKey(not(isEmptyOrNullString())), + hasNamespace( + Matchers.>isOneOf( + transform.getClass(), subComponent1.getClass(), subComponent2.getClass())), + hasType(notNullValue(DisplayData.Type.class)), + hasValue(not(isEmptyOrNullString()))))); + } + + @Test + public void testDefaultInstance() { + DisplayData none = DisplayData.none(); + assertThat(none.items(), empty()); + } + + @Test + public void testCanBuild() { + DisplayData data = + DisplayData.from(new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("Foo", "bar"); + } + }); + + assertThat(data.items(), hasSize(1)); + assertThat(data, 
hasDisplayItem(hasKey("Foo"))); + } + + @Test + public void testAsMap() { + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }); + + Map map = data.asMap(); + assertEquals(map.size(), 1); + assertThat(data, hasDisplayItem(hasKey("foo"))); + assertEquals(map.values(), data.items()); + } + + @Test + public void testItemProperties() { + final Instant value = Instant.now(); + DisplayData data = DisplayData.from(new ConcreteComponent(value)); + + @SuppressWarnings("unchecked") + DisplayData.Item item = (DisplayData.Item) data.items().toArray()[0]; + assertThat( + item, + allOf( + hasNamespace(Matchers.>is(ConcreteComponent.class)), + hasKey("now"), + hasType(is(DisplayData.Type.TIMESTAMP)), + hasValue(is(ISO_FORMATTER.print(value))), + hasShortValue(nullValue(String.class)), + hasLabel(is("the current instant")), + hasUrl(is("http://time.gov")))); + } + + static class ConcreteComponent implements HasDisplayData { + private Instant value; + + ConcreteComponent(Instant value) { + this.value = value; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("now", value).withLabel("the current instant").withLinkUrl("http://time.gov"); + } + } + + @Test + public void testUnspecifiedOptionalProperties() { + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }); + + assertThat( + data, + hasDisplayItem(allOf(hasLabel(nullValue(String.class)), hasUrl(nullValue(String.class))))); + } + + @Test + public void testIncludes() { + final HasDisplayData subComponent = + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }; + + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void 
populateDisplayData(DisplayData.Builder builder) { + builder.include(subComponent); + } + }); + + assertThat( + data, + hasDisplayItem( + allOf( + hasKey("foo"), + hasNamespace(Matchers.>is(subComponent.getClass()))))); + } + + @Test + public void testIdentifierEquality() { + new EqualsTester() + .addEqualityGroup( + DisplayData.Identifier.of(DisplayDataTest.class, "1"), + DisplayData.Identifier.of(DisplayDataTest.class, "1")) + .addEqualityGroup(DisplayData.Identifier.of(Object.class, "1")) + .addEqualityGroup(DisplayData.Identifier.of(DisplayDataTest.class, "2")) + .testEquals(); + } + + @Test + public void testAnonymousClassNamespace() { + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }); + + DisplayData.Item item = (DisplayData.Item) data.items().toArray()[0]; + final Pattern anonClassRegex = Pattern.compile( + Pattern.quote(DisplayDataTest.class.getName()) + "\\$\\d+$"); + assertThat(item.getNamespace(), new CustomTypeSafeMatcher( + "anonymous class regex: " + anonClassRegex) { + @Override + protected boolean matchesSafely(String item) { + java.util.regex.Matcher m = anonClassRegex.matcher(item); + return m.matches(); + } + }); + } + + @Test + public void testAcceptsKeysWithDifferentNamespaces() { + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .add("foo", "bar") + .include( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }); + } + }); + + assertThat(data.items(), hasSize(2)); + } + + @Test + public void testDuplicateKeyThrowsException() { + thrown.expect(IllegalArgumentException.class); + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .add("foo", "bar") + 
.add("foo", "baz"); + } + }); + } + + @Test + public void testToString() { + HasDisplayData component = new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }; + + DisplayData data = DisplayData.from(component); + assertEquals(String.format("%s:foo=bar", component.getClass().getName()), data.toString()); + } + + @Test + public void testHandlesIncludeCycles() { + + final IncludeSubComponent componentA = + new IncludeSubComponent() { + @Override + String getId() { + return "componentA"; + } + }; + final IncludeSubComponent componentB = + new IncludeSubComponent() { + @Override + String getId() { + return "componentB"; + } + }; + + HasDisplayData component = + new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + builder.include(componentA); + } + }; + + componentA.subComponent = componentB; + componentB.subComponent = componentA; + + DisplayData data = DisplayData.from(component); + assertThat(data.items(), hasSize(2)); + } + + @Test + public void testIncludesSubcomponentsWithObjectEquality() { + DisplayData data = DisplayData.from(new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .include(new EqualsEverything("foo1", "bar1")) + .include(new EqualsEverything("foo2", "bar2")); + } + }); + + assertThat(data.items(), hasSize(2)); + } + + private static class EqualsEverything implements HasDisplayData { + private final String value; + private final String key; + EqualsEverything(String key, String value) { + this.key = key; + this.value = value; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add(key, value); + } + + @Override + public int hashCode() { + return 1; + } + + @Override + public boolean equals(Object obj) { + return true; + } + } + + abstract static class IncludeSubComponent implements HasDisplayData { + HasDisplayData subComponent; + + 
@Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("id", getId()).include(subComponent); + } + + abstract String getId(); + } + + @Test + public void testTypeMappings() { + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .add("string", "foobar") + .add("integer", 123) + .add("float", 3.14) + .add("java_class", DisplayDataTest.class) + .add("timestamp", Instant.now()) + .add("duration", Duration.standardHours(1)); + } + }); + + Collection items = data.items(); + assertThat( + items, hasItem(allOf(hasKey("string"), hasType(is(DisplayData.Type.STRING))))); + assertThat( + items, hasItem(allOf(hasKey("integer"), hasType(is(DisplayData.Type.INTEGER))))); + assertThat(items, hasItem(allOf(hasKey("float"), hasType(is(DisplayData.Type.FLOAT))))); + assertThat( + items, + hasItem(allOf(hasKey("java_class"), hasType(is(DisplayData.Type.JAVA_CLASS))))); + assertThat( + items, + hasItem(allOf(hasKey("timestamp"), hasType(is(DisplayData.Type.TIMESTAMP))))); + assertThat( + items, hasItem(allOf(hasKey("duration"), hasType(is(DisplayData.Type.DURATION))))); + } + + @Test + public void testStringFormatting() throws IOException { + final Instant now = Instant.now(); + final Duration oneHour = Duration.standardHours(1); + + HasDisplayData component = new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .add("string", "foobar") + .add("integer", 123) + .add("float", 3.14) + .add("java_class", DisplayDataTest.class) + .add("timestamp", now) + .add("duration", oneHour); + } + }; + DisplayData data = DisplayData.from(component); + + Collection items = data.items(); + assertThat(items, hasItem(allOf(hasKey("string"), hasValue(is("foobar"))))); + assertThat(items, hasItem(allOf(hasKey("integer"), hasValue(is("123"))))); + assertThat(items, hasItem(allOf(hasKey("float"), 
hasValue(is("3.14"))))); + assertThat(items, hasItem(allOf(hasKey("java_class"), + hasValue(is(DisplayDataTest.class.getName())), + hasShortValue(is(DisplayDataTest.class.getSimpleName()))))); + assertThat(items, hasItem(allOf(hasKey("timestamp"), + hasValue(is(ISO_FORMATTER.print(now)))))); + assertThat(items, hasItem(allOf(hasKey("duration"), + hasValue(is(Long.toString(oneHour.getMillis())))))); + } + + @Test + public void testContextProperlyReset() { + final HasDisplayData subComponent = + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }; + + HasDisplayData component = + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .include(subComponent) + .add("alpha", "bravo"); + } + }; + + DisplayData data = DisplayData.from(component); + assertThat( + data.items(), + hasItem( + allOf( + hasKey("alpha"), + hasNamespace(Matchers.>is(component.getClass()))))); + } + + @Test + public void testFromNull() { + thrown.expect(NullPointerException.class); + DisplayData.from(null); + } + + @Test + public void testIncludeNull() { + thrown.expect(NullPointerException.class); + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + builder.include(null); + } + }); + } + + @Test + public void testNullKey() { + thrown.expect(NullPointerException.class); + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + builder.add(null, "foo"); + } + }); + } + + @Test + public void testRejectsNullValues() { + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + try { + builder.add("key", (String) null); + throw new RuntimeException("Should throw on null string value"); + } catch (NullPointerException ex) { + // Expected + } + + try { + builder.add("key", (Class) null); + throw new 
RuntimeException("Should throw on null class value"); + } catch (NullPointerException ex) { + // Expected + } + + try { + builder.add("key", (Duration) null); + throw new RuntimeException("Should throw on null duration value"); + } catch (NullPointerException ex) { + // Expected + } + + try { + builder.add("key", (Instant) null); + throw new RuntimeException("Should throw on null instant value"); + } catch (NullPointerException ex) { + // Expected + } + } + }); + } + + public void testAcceptsNullOptionalValues() { + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + builder.add("key", "value") + .withLabel(null) + .withLinkUrl(null); + } + }); + + // Should not throw + } + + private static Matcher hasNamespace(Matcher> nsMatcher) { + return new FeatureMatcher>( + nsMatcher, "display item with namespace", "namespace") { + @Override + protected Class featureValueOf(DisplayData.Item actual) { + try { + return Class.forName(actual.getNamespace()); + } catch (ClassNotFoundException e) { + return null; + } + } + }; + } + + private static Matcher hasType(Matcher typeMatcher) { + return new FeatureMatcher( + typeMatcher, "display item with type", "type") { + @Override + protected DisplayData.Type featureValueOf(DisplayData.Item actual) { + return actual.getType(); + } + }; + } + + private static Matcher hasLabel(Matcher labelMatcher) { + return new FeatureMatcher( + labelMatcher, "display item with label", "label") { + @Override + protected String featureValueOf(DisplayData.Item actual) { + return actual.getLabel(); + } + }; + } + + private static Matcher hasUrl(Matcher urlMatcher) { + return new FeatureMatcher( + urlMatcher, "display item with url", "URL") { + @Override + protected String featureValueOf(DisplayData.Item actual) { + return actual.getUrl(); + } + }; + } + + private static Matcher hasValue(Matcher valueMatcher) { + return new FeatureMatcher( + valueMatcher, "display item with value", "value") { + 
@Override + protected String featureValueOf(DisplayData.Item actual) { + return actual.getValue(); + } + }; + } + + private static Matcher hasShortValue(Matcher valueStringMatcher) { + return new FeatureMatcher( + valueStringMatcher, "display item with short value", "short value") { + @Override + protected String featureValueOf(DisplayData.Item actual) { + return actual.getShortValue(); + } + }; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ApiSurfaceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ApiSurfaceTest.java index e995b821de..fcfe1d8f2f 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ApiSurfaceTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ApiSurfaceTest.java @@ -49,7 +49,8 @@ public void testOurApiSurface() throws Exception { .pruningClassName("com.google.cloud.dataflow.sdk.util.common.ReflectHelpers") .pruningClassName("com.google.cloud.dataflow.sdk.DataflowMatchers") .pruningClassName("com.google.cloud.dataflow.sdk.TestUtils") - .pruningClassName("com.google.cloud.dataflow.sdk.WindowMatchers"); + .pruningClassName("com.google.cloud.dataflow.sdk.WindowMatchers") + .pruningClassName("com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers"); checkedApiSurface.getExposedClasses(); diff --git a/sdk/src/test/java8/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryJava8Test.java b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryJava8Test.java new file mode 100644 index 0000000000..b7e1467436 --- /dev/null +++ b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryJava8Test.java @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.options; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Java 8 tests for {@link PipelineOptionsFactory}. + */ +@RunWith(JUnit4.class) +public class PipelineOptionsFactoryJava8Test { + @Rule public ExpectedException thrown = ExpectedException.none(); + + private static interface OptionsWithDefaultMethod extends PipelineOptions { + default Number getValue() { + return 1024; + } + + void setValue(Number value); + } + + @Test + public void testDefaultMethodIgnoresDefaultImplementation() { + OptionsWithDefaultMethod optsWithDefault = + PipelineOptionsFactory.as(OptionsWithDefaultMethod.class); + assertThat(optsWithDefault.getValue(), nullValue()); + + optsWithDefault.setValue(12.25); + assertThat(optsWithDefault.getValue(), equalTo(Double.valueOf(12.25))); + } + + private static interface ExtendedOptionsWithDefault extends OptionsWithDefaultMethod {} + + @Test + public void testDefaultMethodInExtendedClassIgnoresDefaultImplementation() { + OptionsWithDefaultMethod extendedOptsWithDefault = + PipelineOptionsFactory.as(ExtendedOptionsWithDefault.class); + assertThat(extendedOptsWithDefault.getValue(), nullValue()); + + extendedOptsWithDefault.setValue(Double.NEGATIVE_INFINITY); + assertThat(extendedOptsWithDefault.getValue(), 
equalTo(Double.NEGATIVE_INFINITY)); + } + + private static interface Options extends PipelineOptions { + Number getValue(); + + void setValue(Number value); + } + + private static interface SubtypeReturingOptions extends Options { + @Override + Integer getValue(); + void setValue(Integer value); + } + + @Test + public void testReturnTypeConflictThrows() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage( + "Method [getValue] has multiple definitions [public abstract java.lang.Integer " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryJava8Test$" + + "SubtypeReturingOptions.getValue(), public abstract java.lang.Number " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryJava8Test$Options" + + ".getValue()] with different return types for [" + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryJava8Test$" + + "SubtypeReturingOptions]."); + PipelineOptionsFactory.as(SubtypeReturingOptions.class); + } +}