Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,34 @@ public void beforeElement(WindowedValue<T> element) {
}
}

@Override
public void afterElement(WindowedValue<T> element) {
verifyUnmodified(mutationElements.get(element));
}

@Override
public void afterFinish(
CommittedBundle<T> input,
InProcessTransformResult result,
Iterable<? extends CommittedBundle<?>> outputs) {
for (MutationDetector detector : mutationElements.values()) {
try {
detector.verifyUnmodified();
} catch (IllegalMutationException e) {
throw UserCodeException.wrap(
new IllegalMutationException(
String.format(
"PTransform %s illegaly mutated value %s of class %s."
+ " Input values must not be mutated in any way.",
transform.getFullName(),
e.getSavedValue(),
e.getSavedValue().getClass()),
e.getSavedValue(),
e.getNewValue()));
}
verifyUnmodified(detector);
}
}

private void verifyUnmodified(MutationDetector detector) {
try {
detector.verifyUnmodified();
} catch (IllegalMutationException e) {
throw new IllegalMutationException(
String.format(
"PTransform %s illegaly mutated value %s of class %s."
+ " Input values must not be mutated in any way.",
transform.getFullName(),
e.getSavedValue(),
e.getSavedValue().getClass()),
e.getSavedValue(),
e.getNewValue());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,13 @@ public interface InProcessPipelineOptions extends PipelineOptions, ApplicationNa
boolean isBlockOnRun();

void setBlockOnRun(boolean b);

@Default.Boolean(true)
@Description(
"Controls whether the runner should ensure that all of the elements of every "
+ "PCollection are not mutated. PTransforms are not permitted to mutate input elements "
+ "at any point, or output elements after they are output.")
boolean isTestImmutability();

void setTestImmutability(boolean test);
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView;
import com.google.cloud.dataflow.sdk.util.InstanceBuilder;
import com.google.cloud.dataflow.sdk.util.MapAggregatorValues;
Expand All @@ -48,13 +49,13 @@
import com.google.cloud.dataflow.sdk.values.POutput;
import com.google.cloud.dataflow.sdk.values.PValue;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

import org.joda.time.Instant;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -269,11 +270,29 @@ public InProcessPipelineResult run(Pipeline pipeline) {

private Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>>
defaultModelEnforcements(InProcessPipelineOptions options) {
return Collections.emptyMap();
ImmutableMap.Builder<Class<? extends PTransform>, Collection<ModelEnforcementFactory>>
enforcements = ImmutableMap.builder();
Collection<ModelEnforcementFactory> parDoEnforcements = createParDoEnforcements(options);
enforcements.put(ParDo.Bound.class, parDoEnforcements);
enforcements.put(ParDo.BoundMulti.class, parDoEnforcements);
return enforcements.build();
}

private Collection<ModelEnforcementFactory> createParDoEnforcements(
InProcessPipelineOptions options) {
ImmutableList.Builder<ModelEnforcementFactory> enforcements = ImmutableList.builder();
if (options.isTestImmutability()) {
enforcements.add(ImmutabilityEnforcementFactory.create());
}
return enforcements.build();
}

private BundleFactory createBundleFactory(InProcessPipelineOptions pipelineOptions) {
return InProcessBundleFactory.create();
BundleFactory bundleFactory = InProcessBundleFactory.create();
if (pipelineOptions.isTestImmutability()) {
bundleFactory = ImmutabilityCheckingBundleFactory.create(bundleFactory);
}
return bundleFactory;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
import com.google.cloud.dataflow.sdk.util.DoFnRunner;
import com.google.cloud.dataflow.sdk.util.DoFnRunners.OutputManager;
import com.google.cloud.dataflow.sdk.util.UserCodeException;
import com.google.cloud.dataflow.sdk.util.WindowedValue;
import com.google.cloud.dataflow.sdk.util.common.CounterSet;
import com.google.cloud.dataflow.sdk.util.state.CopyOnAccessInMemoryStateInternals;
Expand Down Expand Up @@ -56,12 +57,20 @@ public ParDoInProcessEvaluator(

@Override
public void processElement(WindowedValue<T> element) {
fnRunner.processElement(element);
try {
fnRunner.processElement(element);
} catch (Exception e) {
throw UserCodeException.wrap(e);
}
}

@Override
public InProcessTransformResult finishBundle() {
fnRunner.finishBundle();
try {
fnRunner.finishBundle();
} catch (Exception e) {
throw UserCodeException.wrap(e);
}
StepTransformResult.Builder resultBuilder;
CopyOnAccessInMemoryStateInternals<?> state = stepContext.commitState();
if (state != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
*/
package com.google.cloud.dataflow.sdk.runners.inprocess;

import static org.hamcrest.Matchers.isA;

import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle;
import com.google.cloud.dataflow.sdk.testing.TestPipeline;
import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
Expand All @@ -27,7 +25,6 @@
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.util.IllegalMutationException;
import com.google.cloud.dataflow.sdk.util.UserCodeException;
import com.google.cloud.dataflow.sdk.util.WindowedValue;
import com.google.cloud.dataflow.sdk.values.PCollection;

Expand Down Expand Up @@ -94,8 +91,7 @@ public void mutatedDuringProcessElementThrows() {
ModelEnforcement<byte[]> enforcement = factory.forBundle(elements, consumer);
enforcement.beforeElement(element);
element.getValue()[0] = 'f';
thrown.expect(UserCodeException.class);
thrown.expectCause(isA(IllegalMutationException.class));
thrown.expect(IllegalMutationException.class);
thrown.expectMessage(consumer.getFullName());
thrown.expectMessage("illegaly mutated");
thrown.expectMessage("Input values must not be mutated");
Expand All @@ -118,8 +114,7 @@ public void mutatedAfterProcessElementFails() {
enforcement.afterElement(element);

element.getValue()[0] = 'f';
thrown.expect(UserCodeException.class);
thrown.expectCause(isA(IllegalMutationException.class));
thrown.expect(IllegalMutationException.class);
thrown.expectMessage(consumer.getFullName());
thrown.expectMessage("illegaly mutated");
thrown.expectMessage("Input values must not be mutated");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.WithKeys;
import com.google.cloud.dataflow.sdk.util.UserCodeException;
import com.google.cloud.dataflow.sdk.util.IllegalMutationException;
import com.google.cloud.dataflow.sdk.util.WindowedValue;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;
Expand Down Expand Up @@ -413,7 +413,7 @@ public InProcessTransformResult finishBundle() throws Exception {
fooBytes.getValue()[0] = 'b';
evaluatorLatch.countDown();

thrown.expectCause(isA(UserCodeException.class));
thrown.expectCause(isA(IllegalMutationException.class));
task.get();
}

Expand Down Expand Up @@ -472,7 +472,7 @@ public InProcessTransformResult finishBundle() throws Exception {
fooBytes.getValue()[0] = 'b';
evaluatorLatch.countDown();

thrown.expectCause(isA(UserCodeException.class));
thrown.expectCause(isA(IllegalMutationException.class));
task.get();
}

Expand Down