From f16cbb1e4799597071ce82359522a2df218fdc1c Mon Sep 17 00:00:00 2001 From: Dan Halperin Date: Tue, 15 Mar 2016 00:45:42 -0700 Subject: [PATCH 01/11] Minor Javadoc fixes - Add package-info.java to two missing packages. - Fix a compile error leftover from changing a link to a code block, which requires dropping HTML escaping of brackets. --- .../sdk/coders/protobuf/package-info.java | 23 +++++++++++++++++++ .../dataflow/sdk/io/bigtable/BigtableIO.java | 2 +- .../sdk/io/bigtable/package-info.java | 22 ++++++++++++++++++ 3 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/package-info.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/package-info.java diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/package-info.java new file mode 100644 index 0000000000..b5bcf18eda --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/coders/protobuf/package-info.java @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines a {@link com.google.cloud.dataflow.sdk.coders.Coder} + * for Protocol Buffers messages, {@code ProtoCoder}. 
+ * + * @see com.google.cloud.dataflow.sdk.coders.protobuf.ProtoCoder + */ +package com.google.cloud.dataflow.sdk.coders.protobuf; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIO.java index 7ecccf14a6..7d59b09c8d 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIO.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/BigtableIO.java @@ -72,7 +72,7 @@ *

Reading from Cloud Bigtable

* *

The Bigtable source returns a set of rows from a single table, returning a - * {@code PCollection&lt;Row&gt;}. + * {@code PCollection<Row>}. * *

