diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineTranslator.java index 4e60545c20af..5c0745f54ace 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineTranslator.java @@ -17,7 +17,6 @@ */ package org.apache.beam.sdk.runners; -import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray; import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray; import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString; import static org.apache.beam.sdk.util.StringUtils.jsonStringToByteArray; @@ -34,7 +33,6 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.io.BigQueryIO; import org.apache.beam.sdk.io.PubsubIO; @@ -47,7 +45,6 @@ import org.apache.beam.sdk.runners.dataflow.ReadTranslator; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.GroupByKey; @@ -844,45 +841,6 @@ private void translateHelper( } }); - registerTransformTranslator( - Create.Values.class, - new TransformTranslator() { - @Override - public void translate( - Create.Values transform, - TranslationContext context) { - createHelper(transform, context); - } - - private void createHelper( - Create.Values transform, - TranslationContext context) { - context.addStep(transform, "CreateCollection"); - - Coder coder = 
context.getOutput(transform).getCoder(); - List elements = new LinkedList<>(); - for (T elem : transform.getElements()) { - byte[] encodedBytes; - try { - encodedBytes = encodeToByteArray(coder, elem); - } catch (CoderException exn) { - // TODO: Put in better element printing: - // truncate if too long. - throw new IllegalArgumentException( - "Unable to encode element '" + elem + "' of transform '" + transform - + "' using coder '" + coder + "'.", - exn); - } - String encodedJson = byteArrayToJsonString(encodedBytes); - assert Arrays.equals(encodedBytes, - jsonStringToByteArray(encodedJson)); - elements.add(CloudObject.forString(encodedJson)); - } - context.addInput(PropertyNames.ELEMENT, elements); - context.addValueOnlyOutput(PropertyNames.OUTPUT, context.getOutput(transform)); - } - }); - registerTransformTranslator( Flatten.FlattenPCollectionList.class, new TransformTranslator() { diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineRunnerTest.java index 8b024fb8726c..69491287b720 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineRunnerTest.java @@ -21,6 +21,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.startsWith; import static org.hamcrest.collection.IsIterableContainingInOrder.contains; @@ -840,9 +841,16 @@ public void testApplyIsScopedToExactClass() throws IOException { CompositeTransformRecorder recorder = new CompositeTransformRecorder(); p.traverseTopologically(recorder); - assertThat("Expected to have seen CreateTimestamped 
composite transform.", + // The recorder will also have seen a Create.Values composite as well, but we can't obtain that + // transform. + assertThat( + "Expected to have seen CreateTimestamped composite transform.", recorder.getCompositeTransforms(), - Matchers.>contains(transform)); + hasItem(transform)); + assertThat( + "Expected to have two composites, CreateTimestamped and Create.Values", + recorder.getCompositeTransforms(), + hasItem(Matchers.>isA((Class) Create.Values.class))); } @Test diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java index 0d58601d7e8b..a62f55042bf9 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java @@ -751,7 +751,7 @@ public void testToSingletonTranslation() throws Exception { assertEquals(2, steps.size()); Step createStep = steps.get(0); - assertEquals("CreateCollection", createStep.getKind()); + assertEquals("ParallelRead", createStep.getKind()); Step collectionToSingletonStep = steps.get(1); assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind()); @@ -783,7 +783,7 @@ public void testToIterableTranslation() throws Exception { assertEquals(2, steps.size()); Step createStep = steps.get(0); - assertEquals("CreateCollection", createStep.getKind()); + assertEquals("ParallelRead", createStep.getKind()); Step collectionToSingletonStep = steps.get(1); assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind()); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessCreate.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessCreate.java deleted file mode 100644 index 
c29d5ce045d0..000000000000 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessCreate.java +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.runners.inprocess; - -import org.apache.beam.sdk.coders.CannotProvideCoderException; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.io.BoundedSource; -import org.apache.beam.sdk.io.OffsetBasedSource; -import org.apache.beam.sdk.io.OffsetBasedSource.OffsetBasedReader; -import org.apache.beam.sdk.io.Read; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.Create.Values; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PInput; -import org.apache.beam.sdk.values.POutput; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Optional; -import com.google.common.collect.ImmutableList; - -import java.io.IOException; -import java.util.List; -import 
java.util.NoSuchElementException; - -import javax.annotation.Nullable; - -/** - * An in-process implementation of the {@link Values Create.Values} {@link PTransform}, implemented - * using a {@link BoundedSource}. - * - * The coder is inferred via the {@link Values#getDefaultOutputCoder(PInput)} method on the original - * transform. - */ -class InProcessCreate extends ForwardingPTransform> { - private final Create.Values original; - - /** - * A {@link PTransformOverrideFactory} for {@link InProcessCreate}. - */ - public static class InProcessCreateOverrideFactory implements PTransformOverrideFactory { - @Override - public PTransform override( - PTransform transform) { - if (transform instanceof Create.Values) { - @SuppressWarnings("unchecked") - PTransform override = - (PTransform) from((Create.Values) transform); - return override; - } - return transform; - } - } - - public static InProcessCreate from(Create.Values original) { - return new InProcessCreate<>(original); - } - - private InProcessCreate(Values original) { - this.original = original; - } - - @Override - public PCollection apply(PInput input) { - Coder elementCoder; - try { - elementCoder = original.getDefaultOutputCoder(input); - } catch (CannotProvideCoderException e) { - throw new IllegalArgumentException( - "Unable to infer a coder and no Coder was specified. 
" - + "Please set a coder by invoking Create.withCoder() explicitly.", - e); - } - InMemorySource source; - try { - source = InMemorySource.fromIterable(original.getElements(), elementCoder); - } catch (IOException e) { - throw new RuntimeException(e); - } - PCollection result = input.getPipeline().apply(Read.from(source)); - result.setCoder(elementCoder); - return result; - } - - @Override - public PTransform> delegate() { - return original; - } - - @VisibleForTesting - static class InMemorySource extends OffsetBasedSource { - private final List allElementsBytes; - private final long totalSize; - private final Coder coder; - - public static InMemorySource fromIterable(Iterable elements, Coder elemCoder) - throws CoderException, IOException { - ImmutableList.Builder allElementsBytes = ImmutableList.builder(); - long totalSize = 0L; - for (T element : elements) { - byte[] bytes = CoderUtils.encodeToByteArray(elemCoder, element); - allElementsBytes.add(bytes); - totalSize += bytes.length; - } - return new InMemorySource<>(allElementsBytes.build(), totalSize, elemCoder); - } - - /** - * Create a new source with the specified bytes. The new source owns the input element bytes, - * which must not be modified after this constructor is called. 
- */ - private InMemorySource(List elementBytes, long totalSize, Coder coder) { - super(0, elementBytes.size(), 1); - this.allElementsBytes = ImmutableList.copyOf(elementBytes); - this.totalSize = totalSize; - this.coder = coder; - } - - @Override - public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { - return totalSize; - } - - @Override - public boolean producesSortedKeys(PipelineOptions options) throws Exception { - return false; - } - - @Override - public BoundedSource.BoundedReader createReader(PipelineOptions options) throws IOException { - return new BytesReader<>(this); - } - - @Override - public void validate() {} - - @Override - public Coder getDefaultOutputCoder() { - return coder; - } - - @Override - public long getMaxEndOffset(PipelineOptions options) throws Exception { - return allElementsBytes.size(); - } - - @Override - public OffsetBasedSource createSourceForSubrange(long start, long end) { - List primaryElems = allElementsBytes.subList((int) start, (int) end); - long primarySizeEstimate = - (long) (totalSize * primaryElems.size() / (double) allElementsBytes.size()); - return new InMemorySource<>(primaryElems, primarySizeEstimate, coder); - } - - @Override - public long getBytesPerOffset() { - if (allElementsBytes.size() == 0) { - return 0L; - } - return totalSize / allElementsBytes.size(); - } - } - - private static class BytesReader extends OffsetBasedReader { - private int index; - /** - * Use an optional to distinguish between null next element (as Optional.absent()) and no next - * element (next is null). 
- */ - @Nullable private Optional next; - - public BytesReader(InMemorySource source) { - super(source); - index = -1; - } - - @Override - @Nullable - public T getCurrent() throws NoSuchElementException { - if (next == null) { - throw new NoSuchElementException(); - } - return next.orNull(); - } - - @Override - public void close() throws IOException {} - - @Override - protected long getCurrentOffset() { - return index; - } - - @Override - protected boolean startImpl() throws IOException { - return advanceImpl(); - } - - @Override - public synchronized InMemorySource getCurrentSource() { - return (InMemorySource) super.getCurrentSource(); - } - - @Override - protected boolean advanceImpl() throws IOException { - InMemorySource source = getCurrentSource(); - index++; - if (index >= source.allElementsBytes.size()) { - return false; - } - next = - Optional.fromNullable( - CoderUtils.decodeFromByteArray( - source.coder, source.allElementsBytes.get(index))); - return true; - } - } -} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessPipelineRunner.java index 6cc35fb01ee5..7c28238d0dad 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessPipelineRunner.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessPipelineRunner.java @@ -30,11 +30,9 @@ import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly; import org.apache.beam.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOverrideFactory; -import org.apache.beam.sdk.runners.inprocess.InProcessCreate.InProcessCreateOverrideFactory; import org.apache.beam.sdk.runners.inprocess.ViewEvaluatorFactory.InProcessViewOverrideFactory; import org.apache.beam.sdk.transforms.Aggregator; import 
org.apache.beam.sdk.transforms.AppliedPTransform; -import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -83,7 +81,6 @@ public class InProcessPipelineRunner private static Map, PTransformOverrideFactory> defaultTransformOverrides = ImmutableMap., PTransformOverrideFactory>builder() - .put(Create.Values.class, new InProcessCreateOverrideFactory()) .put(GroupByKey.class, new InProcessGroupByKeyOverrideFactory()) .put(CreatePCollectionView.class, new InProcessViewOverrideFactory()) .put(AvroIO.Write.Bound.class, new AvroIOShardedWriteFactory()) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java index 27fb39d8f8ab..1bd4fb3912f2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java @@ -20,33 +20,42 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.VoidCoder; -import org.apache.beam.sdk.runners.DirectPipelineRunner; -import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.OffsetBasedSource; +import org.apache.beam.sdk.io.OffsetBasedSource.OffsetBasedReader; +import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.TimestampedValue; 
import org.apache.beam.sdk.values.TimestampedValue.TimestampedValueCoder; import org.apache.beam.sdk.values.TypeDescriptor; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Function; import com.google.common.base.Optional; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import org.joda.time.Instant; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; +import javax.annotation.Nullable; + /** * {@code Create} takes a collection of elements of type {@code T} * known when the pipeline is constructed and returns a @@ -237,12 +246,13 @@ public Iterable getElements() { public PCollection apply(PInput input) { try { Coder coder = getDefaultOutputCoder(input); - return PCollection - .createPrimitiveOutputInternal( - input.getPipeline(), - WindowingStrategy.globalDefault(), - IsBounded.BOUNDED) - .setCoder(coder); + try { + CreateSource source = CreateSource.fromIterable(elems, coder); + return input.getPipeline().apply(Read.from(source)); + } catch (IOException e) { + throw new RuntimeException( + String.format("Unable to apply Create %s using Coder %s.", this, coder), e); + } } catch (CannotProvideCoderException e) { throw new IllegalArgumentException("Unable to infer a coder and no Coder was specified. 
" + "Please set a coder by invoking Create.withCoder() explicitly.", e); @@ -320,6 +330,136 @@ private Values(Iterable elems, Optional> coder) { this.elems = elems; this.coder = coder; } + + @VisibleForTesting + static class CreateSource extends OffsetBasedSource { + private final List allElementsBytes; + private final long totalSize; + private final Coder coder; + + public static CreateSource fromIterable(Iterable elements, Coder elemCoder) + throws CoderException, IOException { + ImmutableList.Builder allElementsBytes = ImmutableList.builder(); + long totalSize = 0L; + for (T element : elements) { + byte[] bytes = CoderUtils.encodeToByteArray(elemCoder, element); + allElementsBytes.add(bytes); + totalSize += bytes.length; + } + return new CreateSource<>(allElementsBytes.build(), totalSize, elemCoder); + } + + /** + * Create a new source with the specified bytes. The new source owns the input element bytes, + * which must not be modified after this constructor is called. + */ + private CreateSource(List elementBytes, long totalSize, Coder coder) { + super(0, elementBytes.size(), 1); + this.allElementsBytes = ImmutableList.copyOf(elementBytes); + this.totalSize = totalSize; + this.coder = coder; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + return totalSize; + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public BoundedSource.BoundedReader createReader(PipelineOptions options) + throws IOException { + return new BytesReader<>(this); + } + + @Override + public void validate() {} + + @Override + public Coder getDefaultOutputCoder() { + return coder; + } + + @Override + public long getMaxEndOffset(PipelineOptions options) throws Exception { + return allElementsBytes.size(); + } + + @Override + public OffsetBasedSource createSourceForSubrange(long start, long end) { + List primaryElems = allElementsBytes.subList((int) start, (int) 
end); + long primarySizeEstimate = + (long) (totalSize * primaryElems.size() / (double) allElementsBytes.size()); + return new CreateSource<>(primaryElems, primarySizeEstimate, coder); + } + + @Override + public long getBytesPerOffset() { + if (allElementsBytes.size() == 0) { + return 0L; + } + return totalSize / allElementsBytes.size(); + } + } + + private static class BytesReader extends OffsetBasedReader { + private int index; + /** + * Use an optional to distinguish between null next element (as Optional.absent()) and no next + * element (next is null). + */ + @Nullable private Optional next; + + public BytesReader(CreateSource source) { + super(source); + index = -1; + } + + @Override + @Nullable + public T getCurrent() throws NoSuchElementException { + if (next == null) { + throw new NoSuchElementException(); + } + return next.orNull(); + } + + @Override + public void close() throws IOException {} + + @Override + protected long getCurrentOffset() { + return index; + } + + @Override + protected boolean startImpl() throws IOException { + return advanceImpl(); + } + + @Override + public synchronized CreateSource getCurrentSource() { + return (CreateSource) super.getCurrentSource(); + } + + @Override + protected boolean advanceImpl() throws IOException { + CreateSource source = getCurrentSource(); + index++; + if (index >= source.allElementsBytes.size()) { + next = null; + return false; + } + next = + Optional.fromNullable( + CoderUtils.decodeFromByteArray(source.coder, source.allElementsBytes.get(index))); + return true; + } + } } ///////////////////////////////////////////////////////////////////////////// @@ -387,42 +527,4 @@ public void processElement(ProcessContext c) { } } } - - ///////////////////////////////////////////////////////////////////////////// - - static { - registerDefaultTransformEvaluator(); - } - - @SuppressWarnings({"rawtypes", "unchecked"}) - private static void registerDefaultTransformEvaluator() { - 
DirectPipelineRunner.registerDefaultTransformEvaluator( - Create.Values.class, - new DirectPipelineRunner.TransformEvaluator() { - @Override - public void evaluate( - Create.Values transform, - DirectPipelineRunner.EvaluationContext context) { - evaluateHelper(transform, context); - } - }); - } - - private static void evaluateHelper( - Create.Values transform, - DirectPipelineRunner.EvaluationContext context) { - // Convert the Iterable of elems into a List of elems. - List listElems; - if (transform.elems instanceof Collection) { - Collection collectionElems = (Collection) transform.elems; - listElems = new ArrayList<>(collectionElems.size()); - } else { - listElems = new ArrayList<>(); - } - for (T elem : transform.elems) { - listElems.add( - context.ensureElementEncodable(context.getOutput(transform), elem)); - } - context.setPCollection(context.getOutput(transform), listElems); - } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java index 7690d2ba88dc..e4eb2048be20 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java @@ -158,7 +158,8 @@ public void visitTransform(TransformTreeNode node) { // Pick is a composite, should not be visited here. 
assertThat(transform, not(instanceOf(Sample.SampleAny.class))); assertThat(transform, not(instanceOf(Write.Bound.class))); - if (transform instanceof Read.Bounded) { + if (transform instanceof Read.Bounded + && node.getEnclosingNode().getTransform() instanceof TextIO.Read.Bound) { assertTrue(visited.add(TransformsSeen.READ)); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/EncodabilityEnforcementFactoryTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/EncodabilityEnforcementFactoryTest.java index 85c43226f5f2..8ed26843cf12 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/EncodabilityEnforcementFactoryTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/EncodabilityEnforcementFactoryTest.java @@ -20,6 +20,7 @@ import static org.hamcrest.Matchers.isA; import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; @@ -54,16 +55,9 @@ public class EncodabilityEnforcementFactoryTest { @Test public void encodeFailsThrows() { - TestPipeline p = TestPipeline.create(); - PCollection unencodable = - p.apply(Create.of(new Record()).withCoder(new RecordNoEncodeCoder())); - AppliedPTransform consumer = - unencodable.apply(Count.globally()).getProducingTransformInternal(); - WindowedValue record = WindowedValue.valueInGlobalWindow(new Record()); - CommittedBundle input = - bundleFactory.createRootBundle(unencodable).add(record).commit(Instant.now()); - ModelEnforcement enforcement = factory.forBundle(input, consumer); + + ModelEnforcement enforcement = createEnforcement(new RecordNoEncodeCoder(), record); thrown.expect(UserCodeException.class); thrown.expectCause(isA(CoderException.class)); @@ -73,16 +67,9 @@ public void encodeFailsThrows() { @Test 
public void decodeFailsThrows() { - TestPipeline p = TestPipeline.create(); - PCollection unencodable = - p.apply(Create.of(new Record()).withCoder(new RecordNoDecodeCoder())); - AppliedPTransform consumer = - unencodable.apply(Count.globally()).getProducingTransformInternal(); WindowedValue record = WindowedValue.valueInGlobalWindow(new Record()); - CommittedBundle input = - bundleFactory.createRootBundle(unencodable).add(record).commit(Instant.now()); - ModelEnforcement enforcement = factory.forBundle(input, consumer); + ModelEnforcement enforcement = createEnforcement(new RecordNoDecodeCoder(), record); thrown.expect(UserCodeException.class); thrown.expectCause(isA(CoderException.class)); @@ -92,12 +79,6 @@ public void decodeFailsThrows() { @Test public void consistentWithEqualsStructuralValueNotEqualThrows() { - TestPipeline p = TestPipeline.create(); - PCollection unencodable = - p.apply(Create.of(new Record()).withCoder(new RecordStructuralValueCoder())); - AppliedPTransform consumer = - unencodable.apply(Count.globally()).getProducingTransformInternal(); - WindowedValue record = WindowedValue.valueInGlobalWindow( new Record() { @@ -107,9 +88,8 @@ public String toString() { } }); - CommittedBundle input = - bundleFactory.createRootBundle(unencodable).add(record).commit(Instant.now()); - ModelEnforcement enforcement = factory.forBundle(input, consumer); + ModelEnforcement enforcement = + createEnforcement(new RecordStructuralValueCoder(), record); thrown.expect(UserCodeException.class); thrown.expectCause(isA(IllegalArgumentException.class)); @@ -143,6 +123,17 @@ public void notConsistentWithEqualsStructuralValueNotEqualSucceeds() { Collections.>emptyList()); } + private ModelEnforcement createEnforcement(Coder coder, WindowedValue record) { + TestPipeline p = TestPipeline.create(); + PCollection unencodable = p.apply(Create.of().withCoder(coder)); + AppliedPTransform consumer = + unencodable.apply(Count.globally()).getProducingTransformInternal(); + 
CommittedBundle input = + bundleFactory.createRootBundle(unencodable).add(record).commit(Instant.now()); + ModelEnforcement enforcement = factory.forBundle(input, consumer); + return enforcement; + } + @Test public void structurallyEqualResultsSucceeds() { TestPipeline p = TestPipeline.create(); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/InProcessCreateTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/InProcessCreateTest.java deleted file mode 100644 index 5c63af1c8e97..000000000000 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/InProcessCreateTest.java +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.runners.inprocess; - -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.is; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.fail; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.AtomicCoder; -import org.apache.beam.sdk.coders.BigEndianIntegerCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.NullableCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.io.BoundedSource; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.runners.inprocess.InProcessCreate.InMemorySource; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.RunnableOnService; -import org.apache.beam.sdk.testing.SourceTestUtils; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.util.SerializableUtils; -import org.apache.beam.sdk.values.PCollection; - -import com.google.common.collect.ImmutableList; - -import org.hamcrest.Matchers; -import org.junit.Rule; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; - -/** - * Tests for {@link InProcessCreate}. 
- */ -@RunWith(JUnit4.class) -public class InProcessCreateTest { - @Rule - public ExpectedException thrown = ExpectedException.none(); - - @Test - @Category(RunnableOnService.class) - public void testConvertsCreate() { - TestPipeline p = TestPipeline.create(); - Create.Values og = Create.of(1, 2, 3); - - InProcessCreate converted = InProcessCreate.from(og); - - PAssert.that(p.apply(converted)).containsInAnyOrder(2, 1, 3); - - p.run(); - } - - @Test - @Category(RunnableOnService.class) - public void testConvertsCreateWithNullElements() { - Create.Values og = - Create.of("foo", null, "spam", "ham", null, "eggs") - .withCoder(NullableCoder.of(StringUtf8Coder.of())); - - InProcessCreate converted = InProcessCreate.from(og); - TestPipeline p = TestPipeline.create(); - - PAssert.that(p.apply(converted)) - .containsInAnyOrder(null, "foo", null, "spam", "ham", "eggs"); - - p.run(); - } - - static class Record implements Serializable {} - - static class Record2 extends Record {} - - @Test - public void testThrowsIllegalArgumentWhenCannotInferCoder() { - Create.Values og = Create.of(new Record(), new Record2()); - InProcessCreate converted = InProcessCreate.from(og); - - Pipeline p = TestPipeline.create(); - - // Create won't infer a default coder in this case. - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage(Matchers.containsString("Unable to infer a coder")); - - PCollection c = p.apply(converted); - p.run(); - - fail("Unexpectedly Inferred Coder " + c.getCoder()); - } - - /** - * An unserializable class to demonstrate encoding of elements. 
- */ - private static class UnserializableRecord { - private final String myString; - - private UnserializableRecord(String myString) { - this.myString = myString; - } - - @Override - public int hashCode() { - return myString.hashCode(); - } - - @Override - public boolean equals(Object o) { - return myString.equals(((UnserializableRecord) o).myString); - } - - static class UnserializableRecordCoder extends AtomicCoder { - private final Coder stringCoder = StringUtf8Coder.of(); - - @Override - public void encode( - UnserializableRecord value, - OutputStream outStream, - org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException { - stringCoder.encode(value.myString, outStream, context.nested()); - } - - @Override - public UnserializableRecord decode( - InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException { - return new UnserializableRecord(stringCoder.decode(inStream, context.nested())); - } - } - } - - @Test - @Category(RunnableOnService.class) - public void testConvertsUnserializableElements() throws Exception { - List elements = - ImmutableList.of( - new UnserializableRecord("foo"), - new UnserializableRecord("bar"), - new UnserializableRecord("baz")); - InProcessCreate create = - InProcessCreate.from( - Create.of(elements).withCoder(new UnserializableRecord.UnserializableRecordCoder())); - - TestPipeline p = TestPipeline.create(); - PAssert.that(p.apply(create)) - .containsInAnyOrder( - new UnserializableRecord("foo"), - new UnserializableRecord("bar"), - new UnserializableRecord("baz")); - p.run(); - } - - @Test - public void testSerializableOnUnserializableElements() throws Exception { - List elements = - ImmutableList.of( - new UnserializableRecord("foo"), - new UnserializableRecord("bar"), - new UnserializableRecord("baz")); - InMemorySource source = - InMemorySource.fromIterable(elements, new UnserializableRecord.UnserializableRecordCoder()); - 
SerializableUtils.ensureSerializable(source); - } - - @Test - public void testSplitIntoBundles() throws Exception { - InProcessCreate.InMemorySource source = - InMemorySource.fromIterable( - ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8), BigEndianIntegerCoder.of()); - PipelineOptions options = PipelineOptionsFactory.create(); - List> splitSources = source.splitIntoBundles(12, options); - assertThat(splitSources, hasSize(3)); - SourceTestUtils.assertSourcesEqualReferenceSource(source, splitSources, options); - } - - @Test - public void testDoesNotProduceSortedKeys() throws Exception { - InProcessCreate.InMemorySource source = - InMemorySource.fromIterable(ImmutableList.of("spam", "ham", "eggs"), StringUtf8Coder.of()); - assertThat(source.producesSortedKeys(PipelineOptionsFactory.create()), is(false)); - } - - @Test - public void testGetDefaultOutputCoderReturnsConstructorCoder() throws Exception { - Coder coder = VarIntCoder.of(); - InProcessCreate.InMemorySource source = - InMemorySource.fromIterable(ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8), coder); - - Coder defaultCoder = source.getDefaultOutputCoder(); - assertThat(defaultCoder, equalTo(coder)); - } - - @Test - public void testSplitAtFraction() throws Exception { - List elements = new ArrayList<>(); - Random random = new Random(); - for (int i = 0; i < 25; i++) { - elements.add(random.nextInt()); - } - InProcessCreate.InMemorySource source = - InMemorySource.fromIterable(elements, VarIntCoder.of()); - - SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); - } -} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java index 393fedec80c6..2998489d4733 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java @@ -22,19 +22,36 @@ import static 
org.apache.beam.sdk.TestUtils.NO_LINES; import static org.apache.beam.sdk.TestUtils.NO_LINES_ARRAY; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.sdk.coders.BigEndianIntegerCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.RunnableOnService; +import org.apache.beam.sdk.testing.SourceTestUtils; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create.Values.CreateSource; +import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TimestampedValue; +import com.google.common.collect.ImmutableList; + import org.hamcrest.Matchers; import org.joda.time.Instant; import org.junit.Rule; @@ -44,11 +61,15 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Random; /** * Tests for Create. 
@@ -142,6 +163,67 @@ public void testCreateParameterizedType() throws Exception { TimestampedValue.of("a", new Instant(0)), TimestampedValue.of("b", new Instant(0))); } + /** + * An unserializable class to demonstrate encoding of elements. + */ + private static class UnserializableRecord { + private final String myString; + + private UnserializableRecord(String myString) { + this.myString = myString; + } + + @Override + public int hashCode() { + return myString.hashCode(); + } + + @Override + public boolean equals(Object o) { + return o instanceof UnserializableRecord && myString.equals(((UnserializableRecord) o).myString); + } + + static class UnserializableRecordCoder extends AtomicCoder<UnserializableRecord> { + private final Coder<String> stringCoder = StringUtf8Coder.of(); + + @Override + public void encode( + UnserializableRecord value, + OutputStream outStream, + org.apache.beam.sdk.coders.Coder.Context context) + throws CoderException, IOException { + stringCoder.encode(value.myString, outStream, context.nested()); + } + + @Override + public UnserializableRecord decode( + InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) + throws CoderException, IOException { + return new UnserializableRecord(stringCoder.decode(inStream, context.nested())); + } + } + } + + @Test + @Category(RunnableOnService.class) + public void testCreateWithUnserializableElements() throws Exception { + List<UnserializableRecord> elements = + ImmutableList.of( + new UnserializableRecord("foo"), + new UnserializableRecord("bar"), + new UnserializableRecord("baz")); + Create.Values<UnserializableRecord> create = + Create.of(elements).withCoder(new UnserializableRecord.UnserializableRecordCoder()); + + TestPipeline p = TestPipeline.create(); + PAssert.that(p.apply(create)) + .containsInAnyOrder( + new UnserializableRecord("foo"), + new UnserializableRecord("bar"), + new UnserializableRecord("baz")); + p.run(); + } + private static class PrintTimestamps extends DoFn { @Override public void processElement(ProcessContext c) { @@ -239,4 +321,56 @@ public void testCreateGetName() {
assertEquals("Create.Values", Create.of(1, 2, 3).getName()); assertEquals("Create.TimestampedValues", Create.timestamped(Collections.EMPTY_LIST).getName()); } + + @Test + public void testSourceIsSerializableWithUnserializableElements() throws Exception { + List<UnserializableRecord> elements = + ImmutableList.of( + new UnserializableRecord("foo"), + new UnserializableRecord("bar"), + new UnserializableRecord("baz")); + CreateSource<UnserializableRecord> source = + CreateSource.fromIterable(elements, new UnserializableRecord.UnserializableRecordCoder()); + SerializableUtils.ensureSerializable(source); + } + + @Test + public void testSourceSplitIntoBundles() throws Exception { + CreateSource<Integer> source = + CreateSource.fromIterable( + ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8), BigEndianIntegerCoder.of()); + PipelineOptions options = PipelineOptionsFactory.create(); + List<? extends BoundedSource<Integer>> splitSources = source.splitIntoBundles(12, options); + assertThat(splitSources, hasSize(3)); + SourceTestUtils.assertSourcesEqualReferenceSource(source, splitSources, options); + } + + @Test + public void testSourceDoesNotProduceSortedKeys() throws Exception { + CreateSource<String> source = + CreateSource.fromIterable(ImmutableList.of("spam", "ham", "eggs"), StringUtf8Coder.of()); + assertThat(source.producesSortedKeys(PipelineOptionsFactory.create()), is(false)); + } + + @Test + public void testSourceGetDefaultOutputCoderReturnsConstructorCoder() throws Exception { + Coder<Integer> coder = VarIntCoder.of(); + CreateSource<Integer> source = + CreateSource.fromIterable(ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8), coder); + + Coder<Integer> defaultCoder = source.getDefaultOutputCoder(); + assertThat(defaultCoder, equalTo(coder)); + } + + @Test + public void testSourceSplitAtFraction() throws Exception { + List<Integer> elements = new ArrayList<>(); + Random random = new Random(); + for (int i = 0; i < 25; i++) { + elements.add(random.nextInt()); + } + CreateSource<Integer> source = CreateSource.fromIterable(elements, VarIntCoder.of()); + + SourceTestUtils.assertSplitAtFractionExhaustive(source,
PipelineOptionsFactory.create()); + } }