diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineRunner.java index 58733f34dc14..c147f025bc74 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineRunner.java @@ -89,7 +89,6 @@ import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.sdk.util.DataflowReleaseInfo; import org.apache.beam.sdk.util.DataflowTransport; import org.apache.beam.sdk.util.IOChannelUtils; import org.apache.beam.sdk.util.InstanceBuilder; @@ -97,6 +96,7 @@ import org.apache.beam.sdk.util.PCollectionViews; import org.apache.beam.sdk.util.PathValidator; import org.apache.beam.sdk.util.PropertyNames; +import org.apache.beam.sdk.util.ReleaseInfo; import org.apache.beam.sdk.util.Reshuffle; import org.apache.beam.sdk.util.SystemDoFnInternal; import org.apache.beam.sdk.util.ValueWithRecordId; @@ -507,10 +507,10 @@ public DataflowPipelineJob run(Pipeline pipeline) { Job newJob = jobSpecification.getJob(); newJob.setClientRequestId(requestId); - String version = DataflowReleaseInfo.getReleaseInfo().getVersion(); + String version = ReleaseInfo.getReleaseInfo().getVersion(); System.out.println("Dataflow SDK version: " + version); - newJob.getEnvironment().setUserAgent(DataflowReleaseInfo.getReleaseInfo()); + newJob.getEnvironment().setUserAgent(ReleaseInfo.getReleaseInfo()); // The Dataflow Service may write to the temporary directory directly, so // must be verified. 
if (!Strings.isNullOrEmpty(options.getTempLocation())) { diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineTranslator.java index 4e60545c20af..5c0745f54ace 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/sdk/runners/DataflowPipelineTranslator.java @@ -17,7 +17,6 @@ */ package org.apache.beam.sdk.runners; -import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray; import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray; import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString; import static org.apache.beam.sdk.util.StringUtils.jsonStringToByteArray; @@ -34,7 +33,6 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.io.BigQueryIO; import org.apache.beam.sdk.io.PubsubIO; @@ -47,7 +45,6 @@ import org.apache.beam.sdk.runners.dataflow.ReadTranslator; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.GroupByKey; @@ -844,45 +841,6 @@ private void translateHelper( } }); - registerTransformTranslator( - Create.Values.class, - new TransformTranslator() { - @Override - public void translate( - Create.Values transform, - TranslationContext context) { - createHelper(transform, context); - } - - private void createHelper( - Create.Values transform, - TranslationContext context) { - context.addStep(transform, "CreateCollection"); - - Coder coder = context.getOutput(transform).getCoder(); - List elements = new LinkedList<>(); - for (T elem : transform.getElements()) { - byte[] encodedBytes; - try { - encodedBytes = encodeToByteArray(coder, elem); - } catch (CoderException exn) { - // TODO: Put in better element printing: - // truncate if too long. 
- throw new IllegalArgumentException( - "Unable to encode element '" + elem + "' of transform '" + transform - + "' using coder '" + coder + "'.", - exn); - } - String encodedJson = byteArrayToJsonString(encodedBytes); - assert Arrays.equals(encodedBytes, - jsonStringToByteArray(encodedJson)); - elements.add(CloudObject.forString(encodedJson)); - } - context.addInput(PropertyNames.ELEMENT, elements); - context.addValueOnlyOutput(PropertyNames.OUTPUT, context.getOutput(transform)); - } - }); - registerTransformTranslator( Flatten.FlattenPCollectionList.class, new TransformTranslator() { diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineRunnerTest.java index 8b024fb8726c..8b5cbdb9c8e4 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineRunnerTest.java @@ -21,6 +21,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.startsWith; import static org.hamcrest.collection.IsIterableContainingInOrder.contains; @@ -68,9 +69,9 @@ import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.sdk.util.DataflowReleaseInfo; import org.apache.beam.sdk.util.GcsUtil; import org.apache.beam.sdk.util.NoopPathValidator; +import org.apache.beam.sdk.util.ReleaseInfo; import org.apache.beam.sdk.util.TestCredential; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; @@ -375,10 +376,10 @@ public void testRunWithFiles() throws IOException { cloudDataflowDataset, workflowJob.getEnvironment().getDataset()); assertEquals( - DataflowReleaseInfo.getReleaseInfo().getName(), + ReleaseInfo.getReleaseInfo().getName(), workflowJob.getEnvironment().getUserAgent().get("name")); assertEquals( - DataflowReleaseInfo.getReleaseInfo().getVersion(), + ReleaseInfo.getReleaseInfo().getVersion(), workflowJob.getEnvironment().getUserAgent().get("version")); } @@ -840,9 +841,16 @@ public void testApplyIsScopedToExactClass() throws IOException { CompositeTransformRecorder recorder = new CompositeTransformRecorder(); p.traverseTopologically(recorder); - assertThat("Expected to have seen CreateTimestamped composite transform.", + // The recorder will also have seen a Create.Values composite as well, but we can't obtain that + // transform. 
+ assertThat( + "Expected to have seen CreateTimestamped composite transform.", recorder.getCompositeTransforms(), - Matchers.>contains(transform)); + hasItem(transform)); + assertThat( + "Expected to have two composites, CreateTimestamped and Create.Values", + recorder.getCompositeTransforms(), + hasItem(Matchers.>isA((Class) Create.Values.class))); } @Test diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java index 0d58601d7e8b..a62f55042bf9 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java @@ -751,7 +751,7 @@ public void testToSingletonTranslation() throws Exception { assertEquals(2, steps.size()); Step createStep = steps.get(0); - assertEquals("CreateCollection", createStep.getKind()); + assertEquals("ParallelRead", createStep.getKind()); Step collectionToSingletonStep = steps.get(1); assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind()); @@ -783,7 +783,7 @@ public void testToIterableTranslation() throws Exception { assertEquals(2, steps.size()); Step createStep = steps.get(0); - assertEquals("CreateCollection", createStep.getKind()); + assertEquals("ParallelRead", createStep.getKind()); Step collectionToSingletonStep = steps.get(1); assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind()); diff --git a/runners/spark/README.md b/runners/spark/README.md index 1d75b3519fc0..5b2e73232ec3 100644 --- a/runners/spark/README.md +++ b/runners/spark/README.md @@ -93,7 +93,7 @@ Switch to the Spark runner directory: Then run the [word count example][wc] from the SDK using a single threaded Spark instance in local mode: - mvn exec:exec -DmainClass=com.google.cloud.dataflow.examples.WordCount \ + mvn exec:exec -DmainClass=org.apache.beam.examples.WordCount \ -Dinput=/tmp/kinglear.txt -Doutput=/tmp/out -Drunner=SparkPipelineRunner \ -DsparkMaster=local @@ -104,7 +104,7 @@ Check the output by running: __Note: running examples using `mvn exec:exec` only works for Spark local mode at the moment. See the next section for how to run on a cluster.__ -[wc]: https://github.com/apache/incubator-beam/blob/master/examples/src/main/java/com/google/cloud/dataflow/examples/WordCount.java +[wc]: https://github.com/apache/incubator-beam/blob/master/examples/java/src/main/java/org/apache/beam/examples/WordCount.java ## Running on a Cluster Spark Beam pipelines can be run on a cluster using the `spark-submit` command. 
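The `runner` and `sparkMaster` settings used in these examples correspond to the `SparkPipelineOptions` interface that this patch registers through `SparkRunnerRegistrar`, so the same configuration can also be done programmatically in the `main` method of the jar you submit. The snippet below is only a rough sketch of that idea and is not part of this patch; the class name `WordCountOnSpark` is made up for illustration, and the `setSparkMaster` accessor is assumed from the `SparkPipelineOptions` interface in this module:

    import org.apache.beam.runners.spark.SparkPipelineOptions;
    import org.apache.beam.runners.spark.SparkPipelineRunner;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;

    public class WordCountOnSpark {
      public static void main(String[] args) {
        // Parse command-line flags into the Spark-specific options view.
        SparkPipelineOptions options =
            PipelineOptionsFactory.fromArgs(args).as(SparkPipelineOptions.class);
        // Equivalent to --runner=SparkPipelineRunner on the command line.
        options.setRunner(SparkPipelineRunner.class);
        // Equivalent to --sparkMaster=yarn-client; "local[*]" is handy for local testing.
        options.setSparkMaster("yarn-client");

        Pipeline p = Pipeline.create(options);
        // ... apply the WordCount transforms here ...
        p.run();
      }
    }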
@@ -117,7 +117,7 @@ Then run the word count example using Spark submit with the
 `yarn-client` master (`yarn-cluster` works just as well):

     spark-submit \
-      --class com.google.cloud.dataflow.examples.WordCount \
+      --class org.apache.beam.examples.WordCount \
       --master yarn-client \
       target/spark-runner-*-spark-app.jar \
         --inputFile=kinglear.txt --output=out --runner=SparkPipelineRunner --sparkMaster=yarn-client
diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml
index f12d8a6a3c17..5ccaec5b6bc1 100644
--- a/runners/spark/pom.xml
+++ b/runners/spark/pom.xml
@@ -70,6 +70,12 @@
       <artifactId>guava</artifactId>
       <version>${guava.version}</version>
     </dependency>
+    <dependency>
+      <groupId>com.google.auto.service</groupId>
+      <artifactId>auto-service</artifactId>
+      <version>1.0-rc2</version>
+      <optional>true</optional>
+    </dependency>
     <dependency>
       <groupId>org.apache.beam</groupId>
       <artifactId>java-sdk-all</artifactId>
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerRegistrar.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerRegistrar.java
new file mode 100644
index 000000000000..30142f9966fb
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerRegistrar.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import com.google.auto.service.AutoService;
+import com.google.common.collect.ImmutableList;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsRegistrar;
+import org.apache.beam.sdk.runners.PipelineRunner;
+import org.apache.beam.sdk.runners.PipelineRunnerRegistrar;
+
+/**
+ * Contains the {@link PipelineRunnerRegistrar} and {@link PipelineOptionsRegistrar} for the
+ * {@link SparkPipelineRunner}.
+ *
+ * <p>{@link AutoService} will register Spark's implementations of the {@link PipelineRunner}
+ * and {@link PipelineOptions} as available pipeline runner services.
+ */
+public final class SparkRunnerRegistrar {
+  private SparkRunnerRegistrar() {}
+
+  /**
+   * Registers the {@link SparkPipelineRunner}.
+   */
+  @AutoService(PipelineRunnerRegistrar.class)
+  public static class Runner implements PipelineRunnerRegistrar {
+    @Override
+    public Iterable<Class<? extends PipelineRunner<?>>> getPipelineRunners() {
+      return ImmutableList.<Class<? extends PipelineRunner<?>>>of(SparkPipelineRunner.class);
+    }
+  }
+
+  /**
+   * Registers the {@link SparkPipelineOptions} and {@link SparkStreamingPipelineOptions}.
+ */ + @AutoService(PipelineOptionsRegistrar.class) + public static class Options implements PipelineOptionsRegistrar { + @Override + public Iterable> getPipelineOptions() { + return ImmutableList.>of( + SparkPipelineOptions.class, + SparkStreamingPipelineOptions.class); + } + } +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineOptionsRegistrar.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineOptionsRegistrar.java deleted file mode 100644 index c882d7b84820..000000000000 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineOptionsRegistrar.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.beam.runners.spark.translation; - -import com.google.common.collect.ImmutableList; -import org.apache.beam.runners.spark.SparkPipelineOptions; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsRegistrar; - -public class SparkPipelineOptionsRegistrar implements PipelineOptionsRegistrar { - @Override - public Iterable> getPipelineOptions() { - return ImmutableList.>of(SparkPipelineOptions.class); - } -} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineRunnerRegistrar.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineRunnerRegistrar.java deleted file mode 100644 index 38993fb8543e..000000000000 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineRunnerRegistrar.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.beam.runners.spark.translation; - -import com.google.common.collect.ImmutableList; -import org.apache.beam.runners.spark.SparkPipelineRunner; -import org.apache.beam.sdk.runners.PipelineRunner; -import org.apache.beam.sdk.runners.PipelineRunnerRegistrar; - -public class SparkPipelineRunnerRegistrar implements PipelineRunnerRegistrar { - @Override - public Iterable>> getPipelineRunners() { - return ImmutableList.>>of(SparkPipelineRunner.class); - } -} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkStreamingPipelineOptionsRegistrar.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkStreamingPipelineOptionsRegistrar.java deleted file mode 100644 index 2e3509837e9d..000000000000 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkStreamingPipelineOptionsRegistrar.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.translation.streaming; - -import com.google.common.collect.ImmutableList; -import org.apache.beam.runners.spark.SparkStreamingPipelineOptions; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsRegistrar; - -public class SparkStreamingPipelineOptionsRegistrar implements PipelineOptionsRegistrar { - - @Override - public Iterable> getPipelineOptions() { - return ImmutableList.>of(SparkStreamingPipelineOptions - .class); - } -} diff --git a/runners/spark/src/main/resources/META-INF/services/com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar b/runners/spark/src/main/resources/META-INF/services/com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar deleted file mode 100644 index e4a3a737425b..000000000000 --- a/runners/spark/src/main/resources/META-INF/services/com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2014 Cloudera Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -org.apache.beam.runners.spark.translation.SparkPipelineOptionsRegistrar -org.apache.beam.runners.spark.translation.streaming.SparkStreamingPipelineOptionsRegistrar \ No newline at end of file diff --git a/runners/spark/src/main/resources/META-INF/services/com.google.cloud.dataflow.sdk.runners.PipelineRunnerRegistrar b/runners/spark/src/main/resources/META-INF/services/com.google.cloud.dataflow.sdk.runners.PipelineRunnerRegistrar deleted file mode 100644 index 7949db444cad..000000000000 --- a/runners/spark/src/main/resources/META-INF/services/com.google.cloud.dataflow.sdk.runners.PipelineRunnerRegistrar +++ /dev/null @@ -1,16 +0,0 @@ -# -# Copyright 2014 Cloudera Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -org.apache.beam.runners.spark.translation.SparkPipelineRunnerRegistrar \ No newline at end of file diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerRegistrarTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerRegistrarTest.java new file mode 100644 index 000000000000..d51403ffbf7b --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerRegistrarTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import org.apache.beam.sdk.options.PipelineOptionsRegistrar; +import org.apache.beam.sdk.runners.PipelineRunnerRegistrar; +import org.junit.Test; + +import java.util.ServiceLoader; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** + * Test {@link SparkRunnerRegistrar}. 
+ */ +public class SparkRunnerRegistrarTest { + @Test + public void testOptions() { + assertEquals( + ImmutableList.of(SparkPipelineOptions.class, SparkStreamingPipelineOptions.class), + new SparkRunnerRegistrar.Options().getPipelineOptions()); + } + + @Test + public void testRunners() { + assertEquals(ImmutableList.of(SparkPipelineRunner.class), + new SparkRunnerRegistrar.Runner().getPipelineRunners()); + } + + @Test + public void testServiceLoaderForOptions() { + for (PipelineOptionsRegistrar registrar : + Lists.newArrayList(ServiceLoader.load(PipelineOptionsRegistrar.class).iterator())) { + if (registrar instanceof SparkRunnerRegistrar.Options) { + return; + } + } + fail("Expected to find " + SparkRunnerRegistrar.Options.class); + } + + @Test + public void testServiceLoaderForRunner() { + for (PipelineRunnerRegistrar registrar : + Lists.newArrayList(ServiceLoader.load(PipelineRunnerRegistrar.class).iterator())) { + if (registrar instanceof SparkRunnerRegistrar.Runner) { + return; + } + } + fail("Expected to find " + SparkRunnerRegistrar.Runner.class); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BigQueryIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BigQueryIO.java index d9debbdcb3d7..923951431326 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BigQueryIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BigQueryIO.java @@ -50,6 +50,8 @@ import org.apache.beam.sdk.util.BigQueryServicesImpl; import org.apache.beam.sdk.util.BigQueryTableInserter; import org.apache.beam.sdk.util.BigQueryTableRowIterator; +import org.apache.beam.sdk.util.IOChannelFactory; +import org.apache.beam.sdk.util.IOChannelUtils; import org.apache.beam.sdk.util.MimeTypes; import org.apache.beam.sdk.util.PropertyNames; import org.apache.beam.sdk.util.Reshuffle; @@ -1015,7 +1017,19 @@ public PDone apply(PCollection input) { table.setProjectId(options.getProject()); } String jobIdToken = UUID.randomUUID().toString(); - String tempFilePrefix = options.getTempLocation() + "/BigQuerySinkTemp/" + jobIdToken; + String tempLocation = options.getTempLocation(); + String tempFilePrefix; + try { + IOChannelFactory factory = IOChannelUtils.getFactory(tempLocation); + tempFilePrefix = factory.resolve( + factory.resolve(tempLocation, "BigQuerySinkTemp"), + jobIdToken); + } catch (IOException e) { + throw new RuntimeException( + String.format("Failed to resolve BigQuery temp location in %s", tempLocation), + e); + } + BigQueryServices bqServices = getBigQueryServices(); return input.apply("Write", org.apache.beam.sdk.io.Write.to( new BigQuerySink( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/bigtable/BigtableIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/bigtable/BigtableIO.java index b2d9cb34fb54..5177262b57b8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/bigtable/BigtableIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/bigtable/BigtableIO.java @@ -35,7 +35,7 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.util.DataflowReleaseInfo; +import org.apache.beam.sdk.util.ReleaseInfo; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -978,7 +978,7 @@ public BigtableWriteException(KV> record, Throwab */ private static String getUserAgent() { String javaVersion = 
System.getProperty("java.specification.version"); - DataflowReleaseInfo info = DataflowReleaseInfo.getReleaseInfo(); + ReleaseInfo info = ReleaseInfo.getReleaseInfo(); return String.format( "%s/%s (%s); %s", info.getName(), diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessCreate.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessCreate.java deleted file mode 100644 index c29d5ce045d0..000000000000 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessCreate.java +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.runners.inprocess; - -import org.apache.beam.sdk.coders.CannotProvideCoderException; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.io.BoundedSource; -import org.apache.beam.sdk.io.OffsetBasedSource; -import org.apache.beam.sdk.io.OffsetBasedSource.OffsetBasedReader; -import org.apache.beam.sdk.io.Read; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.Create.Values; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PInput; -import org.apache.beam.sdk.values.POutput; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Optional; -import com.google.common.collect.ImmutableList; - -import java.io.IOException; -import java.util.List; -import java.util.NoSuchElementException; - -import javax.annotation.Nullable; - -/** - * An in-process implementation of the {@link Values Create.Values} {@link PTransform}, implemented - * using a {@link BoundedSource}. - * - * The coder is inferred via the {@link Values#getDefaultOutputCoder(PInput)} method on the original - * transform. - */ -class InProcessCreate extends ForwardingPTransform> { - private final Create.Values original; - - /** - * A {@link PTransformOverrideFactory} for {@link InProcessCreate}. 
- */ - public static class InProcessCreateOverrideFactory implements PTransformOverrideFactory { - @Override - public PTransform override( - PTransform transform) { - if (transform instanceof Create.Values) { - @SuppressWarnings("unchecked") - PTransform override = - (PTransform) from((Create.Values) transform); - return override; - } - return transform; - } - } - - public static InProcessCreate from(Create.Values original) { - return new InProcessCreate<>(original); - } - - private InProcessCreate(Values original) { - this.original = original; - } - - @Override - public PCollection apply(PInput input) { - Coder elementCoder; - try { - elementCoder = original.getDefaultOutputCoder(input); - } catch (CannotProvideCoderException e) { - throw new IllegalArgumentException( - "Unable to infer a coder and no Coder was specified. " - + "Please set a coder by invoking Create.withCoder() explicitly.", - e); - } - InMemorySource source; - try { - source = InMemorySource.fromIterable(original.getElements(), elementCoder); - } catch (IOException e) { - throw new RuntimeException(e); - } - PCollection result = input.getPipeline().apply(Read.from(source)); - result.setCoder(elementCoder); - return result; - } - - @Override - public PTransform> delegate() { - return original; - } - - @VisibleForTesting - static class InMemorySource extends OffsetBasedSource { - private final List allElementsBytes; - private final long totalSize; - private final Coder coder; - - public static InMemorySource fromIterable(Iterable elements, Coder elemCoder) - throws CoderException, IOException { - ImmutableList.Builder allElementsBytes = ImmutableList.builder(); - long totalSize = 0L; - for (T element : elements) { - byte[] bytes = CoderUtils.encodeToByteArray(elemCoder, element); - allElementsBytes.add(bytes); - totalSize += bytes.length; - } - return new InMemorySource<>(allElementsBytes.build(), totalSize, elemCoder); - } - - /** - * Create a new source with the specified bytes. The new source owns the input element bytes, - * which must not be modified after this constructor is called. 
- */ - private InMemorySource(List elementBytes, long totalSize, Coder coder) { - super(0, elementBytes.size(), 1); - this.allElementsBytes = ImmutableList.copyOf(elementBytes); - this.totalSize = totalSize; - this.coder = coder; - } - - @Override - public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { - return totalSize; - } - - @Override - public boolean producesSortedKeys(PipelineOptions options) throws Exception { - return false; - } - - @Override - public BoundedSource.BoundedReader createReader(PipelineOptions options) throws IOException { - return new BytesReader<>(this); - } - - @Override - public void validate() {} - - @Override - public Coder getDefaultOutputCoder() { - return coder; - } - - @Override - public long getMaxEndOffset(PipelineOptions options) throws Exception { - return allElementsBytes.size(); - } - - @Override - public OffsetBasedSource createSourceForSubrange(long start, long end) { - List primaryElems = allElementsBytes.subList((int) start, (int) end); - long primarySizeEstimate = - (long) (totalSize * primaryElems.size() / (double) allElementsBytes.size()); - return new InMemorySource<>(primaryElems, primarySizeEstimate, coder); - } - - @Override - public long getBytesPerOffset() { - if (allElementsBytes.size() == 0) { - return 0L; - } - return totalSize / allElementsBytes.size(); - } - } - - private static class BytesReader extends OffsetBasedReader { - private int index; - /** - * Use an optional to distinguish between null next element (as Optional.absent()) and no next - * element (next is null). - */ - @Nullable private Optional next; - - public BytesReader(InMemorySource source) { - super(source); - index = -1; - } - - @Override - @Nullable - public T getCurrent() throws NoSuchElementException { - if (next == null) { - throw new NoSuchElementException(); - } - return next.orNull(); - } - - @Override - public void close() throws IOException {} - - @Override - protected long getCurrentOffset() { - return index; - } - - @Override - protected boolean startImpl() throws IOException { - return advanceImpl(); - } - - @Override - public synchronized InMemorySource getCurrentSource() { - return (InMemorySource) super.getCurrentSource(); - } - - @Override - protected boolean advanceImpl() throws IOException { - InMemorySource source = getCurrentSource(); - index++; - if (index >= source.allElementsBytes.size()) { - return false; - } - next = - Optional.fromNullable( - CoderUtils.decodeFromByteArray( - source.coder, source.allElementsBytes.get(index))); - return true; - } - } -} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessExecutionContext.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessExecutionContext.java index 1430c989a6ca..e6441cf1e60d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessExecutionContext.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessExecutionContext.java @@ -22,7 +22,6 @@ import org.apache.beam.sdk.util.BaseExecutionContext; import org.apache.beam.sdk.util.ExecutionContext; import org.apache.beam.sdk.util.TimerInternals; -import org.apache.beam.sdk.util.common.worker.StateSampler; import org.apache.beam.sdk.util.state.CopyOnAccessInMemoryStateInternals; /** @@ -47,8 +46,7 @@ public InProcessExecutionContext(Clock clock, Object key, } @Override - protected InProcessStepContext createStepContext( - String stepName, String transformName, StateSampler stateSampler) { + 
protected InProcessStepContext createStepContext(String stepName, String transformName) { return new InProcessStepContext(this, stepName, transformName); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessPipelineRunner.java index 6cc35fb01ee5..7c28238d0dad 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessPipelineRunner.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/InProcessPipelineRunner.java @@ -30,11 +30,9 @@ import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly; import org.apache.beam.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOverrideFactory; -import org.apache.beam.sdk.runners.inprocess.InProcessCreate.InProcessCreateOverrideFactory; import org.apache.beam.sdk.runners.inprocess.ViewEvaluatorFactory.InProcessViewOverrideFactory; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.AppliedPTransform; -import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -83,7 +81,6 @@ public class InProcessPipelineRunner private static Map, PTransformOverrideFactory> defaultTransformOverrides = ImmutableMap., PTransformOverrideFactory>builder() - .put(Create.Values.class, new InProcessCreateOverrideFactory()) .put(GroupByKey.class, new InProcessGroupByKeyOverrideFactory()) .put(CreatePCollectionView.class, new InProcessViewOverrideFactory()) .put(AvroIO.Write.Bound.class, new AvroIOShardedWriteFactory()) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/ParDoInProcessEvaluator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/ParDoInProcessEvaluator.java index a2f080c19cf9..35639bdcac5b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/ParDoInProcessEvaluator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/ParDoInProcessEvaluator.java @@ -54,7 +54,7 @@ public static ParDoInProcessEvaluator create( evaluationContext.getExecutionContext(application, inputBundle.getKey()); String stepName = evaluationContext.getStepName(application); InProcessStepContext stepContext = - executionContext.getOrCreateStepContext(stepName, stepName, null); + executionContext.getOrCreateStepContext(stepName, stepName); CounterSet counters = evaluationContext.createCounterSet(); @@ -77,7 +77,11 @@ public static ParDoInProcessEvaluator create( counters.getAddCounterMutator(), application.getInput().getWindowingStrategy()); - runner.startBundle(); + try { + runner.startBundle(); + } catch (Exception e) { + throw UserCodeException.wrap(e); + } return new ParDoInProcessEvaluator<>( runner, application, counters, outputBundles.values(), stepContext); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/worker/IsmFormat.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/worker/IsmFormat.java index 8b23e0a0f795..8df46dd4149a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/worker/IsmFormat.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/worker/IsmFormat.java @@ -54,7 +54,6 @@ import java.io.OutputStream; import java.util.ArrayList; import java.util.List; - 
import javax.annotation.Nullable; /** @@ -746,9 +745,9 @@ public long getEncodedElementByteSize(KeyPrefix value, Coder.Context context) */ @AutoValue public abstract static class Footer { - static final int LONG_BYTES = 8; - static final int FIXED_LENGTH = 3 * LONG_BYTES + 1; - static final byte VERSION = 2; + public static final int LONG_BYTES = 8; + public static final int FIXED_LENGTH = 3 * LONG_BYTES + 1; + public static final byte VERSION = 2; public abstract byte getVersion(); public abstract long getIndexPosition(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java index 27fb39d8f8ab..1bd4fb3912f2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java @@ -20,33 +20,42 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.VoidCoder; -import org.apache.beam.sdk.runners.DirectPipelineRunner; -import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.OffsetBasedSource; +import org.apache.beam.sdk.io.OffsetBasedSource.OffsetBasedReader; +import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TimestampedValue.TimestampedValueCoder; import org.apache.beam.sdk.values.TypeDescriptor; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Function; import com.google.common.base.Optional; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import org.joda.time.Instant; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; +import javax.annotation.Nullable; + /** * {@code Create} takes a collection of elements of type {@code T} * known when the pipeline is constructed and returns a @@ -237,12 +246,13 @@ public Iterable getElements() { public PCollection apply(PInput input) { try { Coder coder = getDefaultOutputCoder(input); - return PCollection - .createPrimitiveOutputInternal( - input.getPipeline(), - WindowingStrategy.globalDefault(), - IsBounded.BOUNDED) - .setCoder(coder); + try { + CreateSource source = CreateSource.fromIterable(elems, coder); + return input.getPipeline().apply(Read.from(source)); + } catch (IOException e) { + throw new RuntimeException( + String.format("Unable to apply Create %s using Coder %s.", this, coder), e); + } } catch (CannotProvideCoderException e) { throw new IllegalArgumentException("Unable to infer a coder and no Coder was specified. 
" + "Please set a coder by invoking Create.withCoder() explicitly.", e); @@ -320,6 +330,136 @@ private Values(Iterable elems, Optional> coder) { this.elems = elems; this.coder = coder; } + + @VisibleForTesting + static class CreateSource extends OffsetBasedSource { + private final List allElementsBytes; + private final long totalSize; + private final Coder coder; + + public static CreateSource fromIterable(Iterable elements, Coder elemCoder) + throws CoderException, IOException { + ImmutableList.Builder allElementsBytes = ImmutableList.builder(); + long totalSize = 0L; + for (T element : elements) { + byte[] bytes = CoderUtils.encodeToByteArray(elemCoder, element); + allElementsBytes.add(bytes); + totalSize += bytes.length; + } + return new CreateSource<>(allElementsBytes.build(), totalSize, elemCoder); + } + + /** + * Create a new source with the specified bytes. The new source owns the input element bytes, + * which must not be modified after this constructor is called. + */ + private CreateSource(List elementBytes, long totalSize, Coder coder) { + super(0, elementBytes.size(), 1); + this.allElementsBytes = ImmutableList.copyOf(elementBytes); + this.totalSize = totalSize; + this.coder = coder; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + return totalSize; + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return false; + } + + @Override + public BoundedSource.BoundedReader createReader(PipelineOptions options) + throws IOException { + return new BytesReader<>(this); + } + + @Override + public void validate() {} + + @Override + public Coder getDefaultOutputCoder() { + return coder; + } + + @Override + public long getMaxEndOffset(PipelineOptions options) throws Exception { + return allElementsBytes.size(); + } + + @Override + public OffsetBasedSource createSourceForSubrange(long start, long end) { + List primaryElems = allElementsBytes.subList((int) start, (int) end); + long primarySizeEstimate = + (long) (totalSize * primaryElems.size() / (double) allElementsBytes.size()); + return new CreateSource<>(primaryElems, primarySizeEstimate, coder); + } + + @Override + public long getBytesPerOffset() { + if (allElementsBytes.size() == 0) { + return 0L; + } + return totalSize / allElementsBytes.size(); + } + } + + private static class BytesReader extends OffsetBasedReader { + private int index; + /** + * Use an optional to distinguish between null next element (as Optional.absent()) and no next + * element (next is null). 
+ */ + @Nullable private Optional next; + + public BytesReader(CreateSource source) { + super(source); + index = -1; + } + + @Override + @Nullable + public T getCurrent() throws NoSuchElementException { + if (next == null) { + throw new NoSuchElementException(); + } + return next.orNull(); + } + + @Override + public void close() throws IOException {} + + @Override + protected long getCurrentOffset() { + return index; + } + + @Override + protected boolean startImpl() throws IOException { + return advanceImpl(); + } + + @Override + public synchronized CreateSource getCurrentSource() { + return (CreateSource) super.getCurrentSource(); + } + + @Override + protected boolean advanceImpl() throws IOException { + CreateSource source = getCurrentSource(); + index++; + if (index >= source.allElementsBytes.size()) { + next = null; + return false; + } + next = + Optional.fromNullable( + CoderUtils.decodeFromByteArray(source.coder, source.allElementsBytes.get(index))); + return true; + } + } } ///////////////////////////////////////////////////////////////////////////// @@ -387,42 +527,4 @@ public void processElement(ProcessContext c) { } } } - - ///////////////////////////////////////////////////////////////////////////// - - static { - registerDefaultTransformEvaluator(); - } - - @SuppressWarnings({"rawtypes", "unchecked"}) - private static void registerDefaultTransformEvaluator() { - DirectPipelineRunner.registerDefaultTransformEvaluator( - Create.Values.class, - new DirectPipelineRunner.TransformEvaluator() { - @Override - public void evaluate( - Create.Values transform, - DirectPipelineRunner.EvaluationContext context) { - evaluateHelper(transform, context); - } - }); - } - - private static void evaluateHelper( - Create.Values transform, - DirectPipelineRunner.EvaluationContext context) { - // Convert the Iterable of elems into a List of elems. 
- List listElems; - if (transform.elems instanceof Collection) { - Collection collectionElems = (Collection) transform.elems; - listElems = new ArrayList<>(collectionElems.size()); - } else { - listElems = new ArrayList<>(); - } - for (T elem : transform.elems) { - listElems.add( - context.ensureElementEncodable(context.getOutput(transform), elem)); - } - context.setPCollection(context.getOutput(transform), listElems); - } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java index 1c60259100fa..16dc731816f8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java @@ -483,15 +483,16 @@ void initializeState() { runnerSideInputs = runnerSideInputs.and(entry.getKey().getTagInternal(), entry.getValue()); } outputManager = new DoFnRunnerBase.ListOutputManager(); - fnRunner = DoFnRunners.createDefault( - options, - fn, - DirectSideInputReader.of(runnerSideInputs), - outputManager, - mainOutputTag, - sideOutputTags, - DirectModeExecutionContext.create().getOrCreateStepContext(STEP_NAME, TRANSFORM_NAME, null), - counterSet.getAddCounterMutator(), - WindowingStrategy.globalDefault()); + fnRunner = + DoFnRunners.createDefault( + options, + fn, + DirectSideInputReader.of(runnerSideInputs), + outputManager, + mainOutputTag, + sideOutputTags, + DirectModeExecutionContext.create().getOrCreateStepContext(STEP_NAME, TRANSFORM_NAME), + counterSet.getAddCounterMutator(), + WindowingStrategy.globalDefault()); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index d266155b470c..02464ac2bc5e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -1200,7 +1200,7 @@ private static void evaluateHelpe outputManager, mainOutputTag, sideOutputTags, - executionContext.getOrCreateStepContext(stepName, stepName, null), + executionContext.getOrCreateStepContext(stepName, stepName), context.getAddCounterMutator(), input.getWindowingStrategy()); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/BaseExecutionContext.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/BaseExecutionContext.java index 33df089226bf..a62444fec239 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/BaseExecutionContext.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/BaseExecutionContext.java @@ -19,7 +19,6 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.util.common.worker.StateSampler; import org.apache.beam.sdk.util.state.StateInternals; import org.apache.beam.sdk.values.TupleTag; @@ -37,7 +36,7 @@ * be cached for the lifetime of this {@link ExecutionContext}. * *

 * <p>BaseExecutionContext is generic to allow implementing subclasses to return a concrete subclass
- * of {@link StepContext} from {@link #getOrCreateStepContext(String, String, StateSampler)} and
+ * of {@link StepContext} from {@link #getOrCreateStepContext(String, String)} and
 * {@link #getAllStepContexts()} without forcing each subclass to override the method, e.g.
 * <pre>{@code
 * @Override
@@ -47,8 +46,8 @@
 * }</pre>
 *

When a subclass of {@code BaseExecutionContext} has been downcast, the return types of - * {@link #createStepContext(String, String, StateSampler)}, - * {@link #getOrCreateStepContext(String, String, StateSampler}, and {@link #getAllStepContexts()} + * {@link #createStepContext(String, String)}, + * {@link #getOrCreateStepContext(String, String)}, and {@link #getAllStepContexts()} * will be appropriately specialized. */ public abstract class BaseExecutionContext @@ -60,21 +59,41 @@ public abstract class BaseExecutionContext() { + @Override + public T create() { + return createStepContext(finalStepName, finalTransformName); + } + }); + } + + /** + * Factory method interface to create an execution context if none exists during + * {@link #getOrCreateStepContext(String, CreateStepContextFunction)}. + */ + protected interface CreateStepContextFunction { + T create(); + } + + protected final T getOrCreateStepContext(String stepName, + CreateStepContextFunction createContextFunc) { T context = cachedStepContexts.get(stepName); if (context == null) { - context = createStepContext(stepName, transformName, stateSampler); + context = createContextFunc.create(); cachedStepContexts.put(stepName, context); } + return context; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DirectModeExecutionContext.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DirectModeExecutionContext.java index c3da3d7fced3..85e36dd6d14b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DirectModeExecutionContext.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DirectModeExecutionContext.java @@ -20,7 +20,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import org.apache.beam.sdk.runners.DirectPipelineRunner.ValueWithMetadata; -import org.apache.beam.sdk.util.common.worker.StateSampler; import org.apache.beam.sdk.util.state.InMemoryStateInternals; import org.apache.beam.sdk.util.state.StateInternals; import org.apache.beam.sdk.values.TupleTag; @@ -48,8 +47,7 @@ public static DirectModeExecutionContext create() { } @Override - protected StepContext createStepContext( - String stepName, String transformName, StateSampler stateSampler) { + protected StepContext createStepContext(String stepName, String transformName) { return new StepContext(this, stepName, transformName); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ExecutionContext.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ExecutionContext.java index 577aa666ec2f..01bde829972d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ExecutionContext.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ExecutionContext.java @@ -19,7 +19,6 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.util.common.worker.StateSampler; import org.apache.beam.sdk.util.state.StateInternals; import org.apache.beam.sdk.values.TupleTag; @@ -34,8 +33,7 @@ public interface ExecutionContext { /** * Returns the {@link StepContext} associated with the given step. */ - StepContext getOrCreateStepContext( - String stepName, String transformName, StateSampler stateSampler); + StepContext getOrCreateStepContext(String stepName, String transformName); /** * Returns a collection view of all of the {@link StepContext}s. 
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DataflowReleaseInfo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ReleaseInfo.java similarity index 74% rename from sdks/java/core/src/main/java/org/apache/beam/sdk/util/DataflowReleaseInfo.java rename to sdks/java/core/src/main/java/org/apache/beam/sdk/util/ReleaseInfo.java index 8c096fb7695c..77289ac63192 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DataflowReleaseInfo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ReleaseInfo.java @@ -28,27 +28,27 @@ import java.util.Properties; /** - * Utilities for working with the Dataflow distribution. + * Utilities for working with release information. */ -public final class DataflowReleaseInfo extends GenericJson { - private static final Logger LOG = LoggerFactory.getLogger(DataflowReleaseInfo.class); +public final class ReleaseInfo extends GenericJson { + private static final Logger LOG = LoggerFactory.getLogger(ReleaseInfo.class); - private static final String DATAFLOW_PROPERTIES_PATH = - "/org.apache.beam/sdk/sdk.properties"; + private static final String PROPERTIES_PATH = + "/org/apache/beam/sdk/sdk.properties"; private static class LazyInit { - private static final DataflowReleaseInfo INSTANCE = - new DataflowReleaseInfo(DATAFLOW_PROPERTIES_PATH); + private static final ReleaseInfo INSTANCE = + new ReleaseInfo(PROPERTIES_PATH); } /** * Returns an instance of DataflowReleaseInfo. */ - public static DataflowReleaseInfo getReleaseInfo() { + public static ReleaseInfo getReleaseInfo() { return LazyInit.INSTANCE; } - @Key private String name = "Google Cloud Dataflow Java SDK"; + @Key private String name = "Apache Beam SDK for Java"; @Key private String version = "Unknown"; /** Provides the SDK name. */ @@ -61,11 +61,11 @@ public String getVersion() { return version; } - private DataflowReleaseInfo(String resourcePath) { + private ReleaseInfo(String resourcePath) { Properties properties = new Properties(); - InputStream in = DataflowReleaseInfo.class.getResourceAsStream( - DATAFLOW_PROPERTIES_PATH); + InputStream in = ReleaseInfo.class.getResourceAsStream( + PROPERTIES_PATH); if (in == null) { LOG.warn("Dataflow properties resource not found: {}", resourcePath); return; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/worker/StateSampler.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/worker/StateSampler.java deleted file mode 100644 index ee95260c9cc9..000000000000 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/worker/StateSampler.java +++ /dev/null @@ -1,367 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.util.common.worker; - -import org.apache.beam.sdk.util.common.Counter; -import org.apache.beam.sdk.util.common.CounterSet; - -import com.google.common.util.concurrent.ThreadFactoryBuilder; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; - -import javax.annotation.concurrent.ThreadSafe; - -/** - * A StateSampler object may be used to obtain an approximate - * breakdown of the time spent by an execution context in various - * states, as a fraction of the total time. The sampling is taken at - * regular intervals, with adjustment for scheduling delay. - */ -@ThreadSafe -public class StateSampler implements AutoCloseable { - - /** Different kinds of states. */ - public enum StateKind { - /** IO, user code, etc. */ - USER, - /** Reading/writing from/to shuffle service, etc. */ - FRAMEWORK - } - - public static final long DEFAULT_SAMPLING_PERIOD_MS = 200; - - private final String prefix; - private final CounterSet.AddCounterMutator counterSetMutator; - - /** Array of counters indexed by their state. */ - private ArrayList> countersByState = new ArrayList<>(); - - /** Map of state name to state. */ - private Map statesByName = new HashMap<>(); - - /** Map of state id to kind. */ - private Map kindsByState = new HashMap<>(); - - /** The current state. */ - private volatile int currentState; - - /** Special value of {@code currentState} that means we do not sample. */ - public static final int DO_NOT_SAMPLE = -1; - - /** - * A counter that increments with each state transition. May be used - * to detect a context being stuck in a state for some amount of - * time. - */ - private volatile long stateTransitionCount; - - /** - * The timestamp (in nanoseconds) corresponding to the last time the - * state was sampled (and recorded). - */ - private long stateTimestampNs = 0; - - /** Using a fixed number of timers for all StateSampler objects. */ - private static final int NUM_EXECUTOR_THREADS = 16; - - private static final ScheduledExecutorService executorService = - Executors.newScheduledThreadPool(NUM_EXECUTOR_THREADS, - new ThreadFactoryBuilder().setDaemon(true).build()); - - private Random rand = new Random(); - - private List callbacks = new ArrayList<>(); - - private ScheduledFuture invocationTriggerFuture = null; - - private ScheduledFuture invocationFuture = null; - - /** - * Constructs a new {@link StateSampler} that can be used to obtain - * an approximate breakdown of the time spent by an execution - * context in various states, as a fraction of the total time. - * - * @param prefix the prefix of the counter names for the states - * @param counterSetMutator the {@link CounterSet.AddCounterMutator} - * used to create a counter for each distinct state - * @param samplingPeriodMs the sampling period in milliseconds - */ - public StateSampler(String prefix, - CounterSet.AddCounterMutator counterSetMutator, - final long samplingPeriodMs) { - this.prefix = prefix; - this.counterSetMutator = counterSetMutator; - currentState = DO_NOT_SAMPLE; - scheduleSampling(samplingPeriodMs); - } - - /** - * Constructs a new {@link StateSampler} that can be used to obtain - * an approximate breakdown of the time spent by an execution - * context in various states, as a fraction of the total time. 
- * - * @param prefix the prefix of the counter names for the states - * @param counterSetMutator the {@link CounterSet.AddCounterMutator} - * used to create a counter for each distinct state - */ - public StateSampler(String prefix, - CounterSet.AddCounterMutator counterSetMutator) { - this(prefix, counterSetMutator, DEFAULT_SAMPLING_PERIOD_MS); - } - - /** - * Called by the constructor to schedule sampling at the given period. - * - *

Should not be overridden by sub-classes unless they want to change - * or disable the automatic sampling of state. - */ - protected void scheduleSampling(final long samplingPeriodMs) { - // Here "stratified sampling" is used, which makes sure that there's 1 uniformly chosen sampled - // point in every bucket of samplingPeriodMs, to prevent pathological behavior in case some - // states happen to occur at a similar period. - // The current implementation uses a fixed-rate timer with a period samplingPeriodMs as a - // trampoline to a one-shot random timer which fires with a random delay within - // samplingPeriodMs. - stateTimestampNs = System.nanoTime(); - invocationTriggerFuture = - executorService.scheduleAtFixedRate( - new Runnable() { - @Override - public void run() { - long delay = rand.nextInt((int) samplingPeriodMs); - synchronized (StateSampler.this) { - if (invocationFuture != null) { - invocationFuture.cancel(false); - } - invocationFuture = - executorService.schedule( - new Runnable() { - @Override - public void run() { - StateSampler.this.run(); - } - }, - delay, - TimeUnit.MILLISECONDS); - } - } - }, - 0, - samplingPeriodMs, - TimeUnit.MILLISECONDS); - } - - public synchronized void run() { - long startTimestampNs = System.nanoTime(); - int state = currentState; - if (state != DO_NOT_SAMPLE) { - StateKind kind = null; - long elapsedMs = TimeUnit.NANOSECONDS.toMillis(startTimestampNs - stateTimestampNs); - kind = kindsByState.get(state); - countersByState.get(state).addValue(elapsedMs); - // Invoke all callbacks. - for (SamplingCallback c : callbacks) { - c.run(state, kind, elapsedMs); - } - } - stateTimestampNs = startTimestampNs; - } - - @Override - public synchronized void close() { - currentState = DO_NOT_SAMPLE; - if (invocationTriggerFuture != null) { - invocationTriggerFuture.cancel(false); - } - if (invocationFuture != null) { - invocationFuture.cancel(false); - } - } - - /** - * Returns the state associated with a name; creating a new state if - * necessary. Using states instead of state names during state - * transitions is done for efficiency. - * - * @name the name for the state - * @kind kind of the state, see {#code StateKind} - * @return the state associated with the state name - */ - public int stateForName(String name, StateKind kind) { - if (name.isEmpty()) { - return DO_NOT_SAMPLE; - } - - synchronized (this) { - Integer state = statesByName.get(name); - if (state == null) { - String counterName = prefix + name + "-msecs"; - Counter counter = counterSetMutator.addCounter( - Counter.longs(counterName, Counter.AggregationKind.SUM)); - state = countersByState.size(); - statesByName.put(name, state); - countersByState.add(counter); - kindsByState.put(state, kind); - } - StateKind originalKind = kindsByState.get(state); - if (originalKind != kind) { - throw new IllegalArgumentException( - "for state named " + name - + ", requested kind " + kind + " different from the original kind " + originalKind); - } - return state; - } - } - - /** - * An internal class for representing StateSampler information - * typically used for debugging. 
- */ - public static class StateSamplerInfo { - public final String state; - public final Long transitionCount; - public final Long stateDurationMillis; - - public StateSamplerInfo(String state, Long transitionCount, - Long stateDurationMillis) { - this.state = state; - this.transitionCount = transitionCount; - this.stateDurationMillis = stateDurationMillis; - } - } - - /** - * Returns information about the current state of this state sampler - * into a {@link StateSamplerInfo} object, or null if sampling is - * not turned on. - * - * @return information about this state sampler or null if sampling is off - */ - public synchronized StateSamplerInfo getInfo() { - return currentState == DO_NOT_SAMPLE ? null - : new StateSamplerInfo(countersByState.get(currentState).getFlatName(), - stateTransitionCount, null); - } - - /** - * Returns the current state of this state sampler. - */ - public int getCurrentState() { - return currentState; - } - - /** - * Sets the current thread state. - * - * @param state the new state to transition to - * @return the previous state - */ - public int setState(int state) { - // Updates to stateTransitionCount are always done by the same - // thread, making the non-atomic volatile update below safe. The - // count is updated first to avoid incorrectly attributing - // stuckness occuring in an old state to the new state. - long previousStateTransitionCount = this.stateTransitionCount; - this.stateTransitionCount = previousStateTransitionCount + 1; - int previousState = currentState; - currentState = state; - return previousState; - } - - /** - * Sets the current thread state. - * - * @param name the name of the new state to transition to - * @param kind kind of the new state - * @return the previous state - */ - public int setState(String name, StateKind kind) { - return setState(stateForName(name, kind)); - } - - /** - * Returns an AutoCloseable {@link ScopedState} that will perform a - * state transition to the given state, and will automatically reset - * the state to the prior state upon closing. - * - * @param state the new state to transition to - * @return a {@link ScopedState} that automatically resets the state - * to the prior state - */ - public ScopedState scopedState(int state) { - return new ScopedState(this, setState(state)); - } - - /** - * Add a callback to the sampler. - * The callbacks will be executed sequentially upon {@link StateSampler#run}. - */ - public synchronized void addSamplingCallback(SamplingCallback callback) { - callbacks.add(callback); - } - - /** Get the counter prefix associated with this sampler. */ - public String getPrefix() { - return prefix; - } - - /** - * A nested class that is used to account for states and state - * transitions based on lexical scopes. - * - *

Thread-safe. - */ - public class ScopedState implements AutoCloseable { - private StateSampler sampler; - private int previousState; - - private ScopedState(StateSampler sampler, int previousState) { - this.sampler = sampler; - this.previousState = previousState; - } - - @Override - public void close() { - sampler.setState(previousState); - } - } - - /** - * Callbacks which supposed to be called sequentially upon {@link StateSampler#run}. - * They should be registered via {@link #addSamplingCallback}. - */ - public static interface SamplingCallback { - /** - * The entrance method of the callback, it is called in {@link StateSampler#run}, - * once per sample. This method should be thread safe. - * - * @param state The state of the StateSampler at the time of sample. - * @param kind The kind associated with the state, see {@link StateKind}. - * @param elapsedMs Milliseconds since last sample. - */ - public void run(int state, StateKind kind, long elapsedMs); - } -} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java index 7690d2ba88dc..e4eb2048be20 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java @@ -158,7 +158,8 @@ public void visitTransform(TransformTreeNode node) { // Pick is a composite, should not be visited here. assertThat(transform, not(instanceOf(Sample.SampleAny.class))); assertThat(transform, not(instanceOf(Write.Bound.class))); - if (transform instanceof Read.Bounded) { + if (transform instanceof Read.Bounded + && node.getEnclosingNode().getTransform() instanceof TextIO.Read.Bound) { assertTrue(visited.add(TransformsSeen.READ)); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/EncodabilityEnforcementFactoryTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/EncodabilityEnforcementFactoryTest.java index 85c43226f5f2..8ed26843cf12 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/EncodabilityEnforcementFactoryTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/EncodabilityEnforcementFactoryTest.java @@ -20,6 +20,7 @@ import static org.hamcrest.Matchers.isA; import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle; @@ -54,16 +55,9 @@ public class EncodabilityEnforcementFactoryTest { @Test public void encodeFailsThrows() { - TestPipeline p = TestPipeline.create(); - PCollection unencodable = - p.apply(Create.of(new Record()).withCoder(new RecordNoEncodeCoder())); - AppliedPTransform consumer = - unencodable.apply(Count.globally()).getProducingTransformInternal(); - WindowedValue record = WindowedValue.valueInGlobalWindow(new Record()); - CommittedBundle input = - bundleFactory.createRootBundle(unencodable).add(record).commit(Instant.now()); - ModelEnforcement enforcement = factory.forBundle(input, consumer); + + ModelEnforcement enforcement = createEnforcement(new RecordNoEncodeCoder(), record); thrown.expect(UserCodeException.class); thrown.expectCause(isA(CoderException.class)); @@ -73,16 +67,9 @@ public void encodeFailsThrows() { @Test public void decodeFailsThrows() { - TestPipeline p = 
TestPipeline.create(); - PCollection unencodable = - p.apply(Create.of(new Record()).withCoder(new RecordNoDecodeCoder())); - AppliedPTransform consumer = - unencodable.apply(Count.globally()).getProducingTransformInternal(); WindowedValue record = WindowedValue.valueInGlobalWindow(new Record()); - CommittedBundle input = - bundleFactory.createRootBundle(unencodable).add(record).commit(Instant.now()); - ModelEnforcement enforcement = factory.forBundle(input, consumer); + ModelEnforcement enforcement = createEnforcement(new RecordNoDecodeCoder(), record); thrown.expect(UserCodeException.class); thrown.expectCause(isA(CoderException.class)); @@ -92,12 +79,6 @@ public void decodeFailsThrows() { @Test public void consistentWithEqualsStructuralValueNotEqualThrows() { - TestPipeline p = TestPipeline.create(); - PCollection unencodable = - p.apply(Create.of(new Record()).withCoder(new RecordStructuralValueCoder())); - AppliedPTransform consumer = - unencodable.apply(Count.globally()).getProducingTransformInternal(); - WindowedValue record = WindowedValue.valueInGlobalWindow( new Record() { @@ -107,9 +88,8 @@ public String toString() { } }); - CommittedBundle input = - bundleFactory.createRootBundle(unencodable).add(record).commit(Instant.now()); - ModelEnforcement enforcement = factory.forBundle(input, consumer); + ModelEnforcement enforcement = + createEnforcement(new RecordStructuralValueCoder(), record); thrown.expect(UserCodeException.class); thrown.expectCause(isA(IllegalArgumentException.class)); @@ -143,6 +123,17 @@ public void notConsistentWithEqualsStructuralValueNotEqualSucceeds() { Collections.>emptyList()); } + private ModelEnforcement createEnforcement(Coder coder, WindowedValue record) { + TestPipeline p = TestPipeline.create(); + PCollection unencodable = p.apply(Create.of().withCoder(coder)); + AppliedPTransform consumer = + unencodable.apply(Count.globally()).getProducingTransformInternal(); + CommittedBundle input = + bundleFactory.createRootBundle(unencodable).add(record).commit(Instant.now()); + ModelEnforcement enforcement = factory.forBundle(input, consumer); + return enforcement; + } + @Test public void structurallyEqualResultsSucceeds() { TestPipeline p = TestPipeline.create(); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/InProcessCreateTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/InProcessCreateTest.java deleted file mode 100644 index 5c63af1c8e97..000000000000 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/InProcessCreateTest.java +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.runners.inprocess; - -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.is; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.fail; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.AtomicCoder; -import org.apache.beam.sdk.coders.BigEndianIntegerCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.NullableCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.io.BoundedSource; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.runners.inprocess.InProcessCreate.InMemorySource; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.RunnableOnService; -import org.apache.beam.sdk.testing.SourceTestUtils; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.util.SerializableUtils; -import org.apache.beam.sdk.values.PCollection; - -import com.google.common.collect.ImmutableList; - -import org.hamcrest.Matchers; -import org.junit.Rule; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; - -/** - * Tests for {@link InProcessCreate}. - */ -@RunWith(JUnit4.class) -public class InProcessCreateTest { - @Rule - public ExpectedException thrown = ExpectedException.none(); - - @Test - @Category(RunnableOnService.class) - public void testConvertsCreate() { - TestPipeline p = TestPipeline.create(); - Create.Values og = Create.of(1, 2, 3); - - InProcessCreate converted = InProcessCreate.from(og); - - PAssert.that(p.apply(converted)).containsInAnyOrder(2, 1, 3); - - p.run(); - } - - @Test - @Category(RunnableOnService.class) - public void testConvertsCreateWithNullElements() { - Create.Values og = - Create.of("foo", null, "spam", "ham", null, "eggs") - .withCoder(NullableCoder.of(StringUtf8Coder.of())); - - InProcessCreate converted = InProcessCreate.from(og); - TestPipeline p = TestPipeline.create(); - - PAssert.that(p.apply(converted)) - .containsInAnyOrder(null, "foo", null, "spam", "ham", "eggs"); - - p.run(); - } - - static class Record implements Serializable {} - - static class Record2 extends Record {} - - @Test - public void testThrowsIllegalArgumentWhenCannotInferCoder() { - Create.Values og = Create.of(new Record(), new Record2()); - InProcessCreate converted = InProcessCreate.from(og); - - Pipeline p = TestPipeline.create(); - - // Create won't infer a default coder in this case. - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage(Matchers.containsString("Unable to infer a coder")); - - PCollection c = p.apply(converted); - p.run(); - - fail("Unexpectedly Inferred Coder " + c.getCoder()); - } - - /** - * An unserializable class to demonstrate encoding of elements. 
- */ - private static class UnserializableRecord { - private final String myString; - - private UnserializableRecord(String myString) { - this.myString = myString; - } - - @Override - public int hashCode() { - return myString.hashCode(); - } - - @Override - public boolean equals(Object o) { - return myString.equals(((UnserializableRecord) o).myString); - } - - static class UnserializableRecordCoder extends AtomicCoder { - private final Coder stringCoder = StringUtf8Coder.of(); - - @Override - public void encode( - UnserializableRecord value, - OutputStream outStream, - org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException { - stringCoder.encode(value.myString, outStream, context.nested()); - } - - @Override - public UnserializableRecord decode( - InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException { - return new UnserializableRecord(stringCoder.decode(inStream, context.nested())); - } - } - } - - @Test - @Category(RunnableOnService.class) - public void testConvertsUnserializableElements() throws Exception { - List elements = - ImmutableList.of( - new UnserializableRecord("foo"), - new UnserializableRecord("bar"), - new UnserializableRecord("baz")); - InProcessCreate create = - InProcessCreate.from( - Create.of(elements).withCoder(new UnserializableRecord.UnserializableRecordCoder())); - - TestPipeline p = TestPipeline.create(); - PAssert.that(p.apply(create)) - .containsInAnyOrder( - new UnserializableRecord("foo"), - new UnserializableRecord("bar"), - new UnserializableRecord("baz")); - p.run(); - } - - @Test - public void testSerializableOnUnserializableElements() throws Exception { - List elements = - ImmutableList.of( - new UnserializableRecord("foo"), - new UnserializableRecord("bar"), - new UnserializableRecord("baz")); - InMemorySource source = - InMemorySource.fromIterable(elements, new UnserializableRecord.UnserializableRecordCoder()); - SerializableUtils.ensureSerializable(source); - } - - @Test - public void testSplitIntoBundles() throws Exception { - InProcessCreate.InMemorySource source = - InMemorySource.fromIterable( - ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8), BigEndianIntegerCoder.of()); - PipelineOptions options = PipelineOptionsFactory.create(); - List> splitSources = source.splitIntoBundles(12, options); - assertThat(splitSources, hasSize(3)); - SourceTestUtils.assertSourcesEqualReferenceSource(source, splitSources, options); - } - - @Test - public void testDoesNotProduceSortedKeys() throws Exception { - InProcessCreate.InMemorySource source = - InMemorySource.fromIterable(ImmutableList.of("spam", "ham", "eggs"), StringUtf8Coder.of()); - assertThat(source.producesSortedKeys(PipelineOptionsFactory.create()), is(false)); - } - - @Test - public void testGetDefaultOutputCoderReturnsConstructorCoder() throws Exception { - Coder coder = VarIntCoder.of(); - InProcessCreate.InMemorySource source = - InMemorySource.fromIterable(ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8), coder); - - Coder defaultCoder = source.getDefaultOutputCoder(); - assertThat(defaultCoder, equalTo(coder)); - } - - @Test - public void testSplitAtFraction() throws Exception { - List elements = new ArrayList<>(); - Random random = new Random(); - for (int i = 0; i < 25; i++) { - elements.add(random.nextInt()); - } - InProcessCreate.InMemorySource source = - InMemorySource.fromIterable(elements, VarIntCoder.of()); - - SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); - } -} diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/InProcessEvaluationContextTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/InProcessEvaluationContextTest.java index 50b83fda21f7..ee56954dbf08 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/InProcessEvaluationContextTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/InProcessEvaluationContextTest.java @@ -160,7 +160,7 @@ public void getExecutionContextSameStepSameKeyState() { StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); - InProcessStepContext stepContext = fooContext.getOrCreateStepContext("s1", "s1", null); + InProcessStepContext stepContext = fooContext.getOrCreateStepContext("s1", "s1"); stepContext.stateInternals().state(StateNamespaces.global(), intBag).add(1); context.handleResult( @@ -176,7 +176,7 @@ public void getExecutionContextSameStepSameKeyState() { context.getExecutionContext(created.getProducingTransformInternal(), "foo"); assertThat( secondFooContext - .getOrCreateStepContext("s1", "s1", null) + .getOrCreateStepContext("s1", "s1") .stateInternals() .state(StateNamespaces.global(), intBag) .read(), @@ -192,7 +192,7 @@ public void getExecutionContextDifferentKeysIndependentState() { StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); fooContext - .getOrCreateStepContext("s1", "s1", null) + .getOrCreateStepContext("s1", "s1") .stateInternals() .state(StateNamespaces.global(), intBag) .add(1); @@ -202,7 +202,7 @@ public void getExecutionContextDifferentKeysIndependentState() { assertThat(barContext, not(equalTo(fooContext))); assertThat( barContext - .getOrCreateStepContext("s1", "s1", null) + .getOrCreateStepContext("s1", "s1") .stateInternals() .state(StateNamespaces.global(), intBag) .read(), @@ -218,7 +218,7 @@ public void getExecutionContextDifferentStepsIndependentState() { StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); fooContext - .getOrCreateStepContext("s1", "s1", null) + .getOrCreateStepContext("s1", "s1") .stateInternals() .state(StateNamespaces.global(), intBag) .add(1); @@ -227,7 +227,7 @@ public void getExecutionContextDifferentStepsIndependentState() { context.getExecutionContext(downstream.getProducingTransformInternal(), myKey); assertThat( barContext - .getOrCreateStepContext("s1", "s1", null) + .getOrCreateStepContext("s1", "s1") .stateInternals() .state(StateNamespaces.global(), intBag) .read(), @@ -273,7 +273,7 @@ public void handleResultStoresState() { StateTag> intBag = StateTags.bag("myBag", VarIntCoder.of()); CopyOnAccessInMemoryStateInternals state = - fooContext.getOrCreateStepContext("s1", "s1", null).stateInternals(); + fooContext.getOrCreateStepContext("s1", "s1").stateInternals(); BagState bag = state.state(StateNamespaces.global(), intBag); bag.add(1); bag.add(2); @@ -293,7 +293,7 @@ public void handleResultStoresState() { context.getExecutionContext(downstream.getProducingTransformInternal(), myKey); CopyOnAccessInMemoryStateInternals afterResultState = - afterResultContext.getOrCreateStepContext("s1", "s1", null).stateInternals(); + afterResultContext.getOrCreateStepContext("s1", "s1").stateInternals(); assertThat(afterResultState.state(StateNamespaces.global(), intBag).read(), contains(1, 2, 4)); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java index 393fedec80c6..2998489d4733 100644 --- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java @@ -22,19 +22,36 @@ import static org.apache.beam.sdk.TestUtils.NO_LINES; import static org.apache.beam.sdk.TestUtils.NO_LINES_ARRAY; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.sdk.coders.BigEndianIntegerCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.RunnableOnService; +import org.apache.beam.sdk.testing.SourceTestUtils; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create.Values.CreateSource; +import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TimestampedValue; +import com.google.common.collect.ImmutableList; + import org.hamcrest.Matchers; import org.joda.time.Instant; import org.junit.Rule; @@ -44,11 +61,15 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Random; /** * Tests for Create. @@ -142,6 +163,67 @@ public void testCreateParameterizedType() throws Exception { TimestampedValue.of("a", new Instant(0)), TimestampedValue.of("b", new Instant(0))); } + /** + * An unserializable class to demonstrate encoding of elements. 
+ */ + private static class UnserializableRecord { + private final String myString; + + private UnserializableRecord(String myString) { + this.myString = myString; + } + + @Override + public int hashCode() { + return myString.hashCode(); + } + + @Override + public boolean equals(Object o) { + return myString.equals(((UnserializableRecord) o).myString); + } + + static class UnserializableRecordCoder extends AtomicCoder { + private final Coder stringCoder = StringUtf8Coder.of(); + + @Override + public void encode( + UnserializableRecord value, + OutputStream outStream, + org.apache.beam.sdk.coders.Coder.Context context) + throws CoderException, IOException { + stringCoder.encode(value.myString, outStream, context.nested()); + } + + @Override + public UnserializableRecord decode( + InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) + throws CoderException, IOException { + return new UnserializableRecord(stringCoder.decode(inStream, context.nested())); + } + } + } + + @Test + @Category(RunnableOnService.class) + public void testCreateWithUnserializableElements() throws Exception { + List elements = + ImmutableList.of( + new UnserializableRecord("foo"), + new UnserializableRecord("bar"), + new UnserializableRecord("baz")); + Create.Values create = + Create.of(elements).withCoder(new UnserializableRecord.UnserializableRecordCoder()); + + TestPipeline p = TestPipeline.create(); + PAssert.that(p.apply(create)) + .containsInAnyOrder( + new UnserializableRecord("foo"), + new UnserializableRecord("bar"), + new UnserializableRecord("baz")); + p.run(); + } + private static class PrintTimestamps extends DoFn { @Override public void processElement(ProcessContext c) { @@ -239,4 +321,56 @@ public void testCreateGetName() { assertEquals("Create.Values", Create.of(1, 2, 3).getName()); assertEquals("Create.TimestampedValues", Create.timestamped(Collections.EMPTY_LIST).getName()); } + + @Test + public void testSourceIsSerializableWithUnserializableElements() throws Exception { + List elements = + ImmutableList.of( + new UnserializableRecord("foo"), + new UnserializableRecord("bar"), + new UnserializableRecord("baz")); + CreateSource source = + CreateSource.fromIterable(elements, new UnserializableRecord.UnserializableRecordCoder()); + SerializableUtils.ensureSerializable(source); + } + + @Test + public void testSourceSplitIntoBundles() throws Exception { + CreateSource source = + CreateSource.fromIterable( + ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8), BigEndianIntegerCoder.of()); + PipelineOptions options = PipelineOptionsFactory.create(); + List> splitSources = source.splitIntoBundles(12, options); + assertThat(splitSources, hasSize(3)); + SourceTestUtils.assertSourcesEqualReferenceSource(source, splitSources, options); + } + + @Test + public void testSourceDoesNotProduceSortedKeys() throws Exception { + CreateSource source = + CreateSource.fromIterable(ImmutableList.of("spam", "ham", "eggs"), StringUtf8Coder.of()); + assertThat(source.producesSortedKeys(PipelineOptionsFactory.create()), is(false)); + } + + @Test + public void testSourceGetDefaultOutputCoderReturnsConstructorCoder() throws Exception { + Coder coder = VarIntCoder.of(); + CreateSource source = + CreateSource.fromIterable(ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8), coder); + + Coder defaultCoder = source.getDefaultOutputCoder(); + assertThat(defaultCoder, equalTo(coder)); + } + + @Test + public void testSourceSplitAtFraction() throws Exception { + List elements = new ArrayList<>(); + Random random = new Random(); + for (int i = 
0; i < 25; i++) { + elements.add(random.nextInt()); + } + CreateSource source = CreateSource.fromIterable(elements, VarIntCoder.of()); + + SourceTestUtils.assertSplitAtFractionExhaustive(source, PipelineOptionsFactory.create()); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/GroupAlsoByWindowsProperties.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/GroupAlsoByWindowsProperties.java index d21edd16357e..d5aa0daffd40 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/GroupAlsoByWindowsProperties.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/GroupAlsoByWindowsProperties.java @@ -708,7 +708,7 @@ List>> runGABW( outputManager, outputTag, new ArrayList>(), - executionContext.getOrCreateStepContext("GABWStep", "GABWTransform", null), + executionContext.getOrCreateStepContext("GABWStep", "GABWTransform"), counters.getAddCounterMutator(), windowingStrategy); } diff --git a/sdks/java/io/kafka/pom.xml b/sdks/java/io/kafka/pom.xml new file mode 100644 index 000000000000..98a091d51faf --- /dev/null +++ b/sdks/java/io/kafka/pom.xml @@ -0,0 +1,104 @@ + + + + 4.0.0 + + + org.apache.beam + io-parent + 0.1.0-incubating-SNAPSHOT + ../pom.xml + + + kafka + Apache Beam :: SDKs :: Java :: IO :: Kafka + Library to read Kafka topics. + jar + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.12 + + + com.puppycrawl.tools + checkstyle + 6.6 + + + + ../../checkstyle.xml + true + true + + + + + check + + + + + + + + + + org.apache.beam + java-sdk-all + ${project.version} + + + + org.apache.kafka + kafka-clients + [0.9,) + + + + + org.hamcrest + hamcrest-all + ${hamcrest.version} + test + + + + junit + junit + ${junit.version} + test + + + + org.slf4j + slf4j-jdk14 + ${slf4j.version} + test + + + diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCheckpointMark.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCheckpointMark.java new file mode 100644 index 000000000000..4b6b976fa54d --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCheckpointMark.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import org.apache.beam.sdk.coders.DefaultCoder; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.io.UnboundedSource; + +import org.apache.kafka.common.TopicPartition; + +import java.io.IOException; +import java.io.Serializable; +import java.util.List; + +/** + * Checkpoint for an unbounded KafkaIO.Read. Consists of Kafka topic name, partition id, + * and the latest offset consumed so far. 
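+ *
+ * <p>The checkpoint holds one {@link PartitionMark} per partition assigned to the split, in the
+ * same order as the split's partition assignment; on restore, the reader verifies this ordering
+ * and resumes each partition from its recorded offset.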
+ */ +@DefaultCoder(SerializableCoder.class) +public class KafkaCheckpointMark implements UnboundedSource.CheckpointMark, Serializable { + + private final List partitions; + + public KafkaCheckpointMark(List partitions) { + this.partitions = partitions; + } + + public List getPartitions() { + return partitions; + } + + @Override + public void finalizeCheckpoint() throws IOException { + /* nothing to do */ + + // We might want to support committing offset in Kafka for better resume point when the job + // is restarted (checkpoint is not available for job restarts). + } + + /** + * A tuple to hold topic, partition, and offset that comprise the checkpoint + * for a single partition. + */ + public static class PartitionMark implements Serializable { + private final TopicPartition topicPartition; + private final long offset; + + public PartitionMark(TopicPartition topicPartition, long offset) { + this.topicPartition = topicPartition; + this.offset = offset; + } + + public TopicPartition getTopicPartition() { + return topicPartition; + } + + public long getOffset() { + return offset; + } + } +} + diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java new file mode 100644 index 000000000000..e6053116fbe5 --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -0,0 +1,1054 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.kafka; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.io.Read.Unbounded; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; +import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; +import org.apache.beam.sdk.io.kafka.KafkaCheckpointMark.PartitionMark; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.ExposedByteArrayInputStream; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PInput; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Function; +import com.google.common.base.Joiner; +import com.google.common.base.Optional; +import com.google.common.base.Throwables; +import com.google.common.collect.ComparisonChain; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.io.Closeables; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.TimeUnit; + +import javax.annotation.Nullable; + +/** + * An unbounded source for Kafka topics. Kafka version 0.9 + * and above are supported. + * + *

<h3>Reading from Kafka topics</h3>

+ * + *

<p>KafkaIO source returns an unbounded collection of Kafka records as + * {@code PCollection<KafkaRecord<K, V>>}. A {@link KafkaRecord} includes basic + * metadata like topic-partition and offset, along with key and value associated with a Kafka + * record. + * + *

<p>Although most applications consume a single topic, the source can be configured to consume + * multiple topics or even a specific set of {@link TopicPartition}s. + * + *

<p>To configure a Kafka source, you must specify at the minimum Kafka <tt>bootstrapServers</tt> + * and one or more topics to consume. The following example illustrates various options for + * configuring the source : + * + *

<pre>{@code
+ *
+ *  pipeline
+ *    .apply(KafkaIO.read()
+ *       .withBootstrapServers("broker_1:9092,broker_2:9092")
+ *       .withTopics(ImmutableList.of("topic_a", "topic_b"))
+ *       // above two are required configuration. returns PCollection<KafkaRecord<byte[], byte[]>>
+ *
+ *       // rest of the settings are optional :
+ *
+ *       // set a Coder for Key and Value (note the change to return type)
+ *       .withKeyCoder(BigEndianLongCoder.of()) // PCollection<KafkaRecord<Long, byte[]>>
+ *       .withValueCoder(StringUtf8Coder.of())  // PCollection<KafkaRecord<Long, String>>
+ *
+ *       // you can further customize KafkaConsumer used to read the records by adding more
+ *       // settings for ConsumerConfig. e.g :
+ *       .updateConsumerProperties(ImmutableMap.of("receive.buffer.bytes", 1024 * 1024))
+ *
+ *       // custom function for calculating record timestamp (default is processing time)
+ *       .withTimestampFn(new MyTimestampFunction())
+ *
+ *       // custom function for watermark (default is record timestamp)
+ *       .withWatermarkFn(new MyWatermarkFunction())
+ *
+ *       // finally, if you don't need Kafka metadata, you can drop it
+ *       .withoutMetadata() // PCollection<KV<Long, String>>
+ *    )
+ *    .apply(Values.<String>create()) // PCollection<String>
+ *     ...
+ * }</pre>
+ * + *
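+ *
+ * <p>For reference, a minimal configuration (a sketch using only the two required settings
+ * above) yields {@code byte[]} keys and values with processing-time timestamps:
+ *
+ * <pre>{@code
+ *  pipeline
+ *    .apply(KafkaIO.read()
+ *       .withBootstrapServers("broker_1:9092,broker_2:9092")
+ *       .withTopics(ImmutableList.of("topic_a")))
+ *    // PCollection<KafkaRecord<byte[], byte[]>>
+ * }</pre>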

<h3>Partition Assignment and Checkpointing</h3>

+ * The Kafka partitions are evenly distributed among splits (workers). + * Dataflow checkpointing is fully supported and + * each split can resume from previous checkpoint. See + * {@link UnboundedKafkaSource#generateInitialSplits(int, PipelineOptions)} for more details on + * splits and checkpoint support. + * + *
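+ *
+ * <p>For illustration: with 12 partitions and {@code desiredNumSplits = 4}, each split is
+ * assigned 3 partitions in round-robin order; with 4 partitions and 8 desired splits, only
+ * 4 splits are generated, since the number of splits never exceeds the number of partitions.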

<p>When the pipeline starts for the first time without any checkpoint, the source starts + * consuming from the latest offsets. You can override this behavior to consume from the + * beginning by setting appropriate properties in {@link ConsumerConfig}, through + * {@link Read#updateConsumerProperties(Map)}. + * + *
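+ *
+ * <p>For example (an illustrative sketch; {@code "earliest"} is the standard value of
+ * {@link ConsumerConfig#AUTO_OFFSET_RESET_CONFIG} for reading from the beginning):
+ *
+ * <pre>{@code
+ *  KafkaIO.read()
+ *    .withBootstrapServers("broker_1:9092")
+ *    .withTopics(ImmutableList.of("topic_a"))
+ *    .updateConsumerProperties(
+ *        ImmutableMap.<String, Object>of("auto.offset.reset", "earliest"))
+ * }</pre>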

<h3>Advanced Kafka Configuration</h3>

+ * KafakIO allows setting most of the properties in {@link ConsumerConfig}. E.g. if you would like + * to enable offset auto commit (for external monitoring or other purposes), you can set + * "group.id", "enable.auto.commit", etc. + */ +public class KafkaIO { + + private static final Logger LOG = LoggerFactory.getLogger(KafkaIO.class); + + private static class NowTimestampFn implements SerializableFunction { + @Override + public Instant apply(T input) { + return Instant.now(); + } + } + + + /** + * Creates and uninitialized {@link Read} {@link PTransform}. Before use, basic Kafka + * configuration should set with {@link Read#withBootstrapServers(String)} and + * {@link Read#withTopics(List)}. Other optional settings include key and value coders, + * custom timestamp and watermark functions. + */ + public static Read read() { + return new Read( + new ArrayList(), + new ArrayList(), + ByteArrayCoder.of(), + ByteArrayCoder.of(), + Read.KAFKA_9_CONSUMER_FACTORY_FN, + Read.DEFAULT_CONSUMER_PROPERTIES, + Long.MAX_VALUE, + null); + } + + /** + * A {@link PTransform} to read from Kafka topics. See {@link KafkaIO} for more + * information on usage and configuration. + */ + public static class Read extends TypedRead { + + /** + * Returns a new {@link Read} with Kafka consumer pointing to {@code bootstrapServers}. + */ + public Read withBootstrapServers(String bootstrapServers) { + return updateConsumerProperties( + ImmutableMap.of( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers)); + } + + /** + * Returns a new {@link Read} that reads from the topics. All the partitions are from each + * of the topics is read. + * See {@link UnboundedKafkaSource#generateInitialSplits(int, PipelineOptions)} for description + * of how the partitions are distributed among the splits. + */ + public Read withTopics(List topics) { + checkState(topicPartitions.isEmpty(), "Only topics or topicPartitions can be set, not both"); + + return new Read(ImmutableList.copyOf(topics), topicPartitions, keyCoder, valueCoder, + consumerFactoryFn, consumerConfig, maxNumRecords, maxReadTime); + } + + /** + * Returns a new {@link Read} that reads from the partitions. This allows reading only a subset + * of partitions for one or more topics when (if ever) needed. + * See {@link UnboundedKafkaSource#generateInitialSplits(int, PipelineOptions)} for description + * of how the partitions are distributed among the splits. + */ + public Read withTopicPartitions(List topicPartitions) { + checkState(topics.isEmpty(), "Only topics or topicPartitions can be set, not both"); + + return new Read(topics, ImmutableList.copyOf(topicPartitions), keyCoder, valueCoder, + consumerFactoryFn, consumerConfig, maxNumRecords, maxReadTime); + } + + /** + * Returns a new {@link Read} with {@link Coder} for key bytes. + */ + public Read withKeyCoder(Coder keyCoder) { + return new Read(topics, topicPartitions, keyCoder, valueCoder, + consumerFactoryFn, consumerConfig, maxNumRecords, maxReadTime); + } + + /** + * Returns a new {@link Read} with {@link Coder} for value bytes. + */ + public Read withValueCoder(Coder valueCoder) { + return new Read(topics, topicPartitions, keyCoder, valueCoder, + consumerFactoryFn, consumerConfig, maxNumRecords, maxReadTime); + } + + /** + * A factory to create Kafka {@link Consumer} from consumer configuration. + * This is useful for supporting another version of Kafka consumer. + * Default is {@link KafkaConsumer}. 
+ */ + public Read withConsumerFactoryFn( + SerializableFunction, Consumer> consumerFactoryFn) { + return new Read(topics, topicPartitions, keyCoder, valueCoder, + consumerFactoryFn, consumerConfig, maxNumRecords, maxReadTime); + } + + /** + * Update consumer configuration with new properties. + */ + public Read updateConsumerProperties(Map configUpdates) { + for (String key : configUpdates.keySet()) { + checkArgument(!IGNORED_CONSUMER_PROPERTIES.containsKey(key), + "No need to configure '%s'. %s", key, IGNORED_CONSUMER_PROPERTIES.get(key)); + } + + Map config = new HashMap<>(consumerConfig); + config.putAll(configUpdates); + + return new Read(topics, topicPartitions, keyCoder, valueCoder, + consumerFactoryFn, config, maxNumRecords, maxReadTime); + } + + /** + * Similar to {@link org.apache.beam.sdk.io.Read.Unbounded#withMaxNumRecords(long)}. + * Mainly used for tests and demo applications. + */ + public Read withMaxNumRecords(long maxNumRecords) { + return new Read(topics, topicPartitions, keyCoder, valueCoder, + consumerFactoryFn, consumerConfig, maxNumRecords, null); + } + + /** + * Similar to + * {@link org.apache.beam.sdk.io.Read.Unbounded#withMaxReadTime(Duration)}. + * Mainly used for tests and demo + * applications. + */ + public Read withMaxReadTime(Duration maxReadTime) { + return new Read(topics, topicPartitions, keyCoder, valueCoder, + consumerFactoryFn, consumerConfig, Long.MAX_VALUE, maxReadTime); + } + + /////////////////////////////////////////////////////////////////////////////////////// + + private Read( + List topics, + List topicPartitions, + Coder keyCoder, + Coder valueCoder, + SerializableFunction, Consumer> consumerFactoryFn, + Map consumerConfig, + long maxNumRecords, + @Nullable Duration maxReadTime) { + + super(topics, topicPartitions, keyCoder, valueCoder, null, null, + consumerFactoryFn, consumerConfig, maxNumRecords, maxReadTime); + } + + /** + * A set of properties that are not required or don't make sense for our consumer. + */ + private static final Map IGNORED_CONSUMER_PROPERTIES = ImmutableMap.of( + ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, "Set keyDecoderFn instead", + ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "Set valueDecoderFn instead" + // "group.id", "enable.auto.commit", "auto.commit.interval.ms" : + // lets allow these, applications can have better resume point for restarts. + ); + + // set config defaults + private static final Map DEFAULT_CONSUMER_PROPERTIES = + ImmutableMap.of( + ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName(), + ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName(), + + // Use large receive buffer. Once KAFKA-3135 is fixed, this _may_ not be required. + // with default value of of 32K, It takes multiple seconds between successful polls. + // All the consumer work is done inside poll(), with smaller send buffer size, it + // takes many polls before a 1MB chunk from the server is fully read. In my testing + // about half of the time select() inside kafka consumer waited for 20-30ms, though + // the server had lots of data in tcp send buffers on its side. Compared to default, + // this setting increased throughput increased by many fold (3-4x). + ConsumerConfig.RECEIVE_BUFFER_CONFIG, 512 * 1024, + + // default to latest offset when we are not resuming. + ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "latest", + // disable auto commit of offsets. we don't require group_id. could be enabled by user. 
+ ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false); + + // default Kafka 0.9 Consumer supplier. + private static final SerializableFunction, Consumer> + KAFKA_9_CONSUMER_FACTORY_FN = + new SerializableFunction, Consumer>() { + public Consumer apply(Map config) { + return new KafkaConsumer<>(config); + } + }; + } + + /** + * A {@link PTransform} to read from Kafka topics. See {@link KafkaIO} for more + * information on usage and configuration. + */ + public static class TypedRead + extends PTransform>> { + + /** + * A function to assign a timestamp to a record. Default is processing timestamp. + */ + public TypedRead withTimestampFn2( + SerializableFunction, Instant> timestampFn) { + checkNotNull(timestampFn); + return new TypedRead(topics, topicPartitions, keyCoder, valueCoder, + timestampFn, watermarkFn, consumerFactoryFn, consumerConfig, + maxNumRecords, maxReadTime); + } + + /** + * A function to calculate watermark after a record. Default is last record timestamp + * @see #withTimestampFn(SerializableFunction) + */ + public TypedRead withWatermarkFn2( + SerializableFunction, Instant> watermarkFn) { + checkNotNull(watermarkFn); + return new TypedRead(topics, topicPartitions, keyCoder, valueCoder, + timestampFn, watermarkFn, consumerFactoryFn, consumerConfig, + maxNumRecords, maxReadTime); + } + + /** + * A function to assign a timestamp to a record. Default is processing timestamp. + */ + public TypedRead withTimestampFn(SerializableFunction, Instant> timestampFn) { + checkNotNull(timestampFn); + return withTimestampFn2(unwrapKafkaAndThen(timestampFn)); + } + + /** + * A function to calculate watermark after a record. Default is last record timestamp + * @see #withTimestampFn(SerializableFunction) + */ + public TypedRead withWatermarkFn(SerializableFunction, Instant> watermarkFn) { + checkNotNull(watermarkFn); + return withWatermarkFn2(unwrapKafkaAndThen(watermarkFn)); + } + + /** + * Returns a {@link PTransform} for PCollection of {@link KV}, dropping Kafka metatdata. + */ + public PTransform>> withoutMetadata() { + return new TypedWithoutMetadata(this); + } + + @Override + public PCollection> apply(PBegin input) { + // Handles unbounded source to bounded conversion if maxNumRecords or maxReadTime is set. 
+ Unbounded> unbounded = + org.apache.beam.sdk.io.Read.from(makeSource()); + + PTransform>> transform = unbounded; + + if (maxNumRecords < Long.MAX_VALUE) { + transform = unbounded.withMaxNumRecords(maxNumRecords); + } else if (maxReadTime != null) { + transform = unbounded.withMaxReadTime(maxReadTime); + } + + return input.getPipeline().apply(transform); + } + + //////////////////////////////////////////////////////////////////////////////////////// + + protected final List topics; + protected final List topicPartitions; // mutually exclusive with topics + protected final Coder keyCoder; + protected final Coder valueCoder; + @Nullable protected final SerializableFunction, Instant> timestampFn; + @Nullable protected final SerializableFunction, Instant> watermarkFn; + protected final + SerializableFunction, Consumer> consumerFactoryFn; + protected final Map consumerConfig; + protected final long maxNumRecords; // bounded read, mainly for testing + protected final Duration maxReadTime; // bounded read, mainly for testing + + private TypedRead(List topics, + List topicPartitions, + Coder keyCoder, + Coder valueCoder, + @Nullable SerializableFunction, Instant> timestampFn, + @Nullable SerializableFunction, Instant> watermarkFn, + SerializableFunction, Consumer> consumerFactoryFn, + Map consumerConfig, + long maxNumRecords, + @Nullable Duration maxReadTime) { + super("KafkaIO.Read"); + + this.topics = topics; + this.topicPartitions = topicPartitions; + this.keyCoder = keyCoder; + this.valueCoder = valueCoder; + this.timestampFn = timestampFn; + this.watermarkFn = watermarkFn; + this.consumerFactoryFn = consumerFactoryFn; + this.consumerConfig = consumerConfig; + this.maxNumRecords = maxNumRecords; + this.maxReadTime = maxReadTime; + } + + /** + * Creates an {@link UnboundedSource, ?>} with the configuration in + * {@link TypedRead}. Primary use case is unit tests, should not be used in an + * application. + */ + @VisibleForTesting + UnboundedSource, KafkaCheckpointMark> makeSource() { + return new UnboundedKafkaSource( + -1, + topics, + topicPartitions, + keyCoder, + valueCoder, + timestampFn, + Optional.fromNullable(watermarkFn), + consumerFactoryFn, + consumerConfig); + } + + // utility method to convert KafkRecord to user KV before applying user functions + private static SerializableFunction, OutT> + unwrapKafkaAndThen(final SerializableFunction, OutT> fn) { + return new SerializableFunction, OutT>() { + public OutT apply(KafkaRecord record) { + return fn.apply(record.getKV()); + } + }; + } + } + + /** + * A {@link PTransform} to read from Kafka topics. Similar to {@link KafkaIO.TypedRead}, but + * removes Kafka metatdata and returns a {@link PCollection} of {@link KV}. + * See {@link KafkaIO} for more information on usage and configuration of reader. + */ + public static class TypedWithoutMetadata extends PTransform>> { + + private final TypedRead typedRead; + + TypedWithoutMetadata(TypedRead read) { + super("KafkaIO.Read"); + this.typedRead = read; + } + + @Override + public PCollection> apply(PBegin begin) { + return typedRead + .apply(begin) + .apply("Remove Kafka Metadata", + ParDo.of(new DoFn, KV>() { + @Override + public void processElement(ProcessContext ctx) { + ctx.output(ctx.element().getKV()); + } + })); + } + } + + /** Static class, prevent instantiation. 
*/ + private KafkaIO() {} + + private static class UnboundedKafkaSource + extends UnboundedSource, KafkaCheckpointMark> { + + private final int id; // split id, mainly for debugging + private final List topics; + private final List assignedPartitions; + private final Coder keyCoder; + private final Coder valueCoder; + private final SerializableFunction, Instant> timestampFn; + // would it be a good idea to pass currentTimestamp to watermarkFn? + private final Optional, Instant>> watermarkFn; + private + SerializableFunction, Consumer> consumerFactoryFn; + private final Map consumerConfig; + + public UnboundedKafkaSource( + int id, + List topics, + List assignedPartitions, + Coder keyCoder, + Coder valueCoder, + @Nullable SerializableFunction, Instant> timestampFn, + Optional, Instant>> watermarkFn, + SerializableFunction, Consumer> consumerFactoryFn, + Map consumerConfig) { + + this.id = id; + this.assignedPartitions = assignedPartitions; + this.topics = topics; + this.keyCoder = keyCoder; + this.valueCoder = valueCoder; + this.timestampFn = + (timestampFn == null ? new NowTimestampFn>() : timestampFn); + this.watermarkFn = watermarkFn; + this.consumerFactoryFn = consumerFactoryFn; + this.consumerConfig = consumerConfig; + } + + /** + * The partitions are evenly distributed among the splits. The number of splits returned is + * {@code min(desiredNumSplits, totalNumPartitions)}, though better not to depend on the exact + * count. + * + *

It is important to assign the partitions deterministically so that we can support + * resuming a split from last checkpoint. The Kafka partitions are sorted by + * {@code } and then assigned to splits in round-robin order. + */ + @Override + public List> generateInitialSplits( + int desiredNumSplits, PipelineOptions options) throws Exception { + + List partitions = new ArrayList<>(assignedPartitions); + + // (a) fetch partitions for each topic + // (b) sort by + // (c) round-robin assign the partitions to splits + + if (partitions.isEmpty()) { + try (Consumer consumer = consumerFactoryFn.apply(consumerConfig)) { + for (String topic : topics) { + for (PartitionInfo p : consumer.partitionsFor(topic)) { + partitions.add(new TopicPartition(p.topic(), p.partition())); + } + } + } + } + + Collections.sort(partitions, new Comparator() { + public int compare(TopicPartition tp1, TopicPartition tp2) { + return ComparisonChain + .start() + .compare(tp1.topic(), tp2.topic()) + .compare(tp1.partition(), tp2.partition()) + .result(); + } + }); + + checkArgument(desiredNumSplits > 0); + checkState(partitions.size() > 0, + "Could not find any partitions. Please check Kafka configuration and topic names"); + + int numSplits = Math.min(desiredNumSplits, partitions.size()); + List> assignments = new ArrayList<>(numSplits); + + for (int i = 0; i < numSplits; i++) { + assignments.add(new ArrayList()); + } + for (int i = 0; i < partitions.size(); i++) { + assignments.get(i % numSplits).add(partitions.get(i)); + } + + List> result = new ArrayList<>(numSplits); + + for (int i = 0; i < numSplits; i++) { + List assignedToSplit = assignments.get(i); + + LOG.info("Partitions assigned to split {} (total {}): {}", + i, assignedToSplit.size(), Joiner.on(",").join(assignedToSplit)); + + result.add(new UnboundedKafkaSource( + i, + this.topics, + assignedToSplit, + this.keyCoder, + this.valueCoder, + this.timestampFn, + this.watermarkFn, + this.consumerFactoryFn, + this.consumerConfig)); + } + + return result; + } + + @Override + public UnboundedKafkaReader createReader(PipelineOptions options, + KafkaCheckpointMark checkpointMark) { + if (assignedPartitions.isEmpty()) { + LOG.warn("Looks like generateSplits() is not called. Generate single split."); + try { + return new UnboundedKafkaReader( + generateInitialSplits(1, options).get(0), checkpointMark); + } catch (Exception e) { + Throwables.propagate(e); + } + } + return new UnboundedKafkaReader(this, checkpointMark); + } + + @Override + public Coder getCheckpointMarkCoder() { + return SerializableCoder.of(KafkaCheckpointMark.class); + } + + @Override + public boolean requiresDeduping() { + // Kafka records are ordered with in partitions. In addition checkpoint guarantees + // records are not consumed twice. 
+ return false; + } + + @Override + public void validate() { + checkNotNull(consumerConfig.get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG), + "Kafka bootstrap servers should be set"); + checkArgument(topics.size() > 0 || assignedPartitions.size() > 0, + "Kafka topics or topic_partitions are required"); + } + + @Override + public Coder> getDefaultOutputCoder() { + return KafkaRecordCoder.of(keyCoder, valueCoder); + } + } + + private static class UnboundedKafkaReader extends UnboundedReader> { + + private final UnboundedKafkaSource source; + private final String name; + private Consumer consumer; + private final List partitionStates; + private KafkaRecord curRecord; + private Instant curTimestamp; + private Iterator curBatch = Collections.emptyIterator(); + + private static final Duration KAFKA_POLL_TIMEOUT = Duration.millis(1000); + // how long to wait for new records from kafka consumer inside advance() + private static final Duration NEW_RECORDS_POLL_TIMEOUT = Duration.millis(10); + + // Use a separate thread to read Kafka messages. Kafka Consumer does all its work including + // network I/O inside poll(). Polling only inside #advance(), especially with a small timeout + // like 100 milliseconds does not work well. This along with large receive buffer for + // consumer achieved best throughput in tests (see `defaultConsumerProperties`). + private final ExecutorService consumerPollThread = Executors.newSingleThreadExecutor(); + private final SynchronousQueue> availableRecordsQueue = + new SynchronousQueue<>(); + private volatile boolean closed = false; + + // Backlog support : + // Kafka consumer does not have an API to fetch latest offset for topic. We need to seekToEnd() + // then look at position(). Use another consumer to do this so that the primary consumer does + // not need to be interrupted. The latest offsets are fetched periodically on another thread. + // This is still a hack. There could be unintended side effects, e.g. if user enabled offset + // auto commit in consumer config, this could interfere with the primary consumer (we will + // handle this particular problem). We might have to make this optional. + private Consumer offsetConsumer; + private final ScheduledExecutorService offsetFetcherThread = + Executors.newSingleThreadScheduledExecutor(); + private static final int OFFSET_UPDATE_INTERVAL_SECONDS = 5; + + /** watermark before any records have been read. */ + private static Instant initialWatermark = new Instant(Long.MIN_VALUE); + + public String toString() { + return name; + } + + // maintains state of each assigned partition (buffered records, consumed offset, etc) + private static class PartitionState { + private final TopicPartition topicPartition; + private long consumedOffset; + private long latestOffset; + private Iterator> recordIter = Collections.emptyIterator(); + + // simple moving average for size of each record in bytes + private double avgRecordSize = 0; + private static final int movingAvgWindow = 1000; // very roughly avg of last 1000 elements + + + PartitionState(TopicPartition partition, long offset) { + this.topicPartition = partition; + this.consumedOffset = offset; + this.latestOffset = -1; + } + + // update consumedOffset and avgRecordSize + void recordConsumed(long offset, int size) { + consumedOffset = offset; + + // this is always updated from single thread. probably not worth making it an AtomicDouble + if (avgRecordSize <= 0) { + avgRecordSize = size; + } else { + // initially, first record heavily contributes to average. 
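+          // After the first record this is an exponential moving average over roughly the last
+          // movingAvgWindow (1000) records: avg += (size - avg) / movingAvgWindow.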
+ avgRecordSize += ((size - avgRecordSize) / movingAvgWindow); + } + } + + synchronized void setLatestOffset(long latestOffset) { + this.latestOffset = latestOffset; + } + + synchronized long approxBacklogInBytes() { + // Note that is an an estimate of uncompressed backlog. + // Messages on Kafka might be comressed. + if (latestOffset < 0 || consumedOffset < 0) { + return UnboundedReader.BACKLOG_UNKNOWN; + } + if (latestOffset <= consumedOffset || consumedOffset < 0) { + return 0; + } + return (long) ((latestOffset - consumedOffset - 1) * avgRecordSize); + } + } + + public UnboundedKafkaReader( + UnboundedKafkaSource source, + @Nullable KafkaCheckpointMark checkpointMark) { + + this.source = source; + this.name = "Reader-" + source.id; + + partitionStates = ImmutableList.copyOf(Lists.transform(source.assignedPartitions, + new Function() { + public PartitionState apply(TopicPartition tp) { + return new PartitionState(tp, -1L); + } + })); + + if (checkpointMark != null) { + // a) verify that assigned and check-pointed partitions match exactly + // b) set consumed offsets + + checkState(checkpointMark.getPartitions().size() == source.assignedPartitions.size(), + "checkPointMark and assignedPartitions should match"); + // we could consider allowing a mismatch, though it is not expected in current Dataflow + + for (int i = 0; i < source.assignedPartitions.size(); i++) { + PartitionMark ckptMark = checkpointMark.getPartitions().get(i); + TopicPartition assigned = source.assignedPartitions.get(i); + + checkState(ckptMark.getTopicPartition().equals(assigned), + "checkpointed partition %s and assigned partition %s don't match", + ckptMark.getTopicPartition(), assigned); + + partitionStates.get(i).consumedOffset = ckptMark.getOffset(); + } + } + } + + private void consumerPollLoop() { + // Read in a loop and enqueue the batch of records, if any, to availableRecordsQueue + while (!closed) { + try { + ConsumerRecords records = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis()); + if (!records.isEmpty()) { + availableRecordsQueue.put(records); // blocks until dequeued. + } + } catch (InterruptedException e) { + LOG.warn("{}: consumer thread is interrupted", this, e); // not expected + break; + } catch (WakeupException e) { + break; + } + } + + LOG.info("{}: Returning from consumer pool loop", this); + } + + private void nextBatch() { + curBatch = Collections.emptyIterator(); + + ConsumerRecords records; + try { + records = availableRecordsQueue.poll(NEW_RECORDS_POLL_TIMEOUT.getMillis(), + TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + LOG.warn("{}: Unexpected", this, e); + return; + } + + if (records == null) { + return; + } + + List nonEmpty = new LinkedList<>(); + + for (PartitionState p : partitionStates) { + p.recordIter = records.records(p.topicPartition).iterator(); + if (p.recordIter.hasNext()) { + nonEmpty.add(p); + } + } + + // cycle through the partitions in order to interleave records from each. 
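+      // Iterators.cycle() keeps looping over the non-empty partitions; advance() removes a
+      // partition from the cycle once its buffered records are exhausted.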
+ curBatch = Iterators.cycle(nonEmpty); + } + + @Override + public boolean start() throws IOException { + consumer = source.consumerFactoryFn.apply(source.consumerConfig); + consumer.assign(source.assignedPartitions); + + // seek to consumedOffset + 1 if it is set + for (PartitionState p : partitionStates) { + if (p.consumedOffset >= 0) { + LOG.info("{}: resuming {} at {}", name, p.topicPartition, p.consumedOffset + 1); + consumer.seek(p.topicPartition, p.consumedOffset + 1); + } else { + LOG.info("{}: resuming {} at default offset", name, p.topicPartition); + } + } + + // start consumer read loop. + // Note that consumer is not thread safe, should not accessed out side consumerPollLoop() + consumerPollThread.submit( + new Runnable() { + public void run() { + consumerPollLoop(); + } + }); + + // offsetConsumer setup : + + Object groupId = source.consumerConfig.get(ConsumerConfig.GROUP_ID_CONFIG); + // override group_id and disable auto_commit so that it does not interfere with main consumer + String offsetGroupId = String.format("%s_offset_consumer_%d_%s", name, + (new Random()).nextInt(Integer.MAX_VALUE), (groupId == null ? "none" : groupId)); + Map offsetConsumerConfig = new HashMap<>(source.consumerConfig); + offsetConsumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, offsetGroupId); + offsetConsumerConfig.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false); + + offsetConsumer = source.consumerFactoryFn.apply(offsetConsumerConfig); + offsetConsumer.assign(source.assignedPartitions); + + offsetFetcherThread.scheduleAtFixedRate( + new Runnable() { + public void run() { + updateLatestOffsets(); + } + }, 0, OFFSET_UPDATE_INTERVAL_SECONDS, TimeUnit.SECONDS); + + return advance(); + } + + @Override + public boolean advance() throws IOException { + /* Read first record (if any). we need to loop here because : + * - (a) some records initially need to be skipped if they are before consumedOffset + * - (b) if curBatch is empty, we want to fetch next batch and then advance. + * - (c) curBatch is an iterator of iterators. we interleave the records from each. + * curBatch.next() might return an empty iterator. + */ + while (true) { + if (curBatch.hasNext()) { + PartitionState pState = curBatch.next(); + + if (!pState.recordIter.hasNext()) { // -- (c) + pState.recordIter = Collections.emptyIterator(); // drop ref + curBatch.remove(); + continue; + } + + ConsumerRecord rawRecord = pState.recordIter.next(); + long consumed = pState.consumedOffset; + long offset = rawRecord.offset(); + + if (consumed >= 0 && offset <= consumed) { // -- (a) + // this can happen when compression is enabled in Kafka (seems to be fixed in 0.10) + // should we check if the offset is way off from consumedOffset (say > 1M)? + LOG.warn("{}: ignoring already consumed offset {} for {}", + this, offset, pState.topicPartition); + continue; + } + + // sanity check + if (consumed >= 0 && (offset - consumed) != 1) { + LOG.warn("{}: gap in offsets for {} after {}. {} records missing.", + this, pState.topicPartition, consumed, offset - consumed - 1); + } + + if (curRecord == null) { + LOG.info("{}: first record offset {}", name, offset); + } + + curRecord = null; // user coders below might throw. + + // apply user coders. might want to allow skipping records that fail to decode. 
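+          // Note that a CoderException thrown by decode() below currently surfaces as an
+          // IOException from advance().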
+ // TODO: wrap exceptions from coders to make explicit to users + KafkaRecord record = new KafkaRecord( + rawRecord.topic(), + rawRecord.partition(), + rawRecord.offset(), + decode(rawRecord.key(), source.keyCoder), + decode(rawRecord.value(), source.valueCoder)); + + curTimestamp = source.timestampFn.apply(record); + curRecord = record; + + int recordSize = (rawRecord.key() == null ? 0 : rawRecord.key().length) + + (rawRecord.value() == null ? 0 : rawRecord.value().length); + pState.recordConsumed(offset, recordSize); + return true; + + } else { // -- (b) + nextBatch(); + + if (!curBatch.hasNext()) { + return false; + } + } + } + } + + private static byte[] nullBytes = new byte[0]; + private static T decode(byte[] bytes, Coder coder) throws IOException { + // If 'bytes' is null, use byte[0]. It is common for key in Kakfa record to be null. + // This makes it impossible for user to distinguish between zero length byte and null. + // Alternately, we could have a ByteArrayCoder that handles nulls, and use that for default + // coder. + byte[] toDecode = bytes == null ? nullBytes : bytes; + return coder.decode(new ExposedByteArrayInputStream(toDecode), Coder.Context.OUTER); + } + + // update latest offset for each partition. + // called from offsetFetcher thread + private void updateLatestOffsets() { + for (PartitionState p : partitionStates) { + try { + offsetConsumer.seekToEnd(p.topicPartition); + long offset = offsetConsumer.position(p.topicPartition); + p.setLatestOffset(offset);; + } catch (Exception e) { + LOG.warn("{}: exception while fetching latest offsets. ignored.", this, e); + p.setLatestOffset(-1L); // reset + } + + LOG.debug("{}: latest offset update for {} : {} (consumed offset {}, avg record size {})", + this, p.topicPartition, p.latestOffset, p.consumedOffset, p.avgRecordSize); + } + + LOG.debug("{}: backlog {}", this, getSplitBacklogBytes()); + } + + @Override + public Instant getWatermark() { + if (curRecord == null) { + LOG.warn("{}: getWatermark() : no records have been read yet.", name); + return initialWatermark; + } + + return source.watermarkFn.isPresent() ? + source.watermarkFn.get().apply(curRecord) : curTimestamp; + } + + @Override + public CheckpointMark getCheckpointMark() { + return new KafkaCheckpointMark(ImmutableList.copyOf(// avoid lazy (consumedOffset can change) + Lists.transform(partitionStates, + new Function() { + public PartitionMark apply(PartitionState p) { + return new PartitionMark(p.topicPartition, p.consumedOffset); + } + } + ))); + } + + @Override + public UnboundedSource, ?> getCurrentSource() { + return source; + } + + @Override + public KafkaRecord getCurrent() throws NoSuchElementException { + // should we delay updating consumed offset till this point? Mostly not required. + return curRecord; + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + return curTimestamp; + } + + + @Override + public long getSplitBacklogBytes() { + long backlogBytes = 0; + + for (PartitionState p : partitionStates) { + long pBacklog = p.approxBacklogInBytes(); + if (pBacklog == UnboundedReader.BACKLOG_UNKNOWN) { + return UnboundedReader.BACKLOG_UNKNOWN; + } + backlogBytes += pBacklog; + } + + return backlogBytes; + } + + @Override + public void close() throws IOException { + closed = true; + availableRecordsQueue.poll(); // drain unread batch, this unblocks consumer thread. 
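+      // wakeup() forces a poll() blocked inside consumerPollLoop() to throw WakeupException,
+      // letting the polling thread exit before the executors are shut down.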
+ consumer.wakeup(); + consumerPollThread.shutdown(); + offsetFetcherThread.shutdown(); + Closeables.close(offsetConsumer, true); + Closeables.close(consumer, true); + } + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaRecord.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaRecord.java new file mode 100644 index 000000000000..76e688b17852 --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaRecord.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import org.apache.beam.sdk.values.KV; + +import java.io.Serializable; +import java.util.Arrays; + +/** + * KafkaRecord contains key and value of the record as well as metadata for the record (topic name, + * partition id, and offset). + */ +public class KafkaRecord implements Serializable { + + private final String topic; + private final int partition; + private final long offset; + private final KV kv; + + public KafkaRecord( + String topic, + int partition, + long offset, + K key, + V value) { + this(topic, partition, offset, KV.of(key, value)); + } + + public KafkaRecord( + String topic, + int partition, + long offset, + KV kv) { + + this.topic = topic; + this.partition = partition; + this.offset = offset; + this.kv = kv; + } + + public String getTopic() { + return topic; + } + + public int getPartition() { + return partition; + } + + public long getOffset() { + return offset; + } + + public KV getKV() { + return kv; + } + + @Override + public int hashCode() { + return Arrays.deepHashCode(new Object[]{topic, partition, offset, kv}); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof KafkaRecord) { + @SuppressWarnings("unchecked") + KafkaRecord other = (KafkaRecord) obj; + return topic.equals(other.topic) + && partition == other.partition + && offset == other.offset + && kv.equals(other.kv); + } else { + return false; + } + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaRecordCoder.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaRecordCoder.java new file mode 100644 index 000000000000..8a3e7f51441d --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaRecordCoder.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StandardCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.util.PropertyNames; +import org.apache.beam.sdk.values.KV; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; + +/** + * {@link Coder} for {@link KafkaRecord}. + */ +public class KafkaRecordCoder extends StandardCoder> { + + private static final StringUtf8Coder stringCoder = StringUtf8Coder.of(); + private static final VarLongCoder longCoder = VarLongCoder.of(); + private static final VarIntCoder intCoder = VarIntCoder.of(); + + private final KvCoder kvCoder; + + @JsonCreator + public static KafkaRecordCoder of(@JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> components) { + KvCoder kvCoder = KvCoder.of(components); + return of(kvCoder.getKeyCoder(), kvCoder.getValueCoder()); + } + + public static KafkaRecordCoder of(Coder keyCoder, Coder valueCoder) { + return new KafkaRecordCoder(keyCoder, valueCoder); + } + + public KafkaRecordCoder(Coder keyCoder, Coder valueCoder) { + this.kvCoder = KvCoder.of(keyCoder, valueCoder); + } + + @Override + public void encode(KafkaRecord value, OutputStream outStream, Context context) + throws CoderException, IOException { + Context nested = context.nested(); + stringCoder.encode(value.getTopic(), outStream, nested); + intCoder.encode(value.getPartition(), outStream, nested); + longCoder.encode(value.getOffset(), outStream, nested); + kvCoder.encode(value.getKV(), outStream, nested); + } + + @Override + public KafkaRecord decode(InputStream inStream, Context context) + throws CoderException, IOException { + Context nested = context.nested(); + return new KafkaRecord( + stringCoder.decode(inStream, nested), + intCoder.decode(inStream, nested), + longCoder.decode(inStream, nested), + kvCoder.decode(inStream, nested)); + } + + @Override + public List> getCoderArguments() { + return kvCoder.getCoderArguments(); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + kvCoder.verifyDeterministic(); + } + + @Override + public boolean isRegisterByteSizeObserverCheap(KafkaRecord value, Context context) { + return kvCoder.isRegisterByteSizeObserverCheap(value.getKV(), context); + //TODO : do we have to implement getEncodedSize()? 
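+    // Topic, partition, and offset add only a small amount on top of the KV, so kvCoder's answer
+    // is a reasonable approximation for the whole record.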
+ } + + @SuppressWarnings("unchecked") + @Override + public Object structuralValue(KafkaRecord value) throws Exception { + if (consistentWithEquals()) { + return value; + } else { + return new KafkaRecord( + value.getTopic(), + value.getPartition(), + value.getOffset(), + (KV) kvCoder.structuralValue(value.getKV())); + } + } + + @Override + public boolean consistentWithEquals() { + return kvCoder.consistentWithEquals(); + } +} diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java new file mode 100644 index 000000000000..96ffc9859a21 --- /dev/null +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.BigEndianLongCoder; +import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.RunnableOnService; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.Max; +import org.apache.beam.sdk.transforms.Min; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.RemoveDuplicates; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.Values; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionList; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import 
javax.annotation.Nullable; + +/** + * Tests of {@link KafkaSource}. + */ +@RunWith(JUnit4.class) +public class KafkaIOTest { + /* + * The tests below borrow code and structure from CountingSourceTest. In addition verifies + * the reader interleaves the records from multiple partitions. + * + * Other tests to consider : + * - test KafkaRecordCoder + */ + + // Update mock consumer with records distributed among the given topics, each with given number + // of partitions. Records are assigned in round-robin order among the partitions. + private static MockConsumer mkMockConsumer( + List topics, int partitionsPerTopic, int numElements) { + + final List partitions = new ArrayList<>(); + final Map>> records = new HashMap<>(); + Map> partitionMap = new HashMap<>(); + + for (String topic : topics) { + List partIds = new ArrayList<>(partitionsPerTopic); + for (int i = 0; i < partitionsPerTopic; i++) { + partitions.add(new TopicPartition(topic, i)); + partIds.add(new PartitionInfo(topic, i, null, null, null)); + } + partitionMap.put(topic, partIds); + } + + int numPartitions = partitions.size(); + long[] offsets = new long[numPartitions]; + + for (int i = 0; i < numElements; i++) { + int pIdx = i % numPartitions; + TopicPartition tp = partitions.get(pIdx); + + if (!records.containsKey(tp)) { + records.put(tp, new ArrayList>()); + } + records.get(tp).add( + // Note: this interface has changed in 0.10. may get fixed before the release. + new ConsumerRecord( + tp.topic(), + tp.partition(), + offsets[pIdx]++, + null, // key + ByteBuffer.wrap(new byte[8]).putLong(i).array())); // value is 8 byte record id. + } + + MockConsumer consumer = + new MockConsumer(OffsetResetStrategy.EARLIEST) { + // override assign() to add records that belong to the assigned partitions. + public void assign(List assigned) { + super.assign(assigned); + for (TopicPartition tp : assigned) { + for (ConsumerRecord r : records.get(tp)) { + addRecord(r); + } + updateBeginningOffsets(ImmutableMap.of(tp, 0L)); + updateEndOffsets(ImmutableMap.of(tp, (long) records.get(tp).size())); + seek(tp, 0); + } + } + }; + + for (String topic : topics) { + consumer.updatePartitions(topic, partitionMap.get(topic)); + } + + return consumer; + } + + private static class ConsumerFactoryFn + implements SerializableFunction, Consumer> { + private final List topics; + private final int partitionsPerTopic; + private final int numElements; + + public ConsumerFactoryFn(List topics, int partitionsPerTopic, int numElements) { + this.topics = topics; + this.partitionsPerTopic = partitionsPerTopic; + this.numElements = numElements; + } + + public Consumer apply(Map config) { + return mkMockConsumer(topics, partitionsPerTopic, numElements); + } + } + + /** + * Creates a consumer with two topics, with 5 partitions each. + * numElements are (round-robin) assigned all the 10 partitions. 
+ */ + private static KafkaIO.TypedRead mkKafkaReadTransform( + int numElements, + @Nullable SerializableFunction, Instant> timestampFn) { + + List topics = ImmutableList.of("topic_a", "topic_b"); + + KafkaIO.Read reader = KafkaIO.read() + .withBootstrapServers("none") + .withTopics(topics) + .withConsumerFactoryFn(new ConsumerFactoryFn(topics, 10, numElements)) // 20 partitions + .withValueCoder(BigEndianLongCoder.of()) + .withMaxNumRecords(numElements); + + if (timestampFn != null) { + return reader.withTimestampFn(timestampFn); + } else { + return reader; + } + } + + private static class AssertMultipleOf implements SerializableFunction, Void> { + private final int num; + + public AssertMultipleOf(int num) { + this.num = num; + } + + @Override + public Void apply(Iterable values) { + for (Long v : values) { + assertEquals(0, v % num); + } + return null; + } + } + + public static void addCountingAsserts(PCollection input, long numElements) { + // Count == numElements + PAssert + .thatSingleton(input.apply("Count", Count.globally())) + .isEqualTo(numElements); + // Unique count == numElements + PAssert + .thatSingleton(input.apply(RemoveDuplicates.create()) + .apply("UniqueCount", Count.globally())) + .isEqualTo(numElements); + // Min == 0 + PAssert + .thatSingleton(input.apply("Min", Min.globally())) + .isEqualTo(0L); + // Max == numElements-1 + PAssert + .thatSingleton(input.apply("Max", Max.globally())) + .isEqualTo(numElements - 1); + } + + @Test + @Category(RunnableOnService.class) + public void testUnboundedSource() { + Pipeline p = TestPipeline.create(); + int numElements = 1000; + + PCollection input = p + .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn()) + .withoutMetadata()) + .apply(Values.create()); + + addCountingAsserts(input, numElements); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testUnboundedSourceWithExplicitPartitions() { + Pipeline p = TestPipeline.create(); + int numElements = 1000; + + List topics = ImmutableList.of("test"); + + KafkaIO.TypedRead reader = KafkaIO.read() + .withBootstrapServers("none") + .withTopicPartitions(ImmutableList.of(new TopicPartition("test", 5))) + .withConsumerFactoryFn(new ConsumerFactoryFn(topics, 10, numElements)) // 10 partitions + .withValueCoder(BigEndianLongCoder.of()) + .withMaxNumRecords(numElements / 10); + + PCollection input = p + .apply(reader.withoutMetadata()) + .apply(Values.create()); + + // assert that every element is a multiple of 5. + PAssert + .that(input) + .satisfies(new AssertMultipleOf(5)); + + PAssert + .thatSingleton(input.apply(Count.globally())) + .isEqualTo(numElements / 10L); + + p.run(); + } + + private static class ElementValueDiff extends DoFn { + @Override + public void processElement(ProcessContext c) throws Exception { + c.output(c.element() - c.timestamp().getMillis()); + } + } + + @Test + @Category(RunnableOnService.class) + public void testUnboundedSourceTimestamps() { + Pipeline p = TestPipeline.create(); + int numElements = 1000; + + PCollection input = p + .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn()).withoutMetadata()) + .apply(Values.create()); + + addCountingAsserts(input, numElements); + + PCollection diffs = input + .apply("TimestampDiff", ParDo.of(new ElementValueDiff())) + .apply("RemoveDuplicateTimestamps", RemoveDuplicates.create()); + // This assert also confirms that diffs only has one unique value. 
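+    // (PAssert.thatSingleton also fails if 'diffs' does not contain exactly one element.)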
+ PAssert.thatSingleton(diffs).isEqualTo(0L); + + p.run(); + } + + private static class RemoveKafkaMetadata extends DoFn, KV> { + @Override + public void processElement(ProcessContext ctx) throws Exception { + ctx.output(ctx.element().getKV()); + } + } + + @Test + @Category(RunnableOnService.class) + public void testUnboundedSourceSplits() throws Exception { + Pipeline p = TestPipeline.create(); + int numElements = 1000; + int numSplits = 10; + + UnboundedSource, ?> initial = + mkKafkaReadTransform(numElements, null).makeSource(); + List, ?>> splits = + initial.generateInitialSplits(numSplits, p.getOptions()); + assertEquals("Expected exact splitting", numSplits, splits.size()); + + long elementsPerSplit = numElements / numSplits; + assertEquals("Expected even splits", numElements, elementsPerSplit * numSplits); + PCollectionList pcollections = PCollectionList.empty(p); + for (int i = 0; i < splits.size(); ++i) { + pcollections = pcollections.and( + p.apply("split" + i, Read.from(splits.get(i)).withMaxNumRecords(elementsPerSplit)) + .apply("Remove Metadata " + i, ParDo.of(new RemoveKafkaMetadata())) + .apply("collection " + i, Values.create())); + } + PCollection input = pcollections.apply(Flatten.pCollections()); + + addCountingAsserts(input, numElements); + p.run(); + } + + /** + * A timestamp function that uses the given value as the timestamp. + */ + private static class ValueAsTimestampFn + implements SerializableFunction, Instant> { + @Override + public Instant apply(KV input) { + return new Instant(input.getValue()); + } + } + + @Test + public void testUnboundedSourceCheckpointMark() throws Exception { + int numElements = 85; // 85 to make sure some partitions have more records than other. + + // create a single split: + UnboundedSource, KafkaCheckpointMark> source = + mkKafkaReadTransform(numElements, new ValueAsTimestampFn()) + .makeSource() + .generateInitialSplits(1, PipelineOptionsFactory.fromArgs(new String[0]).create()) + .get(0); + + UnboundedReader> reader = source.createReader(null, null); + final int numToSkip = 3; + // advance once: + assertTrue(reader.start()); + + // Advance the source numToSkip-1 elements and manually save state. + for (long l = 0; l < numToSkip - 1; ++l) { + assertTrue(reader.advance()); + } + + // Confirm that we get the expected element in sequence before checkpointing. + + assertEquals(numToSkip - 1, (long) reader.getCurrent().getKV().getValue()); + assertEquals(numToSkip - 1, reader.getCurrentTimestamp().getMillis()); + + // Checkpoint and restart, and confirm that the source continues correctly. + KafkaCheckpointMark mark = CoderUtils.clone( + source.getCheckpointMarkCoder(), (KafkaCheckpointMark) reader.getCheckpointMark()); + reader = source.createReader(null, mark); + assertTrue(reader.start()); + + // Confirm that we get the next elements in sequence. + // This also confirms that Reader interleaves records from each partitions by the reader. + for (int i = numToSkip; i < numElements; i++) { + assertEquals(i, (long) reader.getCurrent().getKV().getValue()); + assertEquals(i, reader.getCurrentTimestamp().getMillis()); + reader.advance(); + } + } +} diff --git a/sdks/java/io/pom.xml b/sdks/java/io/pom.xml index 75f192cacc6d..95d1f55c4e1d 100644 --- a/sdks/java/io/pom.xml +++ b/sdks/java/io/pom.xml @@ -36,6 +36,7 @@ hdfs + kafka - \ No newline at end of file +
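For context, a minimal usage sketch of the new transform, pieced together from the builder methods exercised in KafkaIOTest above (withBootstrapServers, withTopics, withValueCoder, withoutMetadata). The broker address, the topic name, and the choice of StringUtf8Coder for values are illustrative assumptions rather than part of the patch, and the generic parameters assume the typed builder suggested by the tests.

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.io.kafka.KafkaIO;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.transforms.Values;
    import org.apache.beam.sdk.values.PCollection;

    import com.google.common.collect.ImmutableList;

    public class KafkaReadSketch {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

        // Read from one topic, drop the Kafka metadata, and keep only the decoded values.
        PCollection<String> values = p
            .apply(KafkaIO.read()
                .withBootstrapServers("broker-1:9092")        // placeholder broker address
                .withTopics(ImmutableList.of("my_topic"))     // placeholder topic
                .withValueCoder(StringUtf8Coder.of())
                .withoutMetadata())                           // KV values instead of KafkaRecord
            .apply(Values.<String>create());

        p.run();
      }
    }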