To configure a Cloud Bigtable source, you must supply a table id and a {@link BigtableOptions} * or builder configured with the project and other information necessary to identify the diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/package-info.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/package-info.java new file mode 100644 index 0000000000..112a954d71 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/bigtable/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Defines transforms for reading and writing from Google Cloud Bigtable. + * + * @see com.google.cloud.dataflow.sdk.io.bigtable.BigtableIO + */ +package com.google.cloud.dataflow.sdk.io.bigtable; From 21e1f34db8c00754ec92101aacc7803530060766 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Fri, 26 Feb 2016 17:28:37 -0800 Subject: [PATCH 02/11] Implement InProcessEvaluationContext This is the primary "global state" object for the evaluation of a Pipeline using the InProcessPipelineRunner, and is responsible for properly routing information about the state of the pipeline to transform evaluators. Remove the InProcessEvaluationContext from the InProcessPipelineRunner class, and implement as a class directly. Fix associated imports. 
--- .../BoundedReadEvaluatorFactory.java | 1 - .../sdk/runners/inprocess/EvaluatorKey.java | 1 - .../inprocess/FlattenEvaluatorFactory.java | 1 - .../inprocess/GroupByKeyEvaluatorFactory.java | 1 - .../inprocess/InMemoryWatermarkManager.java | 14 +- .../runners/inprocess/InProcessBundle.java | 20 +- .../inprocess/InProcessEvaluationContext.java | 364 +++++++++++++++ .../inprocess/InProcessPipelineOptions.java | 7 +- .../inprocess/InProcessPipelineRunner.java | 106 +---- .../InProcessSideInputContainer.java | 71 ++- .../inprocess/ParDoMultiEvaluatorFactory.java | 1 - .../ParDoSingleEvaluatorFactory.java | 1 - .../sdk/runners/inprocess/StepAndKey.java | 68 +++ .../inprocess/TransformEvaluatorFactory.java | 1 - .../inprocess/TransformEvaluatorRegistry.java | 72 +++ .../UnboundedReadEvaluatorFactory.java | 1 - .../inprocess/ViewEvaluatorFactory.java | 1 - .../inprocess/WatermarkCallbackExecutor.java | 143 ++++++ .../BoundedReadEvaluatorFactoryTest.java | 2 +- .../FlattenEvaluatorFactoryTest.java | 1 - .../GroupByKeyEvaluatorFactoryTest.java | 1 - .../InMemoryWatermarkManagerTest.java | 12 + .../InProcessEvaluationContextTest.java | 436 ++++++++++++++++++ .../InProcessSideInputContainerTest.java | 92 ++-- .../ParDoMultiEvaluatorFactoryTest.java | 1 - .../ParDoSingleEvaluatorFactoryTest.java | 1 - .../UnboundedReadEvaluatorFactoryTest.java | 1 - .../inprocess/ViewEvaluatorFactoryTest.java | 1 - .../WatermarkCallbackExecutorTest.java | 126 +++++ 29 files changed, 1372 insertions(+), 176 deletions(-) create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContext.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/StepAndKey.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorRegistry.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutor.java create mode 100644 
sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContextTest.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutorTest.java diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java index 1c0279897a..2a164c3518 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java @@ -18,7 +18,6 @@ import com.google.cloud.dataflow.sdk.io.Read.Bounded; import com.google.cloud.dataflow.sdk.io.Source.Reader; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.PTransform; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EvaluatorKey.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EvaluatorKey.java index 745f8f2718..307bc5cdb5 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EvaluatorKey.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/EvaluatorKey.java @@ -15,7 +15,6 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import java.util.Objects; diff --git 
a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactory.java index 14428888e2..bde1df45e9 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactory.java @@ -16,7 +16,6 @@ package com.google.cloud.dataflow.sdk.runners.inprocess; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.Flatten; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java index 0347281749..ec63be84c9 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java @@ -22,7 +22,6 @@ import com.google.cloud.dataflow.sdk.coders.IterableCoder; import com.google.cloud.dataflow.sdk.coders.KvCoder; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.runners.inprocess.StepTransformResult.Builder; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; diff --git 
a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java index e280e22d2b..7cf53aafe6 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java @@ -1209,8 +1209,11 @@ public TimerUpdateBuilder deletedTimer(TimerData deletedTimer) { * and deletedTimers. */ public TimerUpdate build() { - return new TimerUpdate(key, ImmutableSet.copyOf(completedTimers), - ImmutableSet.copyOf(setTimers), ImmutableSet.copyOf(deletedTimers)); + return new TimerUpdate( + key, + ImmutableSet.copyOf(completedTimers), + ImmutableSet.copyOf(setTimers), + ImmutableSet.copyOf(deletedTimers)); } } @@ -1245,6 +1248,13 @@ Iterable getDeletedTimers() { return deletedTimers; } + /** + * Returns a {@link TimerUpdate} that is like this one, but with the specified completed timers. + */ + public TimerUpdate withCompletedTimers(Iterable completedTimers) { + return new TimerUpdate(this.key, completedTimers, setTimers, deletedTimers); + } + @Override public int hashCode() { return Objects.hash(key, completedTimers, setTimers, deletedTimers); diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundle.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundle.java index cc20161097..112ba17d14 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundle.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessBundle.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2015 Google Inc. + * Copyright (C) 2016 Google Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. 
You may obtain a copy of @@ -22,7 +22,6 @@ import com.google.cloud.dataflow.sdk.util.WindowedValue; import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.common.base.MoreObjects; -import com.google.common.base.MoreObjects.ToStringHelper; import com.google.common.collect.ImmutableList; import org.joda.time.Instant; @@ -64,6 +63,11 @@ private InProcessBundle(PCollection pcollection, boolean keyed, Object key) { this.elements = ImmutableList.builder(); } + @Override + public PCollection getPCollection() { + return pcollection; + } + @Override public InProcessBundle add(WindowedValue element) { checkState(!committed, "Can't add element %s to committed bundle %s", element, this); @@ -105,12 +109,12 @@ public Instant getSynchronizedProcessingOutputWatermark() { @Override public String toString() { - ToStringHelper toStringHelper = - MoreObjects.toStringHelper(this).add("pcollection", pcollection); - if (keyed) { - toStringHelper = toStringHelper.add("key", key); - } - return toStringHelper.add("elements", elements).toString(); + return MoreObjects.toStringHelper(this) + .omitNullValues() + .add("pcollection", pcollection) + .add("key", key) + .add("elements", committedElements) + .toString(); } }; } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContext.java new file mode 100644 index 0000000000..757e9e11d9 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContext.java @@ -0,0 +1,364 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.FiredTimers; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TransformWatermarks; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.PCollectionViewWriter; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; +import com.google.cloud.dataflow.sdk.util.ExecutionContext; +import com.google.cloud.dataflow.sdk.util.SideInputReader; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.state.CopyOnAccessInMemoryStateInternals; +import com.google.cloud.dataflow.sdk.values.PCollection; 
+import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import javax.annotation.Nullable; + +/** + * The evaluation context for a specific pipeline being executed by the + * {@link InProcessPipelineRunner}. Contains state shared within the execution across all + * transforms. + * + *

{@link InProcessEvaluationContext} contains shared state for an execution of the + * {@link InProcessPipelineRunner} that can be used while evaluating a {@link PTransform}. This + * consists of views into underlying state and watermark implementations, access to read and write + * {@link PCollectionView PCollectionViews}, and constructing {@link CounterSet CounterSets} and + * {@link ExecutionContext ExecutionContexts}. This includes executing callbacks asynchronously when + * state changes to the appropriate point (e.g. when a {@link PCollectionView} is requested and + * known to be empty). + * + *

{@link InProcessEvaluationContext} also handles results by committing finalizing bundles based + * on the current global state and updating the global state appropriately. This includes updating + * the per-{@link StepAndKey} state, updating global watermarks, and executing any callbacks that + * can be executed. + */ +class InProcessEvaluationContext { + /** The step name for each {@link AppliedPTransform} in the {@link Pipeline}. */ + private final Map, String> stepNames; + + /** The options that were used to create this {@link Pipeline}. */ + private final InProcessPipelineOptions options; + + /** The current processing time and event time watermarks and timers. */ + private final InMemoryWatermarkManager watermarkManager; + + /** Executes callbacks based on the progression of the watermark. */ + private final WatermarkCallbackExecutor callbackExecutor; + + /** The stateInternals of the world, by applied PTransform and key. */ + private final ConcurrentMap> + applicationStateInternals; + + private final InProcessSideInputContainer sideInputContainer; + + private final CounterSet mergedCounters; + + public static InProcessEvaluationContext create( + InProcessPipelineOptions options, + Collection> rootTransforms, + Map>> valueToConsumers, + Map, String> stepNames, + Collection> views) { + return new InProcessEvaluationContext( + options, rootTransforms, valueToConsumers, stepNames, views); + } + + private InProcessEvaluationContext( + InProcessPipelineOptions options, + Collection> rootTransforms, + Map>> valueToConsumers, + Map, String> stepNames, + Collection> views) { + this.options = checkNotNull(options); + checkNotNull(rootTransforms); + checkNotNull(valueToConsumers); + checkNotNull(stepNames); + checkNotNull(views); + this.stepNames = stepNames; + + this.watermarkManager = + InMemoryWatermarkManager.create( + NanosOffsetClock.create(), rootTransforms, valueToConsumers); + this.sideInputContainer = InProcessSideInputContainer.create(this, views); + + 
this.applicationStateInternals = new ConcurrentHashMap<>(); + this.mergedCounters = new CounterSet(); + + this.callbackExecutor = WatermarkCallbackExecutor.create(); + } + + /** + * Handle the provided {@link InProcessTransformResult}, produced after evaluating the provided + * {@link CommittedBundle} (potentially null, if the result of a root {@link PTransform}). + * + *

The result is the output of running the transform contained in the + * {@link InProcessTransformResult} on the contents of the provided bundle. + * + * @param completedBundle the bundle that was processed to produce the result. Potentially + * {@code null} if the transform that produced the result is a root + * transform + * @param completedTimers the timers that were delivered to produce the {@code completedBundle}, + * or an empty iterable if no timers were delivered + * @param result the result of evaluating the input bundle + * @return the committed bundles contained within the handled {@code result} + */ + public synchronized Iterable> handleResult( + @Nullable CommittedBundle completedBundle, + Iterable completedTimers, + InProcessTransformResult result) { + Iterable> committedBundles = + commitBundles(result.getOutputBundles()); + // Update watermarks and timers + watermarkManager.updateWatermarks( + completedBundle, + result.getTransform(), + result.getTimerUpdate().withCompletedTimers(completedTimers), + committedBundles, + result.getWatermarkHold()); + fireAllAvailableCallbacks(); + // Update counters + if (result.getCounters() != null) { + mergedCounters.merge(result.getCounters()); + } + // Update state internals + CopyOnAccessInMemoryStateInternals theirState = result.getState(); + if (theirState != null) { + CopyOnAccessInMemoryStateInternals committedState = theirState.commit(); + StepAndKey stepAndKey = + StepAndKey.of( + result.getTransform(), completedBundle == null ? 
null : completedBundle.getKey()); + if (!committedState.isEmpty()) { + applicationStateInternals.put(stepAndKey, committedState); + } else { + applicationStateInternals.remove(stepAndKey); + } + } + return committedBundles; + } + + private Iterable> commitBundles( + Iterable> bundles) { + ImmutableList.Builder> completed = ImmutableList.builder(); + for (UncommittedBundle inProgress : bundles) { + AppliedPTransform producing = + inProgress.getPCollection().getProducingTransformInternal(); + TransformWatermarks watermarks = watermarkManager.getWatermarks(producing); + CommittedBundle committed = + inProgress.commit(watermarks.getSynchronizedProcessingOutputTime()); + // Empty bundles don't impact watermarks and shouldn't trigger downstream execution, so + // filter them out + if (!Iterables.isEmpty(committed.getElements())) { + completed.add(committed); + } + } + return completed.build(); + } + + private void fireAllAvailableCallbacks() { + for (AppliedPTransform transform : stepNames.keySet()) { + fireAvailableCallbacks(transform); + } + } + + private void fireAvailableCallbacks(AppliedPTransform producingTransform) { + TransformWatermarks watermarks = watermarkManager.getWatermarks(producingTransform); + callbackExecutor.fireForWatermark(producingTransform, watermarks.getOutputWatermark()); + } + + /** + * Create a {@link UncommittedBundle} for use by a source. + */ + public UncommittedBundle createRootBundle(PCollection output) { + return InProcessBundle.unkeyed(output); + } + + /** + * Create a {@link UncommittedBundle} whose elements belong to the specified {@link + * PCollection}. + */ + public UncommittedBundle createBundle(CommittedBundle input, PCollection output) { + return input.isKeyed() + ? InProcessBundle.keyed(output, input.getKey()) + : InProcessBundle.unkeyed(output); + } + + /** + * Create a {@link UncommittedBundle} with the specified keys at the specified step. For use by + * {@link InProcessGroupByKeyOnly} {@link PTransform PTransforms}. 
+ */ + public UncommittedBundle createKeyedBundle( + CommittedBundle input, Object key, PCollection output) { + return InProcessBundle.keyed(output, key); + } + + /** + * Create a {@link PCollectionViewWriter}, whose elements will be used in the provided + * {@link PCollectionView}. + */ + public PCollectionViewWriter createPCollectionViewWriter( + PCollection> input, final PCollectionView output) { + return new PCollectionViewWriter() { + @Override + public void add(Iterable> values) { + sideInputContainer.write(output, values); + } + }; + } + + /** + * Schedule a callback to be executed after output would be produced for the given window + * if there had been input. + * + *

Output would be produced when the watermark for a {@link PValue} passes the point at + * which the trigger for the specified window (with the specified windowing strategy) must have + * fired from the perspective of that {@link PValue}, as specified by the value of + * {@link Trigger#getWatermarkThatGuaranteesFiring(BoundedWindow)} for the trigger of the + * {@link WindowingStrategy}. When the callback has fired, either values will have been produced + * for a key in that window, the window is empty, or all elements in the window are late. The + * callback will be executed regardless of whether values have been produced. + */ + public void scheduleAfterOutputWouldBeProduced( + PValue value, + BoundedWindow window, + WindowingStrategy windowingStrategy, + Runnable runnable) { + AppliedPTransform producing = getProducing(value); + callbackExecutor.callOnGuaranteedFiring(producing, window, windowingStrategy, runnable); + + fireAvailableCallbacks(lookupProducing(value)); + } + + private AppliedPTransform getProducing(PValue value) { + if (value.getProducingTransformInternal() != null) { + return value.getProducingTransformInternal(); + } + return lookupProducing(value); + } + + private AppliedPTransform lookupProducing(PValue value) { + for (AppliedPTransform transform : stepNames.keySet()) { + if (transform.getOutput().equals(value) || transform.getOutput().expand().contains(value)) { + return transform; + } + } + return null; + } + + /** + * Get the options used by this {@link Pipeline}. + */ + public InProcessPipelineOptions getPipelineOptions() { + return options; + } + + /** + * Get an {@link ExecutionContext} for the provided {@link AppliedPTransform} and key. 
+ */ + public InProcessExecutionContext getExecutionContext( + AppliedPTransform application, Object key) { + StepAndKey stepAndKey = StepAndKey.of(application, key); + return new InProcessExecutionContext( + options.getClock(), + key, + (CopyOnAccessInMemoryStateInternals) applicationStateInternals.get(stepAndKey), + watermarkManager.getWatermarks(application)); + } + + /** + * Get all of the steps used in this {@link Pipeline}. + */ + public Collection> getSteps() { + return stepNames.keySet(); + } + + /** + * Get the Step Name for the provided application. + */ + public String getStepName(AppliedPTransform application) { + return stepNames.get(application); + } + + /** + * Returns a {@link SideInputReader} capable of reading the provided + * {@link PCollectionView PCollectionViews}. + * @param sideInputs the {@link PCollectionView PCollectionViews} the result should be able to + * read + * @return a {@link SideInputReader} that can read all of the provided + * {@link PCollectionView PCollectionViews} + */ + public SideInputReader createSideInputReader(final List> sideInputs) { + return sideInputContainer.createReaderForViews(sideInputs); + } + + /** + * Create a {@link CounterSet} for this {@link Pipeline}. The {@link CounterSet} is independent + * of all other {@link CounterSet CounterSets} created by this call. + * + * The {@link InProcessEvaluationContext} is responsible for unifying the counters present in + * all created {@link CounterSet CounterSets} when the transforms that call this method + * complete. + */ + public CounterSet createCounterSet() { + return new CounterSet(); + } + + /** + * Returns all of the counters that have been merged into this context via calls to + * {@link CounterSet#merge(CounterSet)}. + */ + public CounterSet getCounters() { + return mergedCounters; + } + + /** + * Extracts all timers that have been fired and have not already been extracted. + * + *

This is a destructive operation. Timers will only appear in the result of this method once + * for each time they are set. + */ + public Map, Map> extractFiredTimers() { + return watermarkManager.extractFiredTimers(); + } + + /** + * Returns true if all steps are done. + */ + public boolean isDone() { + return watermarkManager.isDone(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java index d659d962f0..60c8543a2f 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java @@ -15,10 +15,15 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; +import com.google.cloud.dataflow.sdk.options.Default; import com.google.cloud.dataflow.sdk.options.PipelineOptions; /** * Options that can be used to configure the {@link InProcessPipelineRunner}. 
*/ -public interface InProcessPipelineOptions extends PipelineOptions {} +public interface InProcessPipelineOptions extends PipelineOptions { + @Default.InstanceFactory(NanosOffsetClock.Factory.class) + Clock getClock(); + void setClock(Clock clock); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java index 124de46b94..7a268ee5fa 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java @@ -17,31 +17,22 @@ import static com.google.common.base.Preconditions.checkArgument; -import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.annotations.Experimental; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKey; import com.google.cloud.dataflow.sdk.runners.inprocess.ViewEvaluatorFactory.InProcessCreatePCollectionView; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.GroupByKey; -import com.google.cloud.dataflow.sdk.transforms.GroupByKey.GroupByKeyOnly; import com.google.cloud.dataflow.sdk.transforms.PTransform; import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView; -import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; -import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; -import com.google.cloud.dataflow.sdk.util.ExecutionContext; -import com.google.cloud.dataflow.sdk.util.SideInputReader; import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; import com.google.cloud.dataflow.sdk.util.WindowedValue; -import com.google.cloud.dataflow.sdk.util.WindowingStrategy; -import 
com.google.cloud.dataflow.sdk.util.common.CounterSet; import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.cloud.dataflow.sdk.values.PCollectionView; -import com.google.cloud.dataflow.sdk.values.PValue; import com.google.common.collect.ImmutableMap; import org.joda.time.Instant; -import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -82,6 +73,11 @@ public class InProcessPipelineRunner { * @param the type of elements that can be added to this bundle */ public static interface UncommittedBundle { + /** + * Returns the PCollection that the elements of this bundle belong to. + */ + PCollection getPCollection(); + /** * Outputs an element to this bundle. * @@ -110,7 +106,7 @@ public static interface UncommittedBundle { public static interface CommittedBundle { /** - * @return the PCollection that the elements of this bundle belong to + * Returns the PCollection that the elements of this bundle belong to. */ PCollection getPCollection(); @@ -154,84 +150,22 @@ public static interface PCollectionViewWriter { void add(Iterable> values); } - /** - * The evaluation context for the {@link InProcessPipelineRunner}. Contains state shared within - * the current evaluation. - */ - public static interface InProcessEvaluationContext { - /** - * Create a {@link UncommittedBundle} for use by a source. - */ - UncommittedBundle createRootBundle(PCollection output); - - /** - * Create a {@link UncommittedBundle} whose elements belong to the specified {@link - * PCollection}. - */ - UncommittedBundle createBundle(CommittedBundle input, PCollection output); - - /** - * Create a {@link UncommittedBundle} with the specified keys at the specified step. For use by - * {@link GroupByKeyOnly} {@link PTransform PTransforms}. - */ - UncommittedBundle createKeyedBundle( - CommittedBundle input, Object key, PCollection output); - - /** - * Create a bundle whose elements will be used in a PCollectionView. 
- */ - PCollectionViewWriter createPCollectionViewWriter( - PCollection> input, PCollectionView output); - - /** - * Get the options used by this {@link Pipeline}. - */ - InProcessPipelineOptions getPipelineOptions(); - - /** - * Get an {@link ExecutionContext} for the provided application. - */ - InProcessExecutionContext getExecutionContext( - AppliedPTransform application, @Nullable Object key); - - /** - * Get the Step Name for the provided application. - */ - String getStepName(AppliedPTransform application); - - /** - * @param sideInputs the {@link PCollectionView PCollectionViews} the result should be able to - * read - * @return a {@link SideInputReader} that can read all of the provided - * {@link PCollectionView PCollectionViews} - */ - SideInputReader createSideInputReader(List> sideInputs); + //////////////////////////////////////////////////////////////////////////////////////////////// + private final InProcessPipelineOptions options; - /** - * Schedules a callback after the watermark for a {@link PValue} after the trigger for the - * specified window (with the specified windowing strategy) must have fired from the perspective - * of that {@link PValue}, as specified by the value of - * {@link Trigger#getWatermarkThatGuaranteesFiring(BoundedWindow)} for the trigger of the - * {@link WindowingStrategy}. - */ - void callAfterOutputMustHaveBeenProduced(PValue value, BoundedWindow window, - WindowingStrategy windowingStrategy, Runnable runnable); + public static InProcessPipelineRunner fromOptions(PipelineOptions options) { + return new InProcessPipelineRunner(options.as(InProcessPipelineOptions.class)); + } - /** - * Create a {@link CounterSet} for this {@link Pipeline}. The {@link CounterSet} is independent - * of all other {@link CounterSet CounterSets} created by this call. 
- * - * The {@link InProcessEvaluationContext} is responsible for unifying the counters present in - * all created {@link CounterSet CounterSets} when the transforms that call this method - * complete. - */ - CounterSet createCounterSet(); + private InProcessPipelineRunner(InProcessPipelineOptions options) { + this.options = options; + } - /** - * Returns all of the counters that have been merged into this context via calls to - * {@link CounterSet#merge(CounterSet)}. - */ - CounterSet getCounters(); + /** + * Returns the {@link PipelineOptions} used to create this {@link InProcessPipelineRunner}. + */ + public InProcessPipelineOptions getPipelineOptions() { + return options; } /** diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainer.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainer.java index bf9a2e1c53..37c9fcfa65 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainer.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainer.java @@ -17,7 +17,6 @@ import static com.google.common.base.Preconditions.checkArgument; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; import com.google.cloud.dataflow.sdk.util.PCollectionViewWindow; @@ -26,6 +25,7 @@ import com.google.cloud.dataflow.sdk.util.WindowingStrategy; import com.google.cloud.dataflow.sdk.values.PCollectionView; import com.google.common.base.MoreObjects; +import com.google.common.base.Throwables; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; @@ -89,7 +89,7 @@ private InProcessSideInputContainer(InProcessEvaluationContext context, * 
the provided argument. The returned {@link InProcessSideInputContainer} is unmodifiable without * casting, but will change as this {@link InProcessSideInputContainer} is modified. */ - public SideInputReader withViews(Collection> newContainedViews) { + public SideInputReader createReaderForViews(Collection> newContainedViews) { if (!containedViews.containsAll(newContainedViews)) { Set> currentlyContained = ImmutableSet.copyOf(containedViews); Set> newRequested = ImmutableSet.copyOf(newContainedViews); @@ -108,8 +108,20 @@ public SideInputReader withViews(Collection> newContainedView * *

The provided iterable is expected to contain only a single window and pane. */ - public void write(PCollectionView view, Iterable> values) - throws ExecutionException { + public void write(PCollectionView view, Iterable> values) { + Map>> valuesPerWindow = + indexValuesByWindow(values); + for (Map.Entry>> windowValues : + valuesPerWindow.entrySet()) { + updatePCollectionViewWindowValues(view, windowValues.getKey(), windowValues.getValue()); + } + } + + /** + * Index the provided values by all {@link BoundedWindow windows} in which they appear. + */ + private Map>> indexValuesByWindow( + Iterable> values) { Map>> valuesPerWindow = new HashMap<>(); for (WindowedValue value : values) { for (BoundedWindow window : value.getWindows()) { @@ -121,29 +133,40 @@ public void write(PCollectionView view, Iterable> windowValues.add(value); } } - for (Map.Entry>> windowValues : - valuesPerWindow.entrySet()) { - PCollectionViewWindow windowedView = PCollectionViewWindow.of(view, windowValues.getKey()); - SettableFuture>> future = viewByWindows.get(windowedView); + return valuesPerWindow; + } + + /** + * Set the value of the {@link PCollectionView} in the {@link BoundedWindow} to be based on the + * specified values, if the values are part of a later pane than currently exist within the + * {@link PCollectionViewWindow}. + */ + private void updatePCollectionViewWindowValues( + PCollectionView view, BoundedWindow window, Collection> windowValues) { + PCollectionViewWindow windowedView = PCollectionViewWindow.of(view, window); + SettableFuture>> future = null; + try { + future = viewByWindows.get(windowedView); if (future.isDone()) { - try { - Iterator> existingValues = future.get().iterator(); - PaneInfo newPane = windowValues.getValue().iterator().next().getPane(); - // The current value may have no elements, if no elements were produced for the window, - // but we are recieving late data. 
- if (!existingValues.hasNext() - || newPane.getIndex() > existingValues.next().getPane().getIndex()) { - viewByWindows.invalidate(windowedView); - viewByWindows.get(windowedView).set(windowValues.getValue()); - } - } catch (InterruptedException e) { - // TODO: Handle meaningfully. This should never really happen when the result remains - // useful, but the result could be available and the thread can still be interrupted. - Thread.currentThread().interrupt(); + Iterator> existingValues = future.get().iterator(); + PaneInfo newPane = windowValues.iterator().next().getPane(); + // The current value may have no elements, if no elements were produced for the window, + // but we are receiving late data. + if (!existingValues.hasNext() + || newPane.getIndex() > existingValues.next().getPane().getIndex()) { + viewByWindows.invalidate(windowedView); + viewByWindows.get(windowedView).set(windowValues); } } else { - future.set(windowValues.getValue()); + future.set(windowValues); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + if (future != null && !future.isDone()) { + future.set(Collections.>emptyList()); } + } catch (ExecutionException e) { + Throwables.propagate(e.getCause()); } } @@ -165,7 +188,7 @@ public T get(final PCollectionView view, final BoundedWindow window) { viewByWindows.get(windowedView); WindowingStrategy windowingStrategy = view.getWindowingStrategyInternal(); - evaluationContext.callAfterOutputMustHaveBeenProduced( + evaluationContext.scheduleAfterOutputWouldBeProduced( view, window, windowingStrategy, new Runnable() { @Override public void run() { diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactory.java index e3ae1a028c..24142c2151 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactory.java +++ 
b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactory.java @@ -17,7 +17,6 @@ import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessExecutionContext.InProcessStepContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.runners.inprocess.ParDoInProcessEvaluator.BundleOutputManager; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactory.java index cd79c219bd..af5914bab0 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactory.java @@ -17,7 +17,6 @@ import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessExecutionContext.InProcessStepContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.runners.inprocess.ParDoInProcessEvaluator.BundleOutputManager; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/StepAndKey.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/StepAndKey.java new file mode 100644 index 0000000000..15955724eb --- /dev/null +++ 
b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/StepAndKey.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.common.base.MoreObjects; + +import java.util.Objects; + +/** + * A (Step, Key) pair. This is useful as a map key or cache key for things that are available + * per-step in a keyed manner (e.g. State). + */ +final class StepAndKey { + private final AppliedPTransform step; + private final Object key; + + /** + * Create a new {@link StepAndKey} with the provided step and key. 
+ */ + public static StepAndKey of(AppliedPTransform step, Object key) { + return new StepAndKey(step, key); + } + + private StepAndKey(AppliedPTransform step, Object key) { + this.step = step; + this.key = key; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(StepAndKey.class) + .add("step", step.getFullName()) + .add("key", key) + .toString(); + } + + @Override + public int hashCode() { + return Objects.hash(step, key); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (!(other instanceof StepAndKey)) { + return false; + } else { + StepAndKey that = (StepAndKey) other; + return Objects.equals(this.step, that.step) + && Objects.equals(this.key, that.key); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorFactory.java index 3b672e0def..860ddfe48f 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorFactory.java @@ -16,7 +16,6 @@ package com.google.cloud.dataflow.sdk.runners.inprocess; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.DoFn; import com.google.cloud.dataflow.sdk.transforms.PTransform; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorRegistry.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorRegistry.java new file mode 100644 index 0000000000..0c8cb7e80a --- /dev/null +++ 
b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformEvaluatorRegistry.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Flatten.FlattenPCollectionList; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import javax.annotation.Nullable; + +/** + * A {@link TransformEvaluatorFactory} that delegates to primitive {@link TransformEvaluatorFactory} + * implementations based on the type of {@link PTransform} of the application. 
+ */ +class TransformEvaluatorRegistry implements TransformEvaluatorFactory { + public static TransformEvaluatorRegistry defaultRegistry() { + @SuppressWarnings("rawtypes") + ImmutableMap, TransformEvaluatorFactory> primitives = + ImmutableMap., TransformEvaluatorFactory>builder() + .put(Read.Bounded.class, new BoundedReadEvaluatorFactory()) + .put(Read.Unbounded.class, new UnboundedReadEvaluatorFactory()) + .put(ParDo.Bound.class, new ParDoSingleEvaluatorFactory()) + .put(ParDo.BoundMulti.class, new ParDoMultiEvaluatorFactory()) + .put( + GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly.class, + new GroupByKeyEvaluatorFactory()) + .put(FlattenPCollectionList.class, new FlattenEvaluatorFactory()) + .put(ViewEvaluatorFactory.WriteView.class, new ViewEvaluatorFactory()) + .build(); + return new TransformEvaluatorRegistry(primitives); + } + + // the TransformEvaluatorFactories can construct instances of all generic types of transform, + // so all instances of a primitive can be handled with the same evaluator factory. 
+ @SuppressWarnings("rawtypes") + private final Map, TransformEvaluatorFactory> factories; + + private TransformEvaluatorRegistry( + @SuppressWarnings("rawtypes") + Map, TransformEvaluatorFactory> factories) { + this.factories = factories; + } + + @Override + public TransformEvaluator forApplication( + AppliedPTransform application, + @Nullable CommittedBundle inputBundle, + InProcessEvaluationContext evaluationContext) + throws Exception { + TransformEvaluatorFactory factory = factories.get(application.getTransform().getClass()); + return factory.forApplication(application, inputBundle, evaluationContext); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java index 4beac337d6..97f0e25d38 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java @@ -21,7 +21,6 @@ import com.google.cloud.dataflow.sdk.io.UnboundedSource.UnboundedReader; import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.PTransform; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactory.java index f47cd1de98..314d81f6aa 100644 --- 
a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactory.java @@ -17,7 +17,6 @@ import com.google.cloud.dataflow.sdk.coders.KvCoder; import com.google.cloud.dataflow.sdk.coders.VoidCoder; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.PCollectionViewWriter; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.GroupByKey; diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutor.java new file mode 100644 index 0000000000..27d59b9a64 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutor.java @@ -0,0 +1,143 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.common.collect.ComparisonChain; +import com.google.common.collect.Ordering; + +import org.joda.time.Instant; + +import java.util.PriorityQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** + * Executes callbacks that occur based on the progression of the watermark per-step. + * + *

<p>Callbacks are registered by calls to + * {@link #callOnGuaranteedFiring(AppliedPTransform, BoundedWindow, WindowingStrategy, Runnable)}, + * and are executed after a call to {@link #fireForWatermark(AppliedPTransform, Instant)} with the + * same {@link AppliedPTransform} and a watermark sufficient to ensure that the trigger for the + * windowing strategy would have fired. + * + *

<p>NOTE: {@link WatermarkCallbackExecutor} does not track the latest observed watermark for any + * {@link AppliedPTransform} - any call to + * {@link #callOnGuaranteedFiring(AppliedPTransform, BoundedWindow, WindowingStrategy, Runnable)} + * that could have potentially already fired should be followed by a call to + * {@link #fireForWatermark(AppliedPTransform, Instant)} for the same transform with the current + * value of the watermark. + */ +class WatermarkCallbackExecutor { + /** + * Create a new {@link WatermarkCallbackExecutor}. + */ + public static WatermarkCallbackExecutor create() { + return new WatermarkCallbackExecutor(); + } + + private final ConcurrentMap, PriorityQueue> + callbacks; + private final ExecutorService executor; + + private WatermarkCallbackExecutor() { + this.callbacks = new ConcurrentHashMap<>(); + this.executor = Executors.newSingleThreadExecutor(); + } + + /** + * Execute the provided {@link Runnable} after the next call to + * {@link #fireForWatermark(AppliedPTransform, Instant)} where the window is guaranteed to have + * produced output. + */ + public void callOnGuaranteedFiring( + AppliedPTransform step, + BoundedWindow window, + WindowingStrategy windowingStrategy, + Runnable runnable) { + WatermarkCallback callback = + WatermarkCallback.onGuaranteedFiring(window, windowingStrategy, runnable); + + PriorityQueue callbackQueue = callbacks.get(step); + if (callbackQueue == null) { + callbackQueue = new PriorityQueue<>(11, new CallbackOrdering()); + if (callbacks.putIfAbsent(step, callbackQueue) != null) { + callbackQueue = callbacks.get(step); + } + } + + synchronized (callbackQueue) { + callbackQueue.offer(callback); + } + } + + /** + * Schedule all pending callbacks that must have produced output by the time of the provided + * watermark. 
+ */ + public void fireForWatermark(AppliedPTransform step, Instant watermark) { + PriorityQueue callbackQueue = callbacks.get(step); + if (callbackQueue == null) { + return; + } + synchronized (callbackQueue) { + while (!callbackQueue.isEmpty() && callbackQueue.peek().shouldFire(watermark)) { + executor.submit(callbackQueue.poll().getCallback()); + } + } + } + + private static class WatermarkCallback { + public static WatermarkCallback onGuaranteedFiring( + BoundedWindow window, WindowingStrategy strategy, Runnable callback) { + @SuppressWarnings("unchecked") + Instant firingAfter = + strategy.getTrigger().getSpec().getWatermarkThatGuaranteesFiring((W) window); + return new WatermarkCallback(firingAfter, callback); + } + + private final Instant fireAfter; + private final Runnable callback; + + private WatermarkCallback(Instant fireAfter, Runnable callback) { + this.fireAfter = fireAfter; + this.callback = callback; + } + + public boolean shouldFire(Instant currentWatermark) { + return currentWatermark.isAfter(fireAfter) + || currentWatermark.equals(BoundedWindow.TIMESTAMP_MAX_VALUE); + } + + public Runnable getCallback() { + return callback; + } + } + + private static class CallbackOrdering extends Ordering { + @Override + public int compare(WatermarkCallback left, WatermarkCallback right) { + return ComparisonChain.start() + .compare(left.fireAfter, right.fireAfter) + .compare(left.callback, right.callback, Ordering.arbitrary()) + .result(); + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java index 9f22fbbe9e..43955149ee 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java @@ -25,7 +25,7 @@ import 
com.google.cloud.dataflow.sdk.io.BoundedSource; import com.google.cloud.dataflow.sdk.io.CountingSource; import com.google.cloud.dataflow.sdk.io.Read; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; +import com.google.cloud.dataflow.sdk.io.Read.Bounded; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactoryTest.java index bf25970aff..0120b9880d 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/FlattenEvaluatorFactoryTest.java @@ -22,7 +22,6 @@ import static org.mockito.Mockito.when; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactoryTest.java index 5c9e824afe..4ced82f8c7 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactoryTest.java @@ -23,7 +23,6 @@ import 
com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.coders.KvCoder; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Create; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManagerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManagerTest.java index 24251522d8..52398cf73a 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManagerTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManagerTest.java @@ -1047,6 +1047,18 @@ public void timerUpdateBuilderWithCompletedAfterBuildNotAddedToBuilt() { assertThat(built.getCompletedTimers(), emptyIterable()); } + @Test + public void timerUpdateWithCompletedTimersNotAddedToExisting() { + TimerUpdateBuilder builder = TimerUpdate.builder(null); + TimerData timer = TimerData.of(StateNamespaces.global(), Instant.now(), TimeDomain.EVENT_TIME); + + TimerUpdate built = builder.build(); + assertThat(built.getCompletedTimers(), emptyIterable()); + assertThat( + built.withCompletedTimers(ImmutableList.of(timer)).getCompletedTimers(), contains(timer)); + assertThat(built.getCompletedTimers(), emptyIterable()); + } + private static Matcher earlierThan(final Instant laterInstant) { return new BaseMatcher() { @Override diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContextTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContextTest.java new file mode 100644 index 
0000000000..149096040a --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContextTest.java @@ -0,0 +1,436 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.FiredTimers; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessExecutionContext.InProcessStepContext; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.PCollectionViewWriter; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; +import 
com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo.Timing; +import com.google.cloud.dataflow.sdk.util.SideInputReader; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.Counter.AggregationKind; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; +import com.google.cloud.dataflow.sdk.util.state.BagState; +import com.google.cloud.dataflow.sdk.util.state.CopyOnAccessInMemoryStateInternals; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import com.google.cloud.dataflow.sdk.util.state.StateTag; +import com.google.cloud.dataflow.sdk.util.state.StateTags; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +/** + * Tests for {@link 
InProcessEvaluationContext}. + */ +@RunWith(JUnit4.class) +public class InProcessEvaluationContextTest { + private TestPipeline p; + private InProcessEvaluationContext context; + private PCollection created; + private PCollection> downstream; + private PCollectionView> view; + + @Before + public void setup() { + InProcessPipelineRunner runner = + InProcessPipelineRunner.fromOptions(PipelineOptionsFactory.create()); + p = TestPipeline.create(); + created = p.apply(Create.of(1, 2, 3)); + downstream = created.apply(WithKeys.of("foo")); + view = created.apply(View.asIterable()); + Collection> rootTransforms = + ImmutableList.>of(created.getProducingTransformInternal()); + Map>> valueToConsumers = new HashMap<>(); + valueToConsumers.put( + created, + ImmutableList.>of( + downstream.getProducingTransformInternal(), view.getProducingTransformInternal())); + valueToConsumers.put(downstream, ImmutableList.>of()); + valueToConsumers.put(view, ImmutableList.>of()); + + Map, String> stepNames = new HashMap<>(); + stepNames.put(created.getProducingTransformInternal(), "s1"); + stepNames.put(downstream.getProducingTransformInternal(), "s2"); + stepNames.put(view.getProducingTransformInternal(), "s3"); + + Collection> views = ImmutableList.>of(view); + context = InProcessEvaluationContext.create( + runner.getPipelineOptions(), + rootTransforms, + valueToConsumers, + stepNames, + views); + } + + @Test + public void writeToViewWriterThenReadReads() { + PCollectionViewWriter> viewWriter = + context.createPCollectionViewWriter( + PCollection.>createPrimitiveOutputInternal( + p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED), + view); + BoundedWindow window = new TestBoundedWindow(new Instant(1024L)); + BoundedWindow second = new TestBoundedWindow(new Instant(899999L)); + WindowedValue firstValue = + WindowedValue.of(1, new Instant(1222), window, PaneInfo.ON_TIME_AND_ONLY_FIRING); + WindowedValue secondValue = + WindowedValue.of( + 2, new Instant(8766L), second, 
PaneInfo.createPane(true, false, Timing.ON_TIME, 0, 0)); + Iterable> values = ImmutableList.of(firstValue, secondValue); + viewWriter.add(values); + + SideInputReader reader = + context.createSideInputReader(ImmutableList.>of(view)); + assertThat(reader.get(view, window), containsInAnyOrder(1)); + assertThat(reader.get(view, second), containsInAnyOrder(2)); + + WindowedValue overrittenSecondValue = + WindowedValue.of( + 4444, new Instant(8677L), second, PaneInfo.createPane(false, true, Timing.LATE, 1, 1)); + viewWriter.add(Collections.singleton(overrittenSecondValue)); + assertThat(reader.get(view, second), containsInAnyOrder(4444)); + } + + @Test + public void getExecutionContextSameStepSameKeyState() { + InProcessExecutionContext fooContext = + context.getExecutionContext(created.getProducingTransformInternal(), "foo"); + + StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); + + InProcessStepContext stepContext = fooContext.getOrCreateStepContext("s1", "s1", null); + stepContext.stateInternals().state(StateNamespaces.global(), intBag).add(1); + + context.handleResult( + InProcessBundle.keyed(created, "foo").commit(Instant.now()), + ImmutableList.of(), + StepTransformResult.withoutHold(created.getProducingTransformInternal()) + .withState(stepContext.commitState()) + .build()); + + InProcessExecutionContext secondFooContext = + context.getExecutionContext(created.getProducingTransformInternal(), "foo"); + assertThat( + secondFooContext + .getOrCreateStepContext("s1", "s1", null) + .stateInternals() + .state(StateNamespaces.global(), intBag) + .read(), + contains(1)); + } + + + @Test + public void getExecutionContextDifferentKeysIndependentState() { + InProcessExecutionContext fooContext = + context.getExecutionContext(created.getProducingTransformInternal(), "foo"); + + StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); + + fooContext + .getOrCreateStepContext("s1", "s1", null) + .stateInternals() + .state(StateNamespaces.global(), intBag) + 
.add(1); + + InProcessExecutionContext barContext = + context.getExecutionContext(created.getProducingTransformInternal(), "bar"); + assertThat(barContext, not(equalTo(fooContext))); + assertThat( + barContext + .getOrCreateStepContext("s1", "s1", null) + .stateInternals() + .state(StateNamespaces.global(), intBag) + .read(), + emptyIterable()); + } + + @Test + public void getExecutionContextDifferentStepsIndependentState() { + String myKey = "foo"; + InProcessExecutionContext fooContext = + context.getExecutionContext(created.getProducingTransformInternal(), myKey); + + StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); + + fooContext + .getOrCreateStepContext("s1", "s1", null) + .stateInternals() + .state(StateNamespaces.global(), intBag) + .add(1); + + InProcessExecutionContext barContext = + context.getExecutionContext(downstream.getProducingTransformInternal(), myKey); + assertThat( + barContext + .getOrCreateStepContext("s1", "s1", null) + .stateInternals() + .state(StateNamespaces.global(), intBag) + .read(), + emptyIterable()); + } + + @Test + public void handleResultMergesCounters() { + CounterSet counters = context.createCounterSet(); + Counter myCounter = Counter.longs("foo", AggregationKind.SUM); + counters.addCounter(myCounter); + + myCounter.addValue(4L); + InProcessTransformResult result = + StepTransformResult.withoutHold(created.getProducingTransformInternal()) + .withCounters(counters) + .build(); + context.handleResult(null, ImmutableList.of(), result); + assertThat((Long) context.getCounters().getExistingCounter("foo").getAggregate(), equalTo(4L)); + + CounterSet againCounters = context.createCounterSet(); + Counter myLongCounterAgain = Counter.longs("foo", AggregationKind.SUM); + againCounters.add(myLongCounterAgain); + myLongCounterAgain.addValue(8L); + + InProcessTransformResult secondResult = + StepTransformResult.withoutHold(downstream.getProducingTransformInternal()) + .withCounters(againCounters) + .build(); + 
context.handleResult( + InProcessBundle.unkeyed(created).commit(Instant.now()), + ImmutableList.of(), + secondResult); + assertThat((Long) context.getCounters().getExistingCounter("foo").getAggregate(), equalTo(12L)); + } + + @Test + public void handleResultStoresState() { + String myKey = "foo"; + InProcessExecutionContext fooContext = + context.getExecutionContext(downstream.getProducingTransformInternal(), myKey); + + StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); + + CopyOnAccessInMemoryStateInternals state = + fooContext.getOrCreateStepContext("s1", "s1", null).stateInternals(); + BagState bag = state.state(StateNamespaces.global(), intBag); + bag.add(1); + bag.add(2); + bag.add(4); + + InProcessTransformResult stateResult = + StepTransformResult.withoutHold(downstream.getProducingTransformInternal()) + .withState(state) + .build(); + + context.handleResult( + InProcessBundle.keyed(created, myKey).commit(Instant.now()), + ImmutableList.of(), + stateResult); + + InProcessExecutionContext afterResultContext = + context.getExecutionContext(downstream.getProducingTransformInternal(), myKey); + + CopyOnAccessInMemoryStateInternals afterResultState = + afterResultContext.getOrCreateStepContext("s1", "s1", null).stateInternals(); + assertThat(afterResultState.state(StateNamespaces.global(), intBag).read(), contains(1, 2, 4)); + } + + @Test + public void callAfterOutputMustHaveBeenProducedAfterEndOfWatermarkCallsback() throws Exception { + final CountDownLatch callLatch = new CountDownLatch(1); + Runnable callback = + new Runnable() { + @Override + public void run() { + callLatch.countDown(); + } + }; + + // Should call back after the end of the global window + context.scheduleAfterOutputWouldBeProduced( + downstream, GlobalWindow.INSTANCE, WindowingStrategy.globalDefault(), callback); + + InProcessTransformResult result = + StepTransformResult.withHold(created.getProducingTransformInternal(), new Instant(0)) + .build(); + + context.handleResult(null, 
ImmutableList.of(), result); + + // Difficult to demonstrate that we took no action in a multithreaded world; poll for a bit + // will likely be flaky if this logic is broken + assertThat(callLatch.await(500L, TimeUnit.MILLISECONDS), is(false)); + + InProcessTransformResult finishedResult = + StepTransformResult.withoutHold(created.getProducingTransformInternal()).build(); + context.handleResult(null, ImmutableList.of(), finishedResult); + // Obtain the value via blocking call + assertThat(callLatch.await(1, TimeUnit.SECONDS), is(true)); + } + + @Test + public void callAfterOutputMustHaveBeenProducedAlreadyAfterCallsImmediately() throws Exception { + InProcessTransformResult finishedResult = + StepTransformResult.withoutHold(created.getProducingTransformInternal()).build(); + context.handleResult(null, ImmutableList.of(), finishedResult); + + final CountDownLatch callLatch = new CountDownLatch(1); + Runnable callback = + new Runnable() { + @Override + public void run() { + callLatch.countDown(); + } + }; + context.scheduleAfterOutputWouldBeProduced( + downstream, GlobalWindow.INSTANCE, WindowingStrategy.globalDefault(), callback); + assertThat(callLatch.await(1, TimeUnit.SECONDS), is(true)); + } + + @Test + public void extractFiredTimersExtractsTimers() { + InProcessTransformResult holdResult = + StepTransformResult.withHold(created.getProducingTransformInternal(), new Instant(0)) + .build(); + context.handleResult(null, ImmutableList.of(), holdResult); + + String key = "foo"; + TimerData toFire = + TimerData.of(StateNamespaces.global(), new Instant(100L), TimeDomain.EVENT_TIME); + InProcessTransformResult timerResult = + StepTransformResult.withoutHold(downstream.getProducingTransformInternal()) + .withState(CopyOnAccessInMemoryStateInternals.withUnderlying(key, null)) + .withTimerUpdate(TimerUpdate.builder(key).setTimer(toFire).build()) + .build(); + + // haven't added any timers, must be empty + assertThat(context.extractFiredTimers().entrySet(), 
emptyIterable()); + context.handleResult( + InProcessBundle.keyed(created, key).commit(Instant.now()), + ImmutableList.of(), + timerResult); + + // timer hasn't fired + assertThat(context.extractFiredTimers().entrySet(), emptyIterable()); + + InProcessTransformResult advanceResult = + StepTransformResult.withoutHold(created.getProducingTransformInternal()).build(); + // Should cause the downstream timer to fire + context.handleResult(null, ImmutableList.of(), advanceResult); + + Map, Map> fired = context.extractFiredTimers(); + assertThat( + fired, + Matchers.>hasKey(downstream.getProducingTransformInternal())); + Map downstreamFired = + fired.get(downstream.getProducingTransformInternal()); + assertThat(downstreamFired, Matchers.hasKey(key)); + + FiredTimers firedForKey = downstreamFired.get(key); + assertThat(firedForKey.getTimers(TimeDomain.PROCESSING_TIME), emptyIterable()); + assertThat(firedForKey.getTimers(TimeDomain.SYNCHRONIZED_PROCESSING_TIME), emptyIterable()); + assertThat(firedForKey.getTimers(TimeDomain.EVENT_TIME), contains(toFire)); + + // Don't reextract timers + assertThat(context.extractFiredTimers().entrySet(), emptyIterable()); + } + + @Test + public void createBundleUnkeyedResultUnkeyed() { + CommittedBundle> newBundle = + context + .createBundle(InProcessBundle.unkeyed(created).commit(Instant.now()), downstream) + .commit(Instant.now()); + assertThat(newBundle.isKeyed(), is(false)); + } + + @Test + public void createBundleKeyedResultPropagatesKey() { + CommittedBundle> newBundle = + context + .createBundle(InProcessBundle.keyed(created, "foo").commit(Instant.now()), downstream) + .commit(Instant.now()); + assertThat(newBundle.isKeyed(), is(true)); + assertThat(newBundle.getKey(), Matchers.equalTo("foo")); + } + + @Test + public void createRootBundleUnkeyed() { + assertThat(context.createRootBundle(created).commit(Instant.now()).isKeyed(), is(false)); + } + + @Test + public void createKeyedBundleKeyed() { + CommittedBundle> keyedBundle = + 
context + .createKeyedBundle( + InProcessBundle.unkeyed(created).commit(Instant.now()), "foo", downstream) + .commit(Instant.now()); + assertThat(keyedBundle.isKeyed(), is(true)); + assertThat(keyedBundle.getKey(), Matchers.equalTo("foo")); + } + + private static class TestBoundedWindow extends BoundedWindow { + private final Instant ts; + + public TestBoundedWindow(Instant ts) { + this.ts = ts; + } + + @Override + public Instant maxTimestamp() { + return ts; + } + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainerTest.java index 4cfe782936..16b4eb7d07 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainerTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessSideInputContainerTest.java @@ -24,7 +24,6 @@ import com.google.cloud.dataflow.sdk.coders.KvCoder; import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Create; import com.google.cloud.dataflow.sdk.transforms.Mean; @@ -137,7 +136,7 @@ public void getAfterWriteReturnsPaneInWindow() throws Exception { container.write(mapView, ImmutableList.>of(one, two)); Map viewContents = - container.withViews(ImmutableList.>of(mapView)) + container.createReaderForViews(ImmutableList.>of(mapView)) .get(mapView, firstWindow); assertThat(viewContents, hasEntry("one", 1)); assertThat(viewContents, hasEntry("two", 2)); @@ -153,7 +152,7 @@ public void getReturnsLatestPaneInWindow() throws Exception { container.write(mapView, ImmutableList.>of(one, two)); Map viewContents = - container.withViews(ImmutableList.>of(mapView)) + 
container.createReaderForViews(ImmutableList.>of(mapView)) .get(mapView, secondWindow); assertThat(viewContents, hasEntry("one", 1)); assertThat(viewContents, hasEntry("two", 2)); @@ -164,7 +163,7 @@ public void getReturnsLatestPaneInWindow() throws Exception { container.write(mapView, ImmutableList.>of(three)); Map overwrittenViewContents = - container.withViews(ImmutableList.>of(mapView)) + container.createReaderForViews(ImmutableList.>of(mapView)) .get(mapView, secondWindow); assertThat(overwrittenViewContents, hasEntry("three", 3)); assertThat(overwrittenViewContents.size(), is(1)); @@ -176,15 +175,18 @@ public void getReturnsLatestPaneInWindow() throws Exception { */ @Test public void getBlocksUntilPaneAvailable() throws Exception { - BoundedWindow window = new BoundedWindow() { - @Override - public Instant maxTimestamp() { - return new Instant(1024L); - } - }; + BoundedWindow window = + new BoundedWindow() { + @Override + public Instant maxTimestamp() { + return new Instant(1024L); + } + }; Future singletonFuture = - getFutureOfView(container.withViews(ImmutableList.>of(singletonView)), - singletonView, window); + getFutureOfView( + container.createReaderForViews(ImmutableList.>of(singletonView)), + singletonView, + window); WindowedValue singletonValue = WindowedValue.of(4.75, new Instant(475L), window, PaneInfo.ON_TIME_AND_ONLY_FIRING); @@ -203,7 +205,7 @@ public Instant maxTimestamp() { } }; SideInputReader newReader = - container.withViews(ImmutableList.>of(singletonView)); + container.createReaderForViews(ImmutableList.>of(singletonView)); Future singletonFuture = getFutureOfView(newReader, singletonView, window); WindowedValue singletonValue = @@ -216,25 +218,31 @@ public Instant maxTimestamp() { @Test public void withPCollectionViewsErrorsForContainsNotInViews() { - PCollectionView>> newView = PCollectionViews.multimapView(pipeline, - WindowingStrategy.globalDefault(), KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); + PCollectionView>> newView 
= + PCollectionViews.multimapView( + pipeline, + WindowingStrategy.globalDefault(), + KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); thrown.expect(IllegalArgumentException.class); thrown.expectMessage("with unknown views " + ImmutableList.of(newView).toString()); - container.withViews(ImmutableList.>of(newView)); + container.createReaderForViews(ImmutableList.>of(newView)); } @Test public void withViewsForViewNotInContainerFails() { - PCollectionView>> newView = PCollectionViews.multimapView(pipeline, - WindowingStrategy.globalDefault(), KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); + PCollectionView>> newView = + PCollectionViews.multimapView( + pipeline, + WindowingStrategy.globalDefault(), + KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); thrown.expect(IllegalArgumentException.class); thrown.expectMessage("unknown views"); thrown.expectMessage(newView.toString()); - container.withViews(ImmutableList.>of(newView)); + container.createReaderForViews(ImmutableList.>of(newView)); } @Test @@ -242,7 +250,7 @@ public void getOnReaderForViewNotInReaderFails() { thrown.expect(IllegalArgumentException.class); thrown.expectMessage("unknown view: " + iterableView.toString()); - container.withViews(ImmutableList.>of(mapView)) + container.createReaderForViews(ImmutableList.>of(mapView)) .get(iterableView, GlobalWindow.INSTANCE); } @@ -255,11 +263,11 @@ public void writeForMultipleElementsInDifferentWindowsSucceeds() throws Exceptio PaneInfo.ON_TIME_AND_ONLY_FIRING); container.write(singletonView, ImmutableList.of(firstWindowedValue, secondWindowedValue)); assertThat( - container.withViews(ImmutableList.>of(singletonView)) + container.createReaderForViews(ImmutableList.>of(singletonView)) .get(singletonView, firstWindow), equalTo(2.875)); assertThat( - container.withViews(ImmutableList.>of(singletonView)) + container.createReaderForViews(ImmutableList.>of(singletonView)) .get(singletonView, secondWindow), equalTo(4.125)); } @@ -274,7 +282,7 @@ 
public void writeForMultipleIdenticalElementsInSameWindowSucceeds() throws Excep container.write(iterableView, ImmutableList.of(firstValue, secondValue)); assertThat( - container.withViews(ImmutableList.>of(iterableView)) + container.createReaderForViews(ImmutableList.>of(iterableView)) .get(iterableView, firstWindow), contains(44, 44)); } @@ -286,11 +294,11 @@ public void writeForElementInMultipleWindowsSucceeds() throws Exception { ImmutableList.of(firstWindow, secondWindow), PaneInfo.ON_TIME_AND_ONLY_FIRING); container.write(singletonView, ImmutableList.of(multiWindowedValue)); assertThat( - container.withViews(ImmutableList.>of(singletonView)) + container.createReaderForViews(ImmutableList.>of(singletonView)) .get(singletonView, firstWindow), equalTo(2.875)); assertThat( - container.withViews(ImmutableList.>of(singletonView)) + container.createReaderForViews(ImmutableList.>of(singletonView)) .get(singletonView, secondWindow), equalTo(2.875)); } @@ -306,7 +314,7 @@ public void finishDoesNotOverwriteWrittenElements() throws Exception { immediatelyInvokeCallback(mapView, secondWindow); Map viewContents = - container.withViews(ImmutableList.>of(mapView)) + container.createReaderForViews(ImmutableList.>of(mapView)) .get(mapView, secondWindow); assertThat(viewContents, hasEntry("one", 1)); @@ -317,8 +325,11 @@ public void finishDoesNotOverwriteWrittenElements() throws Exception { @Test public void finishOnPendingViewsSetsEmptyElements() throws Exception { immediatelyInvokeCallback(mapView, secondWindow); - Future> mapFuture = getFutureOfView( - container.withViews(ImmutableList.>of(mapView)), mapView, secondWindow); + Future> mapFuture = + getFutureOfView( + container.createReaderForViews(ImmutableList.>of(mapView)), + mapView, + secondWindow); assertThat(mapFuture.get().isEmpty(), is(true)); } @@ -329,18 +340,21 @@ public void finishOnPendingViewsSetsEmptyElements() throws Exception { */ private void immediatelyInvokeCallback(PCollectionView view, BoundedWindow 
window) { doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - Object callback = invocation.getArguments()[3]; - Runnable callbackRunnable = (Runnable) callback; - callbackRunnable.run(); - return null; - } - }) + new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + Object callback = invocation.getArguments()[3]; + Runnable callbackRunnable = (Runnable) callback; + callbackRunnable.run(); + return null; + } + }) .when(context) - .callAfterOutputMustHaveBeenProduced(Mockito.eq(view), Mockito.eq(window), - Mockito.eq(view.getWindowingStrategyInternal()), Mockito.any(Runnable.class)); + .scheduleAfterOutputWouldBeProduced( + Mockito.eq(view), + Mockito.eq(window), + Mockito.eq(view.getWindowingStrategyInternal()), + Mockito.any(Runnable.class)); } private Future getFutureOfView(final SideInputReader myReader, diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactoryTest.java index 033f9de204..66430b6a7f 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoMultiEvaluatorFactoryTest.java @@ -26,7 +26,6 @@ import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Create; diff --git 
a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactoryTest.java index ae599bab62..3b928b9077 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ParDoSingleEvaluatorFactoryTest.java @@ -26,7 +26,6 @@ import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.TimerUpdate; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Create; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java index f139c5648e..a9bbcc8cc5 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java @@ -25,7 +25,6 @@ import com.google.cloud.dataflow.sdk.io.CountingSource; import com.google.cloud.dataflow.sdk.io.Read; import com.google.cloud.dataflow.sdk.io.UnboundedSource; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import 
com.google.cloud.dataflow.sdk.transforms.SerializableFunction; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactoryTest.java index 2f5cd0fb88..2f5bdde6ad 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ViewEvaluatorFactoryTest.java @@ -25,7 +25,6 @@ import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; import com.google.cloud.dataflow.sdk.coders.VoidCoder; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; -import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessEvaluationContext; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.PCollectionViewWriter; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Create; diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutorTest.java new file mode 100644 index 0000000000..be3e062fbd --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/WatermarkCallbackExecutorTest.java @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +/** + * Tests for {@link WatermarkCallbackExecutor}. 
+ */ +@RunWith(JUnit4.class) +public class WatermarkCallbackExecutorTest { + private WatermarkCallbackExecutor executor = WatermarkCallbackExecutor.create(); + private AppliedPTransform create; + private AppliedPTransform sum; + + @Before + public void setup() { + TestPipeline p = TestPipeline.create(); + PCollection created = p.apply(Create.of(1, 2, 3)); + create = created.getProducingTransformInternal(); + sum = created.apply(Sum.integersGlobally()).getProducingTransformInternal(); + } + + @Test + public void onGuaranteedFiringFiresAfterTrigger() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + executor.callOnGuaranteedFiring( + create, + GlobalWindow.INSTANCE, + WindowingStrategy.globalDefault(), + new CountDownLatchCallback(latch)); + + executor.fireForWatermark(create, BoundedWindow.TIMESTAMP_MAX_VALUE); + assertThat(latch.await(500, TimeUnit.MILLISECONDS), equalTo(true)); + } + + @Test + public void multipleCallbacksShouldFireFires() throws Exception { + CountDownLatch latch = new CountDownLatch(2); + WindowFn windowFn = FixedWindows.of(Duration.standardMinutes(10)); + IntervalWindow window = + new IntervalWindow(new Instant(0L), new Instant(0L).plus(Duration.standardMinutes(10))); + executor.callOnGuaranteedFiring( + create, window, WindowingStrategy.of(windowFn), new CountDownLatchCallback(latch)); + executor.callOnGuaranteedFiring( + create, window, WindowingStrategy.of(windowFn), new CountDownLatchCallback(latch)); + + executor.fireForWatermark(create, new Instant(0L).plus(Duration.standardMinutes(10))); + assertThat(latch.await(500, TimeUnit.MILLISECONDS), equalTo(true)); + } + + @Test + public void noCallbacksShouldFire() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + WindowFn windowFn = FixedWindows.of(Duration.standardMinutes(10)); + IntervalWindow window = + new IntervalWindow(new Instant(0L), new Instant(0L).plus(Duration.standardMinutes(10))); + executor.callOnGuaranteedFiring( + create, window, 
WindowingStrategy.of(windowFn), new CountDownLatchCallback(latch)); + + executor.fireForWatermark(create, new Instant(0L).plus(Duration.standardMinutes(5))); + assertThat(latch.await(500, TimeUnit.MILLISECONDS), equalTo(false)); + } + + @Test + public void unrelatedStepShouldNotFire() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + WindowFn windowFn = FixedWindows.of(Duration.standardMinutes(10)); + IntervalWindow window = + new IntervalWindow(new Instant(0L), new Instant(0L).plus(Duration.standardMinutes(10))); + executor.callOnGuaranteedFiring( + sum, window, WindowingStrategy.of(windowFn), new CountDownLatchCallback(latch)); + + executor.fireForWatermark(create, new Instant(0L).plus(Duration.standardMinutes(20))); + assertThat(latch.await(500, TimeUnit.MILLISECONDS), equalTo(false)); + } + + private static class CountDownLatchCallback implements Runnable { + private final CountDownLatch latch; + + public CountDownLatchCallback(CountDownLatch latch) { + this.latch = latch; + } + + @Override + public void run() { + latch.countDown(); + } + } +} From 5104624701a1d05344603ab6081be1be74124ac0 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Wed, 16 Mar 2016 13:28:02 -0700 Subject: [PATCH 03/11] Fix AfterWatermark Early and Late javadoc The docs were reversed - the late trigger is only considered after the watermark has passed the end of the window, and the early trigger only before the watermark has passed the end of the window. 
--- .../dataflow/sdk/transforms/windowing/AfterWatermark.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermark.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermark.java index da16db99c6..fac2c2841b 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermark.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/windowing/AfterWatermark.java @@ -80,7 +80,7 @@ public static FromEndOfWindow pastEndOfWindow() { public interface AfterWatermarkEarly extends TriggerBuilder { /** * Creates a new {@code Trigger} like the this, except that it fires repeatedly whenever - * the given {@code Trigger} fires before the watermark has passed the end of the window. + * the given {@code Trigger} fires after the watermark has passed the end of the window. */ TriggerBuilder withLateFirings(OnceTrigger lateTrigger); } @@ -91,7 +91,7 @@ public interface AfterWatermarkEarly extends TriggerBui public interface AfterWatermarkLate extends TriggerBuilder { /** * Creates a new {@code Trigger} like the this, except that it fires repeatedly whenever - * the given {@code Trigger} fires after the watermark has passed the end of the window. + * the given {@code Trigger} fires before the watermark has passed the end of the window. 
*/ TriggerBuilder withEarlyFirings(OnceTrigger earlyTrigger); } From c2fe45d3271deff14b2ba3627c287012ea072898 Mon Sep 17 00:00:00 2001 From: Pei He Date: Thu, 10 Mar 2016 14:17:36 -0800 Subject: [PATCH 04/11] [BEAM-80] Enable combiner lifting for combine with contexts --- .../runners/DataflowPipelineTranslator.java | 3 +++ .../cloud/dataflow/sdk/transforms/Combine.java | 18 +++--------------- .../cloud/dataflow/sdk/util/PropertyNames.java | 1 + 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java index d0cc4e53d5..0feae957f8 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java @@ -952,6 +952,9 @@ private void groupByKeyHelper( context.addInput( PropertyNames.SERIALIZED_FN, byteArrayToJsonString(serializeToByteArray(windowingStrategy))); + context.addInput( + PropertyNames.IS_MERGING_WINDOW_FN, + !windowingStrategy.getWindowFn().isNonMerging()); } }); diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java index cc0347a124..b8d20e303f 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/Combine.java @@ -1690,21 +1690,9 @@ public List> getSideInputs() { @Override public PCollection> apply(PCollection> input) { - if (fn instanceof RequiresContextInternal) { - return input - .apply(GroupByKey.create(fewKeys)) - .apply(ParDo.of(new DoFn>, KV>>() { - @Override - public void processElement(ProcessContext c) throws Exception { - c.output(c.element()); - } - })) - .apply(Combine.groupedValues(fn).withSideInputs(sideInputs)); - } else { - return input - 
.apply(GroupByKey.create(fewKeys)) - .apply(Combine.groupedValues(fn).withSideInputs(sideInputs)); - } + return input + .apply(GroupByKey.create(fewKeys)) + .apply(Combine.groupedValues(fn).withSideInputs(sideInputs)); } } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java index 5611fabe28..ec6518976b 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PropertyNames.java @@ -65,6 +65,7 @@ public class PropertyNames { public static final String INPUTS = "inputs"; public static final String INPUT_CODER = "input_coder"; public static final String IS_GENERATED = "is_generated"; + public static final String IS_MERGING_WINDOW_FN = "is_merging_window_fn"; public static final String IS_PAIR_LIKE = "is_pair_like"; public static final String IS_STREAM_LIKE = "is_stream_like"; public static final String IS_WRAPPER = "is_wrapper"; From 1556b0462bf9c9f2670a218f27d7698af659aa98 Mon Sep 17 00:00:00 2001 From: Pei He Date: Fri, 4 Mar 2016 13:54:34 -0800 Subject: [PATCH 05/11] [BEAM-96] Add composed `CombineFn` builders in `CombineFns` * `compose()` or `composeKeyed()` are used to start composition * `with()` is used to add an input-transformation, a `CombineFn` and an output `TupleTag`. 
* A non-`CombineFn` initial builder is used to ensure that every composition includes at least one item * Duplicate output tags are not allowed in the same composition --- .../dataflow/sdk/transforms/CombineFns.java | 1100 +++++++++++++++++ .../sdk/transforms/CombineFnsTest.java | 413 +++++++ 2 files changed, 1513 insertions(+) create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java new file mode 100644 index 0000000000..656c010d91 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java @@ -0,0 +1,1100 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.GlobalCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.PerKeyCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** + * Static utility methods that create combine function instances. + */ +public class CombineFns { + + /** + * Returns a {@link ComposeKeyedCombineFnBuilder} to construct a composed + * {@link PerKeyCombineFn}. + * + *

The same {@link TupleTag} cannot be used in a composition multiple times. + * + *

Example: + *

{ @code
+   * PCollection<KV<K, Integer>> latencies = ...;
+   *
+   * TupleTag<Integer> maxLatencyTag = new TupleTag<Integer>();
+   * TupleTag<Double> meanLatencyTag = new TupleTag<Double>();
+   *
+   * SimpleFunction<Integer, Integer> identityFn =
+   *     new SimpleFunction<Integer, Integer>() {
+   *       @Override
+   *       public Integer apply(Integer input) {
+   *           return input;
+   *       }};
+   * PCollection<KV<K, CoCombineResult>> maxAndMean = latencies.apply(
+   *     Combine.perKey(
+   *         CombineFns.composeKeyed()
+   *            .with(identityFn, new MaxIntegerFn(), maxLatencyTag)
+   *            .with(identityFn, new MeanFn<Integer>(), meanLatencyTag)));
+   *
+   * PCollection<T> finalResultCollection = maxAndMean
+   *     .apply(ParDo.of(
+   *         new DoFn<KV<K, CoCombineResult>, T>() {
+   *           @Override
+   *           public void processElement(ProcessContext c) throws Exception {
+   *             KV<K, CoCombineResult> e = c.element();
+   *             Integer maxLatency = e.getValue().get(maxLatencyTag);
+   *             Double meanLatency = e.getValue().get(meanLatencyTag);
+   *             .... Do Something ....
+   *             c.output(...some T...);
+   *           }
+   *         }));
+   * } 
+ */ + public static ComposeKeyedCombineFnBuilder composeKeyed() { + return new ComposeKeyedCombineFnBuilder(); + } + + /** + * Returns a {@link ComposeCombineFnBuilder} to construct a composed + * {@link GlobalCombineFn}. + * + *

The same {@link TupleTag} cannot be used in a composition multiple times. + * + *

Example: + *

{ @code
+   * PCollection<Integer> globalLatencies = ...;
+   *
+   * TupleTag<Integer> maxLatencyTag = new TupleTag<Integer>();
+   * TupleTag<Double> meanLatencyTag = new TupleTag<Double>();
+   *
+   * SimpleFunction<Integer, Integer> identityFn =
+   *     new SimpleFunction<Integer, Integer>() {
+   *       @Override
+   *       public Integer apply(Integer input) {
+   *           return input;
+   *       }};
+   * PCollection<CoCombineResult> maxAndMean = globalLatencies.apply(
+   *     Combine.globally(
+   *         CombineFns.compose()
+   *            .with(identityFn, new MaxIntegerFn(), maxLatencyTag)
+   *            .with(identityFn, new MeanFn<Integer>(), meanLatencyTag)));
+   *
+   * PCollection<T> finalResultCollection = maxAndMean
+   *     .apply(ParDo.of(
+   *         new DoFn<CoCombineResult, T>() {
+   *           @Override
+   *           public void processElement(ProcessContext c) throws Exception {
+   *             CoCombineResult e = c.element();
+   *             Integer maxLatency = e.get(maxLatencyTag);
+   *             Double meanLatency = e.get(meanLatencyTag);
+   *             .... Do Something ....
+   *             c.output(...some T...);
+   *           }
+   *         }));
+   * } 
+ */ + public static ComposeCombineFnBuilder compose() { + return new ComposeCombineFnBuilder(); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A builder class to construct a composed {@link PerKeyCombineFn}. + */ + public static class ComposeKeyedCombineFnBuilder { + /** + * Returns a {@link ComposedKeyedCombineFn} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. + * + *

The {@link ComposedKeyedCombineFn} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code keyedCombineFn}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public ComposedKeyedCombineFn with( + SimpleFunction extractInputFn, + KeyedCombineFn keyedCombineFn, + TupleTag outputTag) { + return new ComposedKeyedCombineFn() + .with(extractInputFn, keyedCombineFn, outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. + * + *

The {@link ComposedKeyedCombineFnWithContext} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code keyedCombineFnWithContext}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + KeyedCombineFnWithContext keyedCombineFnWithContext, + TupleTag outputTag) { + return new ComposedKeyedCombineFnWithContext() + .with(extractInputFn, keyedCombineFnWithContext, outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFn} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. + */ + public ComposedKeyedCombineFn with( + SimpleFunction extractInputFn, + CombineFn combineFn, + TupleTag outputTag) { + return with(extractInputFn, combineFn.asKeyedFn(), outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. + */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + CombineFnWithContext combineFnWithContext, + TupleTag outputTag) { + return with(extractInputFn, combineFnWithContext.asKeyedFn(), outputTag); + } + } + + /** + * A builder class to construct a composed {@link GlobalCombineFn}. + */ + public static class ComposeCombineFnBuilder { + /** + * Returns a {@link ComposedCombineFn} that can take additional + * {@link GlobalCombineFn GlobalCombineFns} and apply them as a single combine function. + * + *

The {@link ComposedCombineFn} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code combineFn}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public ComposedCombineFn with( + SimpleFunction extractInputFn, + CombineFn combineFn, + TupleTag outputTag) { + return new ComposedCombineFn() + .with(extractInputFn, combineFn, outputTag); + } + + /** + * Returns a {@link ComposedCombineFnWithContext} that can take additional + * {@link GlobalCombineFn GlobalCombineFns} and apply them as a single combine function. + * + *

The {@link ComposedCombineFnWithContext} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code combineFnWithContext}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public ComposedCombineFnWithContext with( + SimpleFunction extractInputFn, + CombineFnWithContext combineFnWithContext, + TupleTag outputTag) { + return new ComposedCombineFnWithContext() + .with(extractInputFn, combineFnWithContext, outputTag); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A tuple of outputs produced by composed combine functions. + * + *

See {@link #compose()} or {@link #composeKeyed()} for details. + */ + public static class CoCombineResult implements Serializable { + + private enum NullValue { + INSTANCE; + } + + private final Map, Object> valuesMap; + + /** + * The constructor of {@link CoCombineResult}. + * + *

Null values in {@code valuesMap} are permitted: each one is stored internally as a + * sentinel, so its {@link TupleTag} stays in the key set and {@link #get} returns null + * for it. + * + * @throws NullPointerException if any key in {@code valuesMap} is null + */ + CoCombineResult(Map, Object> valuesMap) { + ImmutableMap.Builder, Object> builder = ImmutableMap.builder(); + for (Entry, Object> entry : valuesMap.entrySet()) { + if (entry.getValue() != null) { + builder.put(entry); + } else { + builder.put(entry.getKey(), NullValue.INSTANCE); + } + } + this.valuesMap = builder.build(); + } + + /** + * Returns the value represented by the given {@link TupleTag}. + * + *

It is an error to request a nonexistent tuple tag from the {@link CoCombineResult}. + */ + @SuppressWarnings("unchecked") + public V get(TupleTag tag) { + checkArgument( + valuesMap.keySet().contains(tag), "TupleTag " + tag + " is not in the CoCombineResult"); + Object value = valuesMap.get(tag); + if (value == NullValue.INSTANCE) { + return null; + } else { + return (V) value; + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A composed {@link CombineFn} that applies multiple {@link CombineFn CombineFns}. + * + *

For each {@link CombineFn} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public static class ComposedCombineFn extends CombineFn { + + private final List> combineFns; + private final List> extractInputFns; + private final List> outputTags; + private final int combineFnCount; + + private ComposedCombineFn() { + this.extractInputFns = ImmutableList.of(); + this.combineFns = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedCombineFn( + ImmutableList> extractInputFns, + ImmutableList> combineFns, + ImmutableList> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedExtractInputFns = (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedCombineFns = (List) combineFns; + this.combineFns = castedCombineFns; + + this.outputTags = outputTags; + this.combineFnCount = this.combineFns.size(); + } + + /** + * Returns a {@link ComposedCombineFn} with an additional {@link CombineFn}. + */ + public ComposedCombineFn with( + SimpleFunction extractInputFn, + CombineFn combineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedCombineFn<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(combineFns) + .add(combineFn) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedCombineFnWithContext} with an additional + * {@link CombineFnWithContext}. 
+ */ + public ComposedCombineFnWithContext with( + SimpleFunction extractInputFn, + CombineFnWithContext combineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + List> fnsWithContext = Lists.newArrayList(); + for (CombineFn fn : combineFns) { + fnsWithContext.add(toFnWithContext(fn)); + } + return new ComposedCombineFnWithContext<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(fnsWithContext) + .add(combineFn) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + @Override + public Object[] createAccumulator() { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = combineFns.get(i).createAccumulator(); + } + return accumsArray; + } + + @Override + public Object[] addInput(Object[] accumulator, DataT value) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = combineFns.get(i).addInput(accumulator[i], input); + } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(Iterable accumulators) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(); + } else { + // Reuses the first accumulator, and overwrites its values. + // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. 
+ Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = combineFns.get(i).mergeAccumulators(new ProjectionIterable(accumulators, i)); + } + return accum; + } + } + + @Override + public CoCombineResult extractOutput(Object[] accumulator) { + Map, Object> valuesMap = Maps.newHashMap(); + for (int i = 0; i < combineFnCount; ++i) { + valuesMap.put( + outputTags.get(i), + combineFns.get(i).extractOutput(accumulator[i])); + } + return new CoCombineResult(valuesMap); + } + + @Override + public Object[] compact(Object[] accumulator) { + for (int i = 0; i < combineFnCount; ++i) { + accumulator[i] = combineFns.get(i).compact(accumulator[i]); + } + return accumulator; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder dataCoder) + throws CannotProvideCoderException { + List> coders = Lists.newArrayList(); + for (int i = 0; i < combineFnCount; ++i) { + Coder inputCoder = + registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder); + coders.add(combineFns.get(i).getAccumulatorCoder(registry, inputCoder)); + } + return new ComposedAccumulatorCoder(coders); + } + } + + /** + * A composed {@link CombineFnWithContext} that applies multiple + * {@link CombineFnWithContext CombineFnWithContexts}. + * + *

For each {@link CombineFnWithContext} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public static class ComposedCombineFnWithContext + extends CombineFnWithContext { + + private final List> extractInputFns; + private final List> combineFnWithContexts; + private final List> outputTags; + private final int combineFnCount; + + private ComposedCombineFnWithContext() { + this.extractInputFns = ImmutableList.of(); + this.combineFnWithContexts = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedCombineFnWithContext( + ImmutableList> extractInputFns, + ImmutableList> combineFnWithContexts, + ImmutableList> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedExtractInputFns = + (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"rawtypes", "unchecked"}) + List> castedCombineFnWithContexts + = (List) combineFnWithContexts; + this.combineFnWithContexts = castedCombineFnWithContexts; + + this.outputTags = outputTags; + this.combineFnCount = this.combineFnWithContexts.size(); + } + + /** + * Returns a {@link ComposedCombineFnWithContext} with an additional {@link GlobalCombineFn}. 
+ */ + public ComposedCombineFnWithContext with( + SimpleFunction extractInputFn, + GlobalCombineFn globalCombineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedCombineFnWithContext<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(combineFnWithContexts) + .add(toFnWithContext(globalCombineFn)) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + @Override + public Object[] createAccumulator(Context c) { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = combineFnWithContexts.get(i).createAccumulator(c); + } + return accumsArray; + } + + @Override + public Object[] addInput(Object[] accumulator, DataT value, Context c) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = combineFnWithContexts.get(i).addInput(accumulator[i], input, c); + } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(Iterable accumulators, Context c) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(c); + } else { + // Reuses the first accumulator, and overwrites its values. + // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. 
+ Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = combineFnWithContexts.get(i).mergeAccumulators( + new ProjectionIterable(accumulators, i), c); + } + return accum; + } + } + + @Override + public CoCombineResult extractOutput(Object[] accumulator, Context c) { + Map, Object> valuesMap = Maps.newHashMap(); + for (int i = 0; i < combineFnCount; ++i) { + valuesMap.put( + outputTags.get(i), + combineFnWithContexts.get(i).extractOutput(accumulator[i], c)); + } + return new CoCombineResult(valuesMap); + } + + @Override + public Object[] compact(Object[] accumulator, Context c) { + for (int i = 0; i < combineFnCount; ++i) { + accumulator[i] = combineFnWithContexts.get(i).compact(accumulator[i], c); + } + return accumulator; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder dataCoder) + throws CannotProvideCoderException { + List> coders = Lists.newArrayList(); + for (int i = 0; i < combineFnCount; ++i) { + Coder inputCoder = + registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder); + coders.add(combineFnWithContexts.get(i).getAccumulatorCoder(registry, inputCoder)); + } + return new ComposedAccumulatorCoder(coders); + } + } + + /** + * A composed {@link KeyedCombineFn} that applies multiple {@link KeyedCombineFn KeyedCombineFns}. + * + *

For each {@link KeyedCombineFn} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public static class ComposedKeyedCombineFn + extends KeyedCombineFn { + + private final List> extractInputFns; + private final List> keyedCombineFns; + private final List> outputTags; + private final int combineFnCount; + + private ComposedKeyedCombineFn() { + this.extractInputFns = ImmutableList.of(); + this.keyedCombineFns = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedKeyedCombineFn( + ImmutableList> extractInputFns, + ImmutableList> keyedCombineFns, + ImmutableList> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedExtractInputFns = (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedKeyedCombineFns = + (List) keyedCombineFns; + this.keyedCombineFns = castedKeyedCombineFns; + this.outputTags = outputTags; + this.combineFnCount = this.keyedCombineFns.size(); + } + + /** + * Returns a {@link ComposedKeyedCombineFn} with an additional {@link KeyedCombineFn}. + */ + public ComposedKeyedCombineFn with( + SimpleFunction extractInputFn, + KeyedCombineFn keyedCombineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedKeyedCombineFn<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(keyedCombineFns) + .add(keyedCombineFn) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link KeyedCombineFnWithContext}. 
+ */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + KeyedCombineFnWithContext keyedCombineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + List> fnsWithContext = + Lists.newArrayList(); + for (KeyedCombineFn fn : keyedCombineFns) { + fnsWithContext.add(toFnWithContext(fn)); + } + return new ComposedKeyedCombineFnWithContext<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(fnsWithContext) + .add(keyedCombineFn) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedKeyedCombineFn} with an additional {@link CombineFn}. + */ + public ComposedKeyedCombineFn with( + SimpleFunction extractInputFn, + CombineFn keyedCombineFn, + TupleTag outputTag) { + return with(extractInputFn, keyedCombineFn.asKeyedFn(), outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link CombineFnWithContext}. 
+ */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + CombineFnWithContext keyedCombineFn, + TupleTag outputTag) { + return with(extractInputFn, keyedCombineFn.asKeyedFn(), outputTag); + } + + @Override + public Object[] createAccumulator(K key) { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = keyedCombineFns.get(i).createAccumulator(key); + } + return accumsArray; + } + + @Override + public Object[] addInput(K key, Object[] accumulator, DataT value) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = keyedCombineFns.get(i).addInput(key, accumulator[i], input); + } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(K key, final Iterable accumulators) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(key); + } else { + // Reuses the first accumulator, and overwrites its values. + // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. 
+ Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = keyedCombineFns.get(i).mergeAccumulators( + key, new ProjectionIterable(accumulators, i)); + } + return accum; + } + } + + @Override + public CoCombineResult extractOutput(K key, Object[] accumulator) { + Map, Object> valuesMap = Maps.newHashMap(); + for (int i = 0; i < combineFnCount; ++i) { + valuesMap.put( + outputTags.get(i), + keyedCombineFns.get(i).extractOutput(key, accumulator[i])); + } + return new CoCombineResult(valuesMap); + } + + @Override + public Object[] compact(K key, Object[] accumulator) { + for (int i = 0; i < combineFnCount; ++i) { + accumulator[i] = keyedCombineFns.get(i).compact(key, accumulator[i]); + } + return accumulator; + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder keyCoder, Coder dataCoder) + throws CannotProvideCoderException { + List> coders = Lists.newArrayList(); + for (int i = 0; i < combineFnCount; ++i) { + Coder inputCoder = + registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder); + coders.add(keyedCombineFns.get(i).getAccumulatorCoder(registry, keyCoder, inputCoder)); + } + return new ComposedAccumulatorCoder(coders); + } + } + + /** + * A composed {@link KeyedCombineFnWithContext} that applies multiple + * {@link KeyedCombineFnWithContext KeyedCombineFnWithContexts}. + * + *

For each {@link KeyedCombineFnWithContext} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public static class ComposedKeyedCombineFnWithContext + extends KeyedCombineFnWithContext { + + private final List> extractInputFns; + private final List> keyedCombineFns; + private final List> outputTags; + private final int combineFnCount; + + private ComposedKeyedCombineFnWithContext() { + this.extractInputFns = ImmutableList.of(); + this.keyedCombineFns = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedKeyedCombineFnWithContext( + ImmutableList> extractInputFns, + ImmutableList> keyedCombineFns, + ImmutableList> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedExtractInputFns = + (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"unchecked", "rawtypes"}) + List> castedKeyedCombineFns = + (List) keyedCombineFns; + this.keyedCombineFns = castedKeyedCombineFns; + this.outputTags = outputTags; + this.combineFnCount = this.keyedCombineFns.size(); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link PerKeyCombineFn}. + */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + PerKeyCombineFn perKeyCombineFn, + TupleTag outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedKeyedCombineFnWithContext<>( + ImmutableList.>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.>builder() + .addAll(keyedCombineFns) + .add(toFnWithContext(perKeyCombineFn)) + .build(), + ImmutableList.>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link GlobalCombineFn}. 
+ */ + public ComposedKeyedCombineFnWithContext with( + SimpleFunction extractInputFn, + GlobalCombineFn perKeyCombineFn, + TupleTag outputTag) { + return with(extractInputFn, perKeyCombineFn.asKeyedFn(), outputTag); + } + + @Override + public Object[] createAccumulator(K key, Context c) { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = keyedCombineFns.get(i).createAccumulator(key, c); + } + return accumsArray; + } + + @Override + public Object[] addInput(K key, Object[] accumulator, DataT value, Context c) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = keyedCombineFns.get(i).addInput(key, accumulator[i], input, c); + } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(K key, Iterable accumulators, Context c) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(key, c); + } else { + // Reuses the first accumulator, and overwrites its values. + // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. 
+ Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = keyedCombineFns.get(i).mergeAccumulators( + key, new ProjectionIterable(accumulators, i), c); + } + return accum; + } + } + + @Override + public CoCombineResult extractOutput(K key, Object[] accumulator, Context c) { + Map, Object> valuesMap = Maps.newHashMap(); + for (int i = 0; i < combineFnCount; ++i) { + valuesMap.put( + outputTags.get(i), + keyedCombineFns.get(i).extractOutput(key, accumulator[i], c)); + } + return new CoCombineResult(valuesMap); + } + + @Override + public Object[] compact(K key, Object[] accumulator, Context c) { + for (int i = 0; i < combineFnCount; ++i) { + accumulator[i] = keyedCombineFns.get(i).compact(key, accumulator[i], c); + } + return accumulator; + } + + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder keyCoder, Coder dataCoder) + throws CannotProvideCoderException { + List> coders = Lists.newArrayList(); + for (int i = 0; i < combineFnCount; ++i) { + Coder inputCoder = + registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder); + coders.add(keyedCombineFns.get(i).getAccumulatorCoder( + registry, keyCoder, inputCoder)); + } + return new ComposedAccumulatorCoder(coders); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + private static class ProjectionIterable implements Iterable { + private final Iterable iterable; + private final int column; + + private ProjectionIterable(Iterable iterable, int column) { + this.iterable = iterable; + this.column = column; + } + + @Override + public Iterator iterator() { + final Iterator iter = iterable.iterator(); + return new Iterator() { + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Object next() { + return iter.next()[column]; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + } + + private static class ComposedAccumulatorCoder 
extends StandardCoder { + private List> coders; + private int codersCount; + + public ComposedAccumulatorCoder(List> coders) { + this.coders = ImmutableList.copyOf(coders); + this.codersCount = coders.size(); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + @JsonCreator + public static ComposedAccumulatorCoder of( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + return new ComposedAccumulatorCoder((List) components); + } + + @Override + public void encode(Object[] value, OutputStream outStream, Context context) + throws CoderException, IOException { + checkArgument(value.length == codersCount); + Context nestedContext = context.nested(); + for (int i = 0; i < codersCount; ++i) { + coders.get(i).encode(value[i], outStream, nestedContext); + } + } + + @Override + public Object[] decode(InputStream inStream, Context context) + throws CoderException, IOException { + Object[] ret = new Object[codersCount]; + Context nestedContext = context.nested(); + for (int i = 0; i < codersCount; ++i) { + ret[i] = coders.get(i).decode(inStream, nestedContext); + } + return ret; + } + + @Override + public List> getCoderArguments() { + return coders; + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + for (int i = 0; i < codersCount; ++i) { + coders.get(i).verifyDeterministic(); + } + } + } + + @SuppressWarnings("unchecked") + private static CombineFnWithContext + toFnWithContext(GlobalCombineFn globalCombineFn) { + if (globalCombineFn instanceof CombineFnWithContext) { + return (CombineFnWithContext) globalCombineFn; + } else { + final CombineFn combineFn = + (CombineFn) globalCombineFn; + return new CombineFnWithContext() { + @Override + public AccumT createAccumulator(Context c) { + return combineFn.createAccumulator(); + } + @Override + public AccumT addInput(AccumT accumulator, InputT input, Context c) { + return combineFn.addInput(accumulator, input); + } + @Override + public AccumT mergeAccumulators(Iterable 
accumulators, Context c) { + return combineFn.mergeAccumulators(accumulators); + } + @Override + public OutputT extractOutput(AccumT accumulator, Context c) { + return combineFn.extractOutput(accumulator); + } + @Override + public AccumT compact(AccumT accumulator, Context c) { + return combineFn.compact(accumulator); + } + @Override + public OutputT defaultValue() { + return combineFn.defaultValue(); + } + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return combineFn.getAccumulatorCoder(registry, inputCoder); + } + @Override + public Coder getDefaultOutputCoder( + CoderRegistry registry, Coder inputCoder) throws CannotProvideCoderException { + return combineFn.getDefaultOutputCoder(registry, inputCoder); + } + }; + } + } + + private static KeyedCombineFnWithContext + toFnWithContext(PerKeyCombineFn perKeyCombineFn) { + if (perKeyCombineFn instanceof KeyedCombineFnWithContext) { + @SuppressWarnings("unchecked") + KeyedCombineFnWithContext keyedCombineFnWithContext = + (KeyedCombineFnWithContext) perKeyCombineFn; + return keyedCombineFnWithContext; + } else { + @SuppressWarnings("unchecked") + final KeyedCombineFn keyedCombineFn = + (KeyedCombineFn) perKeyCombineFn; + return new KeyedCombineFnWithContext() { + @Override + public AccumT createAccumulator(K key, Context c) { + return keyedCombineFn.createAccumulator(key); + } + @Override + public AccumT addInput(K key, AccumT accumulator, InputT value, Context c) { + return keyedCombineFn.addInput(key, accumulator, value); + } + @Override + public AccumT mergeAccumulators(K key, Iterable accumulators, Context c) { + return keyedCombineFn.mergeAccumulators(key, accumulators); + } + @Override + public OutputT extractOutput(K key, AccumT accumulator, Context c) { + return keyedCombineFn.extractOutput(key, accumulator); + } + @Override + public AccumT compact(K key, AccumT accumulator, Context c) { + return keyedCombineFn.compact(key, 
accumulator); + } + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException { + return keyedCombineFn.getAccumulatorCoder(registry, keyCoder, inputCoder); + } + @Override + public Coder getDefaultOutputCoder(CoderRegistry registry, Coder keyCoder, + Coder inputCoder) throws CannotProvideCoderException { + return keyedCombineFn.getDefaultOutputCoder(registry, keyCoder, inputCoder); + } + }; + } + } + + private static void checkUniqueness( + List> registeredTags, TupleTag outputTag) { + checkArgument( + !registeredTags.contains(outputTag), + "Cannot compose with tuple tag %s because it is already present in the composition.", + outputTag); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java new file mode 100644 index 0000000000..ad37708677 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java @@ -0,0 +1,413 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.NullableCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Combine.BinaryCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFns.CoCombineResult; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.Max.MaxIntegerFn; +import com.google.cloud.dataflow.sdk.transforms.Min.MinIntegerFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +/** + * Unit tests for {@link CombineFns}. 
+ */ +@RunWith(JUnit4.class) +public class CombineFnsTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testDuplicatedTags() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag tag = new TupleTag(); + CombineFns.compose() + .with(new GetIntegerFunction(), new MaxIntegerFn(), tag) + .with(new GetIntegerFunction(), new MinIntegerFn(), tag); + } + + @Test + public void testDuplicatedTagsKeyed() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag tag = new TupleTag(); + CombineFns.composeKeyed() + .with(new GetIntegerFunction(), new MaxIntegerFn(), tag) + .with(new GetIntegerFunction(), new MinIntegerFn(), tag); + } + + @Test + public void testDuplicatedTagsWithContext() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag tag = new TupleTag(); + CombineFns.compose() + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */).forKey("G", StringUtf8Coder.of()), + tag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */).forKey("G", StringUtf8Coder.of()), + tag); + } + + @Test + public void testDuplicatedTagsWithContextKeyed() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag tag = new TupleTag(); + CombineFns.composeKeyed() + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */), + tag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */), + tag); + } + + @Test + @Category(RunnableOnService.class) + public void testComposedCombine() { + Pipeline p = TestPipeline.create(); + 
p.getCoderRegistry().registerCoder(UserString.class, UserStringCoder.of()); + + PCollection>> perKeyInput = p.apply( + Create.timestamped( + Arrays.asList( + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(4, UserString.of("4"))), + KV.of("b", KV.of(1, UserString.of("1"))), + KV.of("b", KV.of(13, UserString.of("13")))), + Arrays.asList(0L, 4L, 7L, 10L, 16L)) + .withCoder(KvCoder.of( + StringUtf8Coder.of(), + KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of())))); + + TupleTag maxIntTag = new TupleTag(); + TupleTag concatStringTag = new TupleTag(); + PCollection>> combineGlobally = perKeyInput + .apply(Values.>create()) + .apply(Combine.globally(CombineFns.compose() + .with( + new GetIntegerFunction(), + new MaxIntegerFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatString(), + concatStringTag))) + .apply(WithKeys.of("global")) + .apply( + "ExtractGloballyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + + PCollection>> combinePerKey = perKeyInput + .apply(Combine.perKey(CombineFns.composeKeyed() + .with( + new GetIntegerFunction(), + new MaxIntegerFn().asKeyedFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatString().asKeyedFn(), + concatStringTag))) + .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + DataflowAssert.that(combineGlobally).containsInAnyOrder( + KV.of("global", KV.of(13, "111134"))); + DataflowAssert.that(combinePerKey).containsInAnyOrder( + KV.of("a", KV.of(4, "114")), + KV.of("b", KV.of(13, "113"))); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testComposedCombineWithContext() { + Pipeline p = TestPipeline.create(); + p.getCoderRegistry().registerCoder(UserString.class, UserStringCoder.of()); + + PCollectionView view = p + .apply(Create.of("I")) + .apply(View.asSingleton()); + + PCollection>> perKeyInput = p.apply( + Create.timestamped( + 
Arrays.asList( + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(4, UserString.of("4"))), + KV.of("b", KV.of(1, UserString.of("1"))), + KV.of("b", KV.of(13, UserString.of("13")))), + Arrays.asList(0L, 4L, 7L, 10L, 16L)) + .withCoder(KvCoder.of( + StringUtf8Coder.of(), + KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of())))); + + TupleTag maxIntTag = new TupleTag(); + TupleTag concatStringTag = new TupleTag(); + PCollection>> combineGlobally = perKeyInput + .apply(Values.>create()) + .apply(Combine.globally(CombineFns.compose() + .with( + new GetIntegerFunction(), + new MaxIntegerFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(view).forKey("G", StringUtf8Coder.of()), + concatStringTag)) + .withoutDefaults() + .withSideInputs(ImmutableList.of(view))) + .apply(WithKeys.of("global")) + .apply( + "ExtractGloballyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + + PCollection>> combinePerKey = perKeyInput + .apply(Combine.perKey(CombineFns.composeKeyed() + .with( + new GetIntegerFunction(), + new MaxIntegerFn().asKeyedFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(view), + concatStringTag)) + .withSideInputs(ImmutableList.of(view))) + .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + DataflowAssert.that(combineGlobally).containsInAnyOrder( + KV.of("global", KV.of(13, "111134GI"))); + DataflowAssert.that(combinePerKey).containsInAnyOrder( + KV.of("a", KV.of(4, "114Ia")), + KV.of("b", KV.of(13, "113Ib"))); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testComposedCombineNullValues() { + Pipeline p = TestPipeline.create(); + p.getCoderRegistry().registerCoder(UserString.class, NullableCoder.of(UserStringCoder.of())); + p.getCoderRegistry().registerCoder(String.class, NullableCoder.of(StringUtf8Coder.of())); + + PCollection>> 
perKeyInput = p.apply( + Create.timestamped( + Arrays.asList( + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(4, UserString.of("4"))), + KV.of("b", KV.of(1, UserString.of("1"))), + KV.of("b", KV.of(13, UserString.of("13")))), + Arrays.asList(0L, 4L, 7L, 10L, 16L)) + .withCoder(KvCoder.of( + StringUtf8Coder.of(), + KvCoder.of( + BigEndianIntegerCoder.of(), NullableCoder.of(UserStringCoder.of()))))); + + TupleTag maxIntTag = new TupleTag(); + TupleTag concatStringTag = new TupleTag(); + + PCollection>> combinePerKey = perKeyInput + .apply(Combine.perKey(CombineFns.composeKeyed() + .with( + new GetIntegerFunction(), + new MaxIntegerFn().asKeyedFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new OutputNullString().asKeyedFn(), + concatStringTag))) + .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + DataflowAssert.that(combinePerKey).containsInAnyOrder( + KV.of("a", KV.of(4, (String) null)), + KV.of("b", KV.of(13, (String) null))); + p.run(); + } + + private static class UserString implements Serializable { + private String strValue; + + static UserString of(String strValue) { + UserString ret = new UserString(); + ret.strValue = strValue; + return ret; + } + } + + private static class UserStringCoder extends StandardCoder { + public static UserStringCoder of() { + return INSTANCE; + } + + private static final UserStringCoder INSTANCE = new UserStringCoder(); + + @Override + public void encode(UserString value, OutputStream outStream, Context context) + throws CoderException, IOException { + StringUtf8Coder.of().encode(value.strValue, outStream, context); + } + + @Override + public UserString decode(InputStream inStream, Context context) + throws CoderException, IOException { + return UserString.of(StringUtf8Coder.of().decode(inStream, context)); + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public void 
verifyDeterministic() throws NonDeterministicException {} + } + + private static class GetIntegerFunction + extends SimpleFunction, Integer> { + @Override + public Integer apply(KV input) { + return input.getKey(); + } + } + + private static class GetUserStringFunction + extends SimpleFunction, UserString> { + @Override + public UserString apply(KV input) { + return input.getValue(); + } + } + + private static class ConcatString extends BinaryCombineFn { + @Override + public UserString apply(UserString left, UserString right) { + String retStr = left.strValue + right.strValue; + char[] chars = retStr.toCharArray(); + Arrays.sort(chars); + return UserString.of(new String(chars)); + } + } + + private static class OutputNullString extends BinaryCombineFn { + @Override + public UserString apply(UserString left, UserString right) { + return null; + } + } + + private static class ConcatStringWithContext + extends KeyedCombineFnWithContext { + private final PCollectionView view; + + private ConcatStringWithContext(PCollectionView view) { + this.view = view; + } + + @Override + public UserString createAccumulator(String key, CombineWithContext.Context c) { + return UserString.of(key + c.sideInput(view)); + } + + @Override + public UserString addInput( + String key, UserString accumulator, UserString input, CombineWithContext.Context c) { + assertThat(accumulator.strValue, Matchers.startsWith(key + c.sideInput(view))); + accumulator.strValue += input.strValue; + return accumulator; + } + + @Override + public UserString mergeAccumulators( + String key, Iterable accumulators, CombineWithContext.Context c) { + String keyPrefix = key + c.sideInput(view); + String all = keyPrefix; + for (UserString accumulator : accumulators) { + assertThat(accumulator.strValue, Matchers.startsWith(keyPrefix)); + all += accumulator.strValue.substring(keyPrefix.length()); + accumulator.strValue = "cleared in mergeAccumulators"; + } + return UserString.of(all); + } + + @Override + public 
UserString extractOutput( + String key, UserString accumulator, CombineWithContext.Context c) { + assertThat(accumulator.strValue, Matchers.startsWith(key + c.sideInput(view))); + char[] chars = accumulator.strValue.toCharArray(); + Arrays.sort(chars); + return UserString.of(new String(chars)); + } + } + + private static class ExtractResultDoFn + extends DoFn, KV>>{ + + private final TupleTag maxIntTag; + private final TupleTag concatStringTag; + + ExtractResultDoFn(TupleTag maxIntTag, TupleTag concatStringTag) { + this.maxIntTag = maxIntTag; + this.concatStringTag = concatStringTag; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + UserString userString = c.element().getValue().get(concatStringTag); + KV value = KV.of( + c.element().getValue().get(maxIntTag), + userString == null ? null : userString.strValue); + c.output(KV.of(c.element().getKey(), value)); + } + } +} From d5a8c75f2cdcd3e3fcd529023d0ac9f682d0fa47 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Fri, 26 Feb 2016 17:29:43 -0800 Subject: [PATCH 06/11] Add ExecutorServiceParallelExecutor as an InProcessExecutor This is responsible for scheduling transform evaluations and communicating results back to the evaluation context. The executor handles PTransforms that block arbitrarily waiting for additional input.
--- .../runners/inprocess/CompletionCallback.java | 33 ++ .../ExecutorServiceParallelExecutor.java | 394 ++++++++++++++++++ .../inprocess/InMemoryWatermarkManager.java | 2 +- .../inprocess/InProcessEvaluationContext.java | 21 +- .../runners/inprocess/InProcessExecutor.java | 46 ++ .../inprocess/InProcessPipelineOptions.java | 7 +- .../inprocess/InProcessPipelineRunner.java | 25 -- .../runners/inprocess/TransformExecutor.java | 114 +++++ .../inprocess/TransformExecutorService.java | 34 ++ .../inprocess/TransformExecutorServices.java | 153 +++++++ .../TransformExecutorServicesTest.java | 134 ++++++ .../inprocess/TransformExecutorTest.java | 312 ++++++++++++++ 12 files changed, 1247 insertions(+), 28 deletions(-) create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessExecutor.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorService.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServices.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServicesTest.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java new file mode 100644 index 0000000000..2792631560 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java @@ -0,0 +1,33 
@@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; + +/** + * A callback for completing a bundle of input. + */ +interface CompletionCallback { + /** + * Handle a successful result. + */ + void handleResult(CommittedBundle inputBundle, InProcessTransformResult result); + + /** + * Handle a result that terminated abnormally due to the provided {@link Throwable}. + */ + void handleThrowable(CommittedBundle inputBundle, Throwable t); +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java new file mode 100644 index 0000000000..ae686f2979 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java @@ -0,0 +1,394 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.FiredTimers; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.util.KeyedWorkItem; +import com.google.cloud.dataflow.sdk.util.KeyedWorkItems; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.MoreObjects; +import com.google.common.base.Optional; +import com.google.common.collect.ImmutableList; + +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; + +import javax.annotation.Nullable; + +/** + * An {@link InProcessExecutor} that uses an underlying {@link ExecutorService} and + * {@link 
InProcessEvaluationContext} to execute a {@link Pipeline}. + */ +final class ExecutorServiceParallelExecutor implements InProcessExecutor { + private static final Logger LOG = LoggerFactory.getLogger(ExecutorServiceParallelExecutor.class); + + private final ExecutorService executorService; + + private final Map>> valueToConsumers; + private final Set keyedPValues; + private final TransformEvaluatorRegistry registry; + private final InProcessEvaluationContext evaluationContext; + + private final ConcurrentMap currentEvaluations; + private final ConcurrentMap, Boolean> scheduledExecutors; + + private final Queue allUpdates; + private final BlockingQueue visibleUpdates; + + private final TransformExecutorService parallelExecutorService; + private final CompletionCallback defaultCompletionCallback; + + private Collection> rootNodes; + + public static ExecutorServiceParallelExecutor create( + ExecutorService executorService, + Map>> valueToConsumers, + Set keyedPValues, + TransformEvaluatorRegistry registry, + InProcessEvaluationContext context) { + return new ExecutorServiceParallelExecutor( + executorService, valueToConsumers, keyedPValues, registry, context); + } + + private ExecutorServiceParallelExecutor( + ExecutorService executorService, + Map>> valueToConsumers, + Set keyedPValues, + TransformEvaluatorRegistry registry, + InProcessEvaluationContext context) { + this.executorService = executorService; + this.valueToConsumers = valueToConsumers; + this.keyedPValues = keyedPValues; + this.registry = registry; + this.evaluationContext = context; + + currentEvaluations = new ConcurrentHashMap<>(); + scheduledExecutors = new ConcurrentHashMap<>(); + + this.allUpdates = new ConcurrentLinkedQueue<>(); + this.visibleUpdates = new ArrayBlockingQueue<>(20); + + parallelExecutorService = + TransformExecutorServices.parallel(executorService, scheduledExecutors); + defaultCompletionCallback = new DefaultCompletionCallback(); + } + + @Override + public void start(Collection> 
roots) {
+    rootNodes = ImmutableList.copyOf(roots);
+    Runnable monitorRunnable = new MonitorRunnable();
+    executorService.submit(monitorRunnable);
+  }
+
+  @SuppressWarnings("unchecked")
+  public void scheduleConsumption(
+      AppliedPTransform consumer,
+      @Nullable CommittedBundle bundle,
+      CompletionCallback onComplete) {
+    evaluateBundle(consumer, bundle, onComplete);
+  }
+
+  private void evaluateBundle(
+      final AppliedPTransform transform,
+      @Nullable final CommittedBundle bundle,
+      final CompletionCallback onComplete) {
+    TransformExecutorService transformExecutor;
+    // Root transforms are scheduled with a null bundle; they always use the parallel service.
+    if (bundle != null && isKeyed(bundle.getPCollection())) {
+      final StepAndKey stepAndKey = StepAndKey.of(transform, bundle.getKey());
+      transformExecutor = getSerialExecutorService(stepAndKey);
+    } else {
+      transformExecutor = parallelExecutorService;
+    }
+    TransformExecutor callable =
+        TransformExecutor.create(
+            registry, evaluationContext, bundle, transform, onComplete, transformExecutor);
+    transformExecutor.schedule(callable);
+  }
+
+  private boolean isKeyed(PValue pvalue) {
+    return keyedPValues.contains(pvalue);
+  }
+
+  private void scheduleConsumers(CommittedBundle bundle) {
+    for (AppliedPTransform consumer : valueToConsumers.get(bundle.getPCollection())) {
+      scheduleConsumption(consumer, bundle, defaultCompletionCallback);
+    }
+  }
+
+  private TransformExecutorService getSerialExecutorService(StepAndKey stepAndKey) {
+    if (!currentEvaluations.containsKey(stepAndKey)) {
+      currentEvaluations.putIfAbsent(
+          stepAndKey, TransformExecutorServices.serial(executorService, scheduledExecutors));
+    }
+    return currentEvaluations.get(stepAndKey);
+  }
+
+  @Override
+  public void awaitCompletion() throws Throwable {
+    VisibleExecutorUpdate update;
+    do {
+      update = visibleUpdates.take();
+      if (update.throwable.isPresent()) {
+        throw update.throwable.get();
+      }
+    } while (!update.isDone());
+    executorService.shutdown();
+  }
+
+  /**
+   * The default {@link CompletionCallback}.
The default completion callback is used to complete + * transform evaluations that are triggered due to the arrival of elements from an upstream + * transform, or for a source transform. + */ + private class DefaultCompletionCallback implements CompletionCallback { + @Override + public void handleResult(CommittedBundle inputBundle, InProcessTransformResult result) { + Iterable> resultBundles = + evaluationContext.handleResult(inputBundle, Collections.emptyList(), result); + for (CommittedBundle outputBundle : resultBundles) { + allUpdates.offer(ExecutorUpdate.fromBundle(outputBundle)); + } + } + + @Override + public void handleThrowable(CommittedBundle inputBundle, Throwable t) { + allUpdates.offer(ExecutorUpdate.fromThrowable(t)); + } + } + + /** + * A {@link CompletionCallback} where the completed bundle was produced to deliver some collection + * of {@link TimerData timers}. When the evaluator completes successfully, reports all of the + * timers used to create the input to the {@link InProcessEvaluationContext evaluation context} + * as part of the result. + */ + private class TimerCompletionCallback implements CompletionCallback { + private final Iterable timers; + + private TimerCompletionCallback(Iterable timers) { + this.timers = timers; + } + + @Override + public void handleResult(CommittedBundle inputBundle, InProcessTransformResult result) { + Iterable> resultBundles = + evaluationContext.handleResult(inputBundle, timers, result); + for (CommittedBundle outputBundle : resultBundles) { + allUpdates.offer(ExecutorUpdate.fromBundle(outputBundle)); + } + } + + @Override + public void handleThrowable(CommittedBundle inputBundle, Throwable t) { + allUpdates.offer(ExecutorUpdate.fromThrowable(t)); + } + } + + /** + * An internal status update on the state of the executor. + * + * Used to signal when the executor should be shut down (due to an exception). 
+ */ + private static class ExecutorUpdate { + private final Optional> bundle; + private final Optional throwable; + + public static ExecutorUpdate fromBundle(CommittedBundle bundle) { + return new ExecutorUpdate(bundle, null); + } + + public static ExecutorUpdate fromThrowable(Throwable t) { + return new ExecutorUpdate(null, t); + } + + private ExecutorUpdate(CommittedBundle producedBundle, Throwable throwable) { + this.bundle = Optional.fromNullable(producedBundle); + this.throwable = Optional.fromNullable(throwable); + } + + public Optional> getBundle() { + return bundle; + } + + public Optional getException() { + return throwable; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(ExecutorUpdate.class) + .add("bundle", bundle) + .add("exception", throwable) + .toString(); + } + } + + /** + * An update of interest to the user. Used in {@link #awaitCompletion} to decide whether to + * return normally or throw an exception. + */ + private static class VisibleExecutorUpdate { + private final Optional throwable; + private final boolean done; + + public static VisibleExecutorUpdate fromThrowable(Throwable e) { + return new VisibleExecutorUpdate(false, e); + } + + public static VisibleExecutorUpdate finished() { + return new VisibleExecutorUpdate(true, null); + } + + private VisibleExecutorUpdate(boolean done, @Nullable Throwable exception) { + this.throwable = Optional.fromNullable(exception); + this.done = done; + } + + public boolean isDone() { + return done; + } + } + + private class MonitorRunnable implements Runnable { + private final String runnableName = + String.format( + "%s$%s-monitor", + evaluationContext.getPipelineOptions().getAppName(), + ExecutorServiceParallelExecutor.class.getSimpleName()); + + @Override + public void run() { + String oldName = Thread.currentThread().getName(); + Thread.currentThread().setName(runnableName); + try { + ExecutorUpdate update = allUpdates.poll(); + if (update != null) { + 
LOG.debug("Executor Update: {}", update); + if (update.getBundle().isPresent()) { + scheduleConsumers(update.getBundle().get()); + } else if (update.getException().isPresent()) { + visibleUpdates.offer(VisibleExecutorUpdate.fromThrowable(update.getException().get())); + } + } + fireTimers(); + mightNeedMoreWork(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.error("Monitor died due to being interrupted"); + while (!visibleUpdates.offer(VisibleExecutorUpdate.fromThrowable(e))) { + visibleUpdates.poll(); + } + } catch (Throwable t) { + LOG.error("Monitor thread died due to throwable", t); + while (!visibleUpdates.offer(VisibleExecutorUpdate.fromThrowable(t))) { + visibleUpdates.poll(); + } + } finally { + if (!shouldShutdown()) { + // The monitor thread should always be scheduled; but we only need to be scheduled once + executorService.submit(this); + } + Thread.currentThread().setName(oldName); + } + } + + private void fireTimers() throws Exception { + try { + for (Map.Entry, Map> transformTimers : + evaluationContext.extractFiredTimers().entrySet()) { + AppliedPTransform transform = transformTimers.getKey(); + for (Map.Entry keyTimers : transformTimers.getValue().entrySet()) { + for (TimeDomain domain : TimeDomain.values()) { + Collection delivery = keyTimers.getValue().getTimers(domain); + if (delivery.isEmpty()) { + continue; + } + KeyedWorkItem work = + KeyedWorkItems.timersWorkItem(keyTimers.getKey(), delivery); + @SuppressWarnings({"unchecked", "rawtypes"}) + CommittedBundle bundle = + InProcessBundle.>keyed( + (PCollection) transform.getInput(), keyTimers.getKey()) + .add(WindowedValue.valueInEmptyWindows(work)) + .commit(Instant.now()); + scheduleConsumption(transform, bundle, new TimerCompletionCallback(delivery)); + } + } + } + } catch (Exception e) { + LOG.error("Internal Error while delivering timers", e); + throw e; + } + } + + private boolean shouldShutdown() { + if (evaluationContext.isDone()) { + 
LOG.debug("Pipeline is finished. Shutting down.");
+        while (!visibleUpdates.offer(VisibleExecutorUpdate.finished())) {
+          visibleUpdates.poll();
+        }
+        executorService.shutdown();
+        return true;
+      }
+      return false;
+    }
+
+    private void mightNeedMoreWork() {
+      synchronized (scheduledExecutors) {
+        for (TransformExecutor executor : scheduledExecutors.keySet()) {
+          Thread thread = executor.getThread();
+          if (thread != null) {
+            switch (thread.getState()) {
+              case BLOCKED:
+              case WAITING:
+              case TERMINATED:
+              case TIMED_WAITING:
+                break;
+              default:
+                return;
+            }
+          }
+        }
+      }
+      // All current TransformExecutors are blocked; add more work from the roots.
+      for (AppliedPTransform root : rootNodes) {
+        scheduleConsumption(root, null, defaultCompletionCallback);
+      }
+    }
+  }
+}
+
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java
index 7cf53aafe6..094526d962 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InMemoryWatermarkManager.java
@@ -866,7 +866,7 @@ private void updatePending(
    * {@link #getWatermarks(AppliedPTransform)}, the output watermark will be equal to
    * {@link BoundedWindow#TIMESTAMP_MAX_VALUE}.
*/ - public boolean isDone() { + public boolean allWatermarksAtPositiveInfinity() { for (Map.Entry, TransformWatermarks> watermarksEntry : transformToWatermarks.entrySet()) { Instant endOfTime = THE_END_OF_TIME.get(); diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContext.java index 757e9e11d9..2908fba818 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContext.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessEvaluationContext.java @@ -36,6 +36,7 @@ import com.google.cloud.dataflow.sdk.util.common.CounterSet; import com.google.cloud.dataflow.sdk.util.state.CopyOnAccessInMemoryStateInternals; import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; import com.google.cloud.dataflow.sdk.values.PCollectionView; import com.google.cloud.dataflow.sdk.values.PValue; import com.google.common.collect.ImmutableList; @@ -359,6 +360,24 @@ public CounterSet getCounters() { * Returns true if all steps are done. 
*/ public boolean isDone() { - return watermarkManager.isDone(); + if (!options.isShutdownUnboundedProducersWithMaxWatermark() && containsUnboundedPCollection()) { + return false; + } + if (!watermarkManager.allWatermarksAtPositiveInfinity()) { + return false; + } + return true; + } + + private boolean containsUnboundedPCollection() { + for (AppliedPTransform transform : stepNames.keySet()) { + for (PValue value : transform.getInput().expand()) { + if (value instanceof PCollection + && ((PCollection) value).isBounded().equals(IsBounded.UNBOUNDED)) { + return true; + } + } + } + return false; } } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessExecutor.java new file mode 100644 index 0000000000..7b60bca17d --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessExecutor.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import java.util.Collection; + +/** + * An executor that schedules and executes {@link AppliedPTransform AppliedPTransforms} for both + * source and intermediate {@link PTransform PTransforms}. + */ +interface InProcessExecutor { + /** + * Starts this executor. The provided collection is the collection of root transforms to + * initially schedule. + * + * @param rootTransforms + */ + void start(Collection> rootTransforms); + + /** + * Blocks until the job being executed enters a terminal state. A job is completed after all + * root {@link AppliedPTransform AppliedPTransforms} have completed, and all + * {@link CommittedBundle Bundles} have been consumed. Jobs may also terminate abnormally. + * + * @throws Throwable whenever an executor thread throws anything, transfers the throwable to the + * waiting thread and rethrows it + */ + void awaitCompletion() throws Throwable; +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java index 60c8543a2f..27e9a4be6e 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java @@ -15,15 +15,20 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; +import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions; import com.google.cloud.dataflow.sdk.options.Default; import com.google.cloud.dataflow.sdk.options.PipelineOptions; /** * Options that can be used to configure the {@link InProcessPipelineRunner}. 
*/ -public interface InProcessPipelineOptions extends PipelineOptions { +public interface InProcessPipelineOptions extends PipelineOptions, ApplicationNameOptions { @Default.InstanceFactory(NanosOffsetClock.Factory.class) Clock getClock(); void setClock(Clock clock); + + boolean isShutdownUnboundedProducersWithMaxWatermark(); + + void setShutdownUnboundedProducersWithMaxWatermark(boolean shutdown); } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java index 7a268ee5fa..32859dae63 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java @@ -21,7 +21,6 @@ import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKey; import com.google.cloud.dataflow.sdk.runners.inprocess.ViewEvaluatorFactory.InProcessCreatePCollectionView; -import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.GroupByKey; import com.google.cloud.dataflow.sdk.transforms.PTransform; import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView; @@ -167,28 +166,4 @@ private InProcessPipelineRunner(InProcessPipelineOptions options) { public InProcessPipelineOptions getPipelineOptions() { return options; } - - /** - * An executor that schedules and executes {@link AppliedPTransform AppliedPTransforms} for both - * source and intermediate {@link PTransform PTransforms}. 
- */ - public static interface InProcessExecutor { - /** - * @param root the root {@link AppliedPTransform} to schedule - */ - void scheduleRoot(AppliedPTransform root); - - /** - * @param consumer the {@link AppliedPTransform} to schedule - * @param bundle the input bundle to the consumer - */ - void scheduleConsumption(AppliedPTransform consumer, CommittedBundle bundle); - - /** - * Blocks until the job being executed enters a terminal state. A job is completed after all - * root {@link AppliedPTransform AppliedPTransforms} have completed, and all - * {@link CommittedBundle Bundles} have been consumed. Jobs may also terminate abnormally. - */ - void awaitCompletion(); - } } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java new file mode 100644 index 0000000000..d630749387 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.common.base.Throwables; + +import java.util.concurrent.Callable; + +import javax.annotation.Nullable; + +/** + * A {@link Callable} responsible for constructing a {@link TransformEvaluator} from a + * {@link TransformEvaluatorFactory} and evaluating it on some bundle of input, and registering + * the result using a registered {@link CompletionCallback}. + * + *

A {@link TransformExecutor} that is currently executing also provides access to the thread
+ * that it is being executed on.
+ */
+class TransformExecutor implements Callable {
+  public static TransformExecutor create(
+      TransformEvaluatorFactory factory,
+      InProcessEvaluationContext evaluationContext,
+      CommittedBundle inputBundle,
+      AppliedPTransform transform,
+      CompletionCallback completionCallback,
+      TransformExecutorService transformEvaluationState) {
+    return new TransformExecutor<>(
+        factory,
+        evaluationContext,
+        inputBundle,
+        transform,
+        completionCallback,
+        transformEvaluationState);
+  }
+
+  private final TransformEvaluatorFactory evaluatorFactory;
+  private final InProcessEvaluationContext evaluationContext;
+
+  /** The transform that will be evaluated. */
+  private final AppliedPTransform transform;
+  /** The inputs this {@link TransformExecutor} will deliver to the transform. */
+  private final CommittedBundle inputBundle;
+
+  private final CompletionCallback onComplete;
+  private final TransformExecutorService transformEvaluationState;
+  /** Set by {@link #call()} while it runs; read by other threads via {@link #getThread()}. */
+  private volatile Thread thread;
+
+  private TransformExecutor(
+      TransformEvaluatorFactory factory,
+      InProcessEvaluationContext evaluationContext,
+      CommittedBundle inputBundle,
+      AppliedPTransform transform,
+      CompletionCallback completionCallback,
+      TransformExecutorService transformEvaluationState) {
+    this.evaluatorFactory = factory;
+    this.evaluationContext = evaluationContext;
+
+    this.inputBundle = inputBundle;
+    this.transform = transform;
+
+    this.onComplete = completionCallback;
+
+    this.transformEvaluationState = transformEvaluationState;
+  }
+
+  @Override
+  public InProcessTransformResult call() {
+    this.thread = Thread.currentThread();
+    try {
+      TransformEvaluator evaluator =
+          evaluatorFactory.forApplication(transform, inputBundle, evaluationContext);
+      if (inputBundle != null) {
+        for (WindowedValue value : inputBundle.getElements()) {
+          evaluator.processElement(value);
+        }
+      }
+
InProcessTransformResult result = evaluator.finishBundle(); + onComplete.handleResult(inputBundle, result); + return result; + } catch (Throwable t) { + onComplete.handleThrowable(inputBundle, t); + throw Throwables.propagate(t); + } finally { + this.thread = null; + transformEvaluationState.complete(this); + } + } + + /** + * If this {@link TransformExecutor} is currently executing, return the thread it is executing in. + * Otherwise, return null. + */ + @Nullable + public Thread getThread() { + return this.thread; + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorService.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorService.java new file mode 100644 index 0000000000..3f00da6ebe --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorService.java @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +/** + * Schedules and completes {@link TransformExecutor TransformExecutors}, controlling concurrency as + * appropriate for the {@link StepAndKey} the executor exists for. + */ +interface TransformExecutorService { + /** + * Schedule the provided work to be eventually executed. + */ + void schedule(TransformExecutor work); + + /** + * Finish executing the provided work. 
This may cause additional + * {@link TransformExecutor TransformExecutors} to be evaluated. + */ + void complete(TransformExecutor completed); +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServices.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServices.java new file mode 100644 index 0000000000..34efdf694e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServices.java @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.common.base.MoreObjects; + +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Static factory methods for constructing instances of {@link TransformExecutorService}. + */ +final class TransformExecutorServices { + private TransformExecutorServices() { + // Do not instantiate + } + + /** + * Returns an EvaluationState that evaluates {@link TransformExecutor TransformExecutors} in + * parallel. 
+ */ + public static TransformExecutorService parallel( + ExecutorService executor, Map, Boolean> scheduled) { + return new ParallelEvaluationState(executor, scheduled); + } + + /** + * Returns an EvaluationState that evaluates {@link TransformExecutor TransformExecutors} in + * serial. + */ + public static TransformExecutorService serial( + ExecutorService executor, Map, Boolean> scheduled) { + return new SerialEvaluationState(executor, scheduled); + } + + /** + * A {@link TransformExecutorService} with unlimited parallelism. Any {@link TransformExecutor} + * scheduled will be immediately submitted to the {@link ExecutorService}. + * + *

A principal use of this is for the evaluation of an unkeyed Step. Unkeyed computations are
+   * processed in parallel.
+   */
+  private static class ParallelEvaluationState implements TransformExecutorService {
+    private final ExecutorService executor;
+    private final Map<TransformExecutor<?>, Boolean> scheduled;
+
+    private ParallelEvaluationState(
+        ExecutorService executor, Map<TransformExecutor<?>, Boolean> scheduled) {
+      this.executor = executor;
+      this.scheduled = scheduled;
+    }
+
+    @Override
+    public void schedule(TransformExecutor work) {
+      scheduled.put(work, true); // Register first: the work may complete during submit().
+      executor.submit(work);
+    }
+
+    @Override
+    public void complete(TransformExecutor completed) {
+      scheduled.remove(completed);
+    }
+  }
+
+  /**
+   * A {@link TransformExecutorService} with a single work queue. Any {@link TransformExecutor}
+   * scheduled will be placed on the work queue. Only one item of work will be submitted to the
+   * {@link ExecutorService} at any time.
+   *
+   *

A principal use of this is for the serial evaluation of a (Step, Key) pair. + * Keyed computations are processed serially per step. + */ + private static class SerialEvaluationState implements TransformExecutorService { + private final ExecutorService executor; + private final Map, Boolean> scheduled; + + private AtomicReference> currentlyEvaluating; + private final Queue> workQueue; + + private SerialEvaluationState( + ExecutorService executor, Map, Boolean> scheduled) { + this.scheduled = scheduled; + this.executor = executor; + this.currentlyEvaluating = new AtomicReference<>(); + this.workQueue = new ConcurrentLinkedQueue<>(); + } + + /** + * Schedules the work, adding it to the work queue if there is a bundle currently being + * evaluated and scheduling it immediately otherwise. + */ + @Override + public void schedule(TransformExecutor work) { + workQueue.offer(work); + updateCurrentlyEvaluating(); + } + + @Override + public void complete(TransformExecutor completed) { + if (!currentlyEvaluating.compareAndSet(completed, null)) { + throw new IllegalStateException( + "Finished work " + + completed + + " but could not complete due to unexpected currently executing " + + currentlyEvaluating.get()); + } + scheduled.remove(completed); + updateCurrentlyEvaluating(); + } + + private void updateCurrentlyEvaluating() { + if (currentlyEvaluating.get() == null) { + // Only synchronize if we need to update what's currently evaluating + synchronized (this) { + TransformExecutor newWork = workQueue.poll(); + if (newWork != null) { + if (currentlyEvaluating.compareAndSet(null, newWork)) { + scheduled.put(newWork, true); + executor.submit(newWork); + } else { + workQueue.offer(newWork); + } + } + } + } + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(SerialEvaluationState.class) + .add("currentlyEvaluating", currentlyEvaluating) + .add("workQueue", workQueue) + .toString(); + } + } +} + diff --git 
a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServicesTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServicesTest.java new file mode 100644 index 0000000000..2c66dc2c32 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorServicesTest.java @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import com.google.common.util.concurrent.MoreExecutors; + +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; + +/** + * Tests for {@link TransformExecutorServices}. 
+ */ +@RunWith(JUnit4.class) +public class TransformExecutorServicesTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + + private ExecutorService executorService; + private Map, Boolean> scheduled; + + @Before + public void setup() { + executorService = MoreExecutors.newDirectExecutorService(); + scheduled = new ConcurrentHashMap<>(); + } + + @Test + public void parallelScheduleMultipleSchedulesBothImmediately() { + @SuppressWarnings("unchecked") + TransformExecutor first = mock(TransformExecutor.class); + @SuppressWarnings("unchecked") + TransformExecutor second = mock(TransformExecutor.class); + + TransformExecutorService parallel = + TransformExecutorServices.parallel(executorService, scheduled); + parallel.schedule(first); + parallel.schedule(second); + + verify(first).call(); + verify(second).call(); + assertThat( + scheduled, + Matchers.allOf( + Matchers., Boolean>hasEntry(first, true), + Matchers., Boolean>hasEntry(second, true))); + + parallel.complete(first); + assertThat(scheduled, Matchers., Boolean>hasEntry(second, true)); + assertThat( + scheduled, + not( + Matchers., Boolean>hasEntry( + Matchers.>equalTo(first), any(Boolean.class)))); + parallel.complete(second); + assertThat(scheduled.isEmpty(), is(true)); + } + + @Test + public void serialScheduleTwoWaitsForFirstToComplete() { + @SuppressWarnings("unchecked") + TransformExecutor first = mock(TransformExecutor.class); + @SuppressWarnings("unchecked") + TransformExecutor second = mock(TransformExecutor.class); + + TransformExecutorService serial = TransformExecutorServices.serial(executorService, scheduled); + serial.schedule(first); + verify(first).call(); + + serial.schedule(second); + verify(second, never()).call(); + + assertThat(scheduled, Matchers., Boolean>hasEntry(first, true)); + assertThat( + scheduled, + not( + Matchers., Boolean>hasEntry( + Matchers.>equalTo(second), any(Boolean.class)))); + + serial.complete(first); + verify(second).call(); + assertThat(scheduled, 
Matchers., Boolean>hasEntry(second, true)); + assertThat( + scheduled, + not( + Matchers., Boolean>hasEntry( + Matchers.>equalTo(first), any(Boolean.class)))); + + serial.complete(second); + } + + @Test + public void serialCompleteNotExecutingTaskThrows() { + @SuppressWarnings("unchecked") + TransformExecutor first = mock(TransformExecutor.class); + @SuppressWarnings("unchecked") + TransformExecutor second = mock(TransformExecutor.class); + + TransformExecutorService serial = TransformExecutorServices.serial(executorService, scheduled); + serial.schedule(first); + thrown.expect(IllegalStateException.class); + thrown.expectMessage("unexpected currently executing"); + + serial.complete(second); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java new file mode 100644 index 0000000000..bd6325278a --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java @@ -0,0 +1,312 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.when; + +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.util.concurrent.MoreExecutors; + +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Tests for {@link TransformExecutor}. 
+ */ +@RunWith(JUnit4.class) +public class TransformExecutorTest { + private PCollection created; + private PCollection> downstream; + + private CountDownLatch evaluatorCompleted; + + private RegisteringCompletionCallback completionCallback; + private TransformExecutorService transformEvaluationState; + @Mock private InProcessEvaluationContext evaluationContext; + @Mock private TransformEvaluatorRegistry registry; + private Map, Boolean> scheduled; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + + scheduled = new HashMap<>(); + transformEvaluationState = + TransformExecutorServices.parallel(MoreExecutors.newDirectExecutorService(), scheduled); + + evaluatorCompleted = new CountDownLatch(1); + completionCallback = new RegisteringCompletionCallback(evaluatorCompleted); + + TestPipeline p = TestPipeline.create(); + created = p.apply(Create.of("foo", "spam", "third")); + downstream = created.apply(WithKeys.of(3)); + } + + @Test + public void callWithNullInputBundleFinishesBundleAndCompletes() throws Exception { + final InProcessTransformResult result = + StepTransformResult.withoutHold(created.getProducingTransformInternal()).build(); + final AtomicBoolean finishCalled = new AtomicBoolean(false); + TransformEvaluator evaluator = + new TransformEvaluator() { + @Override + public void processElement(WindowedValue element) throws Exception { + throw new IllegalArgumentException("Shouldn't be called"); + } + + @Override + public InProcessTransformResult finishBundle() throws Exception { + finishCalled.set(true); + return result; + } + }; + + when(registry.forApplication(created.getProducingTransformInternal(), null, evaluationContext)) + .thenReturn(evaluator); + + TransformExecutor executor = + TransformExecutor.create( + registry, + evaluationContext, + null, + created.getProducingTransformInternal(), + completionCallback, + transformEvaluationState); + executor.call(); + + assertThat(finishCalled.get(), is(true)); + 
assertThat(completionCallback.handledResult, equalTo(result)); + assertThat(completionCallback.handledThrowable, is(nullValue())); + assertThat(scheduled, not(Matchers.>hasKey(executor))); + } + + @Test + public void inputBundleProcessesEachElementFinishesAndCompletes() throws Exception { + final InProcessTransformResult result = + StepTransformResult.withoutHold(downstream.getProducingTransformInternal()).build(); + final Collection> elementsProcessed = new ArrayList<>(); + TransformEvaluator evaluator = + new TransformEvaluator() { + @Override + public void processElement(WindowedValue element) throws Exception { + elementsProcessed.add(element); + return; + } + + @Override + public InProcessTransformResult finishBundle() throws Exception { + return result; + } + }; + + WindowedValue foo = WindowedValue.valueInGlobalWindow("foo"); + WindowedValue spam = WindowedValue.valueInGlobalWindow("spam"); + WindowedValue third = WindowedValue.valueInGlobalWindow("third"); + CommittedBundle inputBundle = + InProcessBundle.unkeyed(created).add(foo).add(spam).add(third).commit(Instant.now()); + when( + registry.forApplication( + downstream.getProducingTransformInternal(), inputBundle, evaluationContext)) + .thenReturn(evaluator); + + TransformExecutor executor = + TransformExecutor.create( + registry, + evaluationContext, + inputBundle, + downstream.getProducingTransformInternal(), + completionCallback, + transformEvaluationState); + + Executors.newSingleThreadExecutor().submit(executor); + + evaluatorCompleted.await(); + + assertThat(elementsProcessed, containsInAnyOrder(spam, third, foo)); + assertThat(completionCallback.handledResult, equalTo(result)); + assertThat(completionCallback.handledThrowable, is(nullValue())); + assertThat(scheduled, not(Matchers.>hasKey(executor))); + } + + @Test + public void processElementThrowsExceptionCallsback() throws Exception { + final InProcessTransformResult result = + 
StepTransformResult.withoutHold(downstream.getProducingTransformInternal()).build(); + final Exception exception = new Exception(); + TransformEvaluator evaluator = + new TransformEvaluator() { + @Override + public void processElement(WindowedValue element) throws Exception { + throw exception; + } + + @Override + public InProcessTransformResult finishBundle() throws Exception { + return result; + } + }; + + WindowedValue foo = WindowedValue.valueInGlobalWindow("foo"); + CommittedBundle inputBundle = + InProcessBundle.unkeyed(created).add(foo).commit(Instant.now()); + when( + registry.forApplication( + downstream.getProducingTransformInternal(), inputBundle, evaluationContext)) + .thenReturn(evaluator); + + TransformExecutor executor = + TransformExecutor.create( + registry, + evaluationContext, + inputBundle, + downstream.getProducingTransformInternal(), + completionCallback, + transformEvaluationState); + Executors.newSingleThreadExecutor().submit(executor); + + evaluatorCompleted.await(); + + assertThat(completionCallback.handledResult, is(nullValue())); + assertThat(completionCallback.handledThrowable, Matchers.equalTo(exception)); + assertThat(scheduled, not(Matchers.>hasKey(executor))); + } + + @Test + public void finishBundleThrowsExceptionCallsback() throws Exception { + final Exception exception = new Exception(); + TransformEvaluator evaluator = + new TransformEvaluator() { + @Override + public void processElement(WindowedValue element) throws Exception {} + + @Override + public InProcessTransformResult finishBundle() throws Exception { + throw exception; + } + }; + + CommittedBundle inputBundle = InProcessBundle.unkeyed(created).commit(Instant.now()); + when( + registry.forApplication( + downstream.getProducingTransformInternal(), inputBundle, evaluationContext)) + .thenReturn(evaluator); + + TransformExecutor executor = + TransformExecutor.create( + registry, + evaluationContext, + inputBundle, + downstream.getProducingTransformInternal(), + 
completionCallback, + transformEvaluationState); + Executors.newSingleThreadExecutor().submit(executor); + + evaluatorCompleted.await(); + + assertThat(completionCallback.handledResult, is(nullValue())); + assertThat(completionCallback.handledThrowable, Matchers.equalTo(exception)); + assertThat(scheduled, not(Matchers.>hasKey(executor))); + } + + @Test + public void duringCallGetThreadIsNonNull() throws Exception { + final InProcessTransformResult result = + StepTransformResult.withoutHold(downstream.getProducingTransformInternal()).build(); + final CountDownLatch testLatch = new CountDownLatch(1); + final CountDownLatch evaluatorLatch = new CountDownLatch(1); + TransformEvaluator evaluator = + new TransformEvaluator() { + @Override + public void processElement(WindowedValue element) throws Exception { + throw new IllegalArgumentException("Shouldn't be called"); + } + + @Override + public InProcessTransformResult finishBundle() throws Exception { + testLatch.countDown(); + evaluatorLatch.await(); + return result; + } + }; + + when(registry.forApplication(created.getProducingTransformInternal(), null, evaluationContext)) + .thenReturn(evaluator); + + TransformExecutor executor = + TransformExecutor.create( + registry, + evaluationContext, + null, + created.getProducingTransformInternal(), + completionCallback, + transformEvaluationState); + + Executors.newSingleThreadExecutor().submit(executor); + testLatch.await(); + assertThat(executor.getThread(), not(nullValue())); + + // Finish the execution so everything can get closed down cleanly. 
+ evaluatorLatch.countDown(); + } + + private static class RegisteringCompletionCallback implements CompletionCallback { + private InProcessTransformResult handledResult = null; + private Throwable handledThrowable = null; + private final CountDownLatch onMethod; + + private RegisteringCompletionCallback(CountDownLatch onMethod) { + this.onMethod = onMethod; + } + + @Override + public void handleResult(CommittedBundle inputBundle, InProcessTransformResult result) { + handledResult = result; + onMethod.countDown(); + } + + @Override + public void handleThrowable(CommittedBundle inputBundle, Throwable t) { + handledThrowable = t; + onMethod.countDown(); + } + } +} From 931eb73a19c522663000a8517913e9ab81145e57 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Tue, 15 Mar 2016 11:50:38 -0700 Subject: [PATCH 07/11] Close Readers in InProcess Read Evaluators The readers were formerly left open, which prevents release of any resources that should be released. --- .../BoundedReadEvaluatorFactory.java | 49 +++-- .../UnboundedReadEvaluatorFactory.java | 53 +++--- .../BoundedReadEvaluatorFactoryTest.java | 136 +++++++++++++- .../UnboundedReadEvaluatorFactoryTest.java | 168 ++++++++++++++++++ 4 files changed, 366 insertions(+), 40 deletions(-) diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java index 2a164c3518..eaea3ed293 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactory.java @@ -15,6 +15,8 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.BoundedSource.BoundedReader; import com.google.cloud.dataflow.sdk.io.Read.Bounded; import 
com.google.cloud.dataflow.sdk.io.Source.Reader; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; @@ -78,8 +80,7 @@ private TransformEvaluator getTransformEvaluator( @SuppressWarnings("unchecked") private Queue> getTransformEvaluatorQueue( final AppliedPTransform, Bounded> transform, - final InProcessEvaluationContext evaluationContext) - throws IOException { + final InProcessEvaluationContext evaluationContext) { // Key by the application and the context the evaluation is occurring in (which call to // Pipeline#run). EvaluatorKey key = new EvaluatorKey(transform, evaluationContext); @@ -101,21 +102,25 @@ private Queue> getTransformEvaluatorQueu return evaluatorQueue; } + /** + * A {@link BoundedReadEvaluator} produces elements from an underlying {@link BoundedSource}, + * discarding all input elements. Within the call to {@link #finishBundle()}, the evaluator + * creates the {@link BoundedReader} and consumes all available input. + * + *

A {@link BoundedReadEvaluator} should only be created once per {@link BoundedSource}, and + * each evaluator should only be called once per evaluation of the pipeline. Otherwise, the source + * may produce duplicate elements. + */ private static class BoundedReadEvaluator implements TransformEvaluator { private final AppliedPTransform, Bounded> transform; private final InProcessEvaluationContext evaluationContext; - private final Reader reader; private boolean contentsRemaining; public BoundedReadEvaluator( AppliedPTransform, Bounded> transform, - InProcessEvaluationContext evaluationContext) - throws IOException { + InProcessEvaluationContext evaluationContext) { this.transform = transform; this.evaluationContext = evaluationContext; - reader = - transform.getTransform().getSource().createReader(evaluationContext.getPipelineOptions()); - contentsRemaining = reader.start(); } @Override @@ -123,17 +128,25 @@ public void processElement(WindowedValue element) {} @Override public InProcessTransformResult finishBundle() throws IOException { - UncommittedBundle output = evaluationContext.createRootBundle(transform.getOutput()); - while (contentsRemaining) { - output.add( - WindowedValue.timestampedValueInGlobalWindow( - reader.getCurrent(), reader.getCurrentTimestamp())); - contentsRemaining = reader.advance(); + try (final Reader reader = + transform + .getTransform() + .getSource() + .createReader(evaluationContext.getPipelineOptions());) { + contentsRemaining = reader.start(); + UncommittedBundle output = + evaluationContext.createRootBundle(transform.getOutput()); + while (contentsRemaining) { + output.add( + WindowedValue.timestampedValueInGlobalWindow( + reader.getCurrent(), reader.getCurrentTimestamp())); + contentsRemaining = reader.advance(); + } + reader.close(); + return StepTransformResult.withHold(transform, BoundedWindow.TIMESTAMP_MAX_VALUE) + .addOutput(output) + .build(); } - return StepTransformResult - .withHold(transform, 
BoundedWindow.TIMESTAMP_MAX_VALUE) - .addOutput(output) - .build(); } } } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java index 97f0e25d38..549afabcc0 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactory.java @@ -99,6 +99,16 @@ private Queue> getTransformEvaluatorQu return evaluatorQueue; } + /** + * A {@link UnboundedReadEvaluator} produces elements from an underlying {@link UnboundedSource}, + * discarding all input elements. Within the call to {@link #finishBundle()}, the evaluator + * creates the {@link UnboundedReader} and consumes some currently available input. + * + *

Calls to {@link UnboundedReadEvaluator} are not internally thread-safe, and should only be + * used by a single thread at a time. Each {@link UnboundedReadEvaluator} maintains its own + * checkpoint, and constructs its reader from the current checkpoint in each call to + * {@link #finishBundle()}. + */ private static class UnboundedReadEvaluator implements TransformEvaluator { private static final int ARBITRARY_MAX_ELEMENTS = 10; private final AppliedPTransform, Unbounded> transform; @@ -122,28 +132,29 @@ public void processElement(WindowedValue element) {} @Override public InProcessTransformResult finishBundle() throws IOException { UncommittedBundle output = evaluationContext.createRootBundle(transform.getOutput()); - UnboundedReader reader = - createReader( - transform.getTransform().getSource(), evaluationContext.getPipelineOptions()); - int numElements = 0; - if (reader.start()) { - do { - output.add( - WindowedValue.timestampedValueInGlobalWindow( - reader.getCurrent(), reader.getCurrentTimestamp())); - numElements++; - } while (numElements < ARBITRARY_MAX_ELEMENTS && reader.advance()); + try (UnboundedReader reader = + createReader( + transform.getTransform().getSource(), evaluationContext.getPipelineOptions());) { + int numElements = 0; + if (reader.start()) { + do { + output.add( + WindowedValue.timestampedValueInGlobalWindow( + reader.getCurrent(), reader.getCurrentTimestamp())); + numElements++; + } while (numElements < ARBITRARY_MAX_ELEMENTS && reader.advance()); + } + checkpointMark = reader.getCheckpointMark(); + checkpointMark.finalizeCheckpoint(); + // TODO: When exercising create initial splits, make this the minimum watermark across all + // existing readers + StepTransformResult result = + StepTransformResult.withHold(transform, reader.getWatermark()) + .addOutput(output) + .build(); + evaluatorQueue.offer(this); + return result; } - checkpointMark = reader.getCheckpointMark(); - checkpointMark.finalizeCheckpoint(); - // TODO: When exercising 
create initial splits, make this the minimum watermark across all - // existing readers - StepTransformResult result = - StepTransformResult.withHold(transform, reader.getWatermark()) - .addOutput(output) - .build(); - evaluatorQueue.offer(this); - return result; } private UnboundedReader createReader( diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java index 43955149ee..e641dd6181 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/BoundedReadEvaluatorFactoryTest.java @@ -18,24 +18,39 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.BoundedSource.BoundedReader; import com.google.cloud.dataflow.sdk.io.CountingSource; import com.google.cloud.dataflow.sdk.io.Read; import com.google.cloud.dataflow.sdk.io.Read.Bounded; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import 
com.google.cloud.dataflow.sdk.util.WindowedValue; import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; +import org.joda.time.Instant; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mock; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.NoSuchElementException; /** * Tests for {@link BoundedReadEvaluatorFactory}. @@ -45,7 +60,7 @@ public class BoundedReadEvaluatorFactoryTest { private BoundedSource source; private PCollection longs; private TransformEvaluatorFactory factory; - private InProcessEvaluationContext context; + @Mock private InProcessEvaluationContext context; @Before public void setup() { @@ -146,6 +161,125 @@ public void boundedSourceEvaluatorSimultaneousEvaluations() throws Exception { gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L))); } + @Test + public void boundedSourceEvaluatorClosesReader() throws Exception { + TestSource source = new TestSource<>(BigEndianLongCoder.of(), 1L, 2L, 3L); + + TestPipeline p = TestPipeline.create(); + PCollection pcollection = p.apply(Read.from(source)); + AppliedPTransform sourceTransform = pcollection.getProducingTransformInternal(); + + UncommittedBundle output = InProcessBundle.unkeyed(longs); + when(context.createRootBundle(pcollection)).thenReturn(output); + + TransformEvaluator evaluator = factory.forApplication(sourceTransform, null, context); + evaluator.finishBundle(); + CommittedBundle committed = output.commit(Instant.now()); + assertThat(committed.getElements(), containsInAnyOrder(gw(2L), gw(3L), gw(1L))); + assertThat(TestSource.readerClosed, is(true)); + } + + @Test + public void boundedSourceEvaluatorNoElementsClosesReader() throws Exception { + TestSource source = new TestSource<>(BigEndianLongCoder.of()); + + TestPipeline p = TestPipeline.create(); + PCollection pcollection = 
p.apply(Read.from(source)); + AppliedPTransform sourceTransform = pcollection.getProducingTransformInternal(); + + UncommittedBundle output = InProcessBundle.unkeyed(longs); + when(context.createRootBundle(pcollection)).thenReturn(output); + + TransformEvaluator evaluator = factory.forApplication(sourceTransform, null, context); + evaluator.finishBundle(); + CommittedBundle committed = output.commit(Instant.now()); + assertThat(committed.getElements(), emptyIterable()); + assertThat(TestSource.readerClosed, is(true)); + } + + private static class TestSource extends BoundedSource { + private static boolean readerClosed; + private final Coder coder; + private final T[] elems; + + public TestSource(Coder coder, T... elems) { + this.elems = elems; + this.coder = coder; + readerClosed = false; + } + + @Override + public List> splitIntoBundles( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + return ImmutableList.of(this); + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + return 0; + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public BoundedSource.BoundedReader createReader(PipelineOptions options) throws IOException { + return new TestReader<>(this, elems); + } + + @Override + public void validate() { + } + + @Override + public Coder getDefaultOutputCoder() { + return coder; + } + } + + private static class TestReader extends BoundedReader { + private final BoundedSource source; + private final List elems; + private int index; + + public TestReader(BoundedSource source, T... 
elems) { + this.source = source; + this.elems = Arrays.asList(elems); + this.index = -1; + } + + @Override + public BoundedSource getCurrentSource() { + return source; + } + + @Override + public boolean start() throws IOException { + return advance(); + } + + @Override + public boolean advance() throws IOException { + if (elems.size() > index + 1) { + index++; + return true; + } + return false; + } + + @Override + public T getCurrent() throws NoSuchElementException { + return elems.get(index); + } + + @Override + public void close() throws IOException { + TestSource.readerClosed = true; + } + } + private static WindowedValue gw(Long elem) { return WindowedValue.valueInGlobalWindow(elem); } diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java index a9bbcc8cc5..20a7d60bce 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/UnboundedReadEvaluatorFactoryTest.java @@ -18,20 +18,30 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.BigEndianLongCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; import com.google.cloud.dataflow.sdk.io.CountingSource; import com.google.cloud.dataflow.sdk.io.Read; import com.google.cloud.dataflow.sdk.io.UnboundedSource; +import com.google.cloud.dataflow.sdk.io.UnboundedSource.CheckpointMark; +import 
com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.UncommittedBundle; import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; import com.google.cloud.dataflow.sdk.util.WindowedValue; import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; import org.hamcrest.Matchers; import org.joda.time.DateTime; @@ -41,6 +51,15 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; /** * Tests for {@link UnboundedReadEvaluatorFactory}. 
*/ @@ -111,6 +130,41 @@ public void unboundedSourceInMemoryTransformEvaluatorMultipleSequentialCalls() t tgw(15L), tgw(13L), tgw(10L))); } + @Test + public void boundedSourceEvaluatorClosesReader() throws Exception { + TestUnboundedSource source = + new TestUnboundedSource<>(BigEndianLongCoder.of(), 1L, 2L, 3L); + + TestPipeline p = TestPipeline.create(); + PCollection pcollection = p.apply(Read.from(source)); + AppliedPTransform sourceTransform = pcollection.getProducingTransformInternal(); + + when(context.createRootBundle(pcollection)).thenReturn(output); + + TransformEvaluator evaluator = factory.forApplication(sourceTransform, null, context); + evaluator.finishBundle(); + CommittedBundle committed = output.commit(Instant.now()); + assertThat(ImmutableList.copyOf(committed.getElements()), hasSize(3)); + assertThat(TestUnboundedSource.readerClosedCount, equalTo(1)); + } + + @Test + public void boundedSourceEvaluatorNoElementsClosesReader() throws Exception { + TestUnboundedSource source = new TestUnboundedSource<>(BigEndianLongCoder.of()); + + TestPipeline p = TestPipeline.create(); + PCollection pcollection = p.apply(Read.from(source)); + AppliedPTransform sourceTransform = pcollection.getProducingTransformInternal(); + + when(context.createRootBundle(pcollection)).thenReturn(output); + + TransformEvaluator evaluator = factory.forApplication(sourceTransform, null, context); + evaluator.finishBundle(); + CommittedBundle committed = output.commit(Instant.now()); + assertThat(committed.getElements(), emptyIterable()); + assertThat(TestUnboundedSource.readerClosedCount, equalTo(1)); + } + // TODO: Once the source is split into multiple sources before evaluating, this test will have to // be updated. 
/** @@ -156,4 +210,118 @@ public Instant apply(Long input) { return new Instant(input); } } + + private static class TestUnboundedSource extends UnboundedSource { + static int readerClosedCount; + private final Coder coder; + private final List elems; + + public TestUnboundedSource(Coder coder, T... elems) { + readerClosedCount = 0; + this.coder = coder; + this.elems = Arrays.asList(elems); + } + + @Override + public List> generateInitialSplits( + int desiredNumSplits, PipelineOptions options) throws Exception { + return ImmutableList.of(this); + } + + @Override + public UnboundedSource.UnboundedReader createReader( + PipelineOptions options, TestCheckpointMark checkpointMark) { + return new TestUnboundedReader(elems); + } + + @Override + @Nullable + public Coder getCheckpointMarkCoder() { + return new TestCheckpointMark.Coder(); + } + + @Override + public void validate() {} + + @Override + public Coder getDefaultOutputCoder() { + return coder; + } + + private class TestUnboundedReader extends UnboundedReader { + private final List elems; + private int index; + + public TestUnboundedReader(List elems) { + this.elems = elems; + this.index = -1; + } + + @Override + public boolean start() throws IOException { + return advance(); + } + + @Override + public boolean advance() throws IOException { + if (index + 1 < elems.size()) { + index++; + return true; + } + return false; + } + + @Override + public Instant getWatermark() { + return Instant.now(); + } + + @Override + public CheckpointMark getCheckpointMark() { + return new TestCheckpointMark(); + } + + @Override + public UnboundedSource getCurrentSource() { + TestUnboundedSource source = TestUnboundedSource.this; + return source; + } + + @Override + public T getCurrent() throws NoSuchElementException { + return elems.get(index); + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + return Instant.now(); + } + + @Override + public void close() throws IOException { + 
readerClosedCount++; + } + } + } + + private static class TestCheckpointMark implements CheckpointMark { + @Override + public void finalizeCheckpoint() throws IOException {} + + public static class Coder extends AtomicCoder { + @Override + public void encode( + TestCheckpointMark value, + OutputStream outStream, + com.google.cloud.dataflow.sdk.coders.Coder.Context context) + throws CoderException, IOException {} + + @Override + public TestCheckpointMark decode( + InputStream inStream, com.google.cloud.dataflow.sdk.coders.Coder.Context context) + throws CoderException, IOException { + return new TestCheckpointMark(); + } + } + } } From 2d3ad38ffc3f12bfdf77d9d1d58fddfb6e740dd0 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Fri, 18 Mar 2016 16:20:56 -0700 Subject: [PATCH 08/11] Look up a runner if it is not registered If a fully qualified runner is passed as the value of --runner, and it is not present within the map of registered runners, attempts to look up the runner using Class#forName, and uses the result class if the result class is an instance of PipelineRunner. This brings the behavior in line with the described behavior in PipelineOptions. 
--- .../sdk/options/PipelineOptionsFactory.java | 31 +++++++++++++--- .../options/PipelineOptionsFactoryTest.java | 36 +++++++++++++++++++ 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java index e77b89f9a4..48cff6df93 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java @@ -16,6 +16,8 @@ package com.google.cloud.dataflow.sdk.options; +import static com.google.common.base.Preconditions.checkArgument; + import com.google.cloud.dataflow.sdk.options.Validation.Required; import com.google.cloud.dataflow.sdk.runners.PipelineRunner; import com.google.cloud.dataflow.sdk.runners.PipelineRunnerRegistrar; @@ -1391,7 +1393,10 @@ private static ListMultimap parseCommandLine( * split up each string on ','. * *

We special case the "runner" option. It is mapped to the class of the {@link PipelineRunner} - * based off of the {@link PipelineRunner}s simple class name or fully qualified class name. + * based off of the {@link PipelineRunner PipelineRunner's} simple class name. If the provided + * runner name is not registered via a {@link PipelineRunnerRegistrar}, we attempt to obtain the + * class that the name represents using {@link Class#forName(String)} and use the result class if + * it subclasses {@link PipelineRunner}. * *

If strict parsing is enabled, unknown options or options that cannot be converted to * the expected java type using an {@link ObjectMapper} will be ignored. @@ -1442,10 +1447,26 @@ public boolean apply(@Nullable String input) { JavaType type = MAPPER.getTypeFactory().constructType(method.getGenericReturnType()); if ("runner".equals(entry.getKey())) { String runner = Iterables.getOnlyElement(entry.getValue()); - Preconditions.checkArgument(SUPPORTED_PIPELINE_RUNNERS.containsKey(runner), - "Unknown 'runner' specified '%s', supported pipeline runners %s", - runner, Sets.newTreeSet(SUPPORTED_PIPELINE_RUNNERS.keySet())); - convertedOptions.put("runner", SUPPORTED_PIPELINE_RUNNERS.get(runner)); + if (SUPPORTED_PIPELINE_RUNNERS.containsKey(runner)) { + convertedOptions.put("runner", SUPPORTED_PIPELINE_RUNNERS.get(runner)); + } else { + try { + Class runnerClass = Class.forName(runner); + checkArgument( + PipelineRunner.class.isAssignableFrom(runnerClass), + "Class '%s' does not implement PipelineRunner. 
Supported pipeline runners %s", + runner, + Sets.newTreeSet(SUPPORTED_PIPELINE_RUNNERS.keySet())); + convertedOptions.put("runner", runnerClass); + } catch (ClassNotFoundException e) { + String msg = + String.format( + "Unknown 'runner' specified '%s', supported pipeline runners %s", + runner, + Sets.newTreeSet(SUPPORTED_PIPELINE_RUNNERS.keySet())); + throw new IllegalArgumentException(msg, e); + } + } } else if ((returnType.isArray() && (SIMPLE_TYPES.contains(returnType.getComponentType()) || returnType.getComponentType().isEnum())) || Collection.class.isAssignableFrom(returnType)) { diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java index e687f27989..045a8ad0f2 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java @@ -25,8 +25,12 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; import com.google.cloud.dataflow.sdk.testing.RestoreSystemProperties; import com.google.common.collect.ArrayListMultimap; @@ -824,6 +828,14 @@ public void testSettingRunner() { assertEquals(BlockingDataflowPipelineRunner.class, options.getRunner()); } + @Test + public void testSettingRunnerFullName() { + String[] args = + new String[] {String.format("--runner=%s", DataflowPipelineRunner.class.getName())}; + PipelineOptions opts = 
PipelineOptionsFactory.fromArgs(args).create(); + assertEquals(opts.getRunner(), DataflowPipelineRunner.class); + } + @Test public void testSettingUnknownRunner() { String[] args = new String[] {"--runner=UnknownRunner"}; @@ -834,6 +846,30 @@ public void testSettingUnknownRunner() { PipelineOptionsFactory.fromArgs(args).create(); } + private static class ExampleTestRunner extends PipelineRunner { + @Override + public PipelineResult run(Pipeline pipeline) { + return null; + } + } + + @Test + public void testSettingRunnerCanonicalClassNameNotInSupportedExists() { + String[] args = new String[] {String.format("--runner=%s", ExampleTestRunner.class.getName())}; + PipelineOptions opts = PipelineOptionsFactory.fromArgs(args).create(); + assertEquals(opts.getRunner(), ExampleTestRunner.class); + } + + @Test + public void testSettingRunnerCanonicalClassNameNotInSupportedNotPipelineRunner() { + String[] args = new String[] {"--runner=java.lang.String"}; + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("does not implement PipelineRunner"); + expectedException.expectMessage("java.lang.String"); + + PipelineOptionsFactory.fromArgs(args).create(); + } + @Test public void testUsingArgumentWithUnknownPropertyIsNotAllowed() { String[] args = new String[] {"--unknownProperty=value"}; From 30f8f6b30fbcb083dd5617160c6af10a3dc4f3d0 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 7 Mar 2016 13:47:30 -0800 Subject: [PATCH 09/11] Filter Synthetic Methods in PipelineOptionsFactory --- .../dataflow/sdk/options/PipelineOptions.java | 3 +- .../sdk/options/PipelineOptionsFactory.java | 41 ++++++--- .../PipelineOptionsFactoryJava8Test.java | 90 +++++++++++++++++++ 3 files changed, 122 insertions(+), 12 deletions(-) create mode 100644 sdk/src/test/java8/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryJava8Test.java diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java 
b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java index 923033d5da..8ff1fa9783 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptions.java @@ -137,7 +137,8 @@ * *

{@link Default @Default} represents a set of annotations that can be used to annotate getter * properties on {@link PipelineOptions} with information representing the default value to be - * returned if no value is specified. + * returned if no value is specified. Any default implementation (using the {@code default} keyword) + * is ignored. * *

{@link Hidden @Hidden} hides an option from being listed when {@code --help} * is invoked via {@link PipelineOptionsFactory#fromArgs(String[])}. diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java index 48cff6df93..4781d1c829 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java @@ -445,13 +445,21 @@ Class getProxyClass() { private static final Map>> SUPPORTED_PIPELINE_RUNNERS; /** Classes that are used as the boundary in the stack trace to find the callers class name. */ - private static final Set PIPELINE_OPTIONS_FACTORY_CLASSES = ImmutableSet.of( - PipelineOptionsFactory.class.getName(), - Builder.class.getName()); + private static final Set PIPELINE_OPTIONS_FACTORY_CLASSES = + ImmutableSet.of(PipelineOptionsFactory.class.getName(), Builder.class.getName()); /** Methods that are ignored when validating the proxy class. */ private static final Set IGNORED_METHODS; + /** A predicate that checks that a method is not synthetic via {@link Method#isSynthetic()}. */ + private static final Predicate NOT_SYNTHETIC_PREDICATE = + new Predicate() { + @Override + public boolean apply(Method input) { + return !input.isSynthetic(); + } + }; + /** The set of options that have been registered and visible to the user. 
*/ private static final Set> REGISTERED_OPTIONS = Sets.newConcurrentHashSet(); @@ -664,7 +672,9 @@ public static void printHelp(PrintStream out, Class i Preconditions.checkNotNull(iface); validateWellFormed(iface, REGISTERED_OPTIONS); - Iterable methods = ReflectHelpers.getClosureOfMethodsOnInterface(iface); + Iterable methods = + Iterables.filter( + ReflectHelpers.getClosureOfMethodsOnInterface(iface), NOT_SYNTHETIC_PREDICATE); ListMultimap, Method> ifaceToMethods = ArrayListMultimap.create(); for (Method method : methods) { // Process only methods that are not marked as hidden. @@ -878,7 +888,8 @@ private static List getPropertyDescriptors(Class beanClas throws IntrospectionException { // The sorting is important to make this method stable. SortedSet methods = Sets.newTreeSet(MethodComparator.INSTANCE); - methods.addAll(Arrays.asList(beanClass.getMethods())); + methods.addAll( + Collections2.filter(Arrays.asList(beanClass.getMethods()), NOT_SYNTHETIC_PREDICATE)); SortedMap propertyNamesToGetters = getPropertyNamesToGetters(methods); List descriptors = Lists.newArrayList(); @@ -1019,8 +1030,9 @@ private static List validateClass(Class klass) throws IntrospectionException { Set methods = Sets.newHashSet(IGNORED_METHODS); // Ignore static methods, "equals", "hashCode", "toString" and "as" on the generated class. 
+ // Ignore synthetic methods for (Method method : klass.getMethods()) { - if (Modifier.isStatic(method.getModifiers())) { + if (Modifier.isStatic(method.getModifiers()) || method.isSynthetic()) { methods.add(method); } } @@ -1037,6 +1049,7 @@ private static List validateClass(Class interfaceMethods = FluentIterable .from(ReflectHelpers.getClosureOfMethodsOnInterface(iface)) + .filter(NOT_SYNTHETIC_PREDICATE) .toSortedSet(MethodComparator.INSTANCE); SortedSetMultimap methodNameToMethodMap = TreeMultimap.create(MethodNameComparator.INSTANCE, MethodComparator.INSTANCE); @@ -1061,10 +1074,13 @@ private static List validateClass(Class allInterfaceMethods = FluentIterable - .from(ReflectHelpers.getClosureOfMethodsOnInterfaces(validatedPipelineOptionsInterfaces)) - .append(ReflectHelpers.getClosureOfMethodsOnInterface(iface)) - .toSortedSet(MethodComparator.INSTANCE); + Iterable allInterfaceMethods = + FluentIterable.from( + ReflectHelpers.getClosureOfMethodsOnInterfaces( + validatedPipelineOptionsInterfaces)) + .append(ReflectHelpers.getClosureOfMethodsOnInterface(iface)) + .filter(NOT_SYNTHETIC_PREDICATE) + .toSortedSet(MethodComparator.INSTANCE); SortedSetMultimap methodNameToAllMethodMap = TreeMultimap.create(MethodNameComparator.INSTANCE, MethodComparator.INSTANCE); for (Method method : allInterfaceMethods) { @@ -1148,7 +1164,10 @@ private static List validateClass(Class unknownMethods = new TreeSet<>(MethodComparator.INSTANCE); - unknownMethods.addAll(Sets.difference(Sets.newHashSet(klass.getMethods()), methods)); + unknownMethods.addAll( + Sets.filter( + Sets.difference(Sets.newHashSet(klass.getMethods()), methods), + NOT_SYNTHETIC_PREDICATE)); Preconditions.checkArgument(unknownMethods.isEmpty(), "Methods %s on [%s] do not conform to being bean properties.", FluentIterable.from(unknownMethods).transform(ReflectHelpers.METHOD_FORMATTER), diff --git a/sdk/src/test/java8/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryJava8Test.java 
b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryJava8Test.java new file mode 100644 index 0000000000..b7e1467436 --- /dev/null +++ b/sdk/src/test/java8/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryJava8Test.java @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.options; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThat; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Java 8 tests for {@link PipelineOptionsFactory}. 
+ */ +@RunWith(JUnit4.class) +public class PipelineOptionsFactoryJava8Test { + @Rule public ExpectedException thrown = ExpectedException.none(); + + private static interface OptionsWithDefaultMethod extends PipelineOptions { + default Number getValue() { + return 1024; + } + + void setValue(Number value); + } + + @Test + public void testDefaultMethodIgnoresDefaultImplementation() { + OptionsWithDefaultMethod optsWithDefault = + PipelineOptionsFactory.as(OptionsWithDefaultMethod.class); + assertThat(optsWithDefault.getValue(), nullValue()); + + optsWithDefault.setValue(12.25); + assertThat(optsWithDefault.getValue(), equalTo(Double.valueOf(12.25))); + } + + private static interface ExtendedOptionsWithDefault extends OptionsWithDefaultMethod {} + + @Test + public void testDefaultMethodInExtendedClassIgnoresDefaultImplementation() { + OptionsWithDefaultMethod extendedOptsWithDefault = + PipelineOptionsFactory.as(ExtendedOptionsWithDefault.class); + assertThat(extendedOptsWithDefault.getValue(), nullValue()); + + extendedOptsWithDefault.setValue(Double.NEGATIVE_INFINITY); + assertThat(extendedOptsWithDefault.getValue(), equalTo(Double.NEGATIVE_INFINITY)); + } + + private static interface Options extends PipelineOptions { + Number getValue(); + + void setValue(Number value); + } + + private static interface SubtypeReturingOptions extends Options { + @Override + Integer getValue(); + void setValue(Integer value); + } + + @Test + public void testReturnTypeConflictThrows() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage( + "Method [getValue] has multiple definitions [public abstract java.lang.Integer " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryJava8Test$" + + "SubtypeReturingOptions.getValue(), public abstract java.lang.Number " + + "com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryJava8Test$Options" + + ".getValue()] with different return types for [" + + 
"com.google.cloud.dataflow.sdk.options.PipelineOptionsFactoryJava8Test$" + + "SubtypeReturingOptions]."); + PipelineOptionsFactory.as(SubtypeReturingOptions.class); + } +} From 42d1f15854184901c8abf4c85a6698848ed3d105 Mon Sep 17 00:00:00 2001 From: Scott Wegner Date: Thu, 17 Mar 2016 10:22:42 -0700 Subject: [PATCH 10/11] Add DisplayData builder API to SDK This allows generating the display data which will be attached to PTransforms. --- sdk/pom.xml | 7 + .../cloud/dataflow/sdk/transforms/DoFn.java | 13 +- .../dataflow/sdk/transforms/PTransform.java | 14 +- .../cloud/dataflow/sdk/transforms/ParDo.java | 13 + .../sdk/transforms/display/DisplayData.java | 517 ++++++++++++++ .../transforms/display/HasDisplayData.java | 53 ++ .../dataflow/sdk/transforms/DoFnTest.java | 15 + .../sdk/transforms/PTransformTest.java | 41 ++ .../dataflow/sdk/transforms/ParDoTest.java | 23 + .../display/DisplayDataMatchers.java | 98 +++ .../display/DisplayDataMatchersTest.java | 81 +++ .../transforms/display/DisplayDataTest.java | 633 ++++++++++++++++++ .../dataflow/sdk/util/ApiSurfaceTest.java | 3 +- 13 files changed, 1508 insertions(+), 3 deletions(-) create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayData.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/HasDisplayData.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PTransformTest.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchers.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchersTest.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataTest.java diff --git a/sdk/pom.xml b/sdk/pom.xml index d7e10a53a8..2639bc13db 100644 --- a/sdk/pom.xml +++ b/sdk/pom.xml @@ -685,6 +685,13 @@ ${guava.version} + + com.google.guava + guava-testlib + ${guava.version} + test + + 
com.google.protobuf protobuf-java diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java index af06cc8796..5ba9992143 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/DoFn.java @@ -24,6 +24,8 @@ import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData; +import com.google.cloud.dataflow.sdk.transforms.display.HasDisplayData; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; import com.google.cloud.dataflow.sdk.util.WindowingInternals; @@ -69,7 +71,7 @@ * @param the type of the (main) input elements * @param the type of the (main) output elements */ -public abstract class DoFn implements Serializable { +public abstract class DoFn implements Serializable, HasDisplayData { /** * Information accessible to all methods in this {@code DoFn}. @@ -366,6 +368,15 @@ public void startBundle(Context c) throws Exception { public void finishBundle(Context c) throws Exception { } + /** + * {@inheritDoc} + * + *

By default, does not register any display data. Implementors may override this method + * to provide their own display metadata. + */ + @Override + public void populateDisplayData(DisplayData.Builder builder) { + } ///////////////////////////////////////////////////////////////////////////// diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java index 8a7450997a..d4496b8c74 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/PTransform.java @@ -19,6 +19,8 @@ import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder; +import com.google.cloud.dataflow.sdk.transforms.display.HasDisplayData; import com.google.cloud.dataflow.sdk.util.StringUtils; import com.google.cloud.dataflow.sdk.values.PInput; import com.google.cloud.dataflow.sdk.values.POutput; @@ -168,7 +170,7 @@ * @param the type of the output of this PTransform */ public abstract class PTransform - implements Serializable /* See the note above */ { + implements Serializable /* See the note above */, HasDisplayData { /** * Applies this {@code PTransform} on the given {@code InputT}, and returns its * {@code Output}. @@ -309,4 +311,14 @@ public Coder getDefaultOutputCoder( Coder defaultOutputCoder = (Coder) getDefaultOutputCoder(input); return defaultOutputCoder; } + + /** + * {@inheritDoc} + * + *

By default, does not register any display data. Implementors may override this method + * to provide their own display metadata. + */ + @Override + public void populateDisplayData(Builder builder) { + } } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java index 0922767adc..c77ac4447c 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/ParDo.java @@ -22,6 +22,7 @@ import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.coders.CoderException; import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder; import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; import com.google.cloud.dataflow.sdk.util.DirectModeExecutionContext; import com.google.cloud.dataflow.sdk.util.DirectSideInputReader; @@ -787,6 +788,18 @@ protected String getKindString() { } } + /** + * {@inheritDoc} + * + *

{@link ParDo} registers its internal {@link DoFn} as a subcomponent for display metadata. + * {@link DoFn} implementations can register display data by overriding + * {@link DoFn#populateDisplayData}. + */ + @Override + public void populateDisplayData(Builder builder) { + builder.include(fn); + } + public DoFn getFn() { return fn; } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayData.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayData.java new file mode 100644 index 0000000000..05fa7c7881 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayData.java @@ -0,0 +1,517 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package com.google.cloud.dataflow.sdk.transforms.display; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; + +import org.apache.avro.reflect.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.format.DateTimeFormatter; +import org.joda.time.format.ISODateTimeFormat; + +import java.util.Collection; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** + * Static display metadata associated with a pipeline component. Display data is useful for + * pipeline runner UIs and diagnostic dashboards to display details about + * {@link PTransform PTransforms} that make up a pipeline. + * + *

Components specify their display data by implementing the {@link HasDisplayData} + * interface. + */ +public class DisplayData { + private static final DisplayData EMPTY = new DisplayData(Maps.newHashMap()); + private static final DateTimeFormatter TIMESTAMP_FORMATTER = ISODateTimeFormat.dateTime(); + + private final ImmutableMap entries; + + private DisplayData(Map entries) { + this.entries = ImmutableMap.copyOf(entries); + } + + /** + * Default empty {@link DisplayData} instance. + */ + public static DisplayData none() { + return EMPTY; + } + + /** + * Collect the {@link DisplayData} from a component. This will traverse all subcomponents + * specified via {@link Builder#include} in the given component. Data in this component will be in + * a namespace derived from the component. + */ + public static DisplayData from(HasDisplayData component) { + checkNotNull(component); + return InternalBuilder.forRoot(component).build(); + } + + public Collection items() { + return entries.values(); + } + + public Map asMap() { + return entries; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + boolean isFirstLine = true; + for (Map.Entry entry : entries.entrySet()) { + if (isFirstLine) { + isFirstLine = false; + } else { + builder.append("\n"); + } + + builder.append(entry); + } + + return builder.toString(); + } + + /** + * Utility to build up display metadata from a component and its included + * subcomponents. + */ + public interface Builder { + /** + * Include display metadata from the specified subcomponent. For example, a {@link ParDo} + * transform includes display metadata from the encapsulated {@link DoFn}. + * + * @return A builder instance to continue to build in a fluent-style. + */ + Builder include(HasDisplayData subComponent); + + /** + * Register the given string display metadata. 
The metadata item will be registered with type + * {@link DisplayData.Type#STRING}, and is identified by the specified key and namespace from + * the current transform or component. + */ + ItemBuilder add(String key, String value); + + /** + * Register the given numeric display metadata. The metadata item will be registered with type + * {@link DisplayData.Type#INTEGER}, and is identified by the specified key and namespace from + * the current transform or component. + */ + ItemBuilder add(String key, long value); + + /** + * Register the given floating point display metadata. The metadata item will be registered with + * type {@link DisplayData.Type#FLOAT}, and is identified by the specified key and namespace + * from the current transform or component. + */ + ItemBuilder add(String key, double value); + + /** + * Register the given timestamp display metadata. The metadata item will be registered with type + * {@link DisplayData.Type#TIMESTAMP}, and is identified by the specified key and namespace from + * the current transform or component. + */ + ItemBuilder add(String key, Instant value); + + /** + * Register the given duration display metadata. The metadata item will be registered with type + * {@link DisplayData.Type#DURATION}, and is identified by the specified key and namespace from + * the current transform or component. + */ + ItemBuilder add(String key, Duration value); + + /** + * Register the given class display metadata. The metadata item will be registered with type + * {@link DisplayData.Type#JAVA_CLASS}, and is identified by the specified key and namespace + * from the current transform or component. + */ + ItemBuilder add(String key, Class value); + } + + /** + * Utility to append optional fields to display metadata, or register additional display metadata + * items. + */ + public interface ItemBuilder extends Builder { + /** + * Add a human-readable label to describe the most-recently added metadata field. 
+ * A label is optional; if unspecified, UIs should display the metadata key to identify the + * display item. + * + *

Specifying a null value will clear the label if it was previously defined. + */ + ItemBuilder withLabel(@Nullable String label); + + /** + * Add a link URL to the most-recently added display metadata. A link URL is optional and + * can be provided to point the reader to additional details about the metadata. + * + *

Specifying a null value will clear the URL if it was previously defined. + */ + ItemBuilder withLinkUrl(@Nullable String url); + } + + /** + * A display metadata item. DisplayData items are registered via {@link Builder#add} within + * {@link HasDisplayData#populateDisplayData} implementations. Each metadata item is uniquely + * identified by the specified key and namespace generated from the registering component's + * class name. + */ + public static class Item { + private final String key; + private final String ns; + private final Type type; + private final String value; + private final String shortValue; + private final String label; + private final String url; + + private static Item create(String namespace, String key, Type type, T value) { + FormattedItemValue formatted = type.format(value); + return new Item( + namespace, key, type, formatted.getLongValue(), formatted.getShortValue(), null, null); + } + + private Item( + String namespace, + String key, + Type type, + String value, + String shortValue, + String url, + String label) { + this.ns = namespace; + this.key = key; + this.type = type; + this.value = value; + this.shortValue = shortValue; + this.url = url; + this.label = label; + } + + public String getNamespace() { + return ns; + } + + public String getKey() { + return key; + } + + /** + * Retrieve the {@link DisplayData.Type} of display metadata. All metadata conforms to a + * predefined set of allowed types. + */ + public Type getType() { + return type; + } + + /** + * Retrieve the value of the metadata item. + */ + public String getValue() { + return value; + } + + /** + * Return the optional short value for an item. Types may provide a short-value to be displayed + * instead of or in addition to the full {@link Item#value}. + * + *

Some display data types will not provide a short value, in which case the return value + * will be null. + */ + @Nullable + public String getShortValue() { + return shortValue; + } + + /** + * Retrieve the optional label for an item. The label is a human-readable description of what + * the metadata represents. UIs may choose to display the label instead of the item key. + * + *

If no label was specified, this will return {@code null}. + */ + @Nullable + public String getLabel() { + return label; + } + + /** + * Retrieve the optional link URL for an item. The URL points to an address where the reader + * can find additional context for the display metadata. + * + *

If no URL was specified, this will return {@code null}. + */ + @Nullable + public String getUrl() { + return url; + } + + @Override + public String toString() { + return getValue(); + } + + private Item withLabel(String label) { + return new Item(this.ns, this.key, this.type, this.value, this.shortValue, this.url, label); + } + + private Item withUrl(String url) { + return new Item(this.ns, this.key, this.type, this.value, this.shortValue, url, this.label); + } + } + + /** + * Unique identifier for a display metadata item within a component. + * Identifiers are composed of the key they are registered with and a namespace generated from + * the class of the component which registered the item. + * + *

Display metadata registered with the same key from different components will have different + * namespaces and thus will both be represented in the composed {@link DisplayData}. A + * single component must register each key at most once; registering a duplicate key within + * the same namespace is rejected with an {@link IllegalArgumentException}. + */ + public static class Identifier { + private final String ns; + private final String key; + + static Identifier of(Class namespace, String key) { + return new Identifier(namespace.getName(), key); + } + + private Identifier(String ns, String key) { + this.ns = ns; + this.key = key; + } + + public String getNamespace() { + return ns; + } + + public String getKey() { + return key; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof Identifier) { + Identifier that = (Identifier) obj; + return Objects.equals(this.ns, that.ns) + && Objects.equals(this.key, that.key); + } + + return false; + } + + @Override + public int hashCode() { + return Objects.hash(ns, key); + } + + @Override + public String toString() { + return String.format("%s:%s", ns, key); + } + } + + /** + * Display metadata type.
+ */ + enum Type { + STRING { + @Override + FormattedItemValue format(Object value) { + return new FormattedItemValue((String) value); + } + }, + INTEGER { + @Override + FormattedItemValue format(Object value) { + return new FormattedItemValue(Long.toString((long) value)); + } + }, + FLOAT { + @Override + FormattedItemValue format(Object value) { + return new FormattedItemValue(Double.toString((Double) value)); + } + }, + TIMESTAMP() { + @Override + FormattedItemValue format(Object value) { + return new FormattedItemValue((TIMESTAMP_FORMATTER.print((Instant) value))); + } + }, + DURATION { + @Override + FormattedItemValue format(Object value) { + return new FormattedItemValue(Long.toString(((Duration) value).getMillis())); + } + }, + JAVA_CLASS { + @Override + FormattedItemValue format(Object value) { + Class clazz = (Class) value; + return new FormattedItemValue(clazz.getName(), clazz.getSimpleName()); + } + }; + + /** + * Format the display metadata value into a long string representation, and optionally + * a shorter representation for display. + * + *

Internal-only. Value objects can be safely cast to the expected Java type. + */ + abstract FormattedItemValue format(Object value); + } + + private static class FormattedItemValue { + private final String shortValue; + private final String longValue; + + private FormattedItemValue(String longValue) { + this(longValue, null); + } + + private FormattedItemValue(String longValue, String shortValue) { + this.longValue = longValue; + this.shortValue = shortValue; + } + + private String getLongValue () { + return this.longValue; + } + + private String getShortValue() { + return this.shortValue; + } + } + + private static class InternalBuilder implements ItemBuilder { + private final Map entries; + private final Set visited; + + private Class latestNs; + private Item latestItem; + private Identifier latestIdentifier; + + private InternalBuilder() { + this.entries = Maps.newHashMap(); + this.visited = Sets.newIdentityHashSet(); + } + + private static InternalBuilder forRoot(HasDisplayData instance) { + InternalBuilder builder = new InternalBuilder(); + builder.include(instance); + return builder; + } + + @Override + public Builder include(HasDisplayData subComponent) { + checkNotNull(subComponent); + boolean newComponent = visited.add(subComponent); + if (newComponent) { + Class prevNs = this.latestNs; + this.latestNs = subComponent.getClass(); + subComponent.populateDisplayData(this); + this.latestNs = prevNs; + } + + return this; + } + + @Override + public ItemBuilder add(String key, String value) { + checkNotNull(value); + return addItem(key, Type.STRING, value); + } + + @Override + public ItemBuilder add(String key, long value) { + return addItem(key, Type.INTEGER, value); + } + + @Override + public ItemBuilder add(String key, double value) { + return addItem(key, Type.FLOAT, value); + } + + @Override + public ItemBuilder add(String key, Instant value) { + checkNotNull(value); + return addItem(key, Type.TIMESTAMP, value); + } + + @Override + public ItemBuilder 
add(String key, Duration value) { + checkNotNull(value); + return addItem(key, Type.DURATION, value); + } + + @Override + public ItemBuilder add(String key, Class value) { + checkNotNull(value); + return addItem(key, Type.JAVA_CLASS, value); + } + + private ItemBuilder addItem(String key, Type type, T value) { + checkNotNull(key); + checkArgument(!key.isEmpty()); + + Identifier id = Identifier.of(latestNs, key); + if (entries.containsKey(id)) { + throw new IllegalArgumentException("DisplayData key already exists. All display data " + + "for a component must be registered with a unique key.\nKey: " + id); + } + Item item = Item.create(id.getNamespace(), key, type, value); + entries.put(id, item); + + latestItem = item; + latestIdentifier = id; + + return this; + } + + @Override + public ItemBuilder withLabel(String label) { + latestItem = latestItem.withLabel(label); + entries.put(latestIdentifier, latestItem); + return this; + } + + @Override + public ItemBuilder withLinkUrl(String url) { + latestItem = latestItem.withUrl(url); + entries.put(latestIdentifier, latestItem); + return this; + } + + private DisplayData build() { + return new DisplayData(this.entries); + } + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/HasDisplayData.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/HasDisplayData.java new file mode 100644 index 0000000000..b2eca3d881 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/HasDisplayData.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.display; + +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +/** + * Marker interface for {@link PTransform PTransforms} and components used within + * {@link PTransform PTransforms} to specify display metadata to be used within UIs and diagnostic + * tools. + * + *

Display metadata is optional and may be collected during pipeline construction. It should + * only be used for informational purposes. Tools and components should not assume that display data + * will always be collected, or that collected display data will always be displayed. + */ +public interface HasDisplayData { + /** + * Register display metadata for the given transform or component. Metadata can be registered + * directly on the provided builder, as well as via included sub-components. + * + *

+   * {@code
+   * @Override
+   * public void populateDisplayData(DisplayData.Builder builder) {
+   *  builder
+   *     .include(subComponent)
+   *     .add("minFilter", 42)
+   *     .add("topic", "projects/myproject/topics/mytopic")
+   *       .withLabel("Pub/Sub Topic")
+   *     .add("serviceInstance", "myservice.com/fizzbang")
+   *       .withLinkUrl("http://www.myservice.com/fizzbang");
+   * }
+   * }
+   * 
+ * + * @param builder The builder to populate with display metadata. + */ + void populateDisplayData(DisplayData.Builder builder); +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTest.java index a709a233ab..dabad7bae7 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/DoFnTest.java @@ -17,14 +17,17 @@ package com.google.cloud.dataflow.sdk.transforms; import static org.hamcrest.CoreMatchers.isA; +import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThat; import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; import com.google.cloud.dataflow.sdk.testing.RunnableOnService; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; import com.google.cloud.dataflow.sdk.transforms.Max.MaxIntegerFn; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData; import org.junit.Rule; import org.junit.Test; @@ -188,4 +191,16 @@ private TestPipeline createTestPipeline(DoFn return pipeline; } + + @Test + public void testPopulateDisplayDataDefaultBehavior() { + DoFn usesDefault = + new DoFn() { + @Override + public void processElement(ProcessContext c) throws Exception {} + }; + + DisplayData data = DisplayData.from(usesDefault); + assertThat(data.items(), empty()); + } } diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PTransformTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PTransformTest.java new file mode 100644 index 0000000000..cea1b38df7 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/PTransformTest.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2015 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms; + +import static org.hamcrest.Matchers.empty; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link PTransform} base class. + */ +@RunWith(JUnit4.class) +public class PTransformTest { + @Test + public void testPopulateDisplayDataDefaultBehavior() { + PTransform, PCollection> transform = + new PTransform, PCollection>() {}; + DisplayData displayData = DisplayData.from(transform); + assertThat(displayData.items(), empty()); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java index f3f9bde92d..1ff46e41c2 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/ParDoTest.java @@ -16,6 +16,8 @@ package com.google.cloud.dataflow.sdk.transforms; +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasKey; import static 
com.google.cloud.dataflow.sdk.util.SerializableUtils.serializeToByteArray; import static com.google.cloud.dataflow.sdk.util.StringUtils.byteArrayToJsonString; import static com.google.cloud.dataflow.sdk.util.StringUtils.jsonStringToByteArray; @@ -39,6 +41,9 @@ import com.google.cloud.dataflow.sdk.testing.RunnableOnService; import com.google.cloud.dataflow.sdk.testing.TestPipeline; import com.google.cloud.dataflow.sdk.transforms.DoFn.RequiresWindowAccess; +import com.google.cloud.dataflow.sdk.transforms.ParDo.Bound; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder; import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; import com.google.cloud.dataflow.sdk.transforms.windowing.Window; import com.google.cloud.dataflow.sdk.util.IllegalMutationException; @@ -1515,4 +1520,22 @@ public void testMutatingInputCoderDoFnError() throws Exception { thrown.expectMessage("must not be mutated"); pipeline.run(); } + + @Test + public void testIncludesDoFnDisplayData() { + Bound parDo = + ParDo.of( + new DoFn() { + @Override + public void processElement(ProcessContext c) {} + + @Override + public void populateDisplayData(Builder builder) { + builder.add("foo", "bar"); + } + }); + + DisplayData displayData = DisplayData.from(parDo); + assertThat(displayData, hasDisplayItem(hasKey("foo"))); + } } diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchers.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchers.java new file mode 100644 index 0000000000..2753aafa7b --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchers.java @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.display; + +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Item; + +import org.hamcrest.Description; +import org.hamcrest.FeatureMatcher; +import org.hamcrest.Matcher; +import org.hamcrest.Matchers; +import org.hamcrest.TypeSafeDiagnosingMatcher; + +import java.util.Collection; + +/** + * Hamcrest matcher for making assertions on {@link DisplayData} instances. + */ +public class DisplayDataMatchers { + /** + * Do not instantiate. + */ + private DisplayDataMatchers() {} + + /** + * Creates a matcher that matches if the examined {@link DisplayData} contains any items. + */ + public static Matcher hasDisplayItem() { + return hasDisplayItem(Matchers.any(DisplayData.Item.class)); + } + + /** + * Creates a matcher that matches if the examined {@link DisplayData} contains any item + * matching the specified {@code itemMatcher}. 
+ */ + public static Matcher hasDisplayItem(Matcher itemMatcher) { + return new HasDisplayDataItemMatcher(itemMatcher); + } + + private static class HasDisplayDataItemMatcher extends TypeSafeDiagnosingMatcher { + private final Matcher itemMatcher; + + private HasDisplayDataItemMatcher(Matcher itemMatcher) { + this.itemMatcher = itemMatcher; + } + + @Override + public void describeTo(Description description) { + description.appendText("display data with item: "); + itemMatcher.describeTo(description); + } + + @Override + protected boolean matchesSafely(DisplayData data, Description mismatchDescription) { + Collection items = data.items(); + boolean isMatch = Matchers.hasItem(itemMatcher).matches(items); + if (!isMatch) { + mismatchDescription.appendText("found " + items.size() + " non-matching items"); + } + + return isMatch; + } + } + + /** + * Creates a matcher that matches if the examined {@link DisplayData.Item} contains a key + * with the specified value. + */ + public static Matcher hasKey(String key) { + return hasKey(Matchers.is(key)); + } + + /** + * Creates a matcher that matches if the examined {@link DisplayData.Item} contains a key + * matching the specified key matcher. + */ + public static Matcher hasKey(Matcher keyMatcher) { + return new FeatureMatcher(keyMatcher, "with key", "key") { + @Override + protected String featureValueOf(DisplayData.Item actual) { + return actual.getKey(); + } + }; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchersTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchersTest.java new file mode 100644 index 0000000000..2636cf85c8 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchersTest.java @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2016 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.display; + +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasKey; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.core.StringStartsWith.startsWith; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.StringDescription; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link DisplayDataMatchers}. 
+ */ +@RunWith(JUnit4.class) +public class DisplayDataMatchersTest { + @Test + public void testHasDisplayItem() { + Matcher matcher = hasDisplayItem(); + + assertFalse(matcher.matches(DisplayData.none())); + assertTrue(matcher.matches(createDisplayDataWithItem("foo", "bar"))); + } + + @Test + public void testHasDisplayItemDescription() { + Matcher matcher = hasDisplayItem(); + Description desc = new StringDescription(); + Description mismatchDesc = new StringDescription(); + + matcher.describeTo(desc); + matcher.describeMismatch(DisplayData.none(), mismatchDesc); + + assertThat(desc.toString(), startsWith("display data with item: ")); + assertThat(mismatchDesc.toString(), containsString("found 0 non-matching items")); + } + + @Test + public void testHasKey() { + Matcher matcher = hasDisplayItem(hasKey("foo")); + + assertTrue(matcher.matches(createDisplayDataWithItem("foo", "bar"))); + assertFalse(matcher.matches(createDisplayDataWithItem("fooz", "bar"))); + } + + private DisplayData createDisplayDataWithItem(final String key, final String value) { + return DisplayData.from( + new PTransform, PCollection>() { + @Override + public void populateDisplayData(Builder builder) { + builder.add(key, value); + } + }); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataTest.java new file mode 100644 index 0000000000..13dd618657 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataTest.java @@ -0,0 +1,633 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package com.google.cloud.dataflow.sdk.transforms.display; + +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers.hasKey; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.everyItem; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isEmptyOrNullString; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder; +import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Item; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.testing.EqualsTester; + +import org.hamcrest.CustomTypeSafeMatcher; +import org.hamcrest.FeatureMatcher; +import org.hamcrest.Matcher; +import org.hamcrest.Matchers; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.format.DateTimeFormatter; +import org.joda.time.format.ISODateTimeFormat; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; 
+import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.util.Collection; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * Tests for {@link DisplayData} class. + */ +@RunWith(JUnit4.class) +public class DisplayDataTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + private static final DateTimeFormatter ISO_FORMATTER = ISODateTimeFormat.dateTime(); + + @Test + public void testTypicalUsage() { + final HasDisplayData subComponent1 = + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("ExpectedAnswer", 42); + } + }; + + final HasDisplayData subComponent2 = + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("Location", "Seattle").add("Forecast", "Rain"); + } + }; + + PTransform transform = + new PTransform, PCollection>() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .include(subComponent1) + .include(subComponent2) + .add("MinSproggles", 200) + .withLabel("Mimimum Required Sproggles") + .add("LazerOrientation", "NORTH") + .add("TimeBomb", Instant.now().plus(Duration.standardDays(1))) + .add("FilterLogic", subComponent1.getClass()) + .add("ServiceUrl", "google.com/fizzbang") + .withLinkUrl("http://www.google.com/fizzbang"); + } + }; + + DisplayData data = DisplayData.from(transform); + + assertThat(data.items(), not(empty())); + assertThat( + data.items(), + everyItem( + allOf( + hasKey(not(isEmptyOrNullString())), + hasNamespace( + Matchers.>isOneOf( + transform.getClass(), subComponent1.getClass(), subComponent2.getClass())), + hasType(notNullValue(DisplayData.Type.class)), + hasValue(not(isEmptyOrNullString()))))); + } + + @Test + public void testDefaultInstance() { + DisplayData none = DisplayData.none(); + assertThat(none.items(), empty()); + } + + @Test + public void testCanBuild() { + DisplayData data = + DisplayData.from(new 
HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("Foo", "bar"); + } + }); + + assertThat(data.items(), hasSize(1)); + assertThat(data, hasDisplayItem(hasKey("Foo"))); + } + + @Test + public void testAsMap() { + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }); + + Map map = data.asMap(); + assertEquals(map.size(), 1); + assertThat(data, hasDisplayItem(hasKey("foo"))); + assertEquals(map.values(), data.items()); + } + + @Test + public void testItemProperties() { + final Instant value = Instant.now(); + DisplayData data = DisplayData.from(new ConcreteComponent(value)); + + @SuppressWarnings("unchecked") + DisplayData.Item item = (DisplayData.Item) data.items().toArray()[0]; + assertThat( + item, + allOf( + hasNamespace(Matchers.>is(ConcreteComponent.class)), + hasKey("now"), + hasType(is(DisplayData.Type.TIMESTAMP)), + hasValue(is(ISO_FORMATTER.print(value))), + hasShortValue(nullValue(String.class)), + hasLabel(is("the current instant")), + hasUrl(is("http://time.gov")))); + } + + static class ConcreteComponent implements HasDisplayData { + private Instant value; + + ConcreteComponent(Instant value) { + this.value = value; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("now", value).withLabel("the current instant").withLinkUrl("http://time.gov"); + } + } + + @Test + public void testUnspecifiedOptionalProperties() { + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }); + + assertThat( + data, + hasDisplayItem(allOf(hasLabel(nullValue(String.class)), hasUrl(nullValue(String.class))))); + } + + @Test + public void testIncludes() { + final HasDisplayData subComponent = + new HasDisplayData() { + @Override + 
public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }; + + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.include(subComponent); + } + }); + + assertThat( + data, + hasDisplayItem( + allOf( + hasKey("foo"), + hasNamespace(Matchers.>is(subComponent.getClass()))))); + } + + @Test + public void testIdentifierEquality() { + new EqualsTester() + .addEqualityGroup( + DisplayData.Identifier.of(DisplayDataTest.class, "1"), + DisplayData.Identifier.of(DisplayDataTest.class, "1")) + .addEqualityGroup(DisplayData.Identifier.of(Object.class, "1")) + .addEqualityGroup(DisplayData.Identifier.of(DisplayDataTest.class, "2")) + .testEquals(); + } + + @Test + public void testAnonymousClassNamespace() { + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }); + + DisplayData.Item item = (DisplayData.Item) data.items().toArray()[0]; + final Pattern anonClassRegex = Pattern.compile( + Pattern.quote(DisplayDataTest.class.getName()) + "\\$\\d+$"); + assertThat(item.getNamespace(), new CustomTypeSafeMatcher( + "anonymous class regex: " + anonClassRegex) { + @Override + protected boolean matchesSafely(String item) { + java.util.regex.Matcher m = anonClassRegex.matcher(item); + return m.matches(); + } + }); + } + + @Test + public void testAcceptsKeysWithDifferentNamespaces() { + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .add("foo", "bar") + .include( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }); + } + }); + + assertThat(data.items(), hasSize(2)); + } + + @Test + public void testDuplicateKeyThrowsException() { + 
thrown.expect(IllegalArgumentException.class); + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .add("foo", "bar") + .add("foo", "baz"); + } + }); + } + + @Test + public void testToString() { + HasDisplayData component = new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }; + + DisplayData data = DisplayData.from(component); + assertEquals(String.format("%s:foo=bar", component.getClass().getName()), data.toString()); + } + + @Test + public void testHandlesIncludeCycles() { + + final IncludeSubComponent componentA = + new IncludeSubComponent() { + @Override + String getId() { + return "componentA"; + } + }; + final IncludeSubComponent componentB = + new IncludeSubComponent() { + @Override + String getId() { + return "componentB"; + } + }; + + HasDisplayData component = + new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + builder.include(componentA); + } + }; + + componentA.subComponent = componentB; + componentB.subComponent = componentA; + + DisplayData data = DisplayData.from(component); + assertThat(data.items(), hasSize(2)); + } + + @Test + public void testIncludesSubcomponentsWithObjectEquality() { + DisplayData data = DisplayData.from(new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .include(new EqualsEverything("foo1", "bar1")) + .include(new EqualsEverything("foo2", "bar2")); + } + }); + + assertThat(data.items(), hasSize(2)); + } + + private static class EqualsEverything implements HasDisplayData { + private final String value; + private final String key; + EqualsEverything(String key, String value) { + this.key = key; + this.value = value; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add(key, value); + } + + @Override + public int hashCode() { 
+ return 1; + } + + @Override + public boolean equals(Object obj) { + return true; + } + } + + abstract static class IncludeSubComponent implements HasDisplayData { + HasDisplayData subComponent; + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("id", getId()).include(subComponent); + } + + abstract String getId(); + } + + @Test + public void testTypeMappings() { + DisplayData data = + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .add("string", "foobar") + .add("integer", 123) + .add("float", 3.14) + .add("java_class", DisplayDataTest.class) + .add("timestamp", Instant.now()) + .add("duration", Duration.standardHours(1)); + } + }); + + Collection items = data.items(); + assertThat( + items, hasItem(allOf(hasKey("string"), hasType(is(DisplayData.Type.STRING))))); + assertThat( + items, hasItem(allOf(hasKey("integer"), hasType(is(DisplayData.Type.INTEGER))))); + assertThat(items, hasItem(allOf(hasKey("float"), hasType(is(DisplayData.Type.FLOAT))))); + assertThat( + items, + hasItem(allOf(hasKey("java_class"), hasType(is(DisplayData.Type.JAVA_CLASS))))); + assertThat( + items, + hasItem(allOf(hasKey("timestamp"), hasType(is(DisplayData.Type.TIMESTAMP))))); + assertThat( + items, hasItem(allOf(hasKey("duration"), hasType(is(DisplayData.Type.DURATION))))); + } + + @Test + public void testStringFormatting() throws IOException { + final Instant now = Instant.now(); + final Duration oneHour = Duration.standardHours(1); + + HasDisplayData component = new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .add("string", "foobar") + .add("integer", 123) + .add("float", 3.14) + .add("java_class", DisplayDataTest.class) + .add("timestamp", now) + .add("duration", oneHour); + } + }; + DisplayData data = DisplayData.from(component); + + Collection items = data.items(); + assertThat(items, 
hasItem(allOf(hasKey("string"), hasValue(is("foobar"))))); + assertThat(items, hasItem(allOf(hasKey("integer"), hasValue(is("123"))))); + assertThat(items, hasItem(allOf(hasKey("float"), hasValue(is("3.14"))))); + assertThat(items, hasItem(allOf(hasKey("java_class"), + hasValue(is(DisplayDataTest.class.getName())), + hasShortValue(is(DisplayDataTest.class.getSimpleName()))))); + assertThat(items, hasItem(allOf(hasKey("timestamp"), + hasValue(is(ISO_FORMATTER.print(now)))))); + assertThat(items, hasItem(allOf(hasKey("duration"), + hasValue(is(Long.toString(oneHour.getMillis())))))); + } + + @Test + public void testContextProperlyReset() { + final HasDisplayData subComponent = + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }; + + HasDisplayData component = + new HasDisplayData() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .include(subComponent) + .add("alpha", "bravo"); + } + }; + + DisplayData data = DisplayData.from(component); + assertThat( + data.items(), + hasItem( + allOf( + hasKey("alpha"), + hasNamespace(Matchers.>is(component.getClass()))))); + } + + @Test + public void testFromNull() { + thrown.expect(NullPointerException.class); + DisplayData.from(null); + } + + @Test + public void testIncludeNull() { + thrown.expect(NullPointerException.class); + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + builder.include(null); + } + }); + } + + @Test + public void testNullKey() { + thrown.expect(NullPointerException.class); + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + builder.add(null, "foo"); + } + }); + } + + @Test + public void testRejectsNullValues() { + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + try { + builder.add("key", (String) 
null); + throw new RuntimeException("Should throw on null string value"); + } catch (NullPointerException ex) { + // Expected + } + + try { + builder.add("key", (Class) null); + throw new RuntimeException("Should throw on null class value"); + } catch (NullPointerException ex) { + // Expected + } + + try { + builder.add("key", (Duration) null); + throw new RuntimeException("Should throw on null duration value"); + } catch (NullPointerException ex) { + // Expected + } + + try { + builder.add("key", (Instant) null); + throw new RuntimeException("Should throw on null instant value"); + } catch (NullPointerException ex) { + // Expected + } + } + }); + } + + public void testAcceptsNullOptionalValues() { + DisplayData.from( + new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + builder.add("key", "value") + .withLabel(null) + .withLinkUrl(null); + } + }); + + // Should not throw + } + + private static Matcher hasNamespace(Matcher> nsMatcher) { + return new FeatureMatcher>( + nsMatcher, "display item with namespace", "namespace") { + @Override + protected Class featureValueOf(DisplayData.Item actual) { + try { + return Class.forName(actual.getNamespace()); + } catch (ClassNotFoundException e) { + return null; + } + } + }; + } + + private static Matcher hasType(Matcher typeMatcher) { + return new FeatureMatcher( + typeMatcher, "display item with type", "type") { + @Override + protected DisplayData.Type featureValueOf(DisplayData.Item actual) { + return actual.getType(); + } + }; + } + + private static Matcher hasLabel(Matcher labelMatcher) { + return new FeatureMatcher( + labelMatcher, "display item with label", "label") { + @Override + protected String featureValueOf(DisplayData.Item actual) { + return actual.getLabel(); + } + }; + } + + private static Matcher hasUrl(Matcher urlMatcher) { + return new FeatureMatcher( + urlMatcher, "display item with url", "URL") { + @Override + protected String featureValueOf(DisplayData.Item actual) { 
+ return actual.getUrl(); + } + }; + } + + private static Matcher hasValue(Matcher valueMatcher) { + return new FeatureMatcher( + valueMatcher, "display item with value", "value") { + @Override + protected String featureValueOf(DisplayData.Item actual) { + return actual.getValue(); + } + }; + } + + private static Matcher hasShortValue(Matcher valueStringMatcher) { + return new FeatureMatcher( + valueStringMatcher, "display item with short value", "short value") { + @Override + protected String featureValueOf(DisplayData.Item actual) { + return actual.getShortValue(); + } + }; + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ApiSurfaceTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ApiSurfaceTest.java index e995b821de..fcfe1d8f2f 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ApiSurfaceTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ApiSurfaceTest.java @@ -49,7 +49,8 @@ public void testOurApiSurface() throws Exception { .pruningClassName("com.google.cloud.dataflow.sdk.util.common.ReflectHelpers") .pruningClassName("com.google.cloud.dataflow.sdk.DataflowMatchers") .pruningClassName("com.google.cloud.dataflow.sdk.TestUtils") - .pruningClassName("com.google.cloud.dataflow.sdk.WindowMatchers"); + .pruningClassName("com.google.cloud.dataflow.sdk.WindowMatchers") + .pruningClassName("com.google.cloud.dataflow.sdk.transforms.display.DisplayDataMatchers"); checkedApiSurface.getExposedClasses(); From 3f99e1fb00f17c3e5a84d1f563d085ce815d7a98 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Fri, 26 Feb 2016 17:30:13 -0800 Subject: [PATCH 11/11] Implement InProcessPipelineRunner#run Appropriately construct an evaluation context and executor, and start the pipeline when run is called. Implement InProcessPipelineResult. Apply PTransform overrides. 
--- ...achedThreadPoolExecutorServiceFactory.java | 42 ++++ .../ConsumerTrackingPipelineVisitor.java | 173 +++++++++++++ .../inprocess/ExecutorServiceFactory.java | 32 +++ .../ExecutorServiceParallelExecutor.java | 2 +- .../inprocess/GroupByKeyEvaluatorFactory.java | 8 +- .../inprocess/InProcessPipelineOptions.java | 56 +++++ .../inprocess/InProcessPipelineRunner.java | 228 +++++++++++++++-- .../inprocess/KeyedPValueTrackingVisitor.java | 95 +++++++ .../ConsumerTrackingPipelineVisitorTest.java | 233 ++++++++++++++++++ .../InProcessPipelineRunnerTest.java | 77 ++++++ .../KeyedPValueTrackingVisitorTest.java | 189 ++++++++++++++ 11 files changed, 1104 insertions(+), 31 deletions(-) create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java create mode 100644 sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java create mode 100644 sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java new file mode 100644 index 0000000000..3350d2b4d5 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2016 Google 
Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.options.DefaultValueFactory; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** + * A {@link ExecutorServiceFactory} that produces cached thread pools via + * {@link Executors#newCachedThreadPool()}. + */ +class CachedThreadPoolExecutorServiceFactory + implements DefaultValueFactory, ExecutorServiceFactory { + private static final CachedThreadPoolExecutorServiceFactory INSTANCE = + new CachedThreadPoolExecutorServiceFactory(); + + @Override + public ExecutorServiceFactory create(PipelineOptions options) { + return INSTANCE; + } + + @Override + public ExecutorService create() { + return Executors.newCachedThreadPool(); + } +} diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java new file mode 100644 index 0000000000..c602b23c41 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2016 Google Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.PValue; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** + * Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the + * {@link Pipeline}. This is used to schedule consuming {@link PTransform PTransforms} to consume + * input after the upstream transform has produced and committed output. 
+ */ +public class ConsumerTrackingPipelineVisitor implements PipelineVisitor { + private Map>> valueToConsumers = new HashMap<>(); + private Collection> rootTransforms = new ArrayList<>(); + private Collection> views = new ArrayList<>(); + private Map, String> stepNames = new HashMap<>(); + private Set toFinalize = new HashSet<>(); + private int numTransforms = 0; + private boolean finalized = false; + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempting to traverse a pipeline (node %s) with a %s " + + "which has already visited a Pipeline and is finalized", + node.getFullName(), + ConsumerTrackingPipelineVisitor.class.getSimpleName()); + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempting to traverse a pipeline (node %s) with a %s which is already finalized", + node.getFullName(), + ConsumerTrackingPipelineVisitor.class.getSimpleName()); + if (node.isRootNode()) { + finalized = true; + } + } + + @Override + public void visitTransform(TransformTreeNode node) { + toFinalize.removeAll(node.getInput().expand()); + AppliedPTransform appliedTransform = getAppliedTransform(node); + if (node.getInput().expand().isEmpty()) { + rootTransforms.add(appliedTransform); + } else { + for (PValue value : node.getInput().expand()) { + valueToConsumers.get(value).add(appliedTransform); + stepNames.put(appliedTransform, genStepName()); + } + } + } + + private AppliedPTransform getAppliedTransform(TransformTreeNode node) { + @SuppressWarnings({"rawtypes", "unchecked"}) + AppliedPTransform application = AppliedPTransform.of( + node.getFullName(), node.getInput(), node.getOutput(), (PTransform) node.getTransform()); + return application; + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + toFinalize.add(value); + for (PValue expandedValue : value.expand()) { + valueToConsumers.put(expandedValue, new ArrayList>()); + 
if (expandedValue instanceof PCollectionView) { + views.add((PCollectionView) expandedValue); + } + expandedValue.recordAsOutput(getAppliedTransform(producer)); + } + value.recordAsOutput(getAppliedTransform(producer)); + } + + private String genStepName() { + return String.format("s%s", numTransforms++); + } + + + /** + * Returns a mapping of each fully-expanded {@link PValue} to each + * {@link AppliedPTransform} that consumes it. For each AppliedPTransform in the collection + * returned from {@code getValueToConsumers().get(PValue)}, + * {@code AppliedPTransform#getInput().expand()} will contain the argument {@link PValue}. + */ + public Map>> getValueToConsumers() { + checkState( + finalized, + "Can't call getValueToConsumers before the Pipeline has been completely traversed"); + + return valueToConsumers; + } + + /** + * Returns the mapping for each {@link AppliedPTransform} in the {@link Pipeline} to a unique step + * name. + */ + public Map, String> getStepNames() { + checkState( + finalized, "Can't call getStepNames before the Pipeline has been completely traversed"); + + return stepNames; + } + + /** + * Returns the root transforms of the {@link Pipeline}. A root {@link AppliedPTransform} consumes + * a {@link PInput} where the {@link PInput#expand()} returns an empty collection. + */ + public Collection> getRootTransforms() { + checkState( + finalized, + "Can't call getRootTransforms before the Pipeline has been completely traversed"); + + return rootTransforms; + } + + /** + * Returns all of the {@link PCollectionView PCollectionViews} contained in the visited + * {@link Pipeline}. + */ + public Collection> getViews() { + checkState(finalized, "Can't call getViews before the Pipeline has been completely traversed"); + + return views; + } + + /** + * Returns all of the {@link PValue PValues} that have been produced but not consumed.
These + * {@link PValue PValues} should be finalized by the {@link PipelineRunner} before the + * {@link Pipeline} is executed. + */ + public Set getUnfinalizedPValues() { + checkState( + finalized, + "Can't call getUnfinalizedPValues before the Pipeline has been completely traversed"); + + return toFinalize; + } +} + + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java new file mode 100644 index 0000000000..480bcdefed --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import java.util.concurrent.ExecutorService; + +/** + * A factory that creates {@link ExecutorService ExecutorServices}. + * {@link ExecutorService ExecutorServices} created by this factory should be independent of one + * another (e.g., if any executor is shut down the remaining executors should continue to process + * work). + */ +public interface ExecutorServiceFactory { + /** + * Create a new {@link ExecutorService}. 
+ */ + ExecutorService create(); +} + diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java index ae686f2979..c72a1155f5 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java @@ -126,7 +126,7 @@ private void evaluateBundle( @Nullable final CommittedBundle bundle, final CompletionCallback onComplete) { TransformExecutorService transformExecutor; - if (isKeyed(bundle.getPCollection())) { + if (bundle != null && isKeyed(bundle.getPCollection())) { final StepAndKey stepAndKey = StepAndKey.of(transform, bundle == null ? null : bundle.getKey()); transformExecutor = getSerialExecutorService(stepAndKey); diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java index ec63be84c9..b1d5d35869 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java @@ -59,8 +59,10 @@ public TransformEvaluator forApplication( AppliedPTransform application, CommittedBundle inputBundle, InProcessEvaluationContext evaluationContext) { - return createEvaluator( - (AppliedPTransform) application, (CommittedBundle) inputBundle, evaluationContext); + @SuppressWarnings({"cast", "unchecked", "rawtypes"}) + TransformEvaluator evaluator = createEvaluator( + (AppliedPTransform) application, (CommittedBundle) inputBundle, evaluationContext); + return evaluator; } private TransformEvaluator>> createEvaluator( @@ -183,7 +185,7 @@ public static final class 
InProcessGroupByKey extends ForwardingPTransform>, PCollection>>> { private final GroupByKey original; - public InProcessGroupByKey(GroupByKey from) { + private InProcessGroupByKey(GroupByKey from) { this.original = from; } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java index 27e9a4be6e..5ee0e88b5b 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java @@ -15,20 +15,76 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; +import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions; import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.Hidden; import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.Validation.Required; +import com.google.cloud.dataflow.sdk.transforms.PTransform; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * Options that can be used to configure the {@link InProcessPipelineRunner}. */ public interface InProcessPipelineOptions extends PipelineOptions, ApplicationNameOptions { + /** + * Gets the {@link ExecutorServiceFactory} to use to create instances of {@link ExecutorService} + * to execute {@link PTransform PTransforms}. + * + *

Note that each {@link ExecutorService} returned by the factory must ensure that + * it cannot enter a state in which it will not schedule additional pending work unless currently + * scheduled work completes, as this may cause the {@link Pipeline} to cease processing. + * + *

Defaults to a {@link CachedThreadPoolExecutorServiceFactory}, which produces instances of + * {@link Executors#newCachedThreadPool()}. + */ + @JsonIgnore + @Required + @Hidden + @Default.InstanceFactory(CachedThreadPoolExecutorServiceFactory.class) + ExecutorServiceFactory getExecutorServiceFactory(); + + void setExecutorServiceFactory(ExecutorServiceFactory executorService); + + /** + * Gets the {@link Clock} used by this pipeline. The clock is used in place of accessing the + * system time when time values are required by the evaluator. + */ @Default.InstanceFactory(NanosOffsetClock.Factory.class) + @JsonIgnore + @Required + @Hidden + @Description( + "The processing time source used by the pipeline. When the current time is " + + "needed by the evaluator, the result of clock#now() is used.") Clock getClock(); void setClock(Clock clock); + @Default.Boolean(false) + @Description( + "If the pipeline should shut down producers which have reached the maximum " + + "representable watermark. If this is set to true, a pipeline in which all PTransforms " + + "have reached the maximum watermark will be shut down, even if there are unbounded " + + "sources that could produce additional (late) data. By default, if the pipeline " + + "contains any unbounded PCollections, it will run until explicitly shut down.") boolean isShutdownUnboundedProducersWithMaxWatermark(); void setShutdownUnboundedProducersWithMaxWatermark(boolean shutdown); + + @Default.Boolean(true) + @Description( + "If the pipeline should block awaiting completion of the pipeline. If set to true, " + + "a call to Pipeline#run() will block until all PTransforms are complete. Otherwise, " + + "the Pipeline will execute asynchronously. 
If set to false, the completion of the " + + "pipeline can be awaited on by use of InProcessPipelineResult#awaitCompletion().") + boolean isBlockOnRun(); + + void setBlockOnRun(boolean b); } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java index 32859dae63..a1c8756c5e 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2015 Google Inc. + * Copyright (C) 2016 Google Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of @@ -15,25 +15,46 @@ */ package com.google.cloud.dataflow.sdk.runners.inprocess; -import static com.google.common.base.Preconditions.checkArgument; - +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException; +import com.google.cloud.dataflow.sdk.PipelineResult; import com.google.cloud.dataflow.sdk.annotations.Experimental; import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.AggregatorPipelineExtractor; +import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException; +import com.google.cloud.dataflow.sdk.runners.AggregatorValues; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKey; -import com.google.cloud.dataflow.sdk.runners.inprocess.ViewEvaluatorFactory.InProcessCreatePCollectionView; +import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import 
com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; import com.google.cloud.dataflow.sdk.transforms.GroupByKey; import com.google.cloud.dataflow.sdk.transforms.PTransform; import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView; +import com.google.cloud.dataflow.sdk.util.InstanceBuilder; +import com.google.cloud.dataflow.sdk.util.MapAggregatorValues; import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; +import com.google.cloud.dataflow.sdk.util.UserCodeException; import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.common.Counter; +import com.google.cloud.dataflow.sdk.util.common.CounterSet; import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded; import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.joda.time.Instant; +import java.util.Collection; +import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; import javax.annotation.Nullable; @@ -42,28 +63,25 @@ * {@link PCollection PCollections}. */ @Experimental -public class InProcessPipelineRunner { - @SuppressWarnings({"rawtypes", "unused"}) +public class InProcessPipelineRunner + extends PipelineRunner { + /** + * The default set of transform overrides to use in the {@link InProcessPipelineRunner}. + * + *

A transform override must have a single-argument constructor that takes an instance of the + * type of transform it is overriding. + */ + @SuppressWarnings("rawtypes") private static Map, Class> defaultTransformOverrides = ImmutableMap., Class>builder() + .put(Create.Values.class, InProcessCreate.class) .put(GroupByKey.class, InProcessGroupByKey.class) - .put(CreatePCollectionView.class, InProcessCreatePCollectionView.class) + .put( + CreatePCollectionView.class, + ViewEvaluatorFactory.InProcessCreatePCollectionView.class) .build(); - private static Map, TransformEvaluatorFactory> defaultEvaluatorFactories = - new ConcurrentHashMap<>(); - - /** - * Register a default transform evaluator. - */ - public static > void registerTransformEvaluatorFactory( - Class clazz, TransformEvaluatorFactory evaluator) { - checkArgument(defaultEvaluatorFactories.put(clazz, evaluator) == null, - "Defining a default factory %s to evaluate Transforms of type %s multiple times", evaluator, - clazz); - } - /** * Part of a {@link PCollection}. Elements are output to a bundle, which will cause them to be * executed by {@link PTransform PTransforms} that consume the {@link PCollection} this bundle is @@ -73,7 +91,7 @@ public class InProcessPipelineRunner { */ public static interface UncommittedBundle { /** - * Returns the PCollection that the elements of this bundle belong to. + * Returns the PCollection that the elements of this {@link UncommittedBundle} belong to. */ PCollection getPCollection(); @@ -103,14 +121,13 @@ public static interface UncommittedBundle { * @param the type of elements contained within this bundle */ public static interface CommittedBundle { - /** * Returns the PCollection that the elements of this bundle belong to. */ PCollection getPCollection(); /** - * Returns weather this bundle is keyed. A bundle that is part of a {@link PCollection} that + * Returns whether this bundle is keyed. 
A bundle that is part of a {@link PCollection} that * occurs after a {@link GroupByKey} is keyed by the result of the last {@link GroupByKey}. */ boolean isKeyed(); @@ -119,11 +136,12 @@ public static interface CommittedBundle { * Returns the (possibly null) key that was output in the most recent {@link GroupByKey} in the * execution of this bundle. */ - @Nullable Object getKey(); + @Nullable + Object getKey(); /** - * @return an {@link Iterable} containing all of the elements that have been added to this - * {@link CommittedBundle} + * Returns an {@link Iterable} containing all of the elements that have been added to this + * {@link CommittedBundle}. */ Iterable> getElements(); @@ -166,4 +184,160 @@ private InProcessPipelineRunner(InProcessPipelineOptions options) { public InProcessPipelineOptions getPipelineOptions() { return options; } + + @Override + public OutputT apply( + PTransform transform, InputT input) { + Class overrideClass = defaultTransformOverrides.get(transform.getClass()); + if (overrideClass != null) { + // It is the responsibility of whoever constructs overrides to ensure this is type safe. + @SuppressWarnings("unchecked") + Class> transformClass = + (Class>) transform.getClass(); + + @SuppressWarnings("unchecked") + Class> customTransformClass = + (Class>) overrideClass; + + PTransform customTransform = + InstanceBuilder.ofType(customTransformClass) + .withArg(transformClass, transform) + .build(); + + // This overrides the contents of the apply method without changing the TransformTreeNode that + // is generated by the PCollection application. 
+ return super.apply(customTransform, input); + } else { + return super.apply(transform, input); + } + } + + @Override + public InProcessPipelineResult run(Pipeline pipeline) { + ConsumerTrackingPipelineVisitor consumerTrackingVisitor = new ConsumerTrackingPipelineVisitor(); + pipeline.traverseTopologically(consumerTrackingVisitor); + for (PValue unfinalized : consumerTrackingVisitor.getUnfinalizedPValues()) { + unfinalized.finishSpecifying(); + } + @SuppressWarnings("rawtypes") + KeyedPValueTrackingVisitor keyedPValueVisitor = + KeyedPValueTrackingVisitor.create( + ImmutableSet.>of( + GroupByKey.class, InProcessGroupByKeyOnly.class)); + pipeline.traverseTopologically(keyedPValueVisitor); + + InProcessEvaluationContext context = + InProcessEvaluationContext.create( + getPipelineOptions(), + consumerTrackingVisitor.getRootTransforms(), + consumerTrackingVisitor.getValueToConsumers(), + consumerTrackingVisitor.getStepNames(), + consumerTrackingVisitor.getViews()); + + // independent executor service for each run + ExecutorService executorService = + context.getPipelineOptions().getExecutorServiceFactory().create(); + InProcessExecutor executor = + ExecutorServiceParallelExecutor.create( + executorService, + consumerTrackingVisitor.getValueToConsumers(), + keyedPValueVisitor.getKeyedPValues(), + TransformEvaluatorRegistry.defaultRegistry(), + context); + executor.start(consumerTrackingVisitor.getRootTransforms()); + + Map, Collection>> aggregatorSteps = + new AggregatorPipelineExtractor(pipeline).getAggregatorSteps(); + InProcessPipelineResult result = + new InProcessPipelineResult(executor, context, aggregatorSteps); + if (options.isBlockOnRun()) { + try { + result.awaitCompletion(); + } catch (UserCodeException userException) { + throw new PipelineExecutionException(userException.getCause()); + } catch (Throwable t) { + Throwables.propagate(t); + } + } + return result; + } + + /** + * The result of running a {@link Pipeline} with the {@link InProcessPipelineRunner}. 
+ * + * Throws {@link UnsupportedOperationException} for all methods. + */ + public static class InProcessPipelineResult implements PipelineResult { + private final InProcessExecutor executor; + private final InProcessEvaluationContext evaluationContext; + private final Map, Collection>> aggregatorSteps; + private State state; + + private InProcessPipelineResult( + InProcessExecutor executor, + InProcessEvaluationContext evaluationContext, + Map, Collection>> aggregatorSteps) { + this.executor = executor; + this.evaluationContext = evaluationContext; + this.aggregatorSteps = aggregatorSteps; + // Only ever constructed after the executor has started. + this.state = State.RUNNING; + } + + @Override + public State getState() { + return state; + } + + @Override + public AggregatorValues getAggregatorValues(Aggregator aggregator) + throws AggregatorRetrievalException { + CounterSet counters = evaluationContext.getCounters(); + Collection> steps = aggregatorSteps.get(aggregator); + Map stepValues = new HashMap<>(); + for (AppliedPTransform transform : evaluationContext.getSteps()) { + if (steps.contains(transform.getTransform())) { + String stepName = + String.format( + "user-%s-%s", evaluationContext.getStepName(transform), aggregator.getName()); + Counter counter = (Counter) counters.getExistingCounter(stepName); + if (counter != null) { + stepValues.put(transform.getFullName(), counter.getAggregate()); + } + } + } + return new MapAggregatorValues<>(stepValues); + } + + /** + * Blocks until the {@link Pipeline} execution represented by this + * {@link InProcessPipelineResult} is complete, returning the terminal state. + * + *

If the pipeline terminates abnormally by throwing an exception, this will rethrow the + * exception. Future calls to {@link #getState()} will return + * {@link com.google.cloud.dataflow.sdk.PipelineResult.State#FAILED}. + * + *

NOTE: if the {@link Pipeline} contains an {@link IsBounded#UNBOUNDED unbounded} + * {@link PCollection}, and the {@link PipelineRunner} was created with + * {@link InProcessPipelineOptions#isShutdownUnboundedProducersWithMaxWatermark()} set to false, + * this method will never return. + * + * See also {@link InProcessExecutor#awaitCompletion()}. + */ + public State awaitCompletion() throws Throwable { + if (!state.isTerminal()) { + try { + executor.awaitCompletion(); + state = State.DONE; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw e; + } catch (Throwable t) { + state = State.FAILED; + throw t; + } + } + return state; + } + } } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java new file mode 100644 index 0000000000..23a8c0f506 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PValue; + +import java.util.HashSet; +import java.util.Set; + +/** + * A pipeline visitor that tracks all keyed {@link PValue PValues}. A {@link PValue} is keyed if it + * is the result of a {@link PTransform} that produces keyed outputs. A {@link PTransform} that + * produces keyed outputs is assumed to colocate output elements that share a key. + * + *

All {@link GroupByKey} transforms, or their runner-specific implementation primitive, produce + * keyed output. + */ +// TODO: Handle Key-preserving transforms when appropriate and more aggressively make PTransforms +// unkeyed +class KeyedPValueTrackingVisitor implements PipelineVisitor { + @SuppressWarnings("rawtypes") + private final Set> producesKeyedOutputs; + private final Set keyedValues; + private boolean finalized; + + public static KeyedPValueTrackingVisitor create( + @SuppressWarnings("rawtypes") Set> producesKeyedOutputs) { + return new KeyedPValueTrackingVisitor(producesKeyedOutputs); + } + + private KeyedPValueTrackingVisitor( + @SuppressWarnings("rawtypes") Set> producesKeyedOutputs) { + this.producesKeyedOutputs = producesKeyedOutputs; + this.keyedValues = new HashSet<>(); + } + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempted to use a %s that has already been finalized on a pipeline (visiting node %s)", + KeyedPValueTrackingVisitor.class.getSimpleName(), + node); + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + checkState( + !finalized, + "Attempted to use a %s that has already been finalized on a pipeline (visiting node %s)", + KeyedPValueTrackingVisitor.class.getSimpleName(), + node); + if (node.isRootNode()) { + finalized = true; + } else if (producesKeyedOutputs.contains(node.getTransform().getClass())) { + keyedValues.addAll(node.getExpandedOutputs()); + } + } + + @Override + public void visitTransform(TransformTreeNode node) {} + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + if (producesKeyedOutputs.contains(producer.getTransform().getClass())) { + keyedValues.addAll(value.expand()); + } + } + + public Set getKeyedPValues() { + checkState( + finalized, "can't call getKeyedPValues before a Pipeline has been completely traversed"); + return keyedValues; + } +} + diff --git 
a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java new file mode 100644 index 0000000000..d921f6cdb5 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java @@ -0,0 +1,233 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.emptyIterable; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.io.CountingInput; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.PValue; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; +import java.util.List; + +/** + * Tests for {@link ConsumerTrackingPipelineVisitor}. 
+ */ +@RunWith(JUnit4.class) +public class ConsumerTrackingPipelineVisitorTest implements Serializable { + @Rule public transient ExpectedException thrown = ExpectedException.none(); + + private transient TestPipeline p = TestPipeline.create(); + private transient ConsumerTrackingPipelineVisitor visitor = new ConsumerTrackingPipelineVisitor(); + + @Test + public void getViewsReturnsViews() { + PCollectionView> listView = + p.apply("listCreate", Create.of("foo", "bar")) + .apply( + ParDo.of( + new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) + throws Exception { + c.output(Integer.toString(c.element().length())); + } + })) + .apply(View.asList()); + PCollectionView singletonView = + p.apply("singletonCreate", Create.of(1, 2, 3)).apply(View.asSingleton()); + p.traverseTopologically(visitor); + assertThat( + visitor.getViews(), + Matchers.>containsInAnyOrder(listView, singletonView)); + } + + @Test + public void getRootTransformsContainsPBegins() { + PCollection created = p.apply(Create.of("foo", "bar")); + PCollection counted = p.apply(CountingInput.upTo(1234L)); + PCollection unCounted = p.apply(CountingInput.unbounded()); + p.traverseTopologically(visitor); + assertThat( + visitor.getRootTransforms(), + Matchers.>containsInAnyOrder( + created.getProducingTransformInternal(), + counted.getProducingTransformInternal(), + unCounted.getProducingTransformInternal())); + } + + @Test + public void getRootTransformsContainsEmptyFlatten() { + PCollection empty = + PCollectionList.empty(p).apply(Flatten.pCollections()); + p.traverseTopologically(visitor); + assertThat( + visitor.getRootTransforms(), + Matchers.>containsInAnyOrder( + empty.getProducingTransformInternal())); + } + + @Test + public void getValueToConsumersSucceeds() { + PCollection created = p.apply(Create.of("1", "2", "3")); + PCollection transformed = + created.apply( + ParDo.of( + new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) + throws Exception { + 
c.output(Integer.toString(c.element().length())); + } + })); + + PCollection flattened = + PCollectionList.of(created).and(transformed).apply(Flatten.pCollections()); + + p.traverseTopologically(visitor); + + assertThat( + visitor.getValueToConsumers().get(created), + Matchers.>containsInAnyOrder( + transformed.getProducingTransformInternal(), + flattened.getProducingTransformInternal())); + assertThat( + visitor.getValueToConsumers().get(transformed), + Matchers.>containsInAnyOrder( + flattened.getProducingTransformInternal())); + assertThat(visitor.getValueToConsumers().get(flattened), emptyIterable()); + } + + @Test + public void getUnfinalizedPValuesContainsDanglingOutputs() { + PCollection created = p.apply(Create.of("1", "2", "3")); + PCollection transformed = + created.apply( + ParDo.of( + new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) + throws Exception { + c.output(Integer.toString(c.element().length())); + } + })); + + p.traverseTopologically(visitor); + assertThat(visitor.getUnfinalizedPValues(), Matchers.contains(transformed)); + } + + @Test + public void getUnfinalizedPValuesEmpty() { + p.apply(Create.of("1", "2", "3")) + .apply( + ParDo.of( + new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) + throws Exception { + c.output(Integer.toString(c.element().length())); + } + })) + .apply( + new PTransform() { + @Override + public PDone apply(PInput input) { + return PDone.in(input.getPipeline()); + } + }); + + p.traverseTopologically(visitor); + assertThat(visitor.getUnfinalizedPValues(), emptyIterable()); + } + + @Test + public void traverseMultipleTimesThrows() { + p.apply(Create.of(1, 2, 3)); + + p.traverseTopologically(visitor); + thrown.expect(IllegalStateException.class); + thrown.expectMessage(ConsumerTrackingPipelineVisitor.class.getSimpleName()); + thrown.expectMessage("is finalized"); + p.traverseTopologically(visitor); + } + + @Test + public void traverseIndependentPathsSucceeds() { + 
p.apply("left", Create.of(1, 2, 3)); + p.apply("right", Create.of("foo", "bar", "baz")); + + p.traverseTopologically(visitor); + } + + @Test + public void getRootTransformsWithoutVisitingThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getRootTransforms"); + visitor.getRootTransforms(); + } + @Test + public void getStepNamesWithoutVisitingThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getStepNames"); + visitor.getStepNames(); + } + @Test + public void getUnfinalizedPValuesWithoutVisitingThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getUnfinalizedPValues"); + visitor.getUnfinalizedPValues(); + } + + @Test + public void getValueToConsumersWithoutVisitingThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getValueToConsumers"); + visitor.getValueToConsumers(); + } + + @Test + public void getViewsWithoutVisitingThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getViews"); + visitor.getViews(); + } +} diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java new file mode 100644 index 0000000000..adb64cd625 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessPipelineResult; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.SimpleFunction; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.Serializable; + +/** + * Tests for basic {@link InProcessPipelineRunner} functionality. 
+ */ +@RunWith(JUnit4.class) +public class InProcessPipelineRunnerTest implements Serializable { + @Test + public void wordCountShouldSucceed() throws Throwable { + Pipeline p = getPipeline(); + + PCollection> counts = + p.apply(Create.of("foo", "bar", "foo", "baz", "bar", "foo")) + .apply(MapElements.via(new SimpleFunction() { + @Override + public String apply(String input) { + return input; + } + })) + .apply(Count.perElement()); + PCollection countStrs = + counts.apply(MapElements.via(new SimpleFunction, String>() { + @Override + public String apply(KV input) { + String str = String.format("%s: %s", input.getKey(), input.getValue()); + return str; + } + })); + + DataflowAssert.that(countStrs).containsInAnyOrder("baz: 1", "bar: 2", "foo: 3"); + + InProcessPipelineResult result = ((InProcessPipelineResult) p.run()); + result.awaitCompletion(); + } + + private Pipeline getPipeline() { + PipelineOptions opts = PipelineOptionsFactory.create(); + opts.setRunner(InProcessPipelineRunner.class); + + Pipeline p = Pipeline.create(opts); + return p; + } +} + diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java new file mode 100644 index 0000000000..0aaccc2848 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java @@ -0,0 +1,189 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.runners.inprocess; + +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.IterableCoder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.Keys; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableSet; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Collections; +import java.util.Set; + +/** + * Tests for {@link KeyedPValueTrackingVisitor}. 
+ */ +@RunWith(JUnit4.class) +public class KeyedPValueTrackingVisitorTest { + @Rule public ExpectedException thrown = ExpectedException.none(); + + private KeyedPValueTrackingVisitor visitor; + private Pipeline p; + + @Before + public void setup() { + PipelineOptions options = PipelineOptionsFactory.create(); + + p = Pipeline.create(options); + @SuppressWarnings("rawtypes") + Set> producesKeyed = + ImmutableSet.>of(PrimitiveKeyer.class, CompositeKeyer.class); + visitor = KeyedPValueTrackingVisitor.create(producesKeyed); + } + + @Test + public void primitiveProducesKeyedOutputUnkeyedInputKeyedOutput() { + PCollection keyed = + p.apply(Create.of(1, 2, 3)).apply(new PrimitiveKeyer()); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), hasItem(keyed)); + } + + @Test + public void primitiveProducesKeyedOutputKeyedInputKeyedOutut() { + PCollection keyed = + p.apply(Create.of(1, 2, 3)) + .apply("firstKey", new PrimitiveKeyer()) + .apply("secondKey", new PrimitiveKeyer()); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), hasItem(keyed)); + } + + @Test + public void compositeProducesKeyedOutputUnkeyedInputKeyedOutput() { + PCollection keyed = + p.apply(Create.of(1, 2, 3)).apply(new CompositeKeyer()); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), hasItem(keyed)); + } + + @Test + public void compositeProducesKeyedOutputKeyedInputKeyedOutut() { + PCollection keyed = + p.apply(Create.of(1, 2, 3)) + .apply("firstKey", new CompositeKeyer()) + .apply("secondKey", new CompositeKeyer()); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), hasItem(keyed)); + } + + + @Test + public void noInputUnkeyedOutput() { + PCollection>> unkeyed = + p.apply( + Create.of(KV.>of(-1, Collections.emptyList())) + .withCoder(KvCoder.of(VarIntCoder.of(), IterableCoder.of(VoidCoder.of())))); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed))); 
+ } + + @Test + public void keyedInputNotProducesKeyedOutputUnkeyedOutput() { + PCollection onceKeyed = + p.apply(Create.of(1, 2, 3)) + .apply(new PrimitiveKeyer()) + .apply(ParDo.of(new IdentityFn())); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), not(hasItem(onceKeyed))); + } + + @Test + public void unkeyedInputNotProducesKeyedOutputUnkeyedOutput() { + PCollection unkeyed = + p.apply(Create.of(1, 2, 3)).apply(ParDo.of(new IdentityFn())); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed))); + } + + @Test + public void traverseMultipleTimesThrows() { + p.apply( + Create.>of( + KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null)) + .withCoder(KvCoder.of(VarIntCoder.of(), VoidCoder.of()))) + .apply(GroupByKey.create()) + .apply(Keys.create()); + + p.traverseTopologically(visitor); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("already been finalized"); + thrown.expectMessage(KeyedPValueTrackingVisitor.class.getSimpleName()); + p.traverseTopologically(visitor); + } + + @Test + public void getKeyedPValuesBeforeTraverseThrows() { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("completely traversed"); + thrown.expectMessage("getKeyedPValues"); + visitor.getKeyedPValues(); + } + + private static class PrimitiveKeyer extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection input) { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()) + .setCoder(input.getCoder()); + } + } + + private static class CompositeKeyer extends PTransform, PCollection> { + @Override + public PCollection apply(PCollection input) { + return input.apply(new PrimitiveKeyer()).apply(ParDo.of(new IdentityFn())); + } + } + + private static class IdentityFn extends DoFn { + @Override + public void processElement(DoFn.ProcessContext c) throws Exception { + 
c.output(c.element()); + } + } +}