diff --git a/.travis.yml b/.travis.yml index 52e1d3a5cbd2..973618b10696 100644 --- a/.travis.yml +++ b/.travis.yml @@ -31,5 +31,5 @@ install: script: - travis_retry mvn versions:set -DnewVersion=manual_build - - travis_retry mvn $MAVEN_OVERRIDE install -U + - travis_retry mvn $MAVEN_OVERRIDE verify -U - travis_retry travis/test_wordcount.sh diff --git a/pom.xml b/pom.xml index ba130d25a3d2..de47ff5c4fa5 100644 --- a/pom.xml +++ b/pom.xml @@ -91,6 +91,7 @@ pom sdk + runners examples maven-archetypes/starter maven-archetypes/examples diff --git a/runners/flink/README.md b/runners/flink/README.md new file mode 100644 index 000000000000..0fee6f09cbf5 --- /dev/null +++ b/runners/flink/README.md @@ -0,0 +1,202 @@ +Flink Beam Runner (Flink-Runner) +------------------------------- + +Flink-Runner is a Runner for Apache Beam which enables you to +run Beam dataflows with Flink. It integrates seamlessly with the Beam +API, allowing you to execute Apache Beam programs in streaming or batch mode. + +## Streaming + +### Full Beam Windowing and Triggering Semantics + +The Flink Beam Runner supports *Event Time* allowing you to analyze data with respect to its +associated timestamp. It handles out-or-order and late-arriving elements. You may leverage the full +power of the Beam windowing semantics like *time-based*, *sliding*, *tumbling*, or *count* +windows. You may build *session* windows which allow you to keep track of events associated with +each other. + +### Fault-Tolerance + +The program's state is persisted by Apache Flink. You may re-run and resume your program upon +failure or if you decide to continue computation at a later time. + +### Sources and Sinks + +Build your own data ingestion or digestion using the source/sink interface. Re-use Flink's sources +and sinks or use the provided support for Apache Kafka. + +### Seamless integration + +To execute a Beam program in streaming mode, just enable streaming in the `PipelineOptions`: + + options.setStreaming(true); + +That's it. If you prefer batched execution, simply disable streaming mode. + +## Batch + +### Batch optimization + +Flink gives you out-of-core algorithms which operate on its managed memory to perform sorting, +caching, and hash table operations. We have optimized operations like CoGroup to use Flink's +optimized out-of-core implementation. + +### Fault-Tolerance + +We guarantee job-level fault-tolerance which gracefully restarts failed batch jobs. + +### Sources and Sinks + +Build your own data ingestion or digestion using the source/sink interface or re-use Flink's sources +and sinks. + +## Features + +The Flink Beam Runner maintains as much compatibility with the Beam API as possible. We +support transformations on data like: + +- Grouping +- Windowing +- ParDo +- CoGroup +- Flatten +- Combine +- Side inputs/outputs +- Encoding + +# Getting Started + +To get started using the Flink Runner, we first need to install the latest version. + +## Install Flink-Runner ## + +To retrieve the latest version of Flink-Runner, run the following command + + git clone https://github.com/apache/incubator-beam + +Then switch to the newly created directory and run Maven to build the Beam runner: + + cd incubator-beam + mvn clean install -DskipTests + +Flink-Runner is now installed in your local maven repository. + +## Executing an example + +Next, let's run the classic WordCount example. It's semantically identically to +the example provided with ApacheBeam. 
Only this time, we chose the +`FlinkPipelineRunner` to execute the WordCount on top of Flink. + +Here's an excerpt from the WordCount class file: + +```java +Options options = PipelineOptionsFactory.fromArgs(args).as(Options.class); +// yes, we want to run WordCount with Flink +options.setRunner(FlinkPipelineRunner.class); + +Pipeline p = Pipeline.create(options); + +p.apply(TextIO.Read.named("ReadLines").from(options.getInput())) + .apply(new CountWords()) + .apply(TextIO.Write.named("WriteCounts") + .to(options.getOutput()) + .withNumShards(options.getNumShards())); + +p.run(); +``` + +To execute the example, let's first get some sample data: + + curl http://www.gutenberg.org/cache/epub/1128/pg1128.txt > kinglear.txt + +Then let's run the included WordCount locally on your machine: + + mvn exec:exec -Dinput=kinglear.txt -Doutput=wordcounts.txt + +Congratulations, you have run your first ApacheBeam program on top of Apache Flink! + + +# Running Beam programs on a Flink cluster + +You can run your Beam program on an Apache Flink cluster. Please start off by creating a new +Maven project. + + mvn archetype:generate -DgroupId=com.mycompany.beam -DartifactId=beam-test \ + -DarchetypeArtifactId=maven-archetype-quickstart -DinteractiveMode=false + +The contents of the root `pom.xml` should be slightly changed aftewards (explanation below): + +```xml + + + 4.0.0 + + com.mycompany.beam + beam-test + 1.0 + + + + org.apache.beam + flink-runner + 0.2 + + + + + + + org.apache.maven.plugins + maven-shade-plugin + 2.4.1 + + + package + + shade + + + + + WordCount + + + + + org.apache.flink:* + + + + + + + + + + + + +``` + +The following changes have been made: + +1. The Flink Beam Runner was added as a dependency. + +2. The Maven Shade plugin was added to build a fat jar. + +A fat jar is necessary if you want to submit your Beam code to a Flink cluster. The fat jar +includes your program code but also Beam code which is necessary during runtime. Note that this +step is necessary because the Beam Runner is not part of Flink. + +You can then build the jar using `mvn clean package`. Please submit the fat jar in the `target` +folder to the Flink cluster using the command-line utility like so: + + ./bin/flink run /path/to/fat.jar + + +# More + +For more information, please visit the [Apache Flink Website](http://flink.apache.org) or contact +the [Mailinglists](http://flink.apache.org/community.html#mailing-lists). 
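
## Passing pipeline options when submitting

A brief addendum, with placeholder values rather than anything from the instructions above: the
bundled examples build their options with `PipelineOptionsFactory.fromArgs(args)`, so any pipeline
option can be appended as a program argument after the path to the fat jar:

    ./bin/flink run /path/to/fat.jar --input=kinglear.txt --output=wordcounts.txt --parallelism=4

Substitute the options that your own pipeline actually defines.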
\ No newline at end of file diff --git a/runners/flink/pom.xml b/runners/flink/pom.xml new file mode 100644 index 000000000000..2110c2c8023f --- /dev/null +++ b/runners/flink/pom.xml @@ -0,0 +1,264 @@ + + + + + 4.0.0 + + + org.apache.beam + runners + 1.5.0-SNAPSHOT + + + flink-runner + 0.3-SNAPSHOT + + Flink Beam Runner + jar + + 2015 + + + + The Apache Software License, Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0.txt + repo + + + + + UTF-8 + UTF-8 + 1.0-SNAPSHOT + 1.5.0-SNAPSHOT + + org.apache.beam.runners.flink.examples.WordCount + kinglear.txt + wordcounts.txt + 1 + + + + + apache.snapshots + Apache Development Snapshot Repository + https://repository.apache.org/content/repositories/snapshots/ + + false + + + true + + + + + + + org.apache.flink + flink-core + ${flink.version} + + + org.apache.flink + flink-streaming-java_2.10 + ${flink.version} + + + org.apache.flink + flink-streaming-java_2.10 + ${flink.version} + test + test-jar + + + org.apache.flink + flink-java + ${flink.version} + + + org.apache.flink + flink-avro_2.10 + ${flink.version} + + + org.apache.flink + flink-clients_2.10 + ${flink.version} + + + org.apache.flink + flink-test-utils_2.10 + ${flink.version} + test + + + org.apache.flink + flink-connector-kafka-0.8_2.10 + ${flink.version} + + + org.apache.flink + flink-avro + ${flink.version} + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-all + ${beam.version} + + + org.slf4j + slf4j-jdk14 + + + + + org.mockito + mockito-all + 1.9.5 + test + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.6 + + + + true + true + + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + 1.7 + 1.7 + + + + + + maven-failsafe-plugin + 2.17 + + + + integration-test + verify + + + + + -Dlog4j.configuration=log4j-test.properties -XX:-UseGCOverheadLimit + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.17 + + -Dlog4j.configuration=log4j-test.properties -XX:-UseGCOverheadLimit + + + + + + org.apache.maven.plugins + maven-eclipse-plugin + 2.8 + + + org.eclipse.jdt.launching.JRE_CONTAINER + + true + true + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + 1.3.1 + + + enforce-maven + + enforce + + + + + [1.7,) + + + + [3.0.3,) + + + + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.2.1 + + + none + + + + java + + -classpath + + ${clazz} + --input=${input} + --output=${output} + --parallelism=${parallelism} + + + + + + + + + diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java new file mode 100644 index 000000000000..8825ed36dee0 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import org.apache.beam.runners.flink.translation.FlinkPipelineTranslator; +import org.apache.beam.runners.flink.translation.FlinkBatchPipelineTranslator; +import org.apache.beam.runners.flink.translation.FlinkStreamingPipelineTranslator; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.common.base.Preconditions; +import org.apache.flink.api.common.JobExecutionResult; +import org.apache.flink.api.java.CollectionEnvironment; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.streaming.api.TimeCharacteristic; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +/** + * The class that instantiates and manages the execution of a given job. + * Depending on if the job is a Streaming or Batch processing one, it creates + * the adequate execution environment ({@link ExecutionEnvironment} or {@link StreamExecutionEnvironment}), + * the necessary {@link FlinkPipelineTranslator} ({@link FlinkBatchPipelineTranslator} or + * {@link FlinkStreamingPipelineTranslator})to transform the Beam job into a Flink one, and + * executes the (translated) job. + */ +public class FlinkPipelineExecutionEnvironment { + + private static final Logger LOG = LoggerFactory.getLogger(FlinkPipelineExecutionEnvironment.class); + + private final FlinkPipelineOptions options; + + /** + * The Flink Batch execution environment. This is instantiated to either a + * {@link org.apache.flink.api.java.CollectionEnvironment}, + * a {@link org.apache.flink.api.java.LocalEnvironment} or + * a {@link org.apache.flink.api.java.RemoteEnvironment}, depending on the configuration + * options. + */ + private ExecutionEnvironment flinkBatchEnv; + + + /** + * The Flink Streaming execution environment. This is instantiated to either a + * {@link org.apache.flink.streaming.api.environment.LocalStreamEnvironment} or + * a {@link org.apache.flink.streaming.api.environment.RemoteStreamEnvironment}, depending + * on the configuration options, and more specifically, the url of the master. + */ + private StreamExecutionEnvironment flinkStreamEnv; + + /** + * Translator for this FlinkPipelineRunner. Its role is to translate the Beam operators to + * their Flink counterparts. Based on the options provided by the user, if we have a streaming job, + * this is instantiated as a {@link FlinkStreamingPipelineTranslator}. In other case, i.e. a batch job, + * a {@link FlinkBatchPipelineTranslator} is created. + */ + private FlinkPipelineTranslator flinkPipelineTranslator; + + /** + * Creates a {@link FlinkPipelineExecutionEnvironment} with the user-specified parameters in the + * provided {@link FlinkPipelineOptions}. + * + * @param options the user-defined pipeline options. + * */ + public FlinkPipelineExecutionEnvironment(FlinkPipelineOptions options) { + this.options = Preconditions.checkNotNull(options); + this.createPipelineExecutionEnvironment(); + this.createPipelineTranslator(); + } + + /** + * Depending on the type of job (Streaming or Batch) and the user-specified options, + * this method creates the adequate ExecutionEnvironment. 
+ */ + private void createPipelineExecutionEnvironment() { + if (options.isStreaming()) { + createStreamExecutionEnvironment(); + } else { + createBatchExecutionEnvironment(); + } + } + + /** + * Depending on the type of job (Streaming or Batch), this method creates the adequate job graph + * translator. In the case of batch, it will work with {@link org.apache.flink.api.java.DataSet}, + * while for streaming, it will work with {@link org.apache.flink.streaming.api.datastream.DataStream}. + */ + private void createPipelineTranslator() { + checkInitializationState(); + if (this.flinkPipelineTranslator != null) { + throw new IllegalStateException("FlinkPipelineTranslator already initialized."); + } + + this.flinkPipelineTranslator = options.isStreaming() ? + new FlinkStreamingPipelineTranslator(flinkStreamEnv, options) : + new FlinkBatchPipelineTranslator(flinkBatchEnv, options); + } + + /** + * Depending on if the job is a Streaming or a Batch one, this method creates + * the necessary execution environment and pipeline translator, and translates + * the {@link com.google.cloud.dataflow.sdk.values.PCollection} program into + * a {@link org.apache.flink.api.java.DataSet} or {@link org.apache.flink.streaming.api.datastream.DataStream} + * one. + * */ + public void translate(Pipeline pipeline) { + checkInitializationState(); + if(this.flinkBatchEnv == null && this.flinkStreamEnv == null) { + createPipelineExecutionEnvironment(); + } + if (this.flinkPipelineTranslator == null) { + createPipelineTranslator(); + } + this.flinkPipelineTranslator.translate(pipeline); + } + + /** + * Launches the program execution. + * */ + public JobExecutionResult executePipeline() throws Exception { + if (options.isStreaming()) { + if (this.flinkStreamEnv == null) { + throw new RuntimeException("FlinkPipelineExecutionEnvironment not initialized."); + } + if (this.flinkPipelineTranslator == null) { + throw new RuntimeException("FlinkPipelineTranslator not initialized."); + } + return this.flinkStreamEnv.execute(); + } else { + if (this.flinkBatchEnv == null) { + throw new RuntimeException("FlinkPipelineExecutionEnvironment not initialized."); + } + if (this.flinkPipelineTranslator == null) { + throw new RuntimeException("FlinkPipelineTranslator not initialized."); + } + return this.flinkBatchEnv.execute(); + } + } + + /** + * If the submitted job is a batch processing job, this method creates the adequate + * Flink {@link org.apache.flink.api.java.ExecutionEnvironment} depending + * on the user-specified options. + */ + private void createBatchExecutionEnvironment() { + if (this.flinkStreamEnv != null || this.flinkBatchEnv != null) { + throw new RuntimeException("FlinkPipelineExecutionEnvironment already initialized."); + } + + LOG.info("Creating the required Batch Execution Environment."); + + String masterUrl = options.getFlinkMaster(); + this.flinkStreamEnv = null; + + // depending on the master, create the right environment. 
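+    // The accepted forms mirror the branches below: "[local]" starts an embedded Flink
+    // environment inside this JVM, "[collection]" executes the pipeline eagerly on Java
+    // collections (mainly useful for tests), "[auto]" lets Flink pick an environment based on
+    // the context the program runs in, and "host:port" submits the job to a remote JobManager.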
+ if (masterUrl.equals("[local]")) { + this.flinkBatchEnv = ExecutionEnvironment.createLocalEnvironment(); + } else if (masterUrl.equals("[collection]")) { + this.flinkBatchEnv = new CollectionEnvironment(); + } else if (masterUrl.equals("[auto]")) { + this.flinkBatchEnv = ExecutionEnvironment.getExecutionEnvironment(); + } else if (masterUrl.matches(".*:\\d*")) { + String[] parts = masterUrl.split(":"); + List stagingFiles = options.getFilesToStage(); + this.flinkBatchEnv = ExecutionEnvironment.createRemoteEnvironment(parts[0], + Integer.parseInt(parts[1]), + stagingFiles.toArray(new String[stagingFiles.size()])); + } else { + LOG.warn("Unrecognized Flink Master URL {}. Defaulting to [auto].", masterUrl); + this.flinkBatchEnv = ExecutionEnvironment.getExecutionEnvironment(); + } + + // set the correct parallelism. + if (options.getParallelism() != -1 && !(this.flinkBatchEnv instanceof CollectionEnvironment)) { + this.flinkBatchEnv.setParallelism(options.getParallelism()); + } + + // set parallelism in the options (required by some execution code) + options.setParallelism(flinkBatchEnv.getParallelism()); + } + + /** + * If the submitted job is a stream processing job, this method creates the adequate + * Flink {@link org.apache.flink.streaming.api.environment.StreamExecutionEnvironment} depending + * on the user-specified options. + */ + private void createStreamExecutionEnvironment() { + if (this.flinkStreamEnv != null || this.flinkBatchEnv != null) { + throw new RuntimeException("FlinkPipelineExecutionEnvironment already initialized."); + } + + LOG.info("Creating the required Streaming Environment."); + + String masterUrl = options.getFlinkMaster(); + this.flinkBatchEnv = null; + + // depending on the master, create the right environment. + if (masterUrl.equals("[local]")) { + this.flinkStreamEnv = StreamExecutionEnvironment.createLocalEnvironment(); + } else if (masterUrl.equals("[auto]")) { + this.flinkStreamEnv = StreamExecutionEnvironment.getExecutionEnvironment(); + } else if (masterUrl.matches(".*:\\d*")) { + String[] parts = masterUrl.split(":"); + List stagingFiles = options.getFilesToStage(); + this.flinkStreamEnv = StreamExecutionEnvironment.createRemoteEnvironment(parts[0], + Integer.parseInt(parts[1]), stagingFiles.toArray(new String[stagingFiles.size()])); + } else { + LOG.warn("Unrecognized Flink Master URL {}. Defaulting to [auto].", masterUrl); + this.flinkStreamEnv = StreamExecutionEnvironment.getExecutionEnvironment(); + } + + // set the correct parallelism. + if (options.getParallelism() != -1) { + this.flinkStreamEnv.setParallelism(options.getParallelism()); + } + + // set parallelism in the options (required by some execution code) + options.setParallelism(flinkStreamEnv.getParallelism()); + + // default to event time + this.flinkStreamEnv.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); + + // for the following 2 parameters, a value of -1 means that Flink will use + // the default values as specified in the configuration. + int numRetries = options.getNumberOfExecutionRetries(); + if (numRetries != -1) { + this.flinkStreamEnv.setNumberOfExecutionRetries(numRetries); + } + long retryDelay = options.getExecutionRetryDelay(); + if (retryDelay != -1) { + this.flinkStreamEnv.getConfig().setExecutionRetryDelay(retryDelay); + } + + // A value of -1 corresponds to disabled checkpointing (see CheckpointConfig in Flink). + // If the value is not -1, then the validity checks are applied. + // By default, checkpointing is disabled. 
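+    // As a hypothetical example, launching a pipeline with --checkpointingInterval=60000 would
+    // snapshot its state every 60 seconds; the value is interpreted as milliseconds by
+    // enableCheckpointing(interval) below.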
+ long checkpointInterval = options.getCheckpointingInterval(); + if(checkpointInterval != -1) { + if (checkpointInterval < 1) { + throw new IllegalArgumentException("The checkpoint interval must be positive"); + } + this.flinkStreamEnv.enableCheckpointing(checkpointInterval); + } + } + + private void checkInitializationState() { + if (options.isStreaming() && this.flinkBatchEnv != null) { + throw new IllegalStateException("Attempted to run a Streaming Job with a Batch Execution Environment."); + } else if (!options.isStreaming() && this.flinkStreamEnv != null) { + throw new IllegalStateException("Attempted to run a Batch Job with a Streaming Execution Environment."); + } + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java new file mode 100644 index 000000000000..2f4b3ea47457 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions; +import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.StreamingOptions; + +import java.util.List; + +/** + * Options which can be used to configure a Flink PipelineRunner. + */ +public interface FlinkPipelineOptions extends PipelineOptions, ApplicationNameOptions, StreamingOptions { + + /** + * List of local files to make available to workers. + *

+ * Jars are placed on the worker's classpath.
+ * <p>
+ * The default value is the list of jars from the main program's classpath. + */ + @Description("Jar-Files to send to all workers and put on the classpath. " + + "The default value is all files from the classpath.") + @JsonIgnore + List getFilesToStage(); + void setFilesToStage(List value); + + /** + * The job name is used to identify jobs running on a Flink cluster. + */ + @Description("Dataflow job name, to uniquely identify active jobs. " + + "Defaults to using the ApplicationName-UserName-Date.") + @Default.InstanceFactory(DataflowPipelineOptions.JobNameFactory.class) + String getJobName(); + void setJobName(String value); + + /** + * The url of the Flink JobManager on which to execute pipelines. This can either be + * the the address of a cluster JobManager, in the form "host:port" or one of the special + * Strings "[local]", "[collection]" or "[auto]". "[local]" will start a local Flink + * Cluster in the JVM, "[collection]" will execute the pipeline on Java Collections while + * "[auto]" will let the system decide where to execute the pipeline based on the environment. + */ + @Description("Address of the Flink Master where the Pipeline should be executed. Can" + + " either be of the form \"host:port\" or one of the special values [local], " + + "[collection] or [auto].") + String getFlinkMaster(); + void setFlinkMaster(String value); + + @Description("The degree of parallelism to be used when distributing operations onto workers.") + @Default.Integer(-1) + Integer getParallelism(); + void setParallelism(Integer value); + + @Description("The interval between consecutive checkpoints (i.e. snapshots of the current pipeline state used for " + + "fault tolerance).") + @Default.Long(-1L) + Long getCheckpointingInterval(); + void setCheckpointingInterval(Long interval); + + @Description("Sets the number of times that failed tasks are re-executed. " + + "A value of zero effectively disables fault tolerance. A value of -1 indicates " + + "that the system default value (as defined in the configuration) should be used.") + @Default.Integer(-1) + Integer getNumberOfExecutionRetries(); + void setNumberOfExecutionRetries(Integer retries); + + @Description("Sets the delay between executions. A value of {@code -1} indicates that the default value should be used.") + @Default.Long(-1L) + Long getExecutionRetryDelay(); + void setExecutionRetryDelay(Long delay); +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java new file mode 100644 index 000000000000..fe773d98ad39 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsValidator; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; +import org.apache.flink.api.common.JobExecutionResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.net.URISyntaxException; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * A {@link PipelineRunner} that executes the operations in the + * pipeline by first translating them to a Flink Plan and then executing them either locally + * or on a Flink cluster, depending on the configuration. + *

+ * This is based on {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner}. + */ +public class FlinkPipelineRunner extends PipelineRunner { + + private static final Logger LOG = LoggerFactory.getLogger(FlinkPipelineRunner.class); + + /** + * Provided options. + */ + private final FlinkPipelineOptions options; + + private final FlinkPipelineExecutionEnvironment flinkJobEnv; + + /** + * Construct a runner from the provided options. + * + * @param options Properties which configure the runner. + * @return The newly created runner. + */ + public static FlinkPipelineRunner fromOptions(PipelineOptions options) { + FlinkPipelineOptions flinkOptions = + PipelineOptionsValidator.validate(FlinkPipelineOptions.class, options); + ArrayList missing = new ArrayList<>(); + + if (flinkOptions.getAppName() == null) { + missing.add("appName"); + } + if (missing.size() > 0) { + throw new IllegalArgumentException( + "Missing required values: " + Joiner.on(',').join(missing)); + } + + if (flinkOptions.getFilesToStage() == null) { + flinkOptions.setFilesToStage(detectClassPathResourcesToStage( + DataflowPipelineRunner.class.getClassLoader())); + LOG.info("PipelineOptions.filesToStage was not specified. " + + "Defaulting to files from the classpath: will stage {} files. " + + "Enable logging at DEBUG level to see which files will be staged.", + flinkOptions.getFilesToStage().size()); + LOG.debug("Classpath elements: {}", flinkOptions.getFilesToStage()); + } + + // Verify jobName according to service requirements. + String jobName = flinkOptions.getJobName().toLowerCase(); + Preconditions.checkArgument(jobName.matches("[a-z]([-a-z0-9]*[a-z0-9])?"), "JobName invalid; " + + "the name must consist of only the characters " + "[-a-z0-9], starting with a letter " + + "and ending with a letter " + "or number"); + Preconditions.checkArgument(jobName.length() <= 40, + "JobName too long; must be no more than 40 characters in length"); + + // Set Flink Master to [auto] if no option was specified. + if (flinkOptions.getFlinkMaster() == null) { + flinkOptions.setFlinkMaster("[auto]"); + } + + return new FlinkPipelineRunner(flinkOptions); + } + + private FlinkPipelineRunner(FlinkPipelineOptions options) { + this.options = options; + this.flinkJobEnv = new FlinkPipelineExecutionEnvironment(options); + } + + @Override + public FlinkRunnerResult run(Pipeline pipeline) { + LOG.info("Executing pipeline using FlinkPipelineRunner."); + + LOG.info("Translating pipeline to Flink program."); + + this.flinkJobEnv.translate(pipeline); + + LOG.info("Starting execution of Flink program."); + + JobExecutionResult result; + try { + result = this.flinkJobEnv.executePipeline(); + } catch (Exception e) { + LOG.error("Pipeline execution failed", e); + throw new RuntimeException("Pipeline execution failed", e); + } + + LOG.info("Execution finished in {} msecs", result.getNetRuntime()); + + Map accumulators = result.getAllAccumulatorResults(); + if (accumulators != null && !accumulators.isEmpty()) { + LOG.info("Final aggregator values:"); + + for (Map.Entry entry : result.getAllAccumulatorResults().entrySet()) { + LOG.info("{} : {}", entry.getKey(), entry.getValue()); + } + } + + return new FlinkRunnerResult(accumulators, result.getNetRuntime()); + } + + /** + * For testing. + */ + public FlinkPipelineOptions getPipelineOptions() { + return options; + } + + /** + * Constructs a runner with default properties for testing. + * + * @return The newly created runner. 
+ */ + public static FlinkPipelineRunner createForTest(boolean streaming) { + FlinkPipelineOptions options = PipelineOptionsFactory.as(FlinkPipelineOptions.class); + // we use [auto] for testing since this will make it pick up the Testing + // ExecutionEnvironment + options.setFlinkMaster("[auto]"); + options.setStreaming(streaming); + return new FlinkPipelineRunner(options); + } + + @Override + public Output apply( + PTransform transform, Input input) { + return super.apply(transform, input); + } + + ///////////////////////////////////////////////////////////////////////////// + + @Override + public String toString() { + return "DataflowPipelineRunner#" + hashCode(); + } + + /** + * Attempts to detect all the resources the class loader has access to. This does not recurse + * to class loader parents stopping it from pulling in resources from the system class loader. + * + * @param classLoader The URLClassLoader to use to detect resources to stage. + * @return A list of absolute paths to the resources the class loader uses. + * @throws IllegalArgumentException If either the class loader is not a URLClassLoader or one + * of the resources the class loader exposes is not a file resource. + */ + protected static List detectClassPathResourcesToStage(ClassLoader classLoader) { + if (!(classLoader instanceof URLClassLoader)) { + String message = String.format("Unable to use ClassLoader to detect classpath elements. " + + "Current ClassLoader is %s, only URLClassLoaders are supported.", classLoader); + LOG.error(message); + throw new IllegalArgumentException(message); + } + + List files = new ArrayList<>(); + for (URL url : ((URLClassLoader) classLoader).getURLs()) { + try { + files.add(new File(url.toURI()).getAbsolutePath()); + } catch (IllegalArgumentException | URISyntaxException e) { + String message = String.format("Unable to convert url (%s) to file.", url); + LOG.error(message); + throw new IllegalArgumentException(message, e); + } + } + return files; + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java new file mode 100644 index 000000000000..8fd08ec09132 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException; +import com.google.cloud.dataflow.sdk.runners.AggregatorValues; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; + +import java.util.Collections; +import java.util.Map; + +/** + * Result of executing a {@link com.google.cloud.dataflow.sdk.Pipeline} with Flink. This + * has methods to query to job runtime and the final values of + * {@link com.google.cloud.dataflow.sdk.transforms.Aggregator}s. + */ +public class FlinkRunnerResult implements PipelineResult { + + private final Map aggregators; + + private final long runtime; + + public FlinkRunnerResult(Map aggregators, long runtime) { + this.aggregators = (aggregators == null || aggregators.isEmpty()) ? + Collections.emptyMap() : + Collections.unmodifiableMap(aggregators); + + this.runtime = runtime; + } + + @Override + public State getState() { + return null; + } + + @Override + public AggregatorValues getAggregatorValues(final Aggregator aggregator) throws AggregatorRetrievalException { + // TODO provide a list of all accumulator step values + Object value = aggregators.get(aggregator.getName()); + if (value != null) { + return new AggregatorValues() { + @Override + public Map getValuesAtSteps() { + return (Map) aggregators; + } + }; + } else { + throw new AggregatorRetrievalException("Accumulator results not found.", + new RuntimeException("Accumulator does not exist.")); + } + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/TFIDF.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/TFIDF.java new file mode 100644 index 000000000000..ab23b926e2a7 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/TFIDF.java @@ -0,0 +1,452 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package org.apache.beam.runners.flink.examples; + +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineRunner; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringDelegateCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.GcsOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.Keys; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.transforms.Values; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.WithKeys; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.util.GcsUtil; +import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.HashSet; +import java.util.Set; + +/** + * An example that computes a basic TF-IDF search table for a directory or GCS prefix. + * + *
Concepts: joining data; side inputs; logging
+ *
+ * <p>To execute this pipeline locally, specify general pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ * }</pre>
+ * and a local output file or output prefix on GCS:
+ * <pre>{@code
+ *   --output=[YOUR_LOCAL_FILE | gs://YOUR_OUTPUT_PREFIX]
+ * }</pre>
+ *
+ * <p>To execute this pipeline using the Dataflow service, specify pipeline configuration:
+ * <pre>{@code
+ *   --project=YOUR_PROJECT_ID
+ *   --stagingLocation=gs://YOUR_STAGING_DIRECTORY
+ *   --runner=BlockingDataflowPipelineRunner
+ * and an output prefix on GCS:
+ *   --output=gs://YOUR_OUTPUT_PREFIX
+ * }</pre>
+ *
+ * <p>The default input is {@code gs://dataflow-samples/shakespeare/} and can be overridden with
+ * {@code --input}.
+ */
+public class TFIDF {
+  /**
+   * Options supported by {@link TFIDF}.
+   * <p>
+ * Inherits standard configuration options. + */ + private interface Options extends PipelineOptions, FlinkPipelineOptions { + @Description("Path to the directory or GCS prefix containing files to read from") + @Default.String("gs://dataflow-samples/shakespeare/") + String getInput(); + void setInput(String value); + + @Description("Prefix of output URI to write to") + @Validation.Required + String getOutput(); + void setOutput(String value); + } + + /** + * Lists documents contained beneath the {@code options.input} prefix/directory. + */ + public static Set listInputDocuments(Options options) + throws URISyntaxException, IOException { + URI baseUri = new URI(options.getInput()); + + // List all documents in the directory or GCS prefix. + URI absoluteUri; + if (baseUri.getScheme() != null) { + absoluteUri = baseUri; + } else { + absoluteUri = new URI( + "file", + baseUri.getAuthority(), + baseUri.getPath(), + baseUri.getQuery(), + baseUri.getFragment()); + } + + Set uris = new HashSet<>(); + if (absoluteUri.getScheme().equals("file")) { + File directory = new File(absoluteUri); + for (String entry : directory.list()) { + File path = new File(directory, entry); + uris.add(path.toURI()); + } + } else if (absoluteUri.getScheme().equals("gs")) { + GcsUtil gcsUtil = options.as(GcsOptions.class).getGcsUtil(); + URI gcsUriGlob = new URI( + absoluteUri.getScheme(), + absoluteUri.getAuthority(), + absoluteUri.getPath() + "*", + absoluteUri.getQuery(), + absoluteUri.getFragment()); + for (GcsPath entry : gcsUtil.expand(GcsPath.fromUri(gcsUriGlob))) { + uris.add(entry.toUri()); + } + } + + return uris; + } + + /** + * Reads the documents at the provided uris and returns all lines + * from the documents tagged with which document they are from. + */ + public static class ReadDocuments + extends PTransform>> { + private static final long serialVersionUID = 0; + + private Iterable uris; + + public ReadDocuments(Iterable uris) { + this.uris = uris; + } + + @Override + public Coder getDefaultOutputCoder() { + return KvCoder.of(StringDelegateCoder.of(URI.class), StringUtf8Coder.of()); + } + + @Override + public PCollection> apply(PInput input) { + Pipeline pipeline = input.getPipeline(); + + // Create one TextIO.Read transform for each document + // and add its output to a PCollectionList + PCollectionList> urisToLines = + PCollectionList.empty(pipeline); + + // TextIO.Read supports: + // - file: URIs and paths locally + // - gs: URIs on the service + for (final URI uri : uris) { + String uriString; + if (uri.getScheme().equals("file")) { + uriString = new File(uri).getPath(); + } else { + uriString = uri.toString(); + } + + PCollection> oneUriToLines = pipeline + .apply(TextIO.Read.from(uriString) + .named("TextIO.Read(" + uriString + ")")) + .apply("WithKeys(" + uriString + ")", WithKeys.of(uri)); + + urisToLines = urisToLines.and(oneUriToLines); + } + + return urisToLines.apply(Flatten.>pCollections()); + } + } + + /** + * A transform containing a basic TF-IDF pipeline. The input consists of KV objects + * where the key is the document's URI and the value is a piece + * of the document's content. The output is mapping from terms to + * scores for each document URI. 
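+   * <p>As an illustrative sketch (the values are invented for this comment), a single output
+   * element could look like {@code KV.of("king", KV.of(URI.create("kinglear.txt"), 0.014))}.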
+ */ + public static class ComputeTfIdf + extends PTransform>, PCollection>>> { + private static final long serialVersionUID = 0; + + public ComputeTfIdf() { } + + @Override + public PCollection>> apply( + PCollection> uriToContent) { + + // Compute the total number of documents, and + // prepare this singleton PCollectionView for + // use as a side input. + final PCollectionView totalDocuments = + uriToContent + .apply("GetURIs", Keys.create()) + .apply("RemoveDuplicateDocs", RemoveDuplicates.create()) + .apply(Count.globally()) + .apply(View.asSingleton()); + + // Create a collection of pairs mapping a URI to each + // of the words in the document associated with that that URI. + PCollection> uriToWords = uriToContent + .apply(ParDo.named("SplitWords").of( + new DoFn, KV>() { + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + URI uri = c.element().getKey(); + String line = c.element().getValue(); + for (String word : line.split("\\W+")) { + // Log INFO messages when the word “love” is found. + if (word.toLowerCase().equals("love")) { + LOG.info("Found {}", word.toLowerCase()); + } + + if (!word.isEmpty()) { + c.output(KV.of(uri, word.toLowerCase())); + } + } + } + })); + + // Compute a mapping from each word to the total + // number of documents in which it appears. + PCollection> wordToDocCount = uriToWords + .apply("RemoveDuplicateWords", RemoveDuplicates.>create()) + .apply(Values.create()) + .apply("CountDocs", Count.perElement()); + + // Compute a mapping from each URI to the total + // number of words in the document associated with that URI. + PCollection> uriToWordTotal = uriToWords + .apply("GetURIs2", Keys.create()) + .apply("CountWords", Count.perElement()); + + // Count, for each (URI, word) pair, the number of + // occurrences of that word in the document associated + // with the URI. + PCollection, Long>> uriAndWordToCount = uriToWords + .apply("CountWordDocPairs", Count.>perElement()); + + // Adjust the above collection to a mapping from + // (URI, word) pairs to counts into an isomorphic mapping + // from URI to (word, count) pairs, to prepare for a join + // by the URI key. + PCollection>> uriToWordAndCount = uriAndWordToCount + .apply(ParDo.named("ShiftKeys").of( + new DoFn, Long>, KV>>() { + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + URI uri = c.element().getKey().getKey(); + String word = c.element().getKey().getValue(); + Long occurrences = c.element().getValue(); + c.output(KV.of(uri, KV.of(word, occurrences))); + } + })); + + // Prepare to join the mapping of URI to (word, count) pairs with + // the mapping of URI to total word counts, by associating + // each of the input PCollection> with + // a tuple tag. Each input must have the same key type, URI + // in this case. The type parameter of the tuple tag matches + // the types of the values for each collection. + final TupleTag wordTotalsTag = new TupleTag<>(); + final TupleTag> wordCountsTag = new TupleTag<>(); + KeyedPCollectionTuple coGbkInput = KeyedPCollectionTuple + .of(wordTotalsTag, uriToWordTotal) + .and(wordCountsTag, uriToWordAndCount); + + // Perform a CoGroupByKey (a sort of pre-join) on the prepared + // inputs. This yields a mapping from URI to a CoGbkResult + // (CoGroupByKey Result). The CoGbkResult is a mapping + // from the above tuple tags to the values in each input + // associated with a particular URI. 
In this case, each + // KV group a URI with the total number of + // words in that document as well as all the (word, count) + // pairs for particular words. + PCollection> uriToWordAndCountAndTotal = coGbkInput + .apply("CoGroupByUri", CoGroupByKey.create()); + + // Compute a mapping from each word to a (URI, term frequency) + // pair for each URI. A word's term frequency for a document + // is simply the number of times that word occurs in the document + // divided by the total number of words in the document. + PCollection>> wordToUriAndTf = uriToWordAndCountAndTotal + .apply(ParDo.named("ComputeTermFrequencies").of( + new DoFn, KV>>() { + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + URI uri = c.element().getKey(); + Long wordTotal = c.element().getValue().getOnly(wordTotalsTag); + + for (KV wordAndCount + : c.element().getValue().getAll(wordCountsTag)) { + String word = wordAndCount.getKey(); + Long wordCount = wordAndCount.getValue(); + Double termFrequency = wordCount.doubleValue() / wordTotal.doubleValue(); + c.output(KV.of(word, KV.of(uri, termFrequency))); + } + } + })); + + // Compute a mapping from each word to its document frequency. + // A word's document frequency in a corpus is the number of + // documents in which the word appears divided by the total + // number of documents in the corpus. Note how the total number of + // documents is passed as a side input; the same value is + // presented to each invocation of the DoFn. + PCollection> wordToDf = wordToDocCount + .apply(ParDo + .named("ComputeDocFrequencies") + .withSideInputs(totalDocuments) + .of(new DoFn, KV>() { + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + String word = c.element().getKey(); + Long documentCount = c.element().getValue(); + Long documentTotal = c.sideInput(totalDocuments); + Double documentFrequency = documentCount.doubleValue() + / documentTotal.doubleValue(); + + c.output(KV.of(word, documentFrequency)); + } + })); + + // Join the term frequency and document frequency + // collections, each keyed on the word. + final TupleTag> tfTag = new TupleTag<>(); + final TupleTag dfTag = new TupleTag<>(); + PCollection> wordToUriAndTfAndDf = KeyedPCollectionTuple + .of(tfTag, wordToUriAndTf) + .and(dfTag, wordToDf) + .apply(CoGroupByKey.create()); + + // Compute a mapping from each word to a (URI, TF-IDF) score + // for each URI. There are a variety of definitions of TF-IDF + // ("term frequency - inverse document frequency") score; + // here we use a basic version that is the term frequency + // divided by the log of the document frequency. + + return wordToUriAndTfAndDf + .apply(ParDo.named("ComputeTfIdf").of( + new DoFn, KV>>() { + private static final long serialVersionUID1 = 0; + + @Override + public void processElement(ProcessContext c) { + String word = c.element().getKey(); + Double df = c.element().getValue().getOnly(dfTag); + + for (KV uriAndTf : c.element().getValue().getAll(tfTag)) { + URI uri = uriAndTf.getKey(); + Double tf = uriAndTf.getValue(); + Double tfIdf = tf * Math.log(1 / df); + c.output(KV.of(word, KV.of(uri, tfIdf))); + } + } + })); + } + + // Instantiate Logger. + // It is suggested that the user specify the class name of the containing class + // (in this case ComputeTfIdf). 
+ private static final Logger LOG = LoggerFactory.getLogger(ComputeTfIdf.class); + } + + /** + * A {@link PTransform} to write, in CSV format, a mapping from term and URI + * to score. + */ + public static class WriteTfIdf + extends PTransform>>, PDone> { + private static final long serialVersionUID = 0; + + private String output; + + public WriteTfIdf(String output) { + this.output = output; + } + + @Override + public PDone apply(PCollection>> wordToUriAndTfIdf) { + return wordToUriAndTfIdf + .apply(ParDo.named("Format").of(new DoFn>, String>() { + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + c.output(String.format("%s,\t%s,\t%f", + c.element().getKey(), + c.element().getValue().getKey(), + c.element().getValue().getValue())); + } + })) + .apply(TextIO.Write + .to(output) + .withSuffix(".csv")); + } + } + + public static void main(String[] args) throws Exception { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + + options.setRunner(FlinkPipelineRunner.class); + + Pipeline pipeline = Pipeline.create(options); + pipeline.getCoderRegistry().registerCoder(URI.class, StringDelegateCoder.of(URI.class)); + + pipeline + .apply(new ReadDocuments(listInputDocuments(options))) + .apply(new ComputeTfIdf()) + .apply(new WriteTfIdf(options.getOutput())); + + pipeline.run(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/WordCount.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/WordCount.java new file mode 100644 index 000000000000..7d12fedab294 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/WordCount.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.examples; + +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineRunner; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.*; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +public class WordCount { + + public static class ExtractWordsFn extends DoFn { + private final Aggregator emptyLines = + createAggregator("emptyLines", new Sum.SumLongFn()); + + @Override + public void processElement(ProcessContext c) { + if (c.element().trim().isEmpty()) { + emptyLines.addValue(1L); + } + + // Split the line into words. + String[] words = c.element().split("[^a-zA-Z']+"); + + // Output each word encountered into the output PCollection. + for (String word : words) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + } + + public static class CountWords extends PTransform, + PCollection>> { + @Override + public PCollection> apply(PCollection lines) { + + // Convert lines of text into individual words. + PCollection words = lines.apply( + ParDo.of(new ExtractWordsFn())); + + // Count the number of times each word occurs. + PCollection> wordCounts = + words.apply(Count.perElement()); + + return wordCounts; + } + } + + /** A SimpleFunction that converts a Word and Count into a printable string. */ + public static class FormatAsTextFn extends SimpleFunction, String> { + @Override + public String apply(KV input) { + return input.getKey() + ": " + input.getValue(); + } + } + + /** + * Options supported by {@link WordCount}. + *

+ * Inherits standard configuration options. + */ + public interface Options extends PipelineOptions, FlinkPipelineOptions { + @Description("Path of the file to read from") + @Default.String("gs://dataflow-samples/shakespeare/kinglear.txt") + String getInput(); + void setInput(String value); + + @Description("Path of the file to write to") + String getOutput(); + void setOutput(String value); + } + + public static void main(String[] args) { + + Options options = PipelineOptionsFactory.fromArgs(args).withValidation() + .as(Options.class); + options.setRunner(FlinkPipelineRunner.class); + + Pipeline p = Pipeline.create(options); + + p.apply(TextIO.Read.named("ReadLines").from(options.getInput())) + .apply(new CountWords()) + .apply(MapElements.via(new FormatAsTextFn())) + .apply(TextIO.Write.named("WriteCounts").to(options.getOutput())); + + p.run(); + } + +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/AutoComplete.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/AutoComplete.java new file mode 100644 index 000000000000..816812215feb --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/AutoComplete.java @@ -0,0 +1,387 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package org.apache.beam.runners.flink.examples.streaming; + +import org.apache.beam.runners.flink.FlinkPipelineRunner; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedSocketSource; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.dataflow.sdk.io.*; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.*; +import com.google.cloud.dataflow.sdk.transforms.Partition.PartitionFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.*; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import org.joda.time.Duration; + +import java.io.IOException; +import java.util.List; + +/** + * To run the example, first open a socket on a terminal by executing the command: + *

  • + *
  • + * nc -lk 9999 + *
  • + * + * and then launch the example. Now whatever you type in the terminal is going to be + * the input to the program. + * */ +public class AutoComplete { + + /** + * A PTransform that takes as input a list of tokens and returns + * the most common tokens per prefix. + */ + public static class ComputeTopCompletions + extends PTransform, PCollection>>> { + private static final long serialVersionUID = 0; + + private final int candidatesPerPrefix; + private final boolean recursive; + + protected ComputeTopCompletions(int candidatesPerPrefix, boolean recursive) { + this.candidatesPerPrefix = candidatesPerPrefix; + this.recursive = recursive; + } + + public static ComputeTopCompletions top(int candidatesPerPrefix, boolean recursive) { + return new ComputeTopCompletions(candidatesPerPrefix, recursive); + } + + @Override + public PCollection>> apply(PCollection input) { + PCollection candidates = input + // First count how often each token appears. + .apply(new Count.PerElement()) + + // Map the KV outputs of Count into our own CompletionCandiate class. + .apply(ParDo.named("CreateCompletionCandidates").of( + new DoFn, CompletionCandidate>() { + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + CompletionCandidate cand = new CompletionCandidate(c.element().getKey(), c.element().getValue()); + c.output(cand); + } + })); + + // Compute the top via either a flat or recursive algorithm. + if (recursive) { + return candidates + .apply(new ComputeTopRecursive(candidatesPerPrefix, 1)) + .apply(Flatten.>>pCollections()); + } else { + return candidates + .apply(new ComputeTopFlat(candidatesPerPrefix, 1)); + } + } + } + + /** + * Lower latency, but more expensive. + */ + private static class ComputeTopFlat + extends PTransform, + PCollection>>> { + private static final long serialVersionUID = 0; + + private final int candidatesPerPrefix; + private final int minPrefix; + + public ComputeTopFlat(int candidatesPerPrefix, int minPrefix) { + this.candidatesPerPrefix = candidatesPerPrefix; + this.minPrefix = minPrefix; + } + + @Override + public PCollection>> apply( + PCollection input) { + return input + // For each completion candidate, map it to all prefixes. + .apply(ParDo.of(new AllPrefixes(minPrefix))) + + // Find and return the top candiates for each prefix. + .apply(Top.largestPerKey(candidatesPerPrefix) + .withHotKeyFanout(new HotKeyFanout())); + } + + private static class HotKeyFanout implements SerializableFunction { + private static final long serialVersionUID = 0; + + @Override + public Integer apply(String input) { + return (int) Math.pow(4, 5 - input.length()); + } + } + } + + /** + * Cheaper but higher latency. + * + *

    Returns two PCollections, the first is top prefixes of size greater + * than minPrefix, and the second is top prefixes of size exactly + * minPrefix. + */ + private static class ComputeTopRecursive + extends PTransform, + PCollectionList>>> { + private static final long serialVersionUID = 0; + + private final int candidatesPerPrefix; + private final int minPrefix; + + public ComputeTopRecursive(int candidatesPerPrefix, int minPrefix) { + this.candidatesPerPrefix = candidatesPerPrefix; + this.minPrefix = minPrefix; + } + + private class KeySizePartitionFn implements PartitionFn>> { + private static final long serialVersionUID = 0; + + @Override + public int partitionFor(KV> elem, int numPartitions) { + return elem.getKey().length() > minPrefix ? 0 : 1; + } + } + + private static class FlattenTops + extends DoFn>, CompletionCandidate> { + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + for (CompletionCandidate cc : c.element().getValue()) { + c.output(cc); + } + } + } + + @Override + public PCollectionList>> apply( + PCollection input) { + if (minPrefix > 10) { + // Base case, partitioning to return the output in the expected format. + return input + .apply(new ComputeTopFlat(candidatesPerPrefix, minPrefix)) + .apply(Partition.of(2, new KeySizePartitionFn())); + } else { + // If a candidate is in the top N for prefix a...b, it must also be in the top + // N for a...bX for every X, which is typlically a much smaller set to consider. + // First, compute the top candidate for prefixes of size at least minPrefix + 1. + PCollectionList>> larger = input + .apply(new ComputeTopRecursive(candidatesPerPrefix, minPrefix + 1)); + // Consider the top candidates for each prefix of length minPrefix + 1... + PCollection>> small = + PCollectionList + .of(larger.get(1).apply(ParDo.of(new FlattenTops()))) + // ...together with those (previously excluded) candidates of length + // exactly minPrefix... + .and(input.apply(Filter.by(new SerializableFunction() { + private static final long serialVersionUID = 0; + + @Override + public Boolean apply(CompletionCandidate c) { + return c.getValue().length() == minPrefix; + } + }))) + .apply("FlattenSmall", Flatten.pCollections()) + // ...set the key to be the minPrefix-length prefix... + .apply(ParDo.of(new AllPrefixes(minPrefix, minPrefix))) + // ...and (re)apply the Top operator to all of them together. + .apply(Top.largestPerKey(candidatesPerPrefix)); + + PCollection>> flattenLarger = larger + .apply("FlattenLarge", Flatten.>>pCollections()); + + return PCollectionList.of(flattenLarger).and(small); + } + } + } + + /** + * A DoFn that keys each candidate by all its prefixes. + */ + private static class AllPrefixes + extends DoFn> { + private static final long serialVersionUID = 0; + + private final int minPrefix; + private final int maxPrefix; + public AllPrefixes(int minPrefix) { + this(minPrefix, Integer.MAX_VALUE); + } + public AllPrefixes(int minPrefix, int maxPrefix) { + this.minPrefix = minPrefix; + this.maxPrefix = maxPrefix; + } + @Override + public void processElement(ProcessContext c) { + String word = c.element().value; + for (int i = minPrefix; i <= Math.min(word.length(), maxPrefix); i++) { + KV kv = KV.of(word.substring(0, i), c.element()); + c.output(kv); + } + } + } + + /** + * Class used to store tag-count pairs. 
+ */ + @DefaultCoder(AvroCoder.class) + static class CompletionCandidate implements Comparable { + private long count; + private String value; + + public CompletionCandidate(String value, long count) { + this.value = value; + this.count = count; + } + + public String getValue() { + return value; + } + + // Empty constructor required for Avro decoding. + @SuppressWarnings("unused") + public CompletionCandidate() {} + + @Override + public int compareTo(CompletionCandidate o) { + if (this.count < o.count) { + return -1; + } else if (this.count == o.count) { + return this.value.compareTo(o.value); + } else { + return 1; + } + } + + @Override + public boolean equals(Object other) { + if (other instanceof CompletionCandidate) { + CompletionCandidate that = (CompletionCandidate) other; + return this.count == that.count && this.value.equals(that.value); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Long.valueOf(count).hashCode() ^ value.hashCode(); + } + + @Override + public String toString() { + return "CompletionCandidate[" + value + ", " + count + "]"; + } + } + + static class ExtractWordsFn extends DoFn { + private final Aggregator emptyLines = + createAggregator("emptyLines", new Sum.SumLongFn()); + + @Override + public void processElement(ProcessContext c) { + if (c.element().trim().isEmpty()) { + emptyLines.addValue(1L); + } + + // Split the line into words. + String[] words = c.element().split("[^a-zA-Z']+"); + + // Output each word encountered into the output PCollection. + for (String word : words) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + } + + /** + * Takes as input a the top candidates per prefix, and emits an entity + * suitable for writing to Datastore. + */ + static class FormatForPerTaskLocalFile extends DoFn>, String> + implements DoFn.RequiresWindowAccess{ + + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + StringBuilder str = new StringBuilder(); + KV> elem = c.element(); + + str.append(elem.getKey() +" @ "+ c.window() +" -> "); + for(CompletionCandidate cand: elem.getValue()) { + str.append(cand.toString() + " "); + } + System.out.println(str.toString()); + c.output(str.toString()); + } + } + + /** + * Options supported by this class. + * + *

    Inherits standard Dataflow configuration options. + */ + private interface Options extends WindowedWordCount.StreamingWordCountOptions { + @Description("Whether to use the recursive algorithm") + @Default.Boolean(true) + Boolean getRecursive(); + void setRecursive(Boolean value); + } + + public static void main(String[] args) throws IOException { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + options.setStreaming(true); + options.setCheckpointingInterval(1000L); + options.setNumberOfExecutionRetries(5); + options.setExecutionRetryDelay(3000L); + options.setRunner(FlinkPipelineRunner.class); + + PTransform> readSource = + Read.from(new UnboundedSocketSource<>("localhost", 9999, '\n', 3)).named("WordStream"); + WindowFn windowFn = FixedWindows.of(Duration.standardSeconds(options.getWindowSize())); + + // Create the pipeline. + Pipeline p = Pipeline.create(options); + PCollection>> toWrite = p + .apply(readSource) + .apply(ParDo.of(new ExtractWordsFn())) + .apply(Window.into(windowFn) + .triggering(AfterWatermark.pastEndOfWindow()).withAllowedLateness(Duration.ZERO) + .discardingFiredPanes()) + .apply(ComputeTopCompletions.top(10, options.getRecursive())); + + toWrite + .apply(ParDo.named("FormatForPerTaskFile").of(new FormatForPerTaskLocalFile())) + .apply(TextIO.Write.to("./outputAutoComplete.txt")); + + p.run(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/JoinExamples.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/JoinExamples.java new file mode 100644 index 000000000000..3a8bdb0078d6 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/JoinExamples.java @@ -0,0 +1,158 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package org.apache.beam.runners.flink.examples.streaming; + +import org.apache.beam.runners.flink.FlinkPipelineRunner; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedSocketSource; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.transforms.windowing.*; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PBegin; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import org.joda.time.Duration; + +/** + * To run the example, first open two sockets on two terminals by executing the commands: + *

+ * nc -lk 9999, and
+ * nc -lk 9998
  • + * + * and then launch the example. Now whatever you type in the terminal is going to be + * the input to the program. + * */ +public class JoinExamples { + + static PCollection joinEvents(PCollection streamA, + PCollection streamB) throws Exception { + + final TupleTag firstInfoTag = new TupleTag<>(); + final TupleTag secondInfoTag = new TupleTag<>(); + + // transform both input collections to tuple collections, where the keys are country + // codes in both cases. + PCollection> firstInfo = streamA.apply( + ParDo.of(new ExtractEventDataFn())); + PCollection> secondInfo = streamB.apply( + ParDo.of(new ExtractEventDataFn())); + + // country code 'key' -> CGBKR (, ) + PCollection> kvpCollection = KeyedPCollectionTuple + .of(firstInfoTag, firstInfo) + .and(secondInfoTag, secondInfo) + .apply(CoGroupByKey.create()); + + // Process the CoGbkResult elements generated by the CoGroupByKey transform. + // country code 'key' -> string of , + PCollection> finalResultCollection = + kvpCollection.apply(ParDo.named("Process").of( + new DoFn, KV>() { + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + String key = e.getKey(); + + String defaultA = "NO_VALUE"; + + // the following getOnly is a bit tricky because it expects to have + // EXACTLY ONE value in the corresponding stream and for the corresponding key. + + String lineA = e.getValue().getOnly(firstInfoTag, defaultA); + for (String lineB : c.element().getValue().getAll(secondInfoTag)) { + // Generate a string that combines information from both collection values + c.output(KV.of(key, "Value A: " + lineA + " - Value B: " + lineB)); + } + } + })); + + return finalResultCollection + .apply(ParDo.named("Format").of(new DoFn, String>() { + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + String result = c.element().getKey() + " -> " + c.element().getValue(); + System.out.println(result); + c.output(result); + } + })); + } + + static class ExtractEventDataFn extends DoFn> { + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + String line = c.element().toLowerCase(); + String key = line.split("\\s")[0]; + c.output(KV.of(key, line)); + } + } + + private interface Options extends WindowedWordCount.StreamingWordCountOptions { + + } + + public static void main(String[] args) throws Exception { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + options.setStreaming(true); + options.setCheckpointingInterval(1000L); + options.setNumberOfExecutionRetries(5); + options.setExecutionRetryDelay(3000L); + options.setRunner(FlinkPipelineRunner.class); + + PTransform> readSourceA = + Read.from(new UnboundedSocketSource<>("localhost", 9999, '\n', 3)).named("FirstStream"); + PTransform> readSourceB = + Read.from(new UnboundedSocketSource<>("localhost", 9998, '\n', 3)).named("SecondStream"); + + WindowFn windowFn = FixedWindows.of(Duration.standardSeconds(options.getWindowSize())); + + Pipeline p = Pipeline.create(options); + + // the following two 'applys' create multiple inputs to our pipeline, one for each + // of our two input sources. 
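+ // Both streams are windowed with the same windowFn and trigger so that the
+ // CoGroupByKey inside joinEvents() can match elements from the two sources
+ // window by window.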
+ PCollection streamA = p.apply(readSourceA) + .apply(Window.into(windowFn) + .triggering(AfterWatermark.pastEndOfWindow()).withAllowedLateness(Duration.ZERO) + .discardingFiredPanes()); + PCollection streamB = p.apply(readSourceB) + .apply(Window.into(windowFn) + .triggering(AfterWatermark.pastEndOfWindow()).withAllowedLateness(Duration.ZERO) + .discardingFiredPanes()); + + PCollection formattedResults = joinEvents(streamA, streamB); + formattedResults.apply(TextIO.Write.to("./outputJoin.txt")); + p.run(); + } + +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/KafkaWindowedWordCountExample.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/KafkaWindowedWordCountExample.java new file mode 100644 index 000000000000..55cdc225b998 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/KafkaWindowedWordCountExample.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.examples.streaming; + +import org.apache.beam.runners.flink.FlinkPipelineRunner; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedFlinkSource; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.io.UnboundedSource; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.*; +import com.google.cloud.dataflow.sdk.transforms.windowing.*; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer082; +import org.apache.flink.streaming.util.serialization.SimpleStringSchema; +import org.joda.time.Duration; + +import java.util.Properties; + +public class KafkaWindowedWordCountExample { + + static final String KAFKA_TOPIC = "test"; // Default kafka topic to read from + static final String KAFKA_BROKER = "localhost:9092"; // Default kafka broker to contact + static final String GROUP_ID = "myGroup"; // Default groupId + static final String ZOOKEEPER = "localhost:2181"; // Default zookeeper to connect to for Kafka + + public static class ExtractWordsFn extends DoFn { + private final Aggregator emptyLines = + createAggregator("emptyLines", new Sum.SumLongFn()); + + @Override + public void processElement(ProcessContext c) { + if (c.element().trim().isEmpty()) { + emptyLines.addValue(1L); + } + + // Split the line into words. 
+ String[] words = c.element().split("[^a-zA-Z']+"); + + // Output each word encountered into the output PCollection. + for (String word : words) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + } + + public static class FormatAsStringFn extends DoFn, String> { + @Override + public void processElement(ProcessContext c) { + String row = c.element().getKey() + " - " + c.element().getValue() + " @ " + c.timestamp().toString(); + System.out.println(row); + c.output(row); + } + } + + public interface KafkaStreamingWordCountOptions extends WindowedWordCount.StreamingWordCountOptions { + @Description("The Kafka topic to read from") + @Default.String(KAFKA_TOPIC) + String getKafkaTopic(); + + void setKafkaTopic(String value); + + @Description("The Kafka Broker to read from") + @Default.String(KAFKA_BROKER) + String getBroker(); + + void setBroker(String value); + + @Description("The Zookeeper server to connect to") + @Default.String(ZOOKEEPER) + String getZookeeper(); + + void setZookeeper(String value); + + @Description("The groupId") + @Default.String(GROUP_ID) + String getGroup(); + + void setGroup(String value); + + } + + public static void main(String[] args) { + PipelineOptionsFactory.register(KafkaStreamingWordCountOptions.class); + KafkaStreamingWordCountOptions options = PipelineOptionsFactory.fromArgs(args).as(KafkaStreamingWordCountOptions.class); + options.setJobName("KafkaExample"); + options.setStreaming(true); + options.setCheckpointingInterval(1000L); + options.setNumberOfExecutionRetries(5); + options.setExecutionRetryDelay(3000L); + options.setRunner(FlinkPipelineRunner.class); + + System.out.println(options.getKafkaTopic() +" "+ options.getZookeeper() +" "+ options.getBroker() +" "+ options.getGroup() ); + Pipeline pipeline = Pipeline.create(options); + + Properties p = new Properties(); + p.setProperty("zookeeper.connect", options.getZookeeper()); + p.setProperty("bootstrap.servers", options.getBroker()); + p.setProperty("group.id", options.getGroup()); + + // this is the Flink consumer that reads the input to + // the program from a kafka topic. + FlinkKafkaConsumer082 kafkaConsumer = new FlinkKafkaConsumer082<>( + options.getKafkaTopic(), + new SimpleStringSchema(), p); + + PCollection words = pipeline + .apply(Read.from(new UnboundedFlinkSource(options, kafkaConsumer)).named("StreamingWordCount")) + .apply(ParDo.of(new ExtractWordsFn())) + .apply(Window.into(FixedWindows.of(Duration.standardSeconds(options.getWindowSize()))) + .triggering(AfterWatermark.pastEndOfWindow()).withAllowedLateness(Duration.ZERO) + .discardingFiredPanes()); + + PCollection> wordCounts = + words.apply(Count.perElement()); + + wordCounts.apply(ParDo.of(new FormatAsStringFn())) + .apply(TextIO.Write.to("./outputKafka.txt")); + + pipeline.run(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/WindowedWordCount.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/WindowedWordCount.java new file mode 100644 index 000000000000..7eb69ba870a2 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/examples/streaming/WindowedWordCount.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.examples.streaming; + +import org.apache.beam.runners.flink.FlinkPipelineRunner; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedSocketSource; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.*; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.*; +import com.google.cloud.dataflow.sdk.transforms.windowing.*; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * To run the example, first open a socket on a terminal by executing the command: + *
+ * nc -lk 9999
  • + * + * and then launch the example. Now whatever you type in the terminal is going to be + * the input to the program. + * */ +public class WindowedWordCount { + + private static final Logger LOG = LoggerFactory.getLogger(WindowedWordCount.class); + + static final long WINDOW_SIZE = 10; // Default window duration in seconds + static final long SLIDE_SIZE = 5; // Default window slide in seconds + + static class FormatAsStringFn extends DoFn, String> { + @Override + public void processElement(ProcessContext c) { + String row = c.element().getKey() + " - " + c.element().getValue() + " @ " + c.timestamp().toString(); + c.output(row); + } + } + + static class ExtractWordsFn extends DoFn { + private final Aggregator emptyLines = + createAggregator("emptyLines", new Sum.SumLongFn()); + + @Override + public void processElement(ProcessContext c) { + if (c.element().trim().isEmpty()) { + emptyLines.addValue(1L); + } + + // Split the line into words. + String[] words = c.element().split("[^a-zA-Z']+"); + + // Output each word encountered into the output PCollection. + for (String word : words) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + } + + public interface StreamingWordCountOptions extends org.apache.beam.runners.flink.examples.WordCount.Options { + @Description("Sliding window duration, in seconds") + @Default.Long(WINDOW_SIZE) + Long getWindowSize(); + + void setWindowSize(Long value); + + @Description("Window slide, in seconds") + @Default.Long(SLIDE_SIZE) + Long getSlide(); + + void setSlide(Long value); + } + + public static void main(String[] args) throws IOException { + StreamingWordCountOptions options = PipelineOptionsFactory.fromArgs(args).withValidation().as(StreamingWordCountOptions.class); + options.setStreaming(true); + options.setWindowSize(10L); + options.setSlide(5L); + options.setCheckpointingInterval(1000L); + options.setNumberOfExecutionRetries(5); + options.setExecutionRetryDelay(3000L); + options.setRunner(FlinkPipelineRunner.class); + + LOG.info("Windpwed WordCount with Sliding Windows of " + options.getWindowSize() + + " sec. and a slide of " + options.getSlide()); + + Pipeline pipeline = Pipeline.create(options); + + PCollection words = pipeline + .apply(Read.from(new UnboundedSocketSource<>("localhost", 9999, '\n', 3)).named("StreamingWordCount")) + .apply(ParDo.of(new ExtractWordsFn())) + .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(options.getWindowSize())) + .every(Duration.standardSeconds(options.getSlide()))) + .triggering(AfterWatermark.pastEndOfWindow()).withAllowedLateness(Duration.ZERO) + .discardingFiredPanes()); + + PCollection> wordCounts = + words.apply(Count.perElement()); + + wordCounts.apply(ParDo.of(new FormatAsStringFn())) + .apply(TextIO.Write.to("./outputWordCount.txt")); + + pipeline.run(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/io/ConsoleIO.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/io/ConsoleIO.java new file mode 100644 index 000000000000..71e3b54b3d57 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/io/ConsoleIO.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.io; + +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PDone; + +/** + * Transform for printing the contents of a {@link com.google.cloud.dataflow.sdk.values.PCollection}. + * to standard output. + * + * This is Flink-specific and will only work when executed using the + * {@link org.apache.beam.runners.flink.FlinkPipelineRunner}. + */ +public class ConsoleIO { + + /** + * A PTransform that writes a PCollection to a standard output. + */ + public static class Write { + + /** + * Returns a ConsoleIO.Write PTransform with a default step name. + */ + public static Bound create() { + return new Bound(); + } + + /** + * Returns a ConsoleIO.Write PTransform with the given step name. + */ + public static Bound named(String name) { + return new Bound().named(name); + } + + /** + * A PTransform that writes a bounded PCollection to standard output. + */ + public static class Bound extends PTransform, PDone> { + private static final long serialVersionUID = 0; + + Bound() { + super("ConsoleIO.Write"); + } + + Bound(String name) { + super(name); + } + + /** + * Returns a new ConsoleIO.Write PTransform that's like this one but with the given + * step + * name. Does not modify this object. + */ + public Bound named(String name) { + return new Bound(name); + } + + @Override + public PDone apply(PCollection input) { + return PDone.in(input.getPipeline()); + } + } + } +} + diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java new file mode 100644 index 000000000000..9b47a08339b1 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.values.PValue; +import org.apache.flink.api.java.ExecutionEnvironment; + +/** + * FlinkBatchPipelineTranslator knows how to translate Pipeline objects into Flink Jobs. + * This is based on {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator} + */ +public class FlinkBatchPipelineTranslator extends FlinkPipelineTranslator { + + /** + * The necessary context in the case of a batch job. + */ + private final FlinkBatchTranslationContext batchContext; + + private int depth = 0; + + /** + * Composite transform that we want to translate before proceeding with other transforms. + */ + private PTransform currentCompositeTransform; + + public FlinkBatchPipelineTranslator(ExecutionEnvironment env, PipelineOptions options) { + this.batchContext = new FlinkBatchTranslationContext(env, options); + } + + // -------------------------------------------------------------------------------------------- + // Pipeline Visitor Methods + // -------------------------------------------------------------------------------------------- + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + System.out.println(genSpaces(this.depth) + "enterCompositeTransform- " + formatNodeName(node)); + + PTransform transform = node.getTransform(); + if (transform != null && currentCompositeTransform == null) { + + BatchTransformTranslator translator = FlinkBatchTransformTranslators.getTranslator(transform); + if (translator != null) { + currentCompositeTransform = transform; + if (transform instanceof CoGroupByKey && node.getInput().expand().size() != 2) { + // we can only optimize CoGroupByKey for input size 2 + currentCompositeTransform = null; + } + } + } + this.depth++; + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + if (transform != null && currentCompositeTransform == transform) { + + BatchTransformTranslator translator = FlinkBatchTransformTranslators.getTranslator(transform); + if (translator != null) { + System.out.println(genSpaces(this.depth) + "doingCompositeTransform- " + formatNodeName(node)); + applyBatchTransform(transform, node, translator); + currentCompositeTransform = null; + } else { + throw new IllegalStateException("Attempted to translate composite transform " + + "but no translator was found: " + currentCompositeTransform); + } + } + this.depth--; + System.out.println(genSpaces(this.depth) + "leaveCompositeTransform- " + formatNodeName(node)); + } + + @Override + public void visitTransform(TransformTreeNode node) { + System.out.println(genSpaces(this.depth) + "visitTransform- " + formatNodeName(node)); + if (currentCompositeTransform != null) { + // ignore it + return; + } + + // get the transformation corresponding to hte node we are + // currently visiting and translate it into its Flink alternative. 
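+ // If no translator is registered for the transform below, the pipeline
+ // cannot be run on Flink and translation fails with an exception.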
+ + PTransform transform = node.getTransform(); + BatchTransformTranslator translator = FlinkBatchTransformTranslators.getTranslator(transform); + if (translator == null) { + System.out.println(node.getTransform().getClass()); + throw new UnsupportedOperationException("The transform " + transform + " is currently not supported."); + } + applyBatchTransform(transform, node, translator); + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + // do nothing here + } + + private > void applyBatchTransform(PTransform transform, TransformTreeNode node, BatchTransformTranslator translator) { + + @SuppressWarnings("unchecked") + T typedTransform = (T) transform; + + @SuppressWarnings("unchecked") + BatchTransformTranslator typedTranslator = (BatchTransformTranslator) translator; + + // create the applied PTransform on the batchContext + batchContext.setCurrentTransform(AppliedPTransform.of( + node.getFullName(), node.getInput(), node.getOutput(), (PTransform) transform)); + typedTranslator.translateNode(typedTransform, batchContext); + } + + /** + * A translator of a {@link PTransform}. + */ + public interface BatchTransformTranslator { + void translateNode(Type transform, FlinkBatchTranslationContext context); + } + + private static String genSpaces(int n) { + String s = ""; + for (int i = 0; i < n; i++) { + s += "| "; + } + return s; + } + + private static String formatNodeName(TransformTreeNode node) { + return node.toString().split("@")[1] + node.getTransform(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java new file mode 100644 index 000000000000..48c783d39c9d --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java @@ -0,0 +1,594 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation; + +import org.apache.beam.runners.flink.io.ConsoleIO; +import org.apache.beam.runners.flink.translation.functions.FlinkCoGroupKeyedListAggregator; +import org.apache.beam.runners.flink.translation.functions.FlinkCreateFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkDoFnFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkKeyedListAggregationFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputDoFnFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputPruningFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkPartialReduceFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkReduceFunction; +import org.apache.beam.runners.flink.translation.functions.UnionCoder; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.types.KvCoderTypeInformation; +import org.apache.beam.runners.flink.translation.wrappers.SinkOutputFormat; +import org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat; +import com.google.api.client.util.Maps; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.Write; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResultSchema; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.Lists; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.operators.Keys; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.io.AvroInputFormat; +import org.apache.flink.api.java.io.AvroOutputFormat; +import org.apache.flink.api.java.io.TextInputFormat; +import org.apache.flink.api.java.operators.CoGroupOperator; +import org.apache.flink.api.java.operators.DataSink; +import org.apache.flink.api.java.operators.DataSource; +import org.apache.flink.api.java.operators.FlatMapOperator; +import org.apache.flink.api.java.operators.GroupCombineOperator; +import 
org.apache.flink.api.java.operators.GroupReduceOperator; +import org.apache.flink.api.java.operators.Grouping; +import org.apache.flink.api.java.operators.MapPartitionOperator; +import org.apache.flink.api.java.operators.UnsortedGrouping; +import org.apache.flink.core.fs.Path; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Translators for transforming + * Dataflow {@link com.google.cloud.dataflow.sdk.transforms.PTransform}s to + * Flink {@link org.apache.flink.api.java.DataSet}s + */ +public class FlinkBatchTransformTranslators { + + // -------------------------------------------------------------------------------------------- + // Transform Translator Registry + // -------------------------------------------------------------------------------------------- + + @SuppressWarnings("rawtypes") + private static final Map, FlinkBatchPipelineTranslator.BatchTransformTranslator> TRANSLATORS = new HashMap<>(); + + // register the known translators + static { + TRANSLATORS.put(View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch()); + + TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch()); + // we don't need this because we translate the Combine.PerKey directly + //TRANSLATORS.put(Combine.GroupedValues.class, new CombineGroupedValuesTranslator()); + + TRANSLATORS.put(Create.Values.class, new CreateTranslatorBatch()); + + TRANSLATORS.put(Flatten.FlattenPCollectionList.class, new FlattenPCollectionTranslatorBatch()); + + TRANSLATORS.put(GroupByKey.GroupByKeyOnly.class, new GroupByKeyOnlyTranslatorBatch()); + // TODO we're currently ignoring windows here but that has to change in the future + TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch()); + + TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiTranslatorBatch()); + TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundTranslatorBatch()); + + TRANSLATORS.put(CoGroupByKey.class, new CoGroupByKeyTranslatorBatch()); + + TRANSLATORS.put(AvroIO.Read.Bound.class, new AvroIOReadTranslatorBatch()); + TRANSLATORS.put(AvroIO.Write.Bound.class, new AvroIOWriteTranslatorBatch()); + + TRANSLATORS.put(Read.Bounded.class, new ReadSourceTranslatorBatch()); + TRANSLATORS.put(Write.Bound.class, new WriteSinkTranslatorBatch()); + + TRANSLATORS.put(TextIO.Read.Bound.class, new TextIOReadTranslatorBatch()); + TRANSLATORS.put(TextIO.Write.Bound.class, new TextIOWriteTranslatorBatch()); + + // Flink-specific + TRANSLATORS.put(ConsoleIO.Write.Bound.class, new ConsoleIOWriteTranslatorBatch()); + + } + + + public static FlinkBatchPipelineTranslator.BatchTransformTranslator getTranslator(PTransform transform) { + return TRANSLATORS.get(transform.getClass()); + } + + private static class ReadSourceTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + + @Override + public void translateNode(Read.Bounded transform, FlinkBatchTranslationContext context) { + String name = transform.getName(); + BoundedSource source = transform.getSource(); + PCollection output = context.getOutput(transform); + Coder coder = output.getCoder(); + + TypeInformation typeInformation = context.getTypeInfo(output); + + DataSource dataSource = new DataSource<>(context.getExecutionEnvironment(), + new SourceInputFormat<>(source, context.getPipelineOptions()), typeInformation, name); + + 
context.setOutputDataSet(output, dataSource); + } + } + + private static class AvroIOReadTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static final Logger LOG = LoggerFactory.getLogger(AvroIOReadTranslatorBatch.class); + + @Override + public void translateNode(AvroIO.Read.Bound transform, FlinkBatchTranslationContext context) { + String path = transform.getFilepattern(); + String name = transform.getName(); +// Schema schema = transform.getSchema(); + PValue output = context.getOutput(transform); + + TypeInformation typeInformation = context.getTypeInfo(output); + + // This is super hacky, but unfortunately we cannot get the type otherwise + Class extractedAvroType; + try { + Field typeField = transform.getClass().getDeclaredField("type"); + typeField.setAccessible(true); + @SuppressWarnings("unchecked") + Class avroType = (Class) typeField.get(transform); + extractedAvroType = avroType; + } catch (NoSuchFieldException | IllegalAccessException e) { + // we know that the field is there and it is accessible + throw new RuntimeException("Could not access type from AvroIO.Bound", e); + } + + DataSource source = new DataSource<>(context.getExecutionEnvironment(), + new AvroInputFormat<>(new Path(path), extractedAvroType), + typeInformation, name); + + context.setOutputDataSet(output, source); + } + } + + private static class AvroIOWriteTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static final Logger LOG = LoggerFactory.getLogger(AvroIOWriteTranslatorBatch.class); + + @Override + public void translateNode(AvroIO.Write.Bound transform, FlinkBatchTranslationContext context) { + DataSet inputDataSet = context.getInputDataSet(context.getInput(transform)); + String filenamePrefix = transform.getFilenamePrefix(); + String filenameSuffix = transform.getFilenameSuffix(); + int numShards = transform.getNumShards(); + String shardNameTemplate = transform.getShardNameTemplate(); + + // TODO: Implement these. We need Flink support for this. + LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", + filenameSuffix); + LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. 
Is: {}.", shardNameTemplate); + + // This is super hacky, but unfortunately we cannot get the type otherwise + Class extractedAvroType; + try { + Field typeField = transform.getClass().getDeclaredField("type"); + typeField.setAccessible(true); + @SuppressWarnings("unchecked") + Class avroType = (Class) typeField.get(transform); + extractedAvroType = avroType; + } catch (NoSuchFieldException | IllegalAccessException e) { + // we know that the field is there and it is accessible + throw new RuntimeException("Could not access type from AvroIO.Bound", e); + } + + DataSink dataSink = inputDataSet.output(new AvroOutputFormat<>(new Path + (filenamePrefix), extractedAvroType)); + + if (numShards > 0) { + dataSink.setParallelism(numShards); + } + } + } + + private static class TextIOReadTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static final Logger LOG = LoggerFactory.getLogger(TextIOReadTranslatorBatch.class); + + @Override + public void translateNode(TextIO.Read.Bound transform, FlinkBatchTranslationContext context) { + String path = transform.getFilepattern(); + String name = transform.getName(); + + TextIO.CompressionType compressionType = transform.getCompressionType(); + boolean needsValidation = transform.needsValidation(); + + // TODO: Implement these. We need Flink support for this. + LOG.warn("Translation of TextIO.CompressionType not yet supported. Is: {}.", compressionType); + LOG.warn("Translation of TextIO.Read.needsValidation not yet supported. Is: {}.", needsValidation); + + PValue output = context.getOutput(transform); + + TypeInformation typeInformation = context.getTypeInfo(output); + DataSource source = new DataSource<>(context.getExecutionEnvironment(), new TextInputFormat(new Path(path)), typeInformation, name); + + context.setOutputDataSet(output, source); + } + } + + private static class TextIOWriteTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static final Logger LOG = LoggerFactory.getLogger(TextIOWriteTranslatorBatch.class); + + @Override + public void translateNode(TextIO.Write.Bound transform, FlinkBatchTranslationContext context) { + PValue input = context.getInput(transform); + DataSet inputDataSet = context.getInputDataSet(input); + + String filenamePrefix = transform.getFilenamePrefix(); + String filenameSuffix = transform.getFilenameSuffix(); + boolean needsValidation = transform.needsValidation(); + int numShards = transform.getNumShards(); + String shardNameTemplate = transform.getShardNameTemplate(); + + // TODO: Implement these. We need Flink support for this. + LOG.warn("Translation of TextIO.Write.needsValidation not yet supported. Is: {}.", needsValidation); + LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", filenameSuffix); + LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. 
Is: {}.", shardNameTemplate); + + //inputDataSet.print(); + DataSink dataSink = inputDataSet.writeAsText(filenamePrefix); + + if (numShards > 0) { + dataSink.setParallelism(numShards); + } + } + } + + private static class ConsoleIOWriteTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator { + @Override + public void translateNode(ConsoleIO.Write.Bound transform, FlinkBatchTranslationContext context) { + PValue input = context.getInput(transform); + DataSet inputDataSet = context.getInputDataSet(input); + inputDataSet.printOnTaskManager(transform.getName()); + } + } + + private static class WriteSinkTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + + @Override + public void translateNode(Write.Bound transform, FlinkBatchTranslationContext context) { + String name = transform.getName(); + PValue input = context.getInput(transform); + DataSet inputDataSet = context.getInputDataSet(input); + + inputDataSet.output(new SinkOutputFormat<>(transform, context.getPipelineOptions())).name(name); + } + } + + private static class GroupByKeyOnlyTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + + @Override + public void translateNode(GroupByKey.GroupByKeyOnly transform, FlinkBatchTranslationContext context) { + DataSet> inputDataSet = context.getInputDataSet(context.getInput(transform)); + GroupReduceFunction, KV>> groupReduceFunction = new FlinkKeyedListAggregationFunction<>(); + + TypeInformation>> typeInformation = context.getTypeInfo(context.getOutput(transform)); + + Grouping> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet.getType())); + + GroupReduceOperator, KV>> outputDataSet = + new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + /** + * Translates a GroupByKey while ignoring window assignments. 
This is identical to the {@link GroupByKeyOnlyTranslatorBatch} + */ + private static class GroupByKeyTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + + @Override + public void translateNode(GroupByKey transform, FlinkBatchTranslationContext context) { + DataSet> inputDataSet = context.getInputDataSet(context.getInput(transform)); + GroupReduceFunction, KV>> groupReduceFunction = new FlinkKeyedListAggregationFunction<>(); + + TypeInformation>> typeInformation = context.getTypeInfo(context.getOutput(transform)); + + Grouping> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet.getType())); + + GroupReduceOperator, KV>> outputDataSet = + new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + private static class CombinePerKeyTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + + @Override + public void translateNode(Combine.PerKey transform, FlinkBatchTranslationContext context) { + DataSet> inputDataSet = context.getInputDataSet(context.getInput(transform)); + + @SuppressWarnings("unchecked") + Combine.KeyedCombineFn keyedCombineFn = (Combine.KeyedCombineFn) transform.getFn(); + + KvCoder inputCoder = (KvCoder) context.getInput(transform).getCoder(); + + Coder accumulatorCoder = + null; + try { + accumulatorCoder = keyedCombineFn.getAccumulatorCoder(context.getInput(transform).getPipeline().getCoderRegistry(), inputCoder.getKeyCoder(), inputCoder.getValueCoder()); + } catch (CannotProvideCoderException e) { + e.printStackTrace(); + // TODO + } + + TypeInformation> kvCoderTypeInformation = new KvCoderTypeInformation<>(inputCoder); + TypeInformation> partialReduceTypeInfo = new KvCoderTypeInformation<>(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder)); + + Grouping> inputGrouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, kvCoderTypeInformation)); + + FlinkPartialReduceFunction partialReduceFunction = new FlinkPartialReduceFunction<>(keyedCombineFn); + + // Partially GroupReduce the values into the intermediate format VA (combine) + GroupCombineOperator, KV> groupCombine = + new GroupCombineOperator<>(inputGrouping, partialReduceTypeInfo, partialReduceFunction, + "GroupCombine: " + transform.getName()); + + // Reduce fully to VO + GroupReduceFunction, KV> reduceFunction = new FlinkReduceFunction<>(keyedCombineFn); + + TypeInformation> reduceTypeInfo = context.getTypeInfo(context.getOutput(transform)); + + Grouping> intermediateGrouping = new UnsortedGrouping<>(groupCombine, new Keys.ExpressionKeys<>(new String[]{"key"}, groupCombine.getType())); + + // Fully reduce the values and create output format VO + GroupReduceOperator, KV> outputDataSet = + new GroupReduceOperator<>(intermediateGrouping, reduceTypeInfo, reduceFunction, transform.getName()); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + +// private static class CombineGroupedValuesTranslator implements FlinkPipelineTranslator.TransformTranslator> { +// +// @Override +// public void translateNode(Combine.GroupedValues transform, TranslationContext context) { +// DataSet> inputDataSet = context.getInputDataSet(transform.getInput()); +// +// Combine.KeyedCombineFn keyedCombineFn = transform.getFn(); +// +// GroupReduceFunction, KV> groupReduceFunction = new FlinkCombineFunction<>(keyedCombineFn); +// +// 
TypeInformation> typeInformation = context.getTypeInfo(transform.getOutput()); +// +// Grouping> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{""}, inputDataSet.getType())); +// +// GroupReduceOperator, KV> outputDataSet = +// new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); +// context.setOutputDataSet(transform.getOutput(), outputDataSet); +// } +// } + + private static class ParDoBoundTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundTranslatorBatch.class); + + @Override + public void translateNode(ParDo.Bound transform, FlinkBatchTranslationContext context) { + DataSet inputDataSet = context.getInputDataSet(context.getInput(transform)); + + final DoFn doFn = transform.getFn(); + + TypeInformation typeInformation = context.getTypeInfo(context.getOutput(transform)); + + FlinkDoFnFunction doFnWrapper = new FlinkDoFnFunction<>(doFn, context.getPipelineOptions()); + MapPartitionOperator outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, transform.getName()); + + transformSideInputs(transform.getSideInputs(), outputDataSet, context); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + private static class ParDoBoundMultiTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundMultiTranslatorBatch.class); + + @Override + public void translateNode(ParDo.BoundMulti transform, FlinkBatchTranslationContext context) { + DataSet inputDataSet = context.getInputDataSet(context.getInput(transform)); + + final DoFn doFn = transform.getFn(); + + Map, PCollection> outputs = context.getOutput(transform).getAll(); + + Map, Integer> outputMap = Maps.newHashMap(); + // put the main output at index 0, FlinkMultiOutputDoFnFunction also expects this + outputMap.put(transform.getMainOutputTag(), 0); + int count = 1; + for (TupleTag tag: outputs.keySet()) { + if (!outputMap.containsKey(tag)) { + outputMap.put(tag, count++); + } + } + + // collect all output Coders and create a UnionCoder for our tagged outputs + List> outputCoders = Lists.newArrayList(); + for (PCollection coll: outputs.values()) { + outputCoders.add(coll.getCoder()); + } + + UnionCoder unionCoder = UnionCoder.of(outputCoders); + + @SuppressWarnings("unchecked") + TypeInformation typeInformation = new CoderTypeInformation<>(unionCoder); + + @SuppressWarnings("unchecked") + FlinkMultiOutputDoFnFunction doFnWrapper = new FlinkMultiOutputDoFnFunction(doFn, context.getPipelineOptions(), outputMap); + MapPartitionOperator outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, transform.getName()); + + transformSideInputs(transform.getSideInputs(), outputDataSet, context); + + for (Map.Entry, PCollection> output: outputs.entrySet()) { + TypeInformation outputType = context.getTypeInfo(output.getValue()); + int outputTag = outputMap.get(output.getKey()); + FlinkMultiOutputPruningFunction pruningFunction = new FlinkMultiOutputPruningFunction<>(outputTag); + FlatMapOperator pruningOperator = new + FlatMapOperator<>(outputDataSet, outputType, + pruningFunction, output.getValue().getName()); + context.setOutputDataSet(output.getValue(), pruningOperator); + + } + } + } + + private static class FlattenPCollectionTranslatorBatch implements 
FlinkBatchPipelineTranslator.BatchTransformTranslator> { + + @Override + public void translateNode(Flatten.FlattenPCollectionList transform, FlinkBatchTranslationContext context) { + List> allInputs = context.getInput(transform).getAll(); + DataSet result = null; + for(PCollection collection : allInputs) { + DataSet current = context.getInputDataSet(collection); + if (result == null) { + result = current; + } else { + result = result.union(current); + } + } + context.setOutputDataSet(context.getOutput(transform), result); + } + } + + private static class CreatePCollectionViewTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + @Override + public void translateNode(View.CreatePCollectionView transform, FlinkBatchTranslationContext context) { + DataSet inputDataSet = context.getInputDataSet(context.getInput(transform)); + PCollectionView input = transform.apply(null); + context.setSideInputDataSet(input, inputDataSet); + } + } + + private static class CreateTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + + @Override + public void translateNode(Create.Values transform, FlinkBatchTranslationContext context) { + TypeInformation typeInformation = context.getOutputTypeInfo(); + Iterable elements = transform.getElements(); + + // we need to serialize the elements to byte arrays, since they might contain + // elements that are not serializable by Java serialization. We deserialize them + // in the FlatMap function using the Coder. + + List serializedElements = Lists.newArrayList(); + Coder coder = context.getOutput(transform).getCoder(); + for (OUT element: elements) { + ByteArrayOutputStream bao = new ByteArrayOutputStream(); + try { + coder.encode(element, bao, Coder.Context.OUTER); + serializedElements.add(bao.toByteArray()); + } catch (IOException e) { + throw new RuntimeException("Could not serialize Create elements using Coder: " + e); + } + } + + DataSet initDataSet = context.getExecutionEnvironment().fromElements(1); + FlinkCreateFunction flatMapFunction = new FlinkCreateFunction<>(serializedElements, coder); + FlatMapOperator outputDataSet = new FlatMapOperator<>(initDataSet, typeInformation, flatMapFunction, transform.getName()); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + private static void transformSideInputs(List> sideInputs, + MapPartitionOperator outputDataSet, + FlinkBatchTranslationContext context) { + // get corresponding Flink broadcast DataSets + for(PCollectionView input : sideInputs) { + DataSet broadcastSet = context.getSideInputDataSet(input); + outputDataSet.withBroadcastSet(broadcastSet, input.getTagInternal().getId()); + } + } + +// Disabled because it depends on a pending pull request to the DataFlowSDK + /** + * Special composite transform translator. Only called if the CoGroup is two dimensional. 
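+ * A CoGroupByKey with more than two tagged inputs is not handled here; the
+ * pipeline visitor leaves it to its default expansion instead.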
+ * @param + */ + private static class CoGroupByKeyTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + + @Override + public void translateNode(CoGroupByKey transform, FlinkBatchTranslationContext context) { + KeyedPCollectionTuple input = context.getInput(transform); + + CoGbkResultSchema schema = input.getCoGbkResultSchema(); + List> keyedCollections = input.getKeyedCollections(); + + KeyedPCollectionTuple.TaggedKeyedPCollection taggedCollection1 = keyedCollections.get(0); + KeyedPCollectionTuple.TaggedKeyedPCollection taggedCollection2 = keyedCollections.get(1); + + TupleTag tupleTag1 = taggedCollection1.getTupleTag(); + TupleTag tupleTag2 = taggedCollection2.getTupleTag(); + + PCollection> collection1 = taggedCollection1.getCollection(); + PCollection> collection2 = taggedCollection2.getCollection(); + + DataSet> inputDataSet1 = context.getInputDataSet(collection1); + DataSet> inputDataSet2 = context.getInputDataSet(collection2); + + TypeInformation> typeInfo = context.getOutputTypeInfo(); + + FlinkCoGroupKeyedListAggregator aggregator = new FlinkCoGroupKeyedListAggregator<>(schema, tupleTag1, tupleTag2); + + Keys.ExpressionKeys> keySelector1 = new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet1.getType()); + Keys.ExpressionKeys> keySelector2 = new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet2.getType()); + + DataSet> out = new CoGroupOperator<>(inputDataSet1, inputDataSet2, + keySelector1, keySelector2, + aggregator, typeInfo, null, transform.getName()); + context.setOutputDataSet(context.getOutput(transform), out); + } + } + + // -------------------------------------------------------------------------------------------- + // Miscellaneous + // -------------------------------------------------------------------------------------------- + + private FlinkBatchTransformTranslators() {} +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java new file mode 100644 index 000000000000..22943183db86 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation; + +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.types.KvCoderTypeInformation; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TypedPValue; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.typeutils.GenericTypeInfo; + +import java.util.HashMap; +import java.util.Map; + +public class FlinkBatchTranslationContext { + + private final Map> dataSets; + private final Map, DataSet> broadcastDataSets; + + private final ExecutionEnvironment env; + private final PipelineOptions options; + + private AppliedPTransform currentTransform; + + // ------------------------------------------------------------------------ + + public FlinkBatchTranslationContext(ExecutionEnvironment env, PipelineOptions options) { + this.env = env; + this.options = options; + this.dataSets = new HashMap<>(); + this.broadcastDataSets = new HashMap<>(); + } + + // ------------------------------------------------------------------------ + + public ExecutionEnvironment getExecutionEnvironment() { + return env; + } + + public PipelineOptions getPipelineOptions() { + return options; + } + + @SuppressWarnings("unchecked") + public DataSet getInputDataSet(PValue value) { + return (DataSet) dataSets.get(value); + } + + public void setOutputDataSet(PValue value, DataSet set) { + if (!dataSets.containsKey(value)) { + dataSets.put(value, set); + } + } + + /** + * Sets the AppliedPTransform which carries input/output. 
+ * @param currentTransform + */ + public void setCurrentTransform(AppliedPTransform currentTransform) { + this.currentTransform = currentTransform; + } + + @SuppressWarnings("unchecked") + public DataSet getSideInputDataSet(PCollectionView value) { + return (DataSet) broadcastDataSets.get(value); + } + + public void setSideInputDataSet(PCollectionView value, DataSet set) { + if (!broadcastDataSets.containsKey(value)) { + broadcastDataSets.put(value, set); + } + } + + @SuppressWarnings("unchecked") + public TypeInformation getTypeInfo(PInput output) { + if (output instanceof TypedPValue) { + Coder outputCoder = ((TypedPValue) output).getCoder(); + if (outputCoder instanceof KvCoder) { + return new KvCoderTypeInformation((KvCoder) outputCoder); + } else { + return new CoderTypeInformation(outputCoder); + } + } + return new GenericTypeInfo<>((Class)Object.class); + } + + public TypeInformation getInputTypeInfo() { + return getTypeInfo(currentTransform.getInput()); + } + + public TypeInformation getOutputTypeInfo() { + return getTypeInfo((PValue) currentTransform.getOutput()); + } + + @SuppressWarnings("unchecked") + I getInput(PTransform transform) { + return (I) currentTransform.getInput(); + } + + @SuppressWarnings("unchecked") + O getOutput(PTransform transform) { + return (O) currentTransform.getOutput(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java new file mode 100644 index 000000000000..9407bf564aef --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation; + +import com.google.cloud.dataflow.sdk.Pipeline; + +/** + * The role of this class is to translate the Beam operators to + * their Flink counterparts. If we have a streaming job, this is instantiated as a + * {@link FlinkStreamingPipelineTranslator}. In other case, i.e. for a batch job, + * a {@link FlinkBatchPipelineTranslator} is created. Correspondingly, the + * {@link com.google.cloud.dataflow.sdk.values.PCollection}-based user-provided job is translated into + * a {@link org.apache.flink.streaming.api.datastream.DataStream} (for streaming) or a + * {@link org.apache.flink.api.java.DataSet} (for batch) one. 
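As a rough sketch of how the two concrete translators would be chosen by a caller (the `FlinkPipelineOptions` type, its `isStreaming()` flag, and the `FlinkBatchPipelineTranslator` constructor are assumed here and are not shown in this excerpt):

```java
FlinkPipelineTranslator createTranslator(FlinkPipelineOptions options) {
  if (options.isStreaming()) {
    // streaming job: translate PCollections into Flink DataStreams
    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    return new FlinkStreamingPipelineTranslator(env, options);
  } else {
    // batch job: translate PCollections into Flink DataSets
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    return new FlinkBatchPipelineTranslator(env, options);
  }
}
```

Calling `translate(pipeline)` on the result then simply delegates to `pipeline.traverseTopologically(this)`, as the abstract class body below shows.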
+ */ +public abstract class FlinkPipelineTranslator implements Pipeline.PipelineVisitor { + + public void translate(Pipeline pipeline) { + pipeline.traverseTopologically(this); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java new file mode 100644 index 000000000000..60fba0f0815f --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PValue; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; + +/** + * This is a {@link FlinkPipelineTranslator} for streaming jobs. Its role is to translate the user-provided + * {@link com.google.cloud.dataflow.sdk.values.PCollection}-based job into a + * {@link org.apache.flink.streaming.api.datastream.DataStream} one. + * + * This is based on {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator} + * */ +public class FlinkStreamingPipelineTranslator extends FlinkPipelineTranslator { + + /** The necessary context in the case of a straming job. */ + private final FlinkStreamingTranslationContext streamingContext; + + private int depth = 0; + + /** Composite transform that we want to translate before proceeding with other transforms. 
*/ + private PTransform currentCompositeTransform; + + public FlinkStreamingPipelineTranslator(StreamExecutionEnvironment env, PipelineOptions options) { + this.streamingContext = new FlinkStreamingTranslationContext(env, options); + } + + // -------------------------------------------------------------------------------------------- + // Pipeline Visitor Methods + // -------------------------------------------------------------------------------------------- + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + System.out.println(genSpaces(this.depth) + "enterCompositeTransform- " + formatNodeName(node)); + + PTransform transform = node.getTransform(); + if (transform != null && currentCompositeTransform == null) { + + StreamTransformTranslator translator = FlinkStreamingTransformTranslators.getTranslator(transform); + if (translator != null) { + currentCompositeTransform = transform; + } + } + this.depth++; + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { + PTransform transform = node.getTransform(); + if (transform != null && currentCompositeTransform == transform) { + + StreamTransformTranslator translator = FlinkStreamingTransformTranslators.getTranslator(transform); + if (translator != null) { + System.out.println(genSpaces(this.depth) + "doingCompositeTransform- " + formatNodeName(node)); + applyStreamingTransform(transform, node, translator); + currentCompositeTransform = null; + } else { + throw new IllegalStateException("Attempted to translate composite transform " + + "but no translator was found: " + currentCompositeTransform); + } + } + this.depth--; + System.out.println(genSpaces(this.depth) + "leaveCompositeTransform- " + formatNodeName(node)); + } + + @Override + public void visitTransform(TransformTreeNode node) { + System.out.println(genSpaces(this.depth) + "visitTransform- " + formatNodeName(node)); + if (currentCompositeTransform != null) { + // ignore it + return; + } + + // get the transformation corresponding to hte node we are + // currently visiting and translate it into its Flink alternative. + + PTransform transform = node.getTransform(); + StreamTransformTranslator translator = FlinkStreamingTransformTranslators.getTranslator(transform); + if (translator == null) { + System.out.println(node.getTransform().getClass()); + throw new UnsupportedOperationException("The transform " + transform + " is currently not supported."); + } + applyStreamingTransform(transform, node, translator); + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + // do nothing here + } + + private > void applyStreamingTransform(PTransform transform, TransformTreeNode node, StreamTransformTranslator translator) { + + @SuppressWarnings("unchecked") + T typedTransform = (T) transform; + + @SuppressWarnings("unchecked") + StreamTransformTranslator typedTranslator = (StreamTransformTranslator) translator; + + // create the applied PTransform on the streamingContext + streamingContext.setCurrentTransform(AppliedPTransform.of( + node.getFullName(), node.getInput(), node.getOutput(), (PTransform) transform)); + typedTranslator.translateNode(typedTransform, streamingContext); + } + + /** + * The interface that every Flink translator of a Beam operator should implement. + * This interface is for streaming jobs. For examples of such translators see + * {@link FlinkStreamingTransformTranslators}. 
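Every translator implementing the interface declared just below follows the same shape: fetch the input `DataStream` from the context, build its Flink equivalent, and register the result for the transform's output. A minimal, purely hypothetical identity translator would look roughly like this sketch:

```java
class NoOpTranslator<T> implements
    FlinkStreamingPipelineTranslator.StreamTransformTranslator<PTransform<PCollection<T>, PCollection<T>>> {

  @Override
  public void translateNode(PTransform<PCollection<T>, PCollection<T>> transform,
                            FlinkStreamingTranslationContext context) {
    // look up the DataStream that was produced for the transform's input ...
    DataStream<WindowedValue<T>> input =
        context.getInputDataStream(context.getInput(transform));
    // ... and register it, unchanged, as the output of this transform
    context.setOutputDataStream(context.getOutput(transform), input);
  }
}
```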
+ */ + public interface StreamTransformTranslator { + void translateNode(Type transform, FlinkStreamingTranslationContext context); + } + + private static String genSpaces(int n) { + String s = ""; + for (int i = 0; i < n; i++) { + s += "| "; + } + return s; + } + + private static String formatNodeName(TransformTreeNode node) { + return node.toString().split("@")[1] + node.getTransform(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java new file mode 100644 index 000000000000..bdefeaf80a7c --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java @@ -0,0 +1,406 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.flink.translation; + +import org.apache.beam.runners.flink.translation.functions.UnionCoder; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.wrappers.streaming.*; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.FlinkStreamingCreateFunction; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedFlinkSource; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedSourceWrapper; +import com.google.api.client.util.Maps; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.*; +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import com.google.cloud.dataflow.sdk.transforms.windowing.*; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.Lists; +import org.apache.flink.api.common.functions.FilterFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.streaming.api.datastream.*; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.*; + +/** + * This class 
contains all the mappings between Beam and Flink + * streaming transformations. The {@link FlinkStreamingPipelineTranslator} + * traverses the Beam job and comes here to translate the encountered Beam transformations + * into Flink one, based on the mapping available in this class. + */ +public class FlinkStreamingTransformTranslators { + + // -------------------------------------------------------------------------------------------- + // Transform Translator Registry + // -------------------------------------------------------------------------------------------- + + @SuppressWarnings("rawtypes") + private static final Map, FlinkStreamingPipelineTranslator.StreamTransformTranslator> TRANSLATORS = new HashMap<>(); + + // here you can find all the available translators. + static { + TRANSLATORS.put(Create.Values.class, new CreateStreamingTranslator()); + TRANSLATORS.put(Read.Unbounded.class, new UnboundedReadSourceTranslator()); + TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundStreamingTranslator()); + TRANSLATORS.put(TextIO.Write.Bound.class, new TextIOWriteBoundStreamingTranslator()); + TRANSLATORS.put(Window.Bound.class, new WindowBoundTranslator()); + TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslator()); + TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslator()); + TRANSLATORS.put(Flatten.FlattenPCollectionList.class, new FlattenPCollectionTranslator()); + TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiStreamingTranslator()); + } + + public static FlinkStreamingPipelineTranslator.StreamTransformTranslator getTranslator(PTransform transform) { + return TRANSLATORS.get(transform.getClass()); + } + + // -------------------------------------------------------------------------------------------- + // Transformation Implementations + // -------------------------------------------------------------------------------------------- + + private static class CreateStreamingTranslator implements + FlinkStreamingPipelineTranslator.StreamTransformTranslator> { + + @Override + public void translateNode(Create.Values transform, FlinkStreamingTranslationContext context) { + PCollection output = context.getOutput(transform); + Iterable elements = transform.getElements(); + + // we need to serialize the elements to byte arrays, since they might contain + // elements that are not serializable by Java serialization. We deserialize them + // in the FlatMap function using the Coder. 
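The round trip described in the comment above is the standard Coder encode/decode cycle. Illustrative sketch only, shown with the SDK's `StringUtf8Coder`; any `Coder<T>` behaves the same way:

```java
static String roundTrip(String value) throws IOException {
  // encode on the client, using the element's Coder rather than Java serialization ...
  ByteArrayOutputStream bos = new ByteArrayOutputStream();
  StringUtf8Coder.of().encode(value, bos, Coder.Context.OUTER);
  byte[] bytes = bos.toByteArray();   // this is what ends up in serializedElements

  // ... and decode later, inside the FlatMap function running on the cluster
  return StringUtf8Coder.of().decode(new ByteArrayInputStream(bytes), Coder.Context.OUTER);
}
```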
+ + List serializedElements = Lists.newArrayList(); + Coder elementCoder = context.getOutput(transform).getCoder(); + for (OUT element: elements) { + ByteArrayOutputStream bao = new ByteArrayOutputStream(); + try { + elementCoder.encode(element, bao, Coder.Context.OUTER); + serializedElements.add(bao.toByteArray()); + } catch (IOException e) { + throw new RuntimeException("Could not serialize Create elements using Coder: " + e); + } + } + + + DataStream initDataSet = context.getExecutionEnvironment().fromElements(1); + + FlinkStreamingCreateFunction createFunction = + new FlinkStreamingCreateFunction<>(serializedElements, elementCoder); + + WindowedValue.ValueOnlyWindowedValueCoder windowCoder = WindowedValue.getValueOnlyCoder(elementCoder); + TypeInformation> outputType = new CoderTypeInformation<>(windowCoder); + + DataStream> outputDataStream = initDataSet.flatMap(createFunction) + .returns(outputType); + + context.setOutputDataStream(context.getOutput(transform), outputDataStream); + } + } + + + private static class TextIOWriteBoundStreamingTranslator implements FlinkStreamingPipelineTranslator.StreamTransformTranslator> { + private static final Logger LOG = LoggerFactory.getLogger(TextIOWriteBoundStreamingTranslator.class); + + @Override + public void translateNode(TextIO.Write.Bound transform, FlinkStreamingTranslationContext context) { + PValue input = context.getInput(transform); + DataStream> inputDataStream = context.getInputDataStream(input); + + String filenamePrefix = transform.getFilenamePrefix(); + String filenameSuffix = transform.getFilenameSuffix(); + boolean needsValidation = transform.needsValidation(); + int numShards = transform.getNumShards(); + String shardNameTemplate = transform.getShardNameTemplate(); + + // TODO: Implement these. We need Flink support for this. + LOG.warn("Translation of TextIO.Write.needsValidation not yet supported. Is: {}.", needsValidation); + LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", filenameSuffix); + LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. 
Is: {}.", shardNameTemplate); + + DataStream dataSink = inputDataStream.flatMap(new FlatMapFunction, String>() { + @Override + public void flatMap(WindowedValue value, Collector out) throws Exception { + out.collect(value.getValue().toString()); + } + }); + DataStreamSink output = dataSink.writeAsText(filenamePrefix, FileSystem.WriteMode.OVERWRITE); + + if (numShards > 0) { + output.setParallelism(numShards); + } + } + } + + private static class UnboundedReadSourceTranslator implements FlinkStreamingPipelineTranslator.StreamTransformTranslator> { + + @Override + public void translateNode(Read.Unbounded transform, FlinkStreamingTranslationContext context) { + PCollection output = context.getOutput(transform); + + DataStream> source; + if (transform.getSource().getClass().equals(UnboundedFlinkSource.class)) { + UnboundedFlinkSource flinkSource = (UnboundedFlinkSource) transform.getSource(); + source = context.getExecutionEnvironment() + .addSource(flinkSource.getFlinkSource()) + .flatMap(new FlatMapFunction>() { + @Override + public void flatMap(String s, Collector> collector) throws Exception { + collector.collect(WindowedValue.of(s, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING)); + } + }); + } else { + source = context.getExecutionEnvironment() + .addSource(new UnboundedSourceWrapper<>(context.getPipelineOptions(), transform)); + } + context.setOutputDataStream(output, source); + } + } + + private static class ParDoBoundStreamingTranslator implements FlinkStreamingPipelineTranslator.StreamTransformTranslator> { + + @Override + public void translateNode(ParDo.Bound transform, FlinkStreamingTranslationContext context) { + PCollection output = context.getOutput(transform); + + final WindowingStrategy windowingStrategy = + (WindowingStrategy) + context.getOutput(transform).getWindowingStrategy(); + + WindowedValue.WindowedValueCoder outputStreamCoder = WindowedValue.getFullCoder(output.getCoder(), + windowingStrategy.getWindowFn().windowCoder()); + CoderTypeInformation> outputWindowedValueCoder = + new CoderTypeInformation<>(outputStreamCoder); + + FlinkParDoBoundWrapper doFnWrapper = new FlinkParDoBoundWrapper<>( + context.getPipelineOptions(), windowingStrategy, transform.getFn()); + DataStream> inputDataStream = context.getInputDataStream(context.getInput(transform)); + SingleOutputStreamOperator> outDataStream = inputDataStream.flatMap(doFnWrapper) + .returns(outputWindowedValueCoder); + + context.setOutputDataStream(context.getOutput(transform), outDataStream); + } + } + + public static class WindowBoundTranslator implements FlinkStreamingPipelineTranslator.StreamTransformTranslator> { + + @Override + public void translateNode(Window.Bound transform, FlinkStreamingTranslationContext context) { + PValue input = context.getInput(transform); + DataStream> inputDataStream = context.getInputDataStream(input); + + final WindowingStrategy windowingStrategy = + (WindowingStrategy) + context.getOutput(transform).getWindowingStrategy(); + + final WindowFn windowFn = windowingStrategy.getWindowFn(); + + WindowedValue.WindowedValueCoder outputStreamCoder = WindowedValue.getFullCoder( + context.getInput(transform).getCoder(), windowingStrategy.getWindowFn().windowCoder()); + CoderTypeInformation> outputWindowedValueCoder = + new CoderTypeInformation<>(outputStreamCoder); + + final FlinkParDoBoundWrapper windowDoFnAssigner = new FlinkParDoBoundWrapper<>( + context.getPipelineOptions(), windowingStrategy, createWindowAssigner(windowFn)); + + SingleOutputStreamOperator> windowedStream = + 
inputDataStream.flatMap(windowDoFnAssigner).returns(outputWindowedValueCoder); + context.setOutputDataStream(context.getOutput(transform), windowedStream); + } + + private static DoFn createWindowAssigner(final WindowFn windowFn) { + return new DoFn() { + + @Override + public void processElement(final ProcessContext c) throws Exception { + Collection windows = windowFn.assignWindows( + windowFn.new AssignContext() { + @Override + public T element() { + return c.element(); + } + + @Override + public Instant timestamp() { + return c.timestamp(); + } + + @Override + public Collection windows() { + return c.windowingInternals().windows(); + } + }); + + c.windowingInternals().outputWindowedValue( + c.element(), c.timestamp(), windows, c.pane()); + } + }; + } + } + + public static class GroupByKeyTranslator implements FlinkStreamingPipelineTranslator.StreamTransformTranslator> { + + @Override + public void translateNode(GroupByKey transform, FlinkStreamingTranslationContext context) { + PValue input = context.getInput(transform); + + DataStream>> inputDataStream = context.getInputDataStream(input); + KvCoder inputKvCoder = (KvCoder) context.getInput(transform).getCoder(); + + KeyedStream>, K> groupByKStream = FlinkGroupByKeyWrapper + .groupStreamByKey(inputDataStream, inputKvCoder); + + DataStream>>> groupedByKNWstream = + FlinkGroupAlsoByWindowWrapper.createForIterable(context.getPipelineOptions(), + context.getInput(transform), groupByKStream); + + context.setOutputDataStream(context.getOutput(transform), groupedByKNWstream); + } + } + + public static class CombinePerKeyTranslator implements FlinkStreamingPipelineTranslator.StreamTransformTranslator> { + + @Override + public void translateNode(Combine.PerKey transform, FlinkStreamingTranslationContext context) { + PValue input = context.getInput(transform); + + DataStream>> inputDataStream = context.getInputDataStream(input); + KvCoder inputKvCoder = (KvCoder) context.getInput(transform).getCoder(); + KvCoder outputKvCoder = (KvCoder) context.getOutput(transform).getCoder(); + + KeyedStream>, K> groupByKStream = FlinkGroupByKeyWrapper + .groupStreamByKey(inputDataStream, inputKvCoder); + + Combine.KeyedCombineFn combineFn = (Combine.KeyedCombineFn) transform.getFn(); + DataStream>> groupedByKNWstream = + FlinkGroupAlsoByWindowWrapper.create(context.getPipelineOptions(), + context.getInput(transform), groupByKStream, combineFn, outputKvCoder); + + context.setOutputDataStream(context.getOutput(transform), groupedByKNWstream); + } + } + + public static class FlattenPCollectionTranslator implements FlinkStreamingPipelineTranslator.StreamTransformTranslator> { + + @Override + public void translateNode(Flatten.FlattenPCollectionList transform, FlinkStreamingTranslationContext context) { + List> allInputs = context.getInput(transform).getAll(); + DataStream result = null; + for (PCollection collection : allInputs) { + DataStream current = context.getInputDataStream(collection); + result = (result == null) ? current : result.union(current); + } + context.setOutputDataStream(context.getOutput(transform), result); + } + } + + public static class ParDoBoundMultiStreamingTranslator implements FlinkStreamingPipelineTranslator.StreamTransformTranslator> { + + private final int MAIN_TAG_INDEX = 0; + + @Override + public void translateNode(ParDo.BoundMulti transform, FlinkStreamingTranslationContext context) { + + // we assume that the transformation does not change the windowing strategy. 
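For reference, the `ParDo.BoundMulti` being translated here is what the SDK produces for a multi-output ParDo on the user side. A small illustrative example (the `words` collection and both tags are hypothetical):

```java
final TupleTag<String> shortWordsTag = new TupleTag<String>();
final TupleTag<String> longWordsTag  = new TupleTag<String>();

PCollectionTuple results = words.apply(
    ParDo.withOutputTags(shortWordsTag, TupleTagList.of(longWordsTag))
         .of(new DoFn<String, String>() {
           @Override
           public void processElement(ProcessContext c) {
             if (c.element().length() <= 4) {
               c.output(c.element());                    // main output
             } else {
               c.sideOutput(longWordsTag, c.element());  // side output
             }
           }
         }));

PCollection<String> shortWords = results.get(shortWordsTag);
PCollection<String> longWords  = results.get(longWordsTag);
```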
+ WindowingStrategy windowingStrategy = context.getInput(transform).getWindowingStrategy(); + + Map, PCollection> outputs = context.getOutput(transform).getAll(); + Map, Integer> tagsToLabels = transformTupleTagsToLabels( + transform.getMainOutputTag(), outputs.keySet()); + + UnionCoder intermUnionCoder = getIntermUnionCoder(outputs.values()); + WindowedValue.WindowedValueCoder outputStreamCoder = WindowedValue.getFullCoder( + intermUnionCoder, windowingStrategy.getWindowFn().windowCoder()); + + CoderTypeInformation> intermWindowedValueCoder = + new CoderTypeInformation<>(outputStreamCoder); + + FlinkParDoBoundMultiWrapper doFnWrapper = new FlinkParDoBoundMultiWrapper<>( + context.getPipelineOptions(), windowingStrategy, transform.getFn(), + transform.getMainOutputTag(), tagsToLabels); + + DataStream> inputDataStream = context.getInputDataStream(context.getInput(transform)); + SingleOutputStreamOperator> intermDataStream = + inputDataStream.flatMap(doFnWrapper).returns(intermWindowedValueCoder); + + for (Map.Entry, PCollection> output : outputs.entrySet()) { + final int outputTag = tagsToLabels.get(output.getKey()); + + WindowedValue.WindowedValueCoder coderForTag = WindowedValue.getFullCoder( + output.getValue().getCoder(), + windowingStrategy.getWindowFn().windowCoder()); + + CoderTypeInformation> windowedValueCoder = + new CoderTypeInformation(coderForTag); + + context.setOutputDataStream(output.getValue(), + intermDataStream.filter(new FilterFunction>() { + @Override + public boolean filter(WindowedValue value) throws Exception { + return value.getValue().getUnionTag() == outputTag; + } + }).flatMap(new FlatMapFunction, WindowedValue>() { + @Override + public void flatMap(WindowedValue value, Collector> collector) throws Exception { + collector.collect(WindowedValue.of( + value.getValue().getValue(), + value.getTimestamp(), + value.getWindows(), + value.getPane())); + } + }).returns(windowedValueCoder)); + } + } + + private Map, Integer> transformTupleTagsToLabels(TupleTag mainTag, Set> secondaryTags) { + Map, Integer> tagToLabelMap = Maps.newHashMap(); + tagToLabelMap.put(mainTag, MAIN_TAG_INDEX); + int count = MAIN_TAG_INDEX + 1; + for (TupleTag tag : secondaryTags) { + if (!tagToLabelMap.containsKey(tag)) { + tagToLabelMap.put(tag, count++); + } + } + return tagToLabelMap; + } + + private UnionCoder getIntermUnionCoder(Collection> taggedCollections) { + List> outputCoders = Lists.newArrayList(); + for (PCollection coll : taggedCollections) { + outputCoders.add(coll.getCoder()); + } + return UnionCoder.of(outputCoders); + } + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTranslationContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTranslationContext.java new file mode 100644 index 000000000000..f6bdecd88929 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTranslationContext.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.values.PInput; +import com.google.cloud.dataflow.sdk.values.POutput; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.common.base.Preconditions; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; + +import java.util.HashMap; +import java.util.Map; + +public class FlinkStreamingTranslationContext { + + private final StreamExecutionEnvironment env; + private final PipelineOptions options; + + /** + * Keeps a mapping between the output value of the PTransform (in Dataflow) and the + * Flink Operator that produced it, after the translation of the correspondinf PTransform + * to its Flink equivalent. + * */ + private final Map> dataStreams; + + private AppliedPTransform currentTransform; + + public FlinkStreamingTranslationContext(StreamExecutionEnvironment env, PipelineOptions options) { + this.env = Preconditions.checkNotNull(env); + this.options = Preconditions.checkNotNull(options); + this.dataStreams = new HashMap<>(); + } + + public StreamExecutionEnvironment getExecutionEnvironment() { + return env; + } + + public PipelineOptions getPipelineOptions() { + return options; + } + + @SuppressWarnings("unchecked") + public DataStream getInputDataStream(PValue value) { + return (DataStream) dataStreams.get(value); + } + + public void setOutputDataStream(PValue value, DataStream set) { + if (!dataStreams.containsKey(value)) { + dataStreams.put(value, set); + } + } + + /** + * Sets the AppliedPTransform which carries input/output. + * @param currentTransform + */ + public void setCurrentTransform(AppliedPTransform currentTransform) { + this.currentTransform = currentTransform; + } + + @SuppressWarnings("unchecked") + public I getInput(PTransform transform) { + return (I) currentTransform.getInput(); + } + + @SuppressWarnings("unchecked") + public O getOutput(PTransform transform) { + return (O) currentTransform.getOutput(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCoGroupKeyedListAggregator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCoGroupKeyedListAggregator.java new file mode 100644 index 000000000000..d5562b8653ff --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCoGroupKeyedListAggregator.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResultSchema; +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import org.apache.flink.api.common.functions.CoGroupFunction; +import org.apache.flink.util.Collector; + +import java.util.ArrayList; +import java.util.List; + + +public class FlinkCoGroupKeyedListAggregator implements CoGroupFunction, KV, KV>{ + + private CoGbkResultSchema schema; + private TupleTag tupleTag1; + private TupleTag tupleTag2; + + public FlinkCoGroupKeyedListAggregator(CoGbkResultSchema schema, TupleTag tupleTag1, TupleTag tupleTag2) { + this.schema = schema; + this.tupleTag1 = tupleTag1; + this.tupleTag2 = tupleTag2; + } + + @Override + public void coGroup(Iterable> first, Iterable> second, Collector> out) throws Exception { + K k = null; + List result = new ArrayList<>(); + int index1 = schema.getIndex(tupleTag1); + for (KV entry : first) { + k = entry.getKey(); + result.add(new RawUnionValue(index1, entry.getValue())); + } + int index2 = schema.getIndex(tupleTag2); + for (KV entry : second) { + k = entry.getKey(); + result.add(new RawUnionValue(index2, entry.getValue())); + } + out.collect(KV.of(k, new CoGbkResult(schema, result))); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCreateFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCreateFunction.java new file mode 100644 index 000000000000..56af39758ea8 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCreateFunction.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.runners.flink.translation.types.VoidCoderTypeSerializer; +import com.google.cloud.dataflow.sdk.coders.Coder; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.util.Collector; + +import java.io.ByteArrayInputStream; +import java.util.List; + +/** + * This is a hack for transforming a {@link com.google.cloud.dataflow.sdk.transforms.Create} + * operation. Flink does not allow {@code null} in it's equivalent operation: + * {@link org.apache.flink.api.java.ExecutionEnvironment#fromElements(Object[])}. Therefore + * we use a DataSource with one dummy element and output the elements of the Create operation + * inside this FlatMap. + */ +public class FlinkCreateFunction implements FlatMapFunction { + + private final List elements; + private final Coder coder; + + public FlinkCreateFunction(List elements, Coder coder) { + this.elements = elements; + this.coder = coder; + } + + @Override + @SuppressWarnings("unchecked") + public void flatMap(IN value, Collector out) throws Exception { + + for (byte[] element : elements) { + ByteArrayInputStream bai = new ByteArrayInputStream(element); + OUT outValue = coder.decode(bai, Coder.Context.OUTER); + if (outValue == null) { + // TODO Flink doesn't allow null values in records + out.collect((OUT) VoidCoderTypeSerializer.VoidValue.INSTANCE); + } else { + out.collect(outValue); + } + } + + out.close(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java new file mode 100644 index 000000000000..fe77e6434f64 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.runners.flink.translation.wrappers.SerializableFnAggregatorWrapper; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.TimerInternals; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * Encapsulates a {@link com.google.cloud.dataflow.sdk.transforms.DoFn} + * inside a Flink {@link org.apache.flink.api.common.functions.RichMapPartitionFunction}. + */ +public class FlinkDoFnFunction extends RichMapPartitionFunction { + + private final DoFn doFn; + private transient PipelineOptions options; + + public FlinkDoFnFunction(DoFn doFn, PipelineOptions options) { + this.doFn = doFn; + this.options = options; + } + + private void writeObject(ObjectOutputStream out) + throws IOException, ClassNotFoundException { + out.defaultWriteObject(); + ObjectMapper mapper = new ObjectMapper(); + mapper.writeValue(out, options); + } + + private void readObject(ObjectInputStream in) + throws IOException, ClassNotFoundException { + in.defaultReadObject(); + ObjectMapper mapper = new ObjectMapper(); + options = mapper.readValue(in, PipelineOptions.class); + } + + @Override + public void mapPartition(Iterable values, Collector out) throws Exception { + ProcessContext context = new ProcessContext(doFn, out); + this.doFn.startBundle(context); + for (IN value : values) { + context.inValue = value; + doFn.processElement(context); + } + this.doFn.finishBundle(context); + } + + private class ProcessContext extends DoFn.ProcessContext { + + IN inValue; + Collector outCollector; + + public ProcessContext(DoFn fn, Collector outCollector) { + fn.super(); + super.setupDelegateAggregators(); + this.outCollector = outCollector; + } + + @Override + public IN element() { + return this.inValue; + } + + + @Override + public Instant timestamp() { + return Instant.now(); + } + + @Override + public BoundedWindow window() { + return GlobalWindow.INSTANCE; + } + + @Override + public PaneInfo pane() { + return PaneInfo.NO_FIRING; + } + + @Override + public WindowingInternals windowingInternals() { + return new WindowingInternals() { + @Override + public StateInternals stateInternals() { + return null; + } + + @Override + public void outputWindowedValue(OUT output, Instant timestamp, Collection windows, PaneInfo pane) { + + } + + @Override + public TimerInternals timerInternals() { 
+ return null; + } + + @Override + public Collection windows() { + return ImmutableList.of(GlobalWindow.INSTANCE); + } + + @Override + public PaneInfo pane() { + return PaneInfo.NO_FIRING; + } + + @Override + public void writePCollectionViewData(TupleTag tag, Iterable> data, Coder elemCoder) throws IOException { + } + + @Override + public T sideInput(PCollectionView view, BoundedWindow mainInputWindow) { + throw new RuntimeException("sideInput() not implemented."); + } + }; + } + + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + + @Override + public T sideInput(PCollectionView view) { + List sideInput = getRuntimeContext().getBroadcastVariable(view.getTagInternal().getId()); + List> windowedValueList = new ArrayList<>(sideInput.size()); + for (T input : sideInput) { + windowedValueList.add(WindowedValue.of(input, Instant.now(), ImmutableList.of(GlobalWindow.INSTANCE), pane())); + } + return view.fromIterableInternal(windowedValueList); + } + + @Override + public void output(OUT output) { + outCollector.collect(output); + } + + @Override + public void outputWithTimestamp(OUT output, Instant timestamp) { + // not FLink's way, just output normally + output(output); + } + + @Override + public void sideOutput(TupleTag tag, T output) { + // ignore the side output, this can happen when a user does not register + // side outputs but then outputs using a freshly created TupleTag. + } + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + sideOutput(tag, output); + } + + @Override + protected Aggregator createAggregatorInternal(String name, Combine.CombineFn combiner) { + SerializableFnAggregatorWrapper wrapper = new SerializableFnAggregatorWrapper<>(combiner); + getRuntimeContext().addAccumulator(name, wrapper); + return wrapper; + } + + + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkKeyedListAggregationFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkKeyedListAggregationFunction.java new file mode 100644 index 000000000000..f92f888734d0 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkKeyedListAggregationFunction.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import com.google.cloud.dataflow.sdk.values.KV; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.util.Collector; + +import java.util.Iterator; + +/** + * Flink {@link org.apache.flink.api.common.functions.GroupReduceFunction} for executing a + * {@link com.google.cloud.dataflow.sdk.transforms.GroupByKey} operation. 
This reads the input + * {@link com.google.cloud.dataflow.sdk.values.KV} elements, extracts the key and collects + * the values in a {@code List}. + */ +public class FlinkKeyedListAggregationFunction implements GroupReduceFunction, KV>> { + + @Override + public void reduce(Iterable> values, Collector>> out) throws Exception { + Iterator> it = values.iterator(); + KV first = it.next(); + Iterable passThrough = new PassThroughIterable<>(first, it); + out.collect(KV.of(first.getKey(), passThrough)); + } + + private static class PassThroughIterable implements Iterable, Iterator { + private KV first; + private Iterator> iterator; + + public PassThroughIterable(KV first, Iterator> iterator) { + this.first = first; + this.iterator = iterator; + } + + @Override + public Iterator iterator() { + return this; + } + + @Override + public boolean hasNext() { + return first != null || iterator.hasNext(); + } + + @Override + public V next() { + if (first != null) { + V result = first.getValue(); + first = null; + return result; + } else { + return iterator.next().getValue(); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException("Cannot remove elements from input."); + } + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java new file mode 100644 index 000000000000..ca667ee4e338 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.runners.flink.translation.wrappers.SerializableFnAggregatorWrapper; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Encapsulates a {@link com.google.cloud.dataflow.sdk.transforms.DoFn} that uses side outputs + * inside a Flink {@link org.apache.flink.api.common.functions.RichMapPartitionFunction}. + * + * We get a mapping from {@link com.google.cloud.dataflow.sdk.values.TupleTag} to output index + * and must tag all outputs with the output number. Afterwards a filter will filter out + * those elements that are not to be in a specific output. 
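In other words, every emitted element is wrapped in a `RawUnionValue` carrying the index of the output it belongs to, and one pruning operator per output later keeps only matching elements. A tiny sketch of the convention (the concrete values are illustrative):

```java
// FlinkMultiOutputDoFnFunction tags each element with the index of its output ...
RawUnionValue mainElement = new RawUnionValue(0, "hello");  // index 0 = main output
RawUnionValue sideElement = new RawUnionValue(1, 42);       // index 1 = first side output

// ... and a per-output pruning step keeps only elements with its own index:
boolean keptByMainOutput = mainElement.getUnionTag() == 0;  // true
boolean keptBySideOutput = sideElement.getUnionTag() == 0;  // false
```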
+ */ +public class FlinkMultiOutputDoFnFunction extends RichMapPartitionFunction { + + private final DoFn doFn; + private transient PipelineOptions options; + private final Map, Integer> outputMap; + + public FlinkMultiOutputDoFnFunction(DoFn doFn, PipelineOptions options, Map, Integer> outputMap) { + this.doFn = doFn; + this.options = options; + this.outputMap = outputMap; + } + + private void writeObject(ObjectOutputStream out) + throws IOException, ClassNotFoundException { + out.defaultWriteObject(); + ObjectMapper mapper = new ObjectMapper(); + mapper.writeValue(out, options); + } + + private void readObject(ObjectInputStream in) + throws IOException, ClassNotFoundException { + in.defaultReadObject(); + ObjectMapper mapper = new ObjectMapper(); + options = mapper.readValue(in, PipelineOptions.class); + + } + + @Override + public void mapPartition(Iterable values, Collector out) throws Exception { + ProcessContext context = new ProcessContext(doFn, out); + this.doFn.startBundle(context); + for (IN value : values) { + context.inValue = value; + doFn.processElement(context); + } + this.doFn.finishBundle(context); + } + + private class ProcessContext extends DoFn.ProcessContext { + + IN inValue; + Collector outCollector; + + public ProcessContext(DoFn fn, Collector outCollector) { + fn.super(); + this.outCollector = outCollector; + } + + @Override + public IN element() { + return this.inValue; + } + + @Override + public Instant timestamp() { + return Instant.now(); + } + + @Override + public BoundedWindow window() { + return GlobalWindow.INSTANCE; + } + + @Override + public PaneInfo pane() { + return PaneInfo.NO_FIRING; + } + + @Override + public WindowingInternals windowingInternals() { + return null; + } + + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + + @Override + public T sideInput(PCollectionView view) { + List sideInput = getRuntimeContext().getBroadcastVariable(view.getTagInternal() + .getId()); + List> windowedValueList = new ArrayList<>(sideInput.size()); + for (T input : sideInput) { + windowedValueList.add(WindowedValue.of(input, Instant.now(), ImmutableList.of(GlobalWindow.INSTANCE), pane())); + } + return view.fromIterableInternal(windowedValueList); + } + + @Override + public void output(OUT value) { + // assume that index 0 is the default output + outCollector.collect(new RawUnionValue(0, value)); + } + + @Override + public void outputWithTimestamp(OUT output, Instant timestamp) { + // not FLink's way, just output normally + output(output); + } + + @Override + @SuppressWarnings("unchecked") + public void sideOutput(TupleTag tag, T value) { + Integer index = outputMap.get(tag); + if (index != null) { + outCollector.collect(new RawUnionValue(index, value)); + } + } + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + sideOutput(tag, output); + } + + @Override + protected Aggregator createAggregatorInternal(String name, Combine.CombineFn combiner) { + SerializableFnAggregatorWrapper wrapper = new SerializableFnAggregatorWrapper<>(combiner); + getRuntimeContext().addAccumulator(name, wrapper); + return null; + } + + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java new file mode 100644 index 000000000000..37de37e4d68c --- /dev/null +++ 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.util.Collector; + +/** + * A FlatMap function that filters out those elements that don't belong in this output. We need + * this to implement MultiOutput ParDo functions. + */ +public class FlinkMultiOutputPruningFunction implements FlatMapFunction { + + private final int outputTag; + + public FlinkMultiOutputPruningFunction(int outputTag) { + this.outputTag = outputTag; + } + + @Override + @SuppressWarnings("unchecked") + public void flatMap(RawUnionValue rawUnionValue, Collector collector) throws Exception { + if (rawUnionValue.getUnionTag() == outputTag) { + collector.collect((T) rawUnionValue.getValue()); + } + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java new file mode 100644 index 000000000000..2de681b54739 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.values.KV; +import org.apache.flink.api.common.functions.GroupCombineFunction; +import org.apache.flink.util.Collector; + +import java.util.Iterator; + +/** + * Flink {@link org.apache.flink.api.common.functions.GroupCombineFunction} for executing a + * {@link com.google.cloud.dataflow.sdk.transforms.Combine.PerKey} operation. 
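A hedged sketch of how the pruning step can be applied after the multi-output function: each logical output is recovered from the tagged `RawUnionValue` stream by filtering on its union tag. The `flatMap`/`returns` calls are standard Flink DataSet API; the exact translator wiring is not part of this excerpt.

```java
// Illustrative wiring (assumptions noted above): extract one logical output from the
// RawUnionValue stream produced by FlinkMultiOutputDoFnFunction.
import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue;
import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputPruningFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;

public class PruningExample {

  public static <T> DataSet<T> extractOutput(
      DataSet<RawUnionValue> taggedOutputs, int unionTag, TypeInformation<T> outputType) {
    return taggedOutputs
        .flatMap(new FlinkMultiOutputPruningFunction<T>(unionTag)) // keep only this union tag
        .returns(outputType); // supply type information explicitly, since T is generic
  }
}
```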
This reads the input + * {@link com.google.cloud.dataflow.sdk.values.KV} elements VI, extracts the key and emits accumulated + * values which have the intermediate format VA. + */ +public class FlinkPartialReduceFunction implements GroupCombineFunction, KV> { + + private final Combine.KeyedCombineFn keyedCombineFn; + + public FlinkPartialReduceFunction(Combine.KeyedCombineFn + keyedCombineFn) { + this.keyedCombineFn = keyedCombineFn; + } + + @Override + public void combine(Iterable> elements, Collector> out) throws Exception { + + final Iterator> iterator = elements.iterator(); + // create accumulator using the first elements key + KV first = iterator.next(); + K key = first.getKey(); + VI value = first.getValue(); + VA accumulator = keyedCombineFn.createAccumulator(key); + accumulator = keyedCombineFn.addInput(key, accumulator, value); + + while(iterator.hasNext()) { + value = iterator.next().getValue(); + accumulator = keyedCombineFn.addInput(key, accumulator, value); + } + + out.collect(KV.of(key, accumulator)); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java new file mode 100644 index 000000000000..29193a2a8e05 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.collect.ImmutableList; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.util.Collector; + +import java.util.Iterator; + +/** + * Flink {@link org.apache.flink.api.common.functions.GroupReduceFunction} for executing a + * {@link com.google.cloud.dataflow.sdk.transforms.Combine.PerKey} operation. This reads the input + * {@link com.google.cloud.dataflow.sdk.values.KV} elements, extracts the key and merges the + * accumulators resulting from the PartialReduce which produced the input VA. 
+ */ +public class FlinkReduceFunction implements GroupReduceFunction, KV> { + + private final Combine.KeyedCombineFn keyedCombineFn; + + public FlinkReduceFunction(Combine.KeyedCombineFn keyedCombineFn) { + this.keyedCombineFn = keyedCombineFn; + } + + @Override + public void reduce(Iterable> values, Collector> out) throws Exception { + Iterator> it = values.iterator(); + + KV current = it.next(); + K k = current.getKey(); + VA accumulator = current.getValue(); + + while (it.hasNext()) { + current = it.next(); + keyedCombineFn.mergeAccumulators(k, ImmutableList.of(accumulator, current.getValue()) ); + } + + out.collect(KV.of(k, keyedCombineFn.extractOutput(k, accumulator))); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/UnionCoder.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/UnionCoder.java new file mode 100644 index 000000000000..05f441551a7b --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/UnionCoder.java @@ -0,0 +1,150 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package org.apache.beam.runners.flink.translation.functions; + + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; + +/** + * A UnionCoder encodes RawUnionValues. + * + * This file copied from {@link com.google.cloud.dataflow.sdk.transforms.join.UnionCoder} + */ +@SuppressWarnings("serial") +public class UnionCoder extends StandardCoder { + // TODO: Think about how to integrate this with a schema object (i.e. + // a tuple of tuple tags). + /** + * Builds a union coder with the given list of element coders. This list + * corresponds to a mapping of union tag to Coder. Union tags start at 0. + */ + public static UnionCoder of(List> elementCoders) { + return new UnionCoder(elementCoders); + } + + @JsonCreator + public static UnionCoder jsonOf( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List> elements) { + return UnionCoder.of(elements); + } + + private int getIndexForEncoding(RawUnionValue union) { + if (union == null) { + throw new IllegalArgumentException("cannot encode a null tagged union"); + } + int index = union.getUnionTag(); + if (index < 0 || index >= elementCoders.size()) { + throw new IllegalArgumentException( + "union value index " + index + " not in range [0.." 
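A hedged sketch of how the partial and final reduce functions are meant to compose for a `Combine.PerKey`: a pre-shuffle combine into accumulators, followed by a final merge and extraction per key. Grouping on the `"key"` field assumes the input DataSet carries `KvCoderTypeInformation`; the actual translator code is outside this excerpt.

```java
// Illustrative composition (assumptions noted above).
import com.google.cloud.dataflow.sdk.transforms.Combine;
import com.google.cloud.dataflow.sdk.values.KV;
import org.apache.beam.runners.flink.translation.functions.FlinkPartialReduceFunction;
import org.apache.beam.runners.flink.translation.functions.FlinkReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;

public class CombinePerKeyExample {

  public static <K, VI, VA, VO> DataSet<KV<K, VO>> combinePerKey(
      DataSet<KV<K, VI>> input,
      Combine.KeyedCombineFn<K, VI, VA, VO> combineFn,
      TypeInformation<KV<K, VA>> accumType,
      TypeInformation<KV<K, VO>> outputType) {

    // stage 1: combine each key's values into an accumulator before the shuffle
    DataSet<KV<K, VA>> partiallyCombined = input
        .groupBy("key")
        .combineGroup(new FlinkPartialReduceFunction<K, VI, VA>(combineFn))
        .returns(accumType);

    // stage 2: merge the accumulators per key and extract the final output
    return partiallyCombined
        .groupBy("key")
        .reduceGroup(new FlinkReduceFunction<K, VA, VO>(combineFn))
        .returns(outputType);
  }
}
```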
+ + (elementCoders.size() - 1) + "]"); + } + return index; + } + + @SuppressWarnings("unchecked") + @Override + public void encode( + RawUnionValue union, + OutputStream outStream, + Context context) + throws IOException { + int index = getIndexForEncoding(union); + // Write out the union tag. + VarInt.encode(index, outStream); + + // Write out the actual value. + Coder coder = (Coder) elementCoders.get(index); + coder.encode( + union.getValue(), + outStream, + context); + } + + @Override + public RawUnionValue decode(InputStream inStream, Context context) + throws IOException { + int index = VarInt.decodeInt(inStream); + Object value = elementCoders.get(index).decode(inStream, context); + return new RawUnionValue(index, value); + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public List> getComponents() { + return elementCoders; + } + + /** + * Since this coder uses elementCoders.get(index) and coders that are known to run in constant + * time, we defer the return value to that coder. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(RawUnionValue union, Context context) { + int index = getIndexForEncoding(union); + @SuppressWarnings("unchecked") + Coder coder = (Coder) elementCoders.get(index); + return coder.isRegisterByteSizeObserverCheap(union.getValue(), context); + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the encoded value using this coder. + */ + @Override + public void registerByteSizeObserver( + RawUnionValue union, ElementByteSizeObserver observer, Context context) + throws Exception { + int index = getIndexForEncoding(union); + // Write out the union tag. + observer.update(VarInt.getLength(index)); + // Write out the actual value. + @SuppressWarnings("unchecked") + Coder coder = (Coder) elementCoders.get(index); + coder.registerByteSizeObserver(union.getValue(), observer, context); + } + + ///////////////////////////////////////////////////////////////////////////// + + private final List> elementCoders; + + private UnionCoder(List> elementCoders) { + this.elementCoders = elementCoders; + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + "UnionCoder is only deterministic if all element coders are", + elementCoders); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderComparator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderComparator.java new file mode 100644 index 000000000000..12490363040e --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderComparator.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
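A hedged round-trip example for `UnionCoder`: the union tag is written as a VarInt, followed by the value encoded with the coder registered for that tag. `StringUtf8Coder` and `VarIntCoder` are standard SDK coders used here purely for illustration.

```java
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.coders.VarIntCoder;
import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue;
import com.google.common.collect.ImmutableList;
import org.apache.beam.runners.flink.translation.functions.UnionCoder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;

public class UnionCoderExample {

  public static void main(String[] args) throws Exception {
    // union tag 0 carries Strings, union tag 1 carries Integers
    UnionCoder coder = UnionCoder.of(
        ImmutableList.<Coder<?>>of(StringUtf8Coder.of(), VarIntCoder.of()));

    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    coder.encode(new RawUnionValue(1, 42), bytes, Coder.Context.OUTER);

    RawUnionValue decoded =
        coder.decode(new ByteArrayInputStream(bytes.toByteArray()), Coder.Context.OUTER);
    System.out.println(decoded.getUnionTag() + " -> " + decoded.getValue()); // prints: 1 -> 42
  }
}
```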
+ */ +package org.apache.beam.runners.flink.translation.types; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.MemorySegment; + +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * Flink {@link org.apache.flink.api.common.typeutils.TypeComparator} for + * {@link com.google.cloud.dataflow.sdk.coders.Coder}. + */ +public class CoderComparator extends TypeComparator { + + private Coder coder; + + // We use these for internal encoding/decoding for creating copies and comparing + // serialized forms using a Coder + private transient InspectableByteArrayOutputStream buffer1; + private transient InspectableByteArrayOutputStream buffer2; + + // For storing the Reference in encoded form + private transient InspectableByteArrayOutputStream referenceBuffer; + + public CoderComparator(Coder coder) { + this.coder = coder; + buffer1 = new InspectableByteArrayOutputStream(); + buffer2 = new InspectableByteArrayOutputStream(); + referenceBuffer = new InspectableByteArrayOutputStream(); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + buffer1 = new InspectableByteArrayOutputStream(); + buffer2 = new InspectableByteArrayOutputStream(); + referenceBuffer = new InspectableByteArrayOutputStream(); + } + + @Override + public int hash(T record) { + return record.hashCode(); + } + + @Override + public void setReference(T toCompare) { + referenceBuffer.reset(); + try { + coder.encode(toCompare, referenceBuffer, Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not set reference " + toCompare + ": " + e); + } + } + + @Override + public boolean equalToReference(T candidate) { + try { + buffer2.reset(); + coder.encode(candidate, buffer2, Coder.Context.OUTER); + byte[] arr = referenceBuffer.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (referenceBuffer.size() != buffer2.size()) { + return false; + } + int len = buffer2.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return false; + } + } + return true; + } catch (IOException e) { + throw new RuntimeException("Could not compare reference.", e); + } + } + + @Override + public int compareToReference(TypeComparator other) { + InspectableByteArrayOutputStream otherReferenceBuffer = ((CoderComparator) other).referenceBuffer; + + byte[] arr = referenceBuffer.getBuffer(); + byte[] arrOther = otherReferenceBuffer.getBuffer(); + if (referenceBuffer.size() != otherReferenceBuffer.size()) { + return referenceBuffer.size() - otherReferenceBuffer.size(); + } + int len = referenceBuffer.size(); + for (int i = 0; i < len; i++) { + if (arr[i] != arrOther[i]) { + return arr[i] - arrOther[i]; + } + } + return 0; + } + + @Override + public int compare(T first, T second) { + try { + buffer1.reset(); + buffer2.reset(); + coder.encode(first, buffer1, Coder.Context.OUTER); + coder.encode(second, buffer2, Coder.Context.OUTER); + byte[] arr = buffer1.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (buffer1.size() != buffer2.size()) { + return buffer1.size() - buffer2.size(); + } + int len = buffer1.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return arr[i] - arrOther[i]; + } + } + return 0; + } catch (IOException e) { + throw new RuntimeException("Could not compare: ", e); + 
} + } + + @Override + public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException { + CoderTypeSerializer serializer = new CoderTypeSerializer<>(coder); + T first = serializer.deserialize(firstSource); + T second = serializer.deserialize(secondSource); + return compare(first, second); + } + + @Override + public boolean supportsNormalizedKey() { + return true; + } + + @Override + public boolean supportsSerializationWithKeyNormalization() { + return false; + } + + @Override + public int getNormalizeKeyLen() { + return Integer.MAX_VALUE; + } + + @Override + public boolean isNormalizedKeyPrefixOnly(int keyBytes) { + return true; + } + + @Override + public void putNormalizedKey(T record, MemorySegment target, int offset, int numBytes) { + buffer1.reset(); + try { + coder.encode(record, buffer1, Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not serializer " + record + " using coder " + coder + ": " + e); + } + final byte[] data = buffer1.getBuffer(); + final int limit = offset + numBytes; + + target.put(offset, data, 0, Math.min(numBytes, buffer1.size())); + + offset += buffer1.size(); + + while (offset < limit) { + target.put(offset++, (byte) 0); + } + } + + @Override + public void writeWithKeyNormalization(T record, DataOutputView target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public T readWithKeyDenormalization(T reuse, DataInputView source) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean invertNormalizedKey() { + return false; + } + + @Override + public TypeComparator duplicate() { + return new CoderComparator<>(coder); + } + + @Override + public int extractKeys(Object record, Object[] target, int index) { + target[index] = record; + return 1; + } + + @Override + public TypeComparator[] getFlatComparators() { + return new TypeComparator[] { this.duplicate() }; + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java new file mode 100644 index 000000000000..f9d4dcd20cab --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.types; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.AtomicType; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import com.google.common.base.Preconditions; + +/** + * Flink {@link org.apache.flink.api.common.typeinfo.TypeInformation} for + * Dataflow {@link com.google.cloud.dataflow.sdk.coders.Coder}s. + */ +public class CoderTypeInformation extends TypeInformation implements AtomicType { + + private final Coder coder; + + public CoderTypeInformation(Coder coder) { + Preconditions.checkNotNull(coder); + this.coder = coder; + } + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 1; + } + + @Override + @SuppressWarnings("unchecked") + public Class getTypeClass() { + // We don't have the Class, so we have to pass null here. What a shame... + return (Class) Object.class; + } + + @Override + public boolean isKeyType() { + return true; + } + + @Override + @SuppressWarnings("unchecked") + public TypeSerializer createSerializer(ExecutionConfig config) { + if (coder instanceof VoidCoder) { + return (TypeSerializer) new VoidCoderTypeSerializer(); + } + return new CoderTypeSerializer<>(coder); + } + + @Override + public int getTotalFields() { + return 2; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + CoderTypeInformation that = (CoderTypeInformation) o; + + return coder.equals(that.coder); + + } + + @Override + public int hashCode() { + return coder.hashCode(); + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof CoderTypeInformation; + } + + @Override + public String toString() { + return "CoderTypeInformation{" + + "coder=" + coder + + '}'; + } + + @Override + public TypeComparator createComparator(boolean sortOrderAscending, ExecutionConfig + executionConfig) { + return new CoderComparator<>(coder); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java new file mode 100644 index 000000000000..4e81054d1c9b --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
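A hedged sketch of what `CoderTypeInformation` provides to Flink: serializers and comparators backed entirely by the wrapped coder. `StringUtf8Coder` is used only as an example.

```java
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;

public class CoderTypeInformationExample {

  public static void main(String[] args) {
    CoderTypeInformation<String> typeInfo = new CoderTypeInformation<>(StringUtf8Coder.of());

    // Flink requests these when it builds serializers and sort/group comparators
    TypeSerializer<String> serializer = typeInfo.createSerializer(new ExecutionConfig());
    TypeComparator<String> comparator = typeInfo.createComparator(true, new ExecutionConfig());

    System.out.println(serializer.getClass().getSimpleName()); // CoderTypeSerializer
    System.out.println(comparator.getClass().getSimpleName()); // CoderComparator
  }
}
```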
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.types; + +import org.apache.beam.runners.flink.translation.wrappers.DataInputViewWrapper; +import org.apache.beam.runners.flink.translation.wrappers.DataOutputViewWrapper; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; + +import java.io.ByteArrayInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * Flink {@link org.apache.flink.api.common.typeutils.TypeSerializer} for + * Dataflow {@link com.google.cloud.dataflow.sdk.coders.Coder}s + */ +public class CoderTypeSerializer extends TypeSerializer { + + private Coder coder; + private transient DataInputViewWrapper inputWrapper; + private transient DataOutputViewWrapper outputWrapper; + + // We use this for internal encoding/decoding for creating copies using the Coder. + private transient InspectableByteArrayOutputStream buffer; + + public CoderTypeSerializer(Coder coder) { + this.coder = coder; + this.inputWrapper = new DataInputViewWrapper(null); + this.outputWrapper = new DataOutputViewWrapper(null); + + buffer = new InspectableByteArrayOutputStream(); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + this.inputWrapper = new DataInputViewWrapper(null); + this.outputWrapper = new DataOutputViewWrapper(null); + + buffer = new InspectableByteArrayOutputStream(); + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public CoderTypeSerializer duplicate() { + return new CoderTypeSerializer<>(coder); + } + + @Override + public T createInstance() { + return null; + } + + @Override + public T copy(T t) { + buffer.reset(); + try { + coder.encode(t, buffer, Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not copy.", e); + } + try { + return coder.decode(new ByteArrayInputStream(buffer.getBuffer(), 0, buffer + .size()), Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not copy.", e); + } + } + + @Override + public T copy(T t, T reuse) { + return copy(t); + } + + @Override + public int getLength() { + return 0; + } + + @Override + public void serialize(T t, DataOutputView dataOutputView) throws IOException { + outputWrapper.setOutputView(dataOutputView); + coder.encode(t, outputWrapper, Coder.Context.NESTED); + } + + @Override + public T deserialize(DataInputView dataInputView) throws IOException { + try { + inputWrapper.setInputView(dataInputView); + return coder.decode(inputWrapper, Coder.Context.NESTED); + } catch (CoderException e) { + Throwable cause = e.getCause(); + if (cause instanceof EOFException) { + throw (EOFException) cause; + } else { + throw e; + } + } + } + + @Override + public T deserialize(T t, DataInputView dataInputView) throws IOException { + return deserialize(dataInputView); + } + + @Override + public void copy(DataInputView dataInputView, DataOutputView dataOutputView) throws IOException { + serialize(deserialize(dataInputView), dataOutputView); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + 
CoderTypeSerializer that = (CoderTypeSerializer) o; + return coder.equals(that.coder); + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof CoderTypeSerializer; + } + + @Override + public int hashCode() { + return coder.hashCode(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/InspectableByteArrayOutputStream.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/InspectableByteArrayOutputStream.java new file mode 100644 index 000000000000..36b5ba319180 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/InspectableByteArrayOutputStream.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.types; + +import java.io.ByteArrayOutputStream; + +/** + * Version of {@link java.io.ByteArrayOutputStream} that allows to retrieve the internal + * byte[] buffer without incurring an array copy. + */ +public class InspectableByteArrayOutputStream extends ByteArrayOutputStream { + + /** + * Get the underlying byte array. + */ + public byte[] getBuffer() { + return buf; + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderComperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderComperator.java new file mode 100644 index 000000000000..3912295afb0b --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderComperator.java @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
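A hedged example of `CoderTypeSerializer` in isolation: `copy()` performs an encode/decode round trip through the wrapped coder, which is the same mechanism used when serializing to Flink's `DataOutputView`/`DataInputView`.

```java
import com.google.cloud.dataflow.sdk.coders.KvCoder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.coders.VarIntCoder;
import com.google.cloud.dataflow.sdk.values.KV;
import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer;

public class CoderTypeSerializerExample {

  public static void main(String[] args) {
    CoderTypeSerializer<KV<String, Integer>> serializer =
        new CoderTypeSerializer<>(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));

    KV<String, Integer> original = KV.of("flink", 1);

    // copy() encodes to the internal buffer and decodes again, i.e. a deep copy via the coder
    KV<String, Integer> copy = serializer.copy(original);
    System.out.println(copy.getKey() + "=" + copy.getValue()); // flink=1
  }
}
```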
+ */ +package org.apache.beam.runners.flink.translation.types; + +import org.apache.beam.runners.flink.translation.wrappers.DataInputViewWrapper; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.values.KV; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.MemorySegment; + +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * Flink {@link org.apache.flink.api.common.typeutils.TypeComparator} for + * {@link com.google.cloud.dataflow.sdk.coders.KvCoder}. We have a special comparator + * for {@link KV} that always compares on the key only. + */ +public class KvCoderComperator extends TypeComparator> { + + private KvCoder coder; + private Coder keyCoder; + + // We use these for internal encoding/decoding for creating copies and comparing + // serialized forms using a Coder + private transient InspectableByteArrayOutputStream buffer1; + private transient InspectableByteArrayOutputStream buffer2; + + // For storing the Reference in encoded form + private transient InspectableByteArrayOutputStream referenceBuffer; + + + // For deserializing the key + private transient DataInputViewWrapper inputWrapper; + + public KvCoderComperator(KvCoder coder) { + this.coder = coder; + this.keyCoder = coder.getKeyCoder(); + + buffer1 = new InspectableByteArrayOutputStream(); + buffer2 = new InspectableByteArrayOutputStream(); + referenceBuffer = new InspectableByteArrayOutputStream(); + + inputWrapper = new DataInputViewWrapper(null); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + + buffer1 = new InspectableByteArrayOutputStream(); + buffer2 = new InspectableByteArrayOutputStream(); + referenceBuffer = new InspectableByteArrayOutputStream(); + + inputWrapper = new DataInputViewWrapper(null); + } + + @Override + public int hash(KV record) { + K key = record.getKey(); + if (key != null) { + return key.hashCode(); + } else { + return 0; + } + } + + @Override + public void setReference(KV toCompare) { + referenceBuffer.reset(); + try { + keyCoder.encode(toCompare.getKey(), referenceBuffer, Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not set reference " + toCompare + ": " + e); + } + } + + @Override + public boolean equalToReference(KV candidate) { + try { + buffer2.reset(); + keyCoder.encode(candidate.getKey(), buffer2, Coder.Context.OUTER); + byte[] arr = referenceBuffer.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (referenceBuffer.size() != buffer2.size()) { + return false; + } + int len = buffer2.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return false; + } + } + return true; + } catch (IOException e) { + throw new RuntimeException("Could not compare reference.", e); + } + } + + @Override + public int compareToReference(TypeComparator> other) { + InspectableByteArrayOutputStream otherReferenceBuffer = ((KvCoderComperator) other).referenceBuffer; + + byte[] arr = referenceBuffer.getBuffer(); + byte[] arrOther = otherReferenceBuffer.getBuffer(); + if (referenceBuffer.size() != otherReferenceBuffer.size()) { + return referenceBuffer.size() - otherReferenceBuffer.size(); + } + int len = referenceBuffer.size(); + for (int i = 0; i < len; i++) { + if (arr[i] != arrOther[i]) { + return 
arr[i] - arrOther[i]; + } + } + return 0; + } + + + @Override + public int compare(KV first, KV second) { + try { + buffer1.reset(); + buffer2.reset(); + keyCoder.encode(first.getKey(), buffer1, Coder.Context.OUTER); + keyCoder.encode(second.getKey(), buffer2, Coder.Context.OUTER); + byte[] arr = buffer1.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (buffer1.size() != buffer2.size()) { + return buffer1.size() - buffer2.size(); + } + int len = buffer1.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return arr[i] - arrOther[i]; + } + } + return 0; + } catch (IOException e) { + throw new RuntimeException("Could not compare reference.", e); + } + } + + @Override + public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException { + + inputWrapper.setInputView(firstSource); + K firstKey = keyCoder.decode(inputWrapper, Coder.Context.NESTED); + inputWrapper.setInputView(secondSource); + K secondKey = keyCoder.decode(inputWrapper, Coder.Context.NESTED); + + try { + buffer1.reset(); + buffer2.reset(); + keyCoder.encode(firstKey, buffer1, Coder.Context.OUTER); + keyCoder.encode(secondKey, buffer2, Coder.Context.OUTER); + byte[] arr = buffer1.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (buffer1.size() != buffer2.size()) { + return buffer1.size() - buffer2.size(); + } + int len = buffer1.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return arr[i] - arrOther[i]; + } + } + return 0; + } catch (IOException e) { + throw new RuntimeException("Could not compare reference.", e); + } + } + + @Override + public boolean supportsNormalizedKey() { + return true; + } + + @Override + public boolean supportsSerializationWithKeyNormalization() { + return false; + } + + @Override + public int getNormalizeKeyLen() { + return Integer.MAX_VALUE; + } + + @Override + public boolean isNormalizedKeyPrefixOnly(int keyBytes) { + return true; + } + + @Override + public void putNormalizedKey(KV record, MemorySegment target, int offset, int numBytes) { + buffer1.reset(); + try { + keyCoder.encode(record.getKey(), buffer1, Coder.Context.NESTED); + } catch (IOException e) { + throw new RuntimeException("Could not serializer " + record + " using coder " + coder + ": " + e); + } + final byte[] data = buffer1.getBuffer(); + final int limit = offset + numBytes; + + int numBytesPut = Math.min(numBytes, buffer1.size()); + + target.put(offset, data, 0, numBytesPut); + + offset += numBytesPut; + + while (offset < limit) { + target.put(offset++, (byte) 0); + } + } + + @Override + public void writeWithKeyNormalization(KV record, DataOutputView target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public KV readWithKeyDenormalization(KV reuse, DataInputView source) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean invertNormalizedKey() { + return false; + } + + @Override + public TypeComparator> duplicate() { + return new KvCoderComperator<>(coder); + } + + @Override + public int extractKeys(Object record, Object[] target, int index) { + KV kv = (KV) record; + K k = kv.getKey(); + target[index] = k; + return 1; + } + + @Override + public TypeComparator[] getFlatComparators() { + return new TypeComparator[] {new CoderComparator<>(keyCoder)}; + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderTypeInformation.java 
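A hedged illustration of the key-only comparison semantics of `KvCoderComperator`: KVs with equal keys compare as equal regardless of their values, which is what grouping and sorting in the runner rely on.

```java
import com.google.cloud.dataflow.sdk.coders.KvCoder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.coders.VarIntCoder;
import com.google.cloud.dataflow.sdk.values.KV;
import org.apache.beam.runners.flink.translation.types.KvCoderComperator;

public class KvCoderComperatorExample {

  public static void main(String[] args) {
    KvCoderComperator<String, Integer> comparator =
        new KvCoderComperator<>(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));

    // equal keys compare as equal, regardless of the values
    System.out.println(comparator.compare(KV.of("a", 1), KV.of("a", 2))); // 0

    // different keys are ordered by their encoded key bytes
    System.out.println(comparator.compare(KV.of("a", 1), KV.of("b", 1)) < 0); // true
  }
}
```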
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderTypeInformation.java new file mode 100644 index 000000000000..8862d48b9f36 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderTypeInformation.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.types; + +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.values.KV; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import com.google.common.base.Preconditions; + +import java.util.List; + +/** + * Flink {@link org.apache.flink.api.common.typeinfo.TypeInformation} for + * Dataflow {@link com.google.cloud.dataflow.sdk.coders.KvCoder}. + */ +public class KvCoderTypeInformation extends CompositeType> { + + private KvCoder coder; + + // We don't have the Class, so we have to pass null here. What a shame... 
+ private static Object DUMMY = new Object(); + + @SuppressWarnings("unchecked") + public KvCoderTypeInformation(KvCoder coder) { + super(((Class>) DUMMY.getClass())); + this.coder = coder; + Preconditions.checkNotNull(coder); + } + + @Override + @SuppressWarnings("unchecked") + public TypeComparator> createComparator(int[] logicalKeyFields, boolean[] orders, int logicalFieldOffset, ExecutionConfig config) { + return new KvCoderComperator((KvCoder) coder); + } + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 2; + } + + @Override + @SuppressWarnings("unchecked") + public Class> getTypeClass() { + return privateGetTypeClass(); + } + + @SuppressWarnings("unchecked") + private static Class privateGetTypeClass() { + return (Class) Object.class; + } + + @Override + public boolean isKeyType() { + return true; + } + + @Override + @SuppressWarnings("unchecked") + public TypeSerializer> createSerializer(ExecutionConfig config) { + return new CoderTypeSerializer<>(coder); + } + + @Override + public int getTotalFields() { + return 2; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + KvCoderTypeInformation that = (KvCoderTypeInformation) o; + + return coder.equals(that.coder); + + } + + @Override + public int hashCode() { + return coder.hashCode(); + } + + @Override + public String toString() { + return "CoderTypeInformation{" + + "coder=" + coder + + '}'; + } + + @Override + @SuppressWarnings("unchecked") + public TypeInformation getTypeAt(int pos) { + if (pos == 0) { + return (TypeInformation) new CoderTypeInformation<>(coder.getKeyCoder()); + } else if (pos == 1) { + return (TypeInformation) new CoderTypeInformation<>(coder.getValueCoder()); + } else { + throw new RuntimeException("Invalid field position " + pos); + } + } + + @Override + @SuppressWarnings("unchecked") + public TypeInformation getTypeAt(String fieldExpression) { + switch (fieldExpression) { + case "key": + return (TypeInformation) new CoderTypeInformation<>(coder.getKeyCoder()); + case "value": + return (TypeInformation) new CoderTypeInformation<>(coder.getValueCoder()); + default: + throw new UnsupportedOperationException("Only KvCoder has fields."); + } + } + + @Override + public String[] getFieldNames() { + return new String[]{"key", "value"}; + } + + @Override + public int getFieldIndex(String fieldName) { + switch (fieldName) { + case "key": + return 0; + case "value": + return 1; + default: + return -1; + } + } + + @Override + public void getFlatFields(String fieldExpression, int offset, List result) { + CoderTypeInformation keyTypeInfo = new CoderTypeInformation<>(coder.getKeyCoder()); + result.add(new FlatFieldDescriptor(0, keyTypeInfo)); + } + + @Override + protected TypeComparatorBuilder> createTypeComparatorBuilder() { + return new KvCoderTypeComparatorBuilder(); + } + + private class KvCoderTypeComparatorBuilder implements TypeComparatorBuilder> { + + @Override + public void initializeTypeComparatorBuilder(int size) {} + + @Override + public void addComparatorField(int fieldId, TypeComparator comparator) {} + + @Override + public TypeComparator> createTypeComparator(ExecutionConfig config) { + return new KvCoderComperator<>(coder); + } + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/VoidCoderTypeSerializer.java 
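A hedged sketch of the composite-type view that `KvCoderTypeInformation` exposes: named `"key"` and `"value"` fields, each again backed by a coder-based type information. This is presumably what allows grouping a DataSet of KVs by the key field.

```java
import com.google.cloud.dataflow.sdk.coders.KvCoder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.coders.VarIntCoder;
import org.apache.beam.runners.flink.translation.types.KvCoderTypeInformation;

public class KvCoderTypeInformationExample {

  public static void main(String[] args) {
    KvCoderTypeInformation<String, Integer> typeInfo =
        new KvCoderTypeInformation<>(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));

    System.out.println(typeInfo.getFieldIndex("key"));   // 0
    System.out.println(typeInfo.getFieldIndex("value")); // 1

    // per-field type information is again coder-backed
    System.out.println(typeInfo.getTypeAt("key").getClass().getSimpleName()); // CoderTypeInformation
  }
}
```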
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/VoidCoderTypeSerializer.java new file mode 100644 index 000000000000..8bc362002707 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/VoidCoderTypeSerializer.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.types; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; + +import java.io.IOException; + +/** + * Special Flink {@link org.apache.flink.api.common.typeutils.TypeSerializer} for + * {@link com.google.cloud.dataflow.sdk.coders.VoidCoder}. We need this because Flink does not + * allow returning {@code null} from an input reader. We return a {@link VoidValue} instead + * that behaves like a {@code null}, hopefully. + */ +public class VoidCoderTypeSerializer extends TypeSerializer { + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public VoidCoderTypeSerializer duplicate() { + return this; + } + + @Override + public VoidValue createInstance() { + return VoidValue.INSTANCE; + } + + @Override + public VoidValue copy(VoidValue from) { + return from; + } + + @Override + public VoidValue copy(VoidValue from, VoidValue reuse) { + return from; + } + + @Override + public int getLength() { + return 0; + } + + @Override + public void serialize(VoidValue record, DataOutputView target) throws IOException { + target.writeByte(1); + } + + @Override + public VoidValue deserialize(DataInputView source) throws IOException { + source.readByte(); + return VoidValue.INSTANCE; + } + + @Override + public VoidValue deserialize(VoidValue reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + source.readByte(); + target.writeByte(1); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof VoidCoderTypeSerializer) { + VoidCoderTypeSerializer other = (VoidCoderTypeSerializer) obj; + return other.canEqual(this); + } else { + return false; + } + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof VoidCoderTypeSerializer; + } + + @Override + public int hashCode() { + return 0; + } + + public static class VoidValue { + private VoidValue() {} + + public static VoidValue INSTANCE = new VoidValue(); + } + +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/CombineFnAggregatorWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/CombineFnAggregatorWrapper.java new file 
mode 100644 index 000000000000..445d41129917 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/CombineFnAggregatorWrapper.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.common.collect.Lists; +import org.apache.flink.api.common.accumulators.Accumulator; + +import java.io.Serializable; + +/** + * Wrapper that wraps a {@link com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn} + * in a Flink {@link org.apache.flink.api.common.accumulators.Accumulator} for using + * the combine function as an aggregator in a {@link com.google.cloud.dataflow.sdk.transforms.ParDo} + * operation. + */ +public class CombineFnAggregatorWrapper implements Aggregator, Accumulator { + + private AA aa; + private Combine.CombineFn combiner; + + public CombineFnAggregatorWrapper() { + } + + public CombineFnAggregatorWrapper(Combine.CombineFn combiner) { + this.combiner = combiner; + this.aa = combiner.createAccumulator(); + } + + @Override + public void add(AI value) { + combiner.addInput(aa, value); + } + + @Override + public Serializable getLocalValue() { + return (Serializable) combiner.extractOutput(aa); + } + + @Override + public void resetLocal() { + aa = combiner.createAccumulator(); + } + + @Override + @SuppressWarnings("unchecked") + public void merge(Accumulator other) { + aa = combiner.mergeAccumulators(Lists.newArrayList(aa, ((CombineFnAggregatorWrapper)other).aa)); + } + + @Override + public Accumulator clone() { + // copy it by merging + AA aaCopy = combiner.mergeAccumulators(Lists.newArrayList(aa)); + CombineFnAggregatorWrapper result = new + CombineFnAggregatorWrapper<>(combiner); + result.aa = aaCopy; + return result; + } + + @Override + public void addValue(AI value) { + add(value); + } + + @Override + public String getName() { + return "CombineFn: " + combiner.toString(); + } + + @Override + public Combine.CombineFn getCombineFn() { + return combiner; + } + +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataInputViewWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataInputViewWrapper.java new file mode 100644 index 000000000000..6a3cf507be30 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataInputViewWrapper.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers; + +import org.apache.flink.core.memory.DataInputView; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; + +/** + * Wrapper for {@link DataInputView}. We need this because Flink reads data using a + * {@link org.apache.flink.core.memory.DataInputView} while + * Dataflow {@link com.google.cloud.dataflow.sdk.coders.Coder}s expect an + * {@link java.io.InputStream}. + */ +public class DataInputViewWrapper extends InputStream { + + private DataInputView inputView; + + public DataInputViewWrapper(DataInputView inputView) { + this.inputView = inputView; + } + + public void setInputView(DataInputView inputView) { + this.inputView = inputView; + } + + @Override + public int read() throws IOException { + try { + return inputView.readUnsignedByte(); + } catch (EOFException e) { + // translate between DataInput and InputStream, + // DataInput signals EOF by exception, InputStream does it by returning -1 + return -1; + } + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + return inputView.read(b, off, len); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataOutputViewWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataOutputViewWrapper.java new file mode 100644 index 000000000000..6bd2240d567c --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataOutputViewWrapper.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers; + +import org.apache.flink.core.memory.DataOutputView; + +import java.io.IOException; +import java.io.OutputStream; + +/** + * Wrapper for {@link org.apache.flink.core.memory.DataOutputView}. 
We need this because + * Flink writes data using a {@link org.apache.flink.core.memory.DataInputView} while + * Dataflow {@link com.google.cloud.dataflow.sdk.coders.Coder}s expect an + * {@link java.io.OutputStream}. + */ +public class DataOutputViewWrapper extends OutputStream { + + private DataOutputView outputView; + + public DataOutputViewWrapper(DataOutputView outputView) { + this.outputView = outputView; + } + + public void setOutputView(DataOutputView outputView) { + this.outputView = outputView; + } + + @Override + public void write(int b) throws IOException { + outputView.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + outputView.write(b, off, len); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SerializableFnAggregatorWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SerializableFnAggregatorWrapper.java new file mode 100644 index 000000000000..440958648309 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SerializableFnAggregatorWrapper.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import org.apache.flink.api.common.accumulators.Accumulator; + +import java.io.Serializable; + +/** + * Wrapper that wraps a {@link com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn} + * in a Flink {@link org.apache.flink.api.common.accumulators.Accumulator} for using + * the function as an aggregator in a {@link com.google.cloud.dataflow.sdk.transforms.ParDo} + * operation. 
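A hedged round-trip sketch showing both wrappers bridging an SDK coder and Flink's memory views. `DataOutputSerializer` and `DataInputDeserializer` are assumed here as convenient `DataOutputView`/`DataInputView` implementations shipped with Flink; their exact method names may differ across Flink versions.

```java
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import org.apache.beam.runners.flink.translation.wrappers.DataInputViewWrapper;
import org.apache.beam.runners.flink.translation.wrappers.DataOutputViewWrapper;
import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;

public class ViewWrapperExample {

  public static void main(String[] args) throws Exception {
    StringUtf8Coder coder = StringUtf8Coder.of();

    // encode into a Flink DataOutputView through the OutputStream adapter
    DataOutputSerializer outView = new DataOutputSerializer(64);
    coder.encode("hello", new DataOutputViewWrapper(outView), Coder.Context.NESTED);

    // decode from a Flink DataInputView through the InputStream adapter
    DataInputDeserializer inView =
        new DataInputDeserializer(outView.getByteArray(), 0, outView.length());
    String decoded = coder.decode(new DataInputViewWrapper(inView), Coder.Context.NESTED);
    System.out.println(decoded); // hello
  }
}
```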
+ */ +public class SerializableFnAggregatorWrapper implements Aggregator, Accumulator { + + private AO aa; + private Combine.CombineFn combiner; + + public SerializableFnAggregatorWrapper(Combine.CombineFn combiner) { + this.combiner = combiner; + resetLocal(); + } + + @Override + @SuppressWarnings("unchecked") + public void add(AI value) { + this.aa = combiner.apply(ImmutableList.of((AI) aa, value)); + } + + @Override + public Serializable getLocalValue() { + return (Serializable) aa; + } + + @Override + public void resetLocal() { + this.aa = combiner.apply(ImmutableList.of()); + } + + @Override + @SuppressWarnings("unchecked") + public void merge(Accumulator other) { + this.aa = combiner.apply(ImmutableList.of((AI) aa, (AI) other.getLocalValue())); + } + + @Override + public void addValue(AI value) { + add(value); + } + + @Override + public String getName() { + return "Aggregator :" + combiner.toString(); + } + + @Override + public Combine.CombineFn getCombineFn() { + return combiner; + } + + @Override + public Accumulator clone() { + // copy it by merging + AO resultCopy = combiner.apply(Lists.newArrayList((AI) aa)); + SerializableFnAggregatorWrapper result = new + SerializableFnAggregatorWrapper<>(combiner); + + result.aa = resultCopy; + return result; + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SinkOutputFormat.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SinkOutputFormat.java new file mode 100644 index 000000000000..4c2475dbf4da --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SinkOutputFormat.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.flink.translation.wrappers; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.dataflow.sdk.io.Sink; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.common.base.Preconditions; +import com.google.cloud.dataflow.sdk.transforms.Write; +import org.apache.flink.api.common.io.OutputFormat; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.util.AbstractID; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.lang.reflect.Field; + +/** + * Wrapper class to use generic Write.Bound transforms as sinks. + * @param The type of the incoming records. 
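A hedged usage sketch for the aggregator wrapper: every added value is folded into the running result via `CombineFn#apply`, and the wrapper can be registered like any Flink accumulator. It assumes `Sum.SumIntegerFn` from the SDK is usable as the `CombineFn` here; the accumulator name is hypothetical.

```java
import com.google.cloud.dataflow.sdk.transforms.Sum;
import org.apache.beam.runners.flink.translation.wrappers.SerializableFnAggregatorWrapper;

public class AggregatorWrapperExample {

  public static void main(String[] args) {
    // assumes Sum.SumIntegerFn is usable as a Combine.CombineFn<Integer, ?, Integer>
    SerializableFnAggregatorWrapper<Integer, Integer> aggregator =
        new SerializableFnAggregatorWrapper<>(new Sum.SumIntegerFn());

    aggregator.addValue(3);
    aggregator.addValue(4);
    System.out.println(aggregator.getLocalValue()); // 7

    // inside an operator this would be registered with
    // getRuntimeContext().addAccumulator("myAggregator", aggregator);
  }
}
```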
+ */ +public class SinkOutputFormat implements OutputFormat { + + private final Sink sink; + + private transient PipelineOptions pipelineOptions; + + private Sink.WriteOperation writeOperation; + private Sink.Writer writer; + + private AbstractID uid = new AbstractID(); + + public SinkOutputFormat(Write.Bound transform, PipelineOptions pipelineOptions) { + this.sink = extractSink(transform); + this.pipelineOptions = Preconditions.checkNotNull(pipelineOptions); + } + + private Sink extractSink(Write.Bound transform) { + // TODO possibly add a getter in the upstream + try { + Field sinkField = transform.getClass().getDeclaredField("sink"); + sinkField.setAccessible(true); + @SuppressWarnings("unchecked") + Sink extractedSink = (Sink) sinkField.get(transform); + return extractedSink; + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException("Could not acquire custom sink field.", e); + } + } + + @Override + public void configure(Configuration configuration) { + writeOperation = sink.createWriteOperation(pipelineOptions); + try { + writeOperation.initialize(pipelineOptions); + } catch (Exception e) { + throw new RuntimeException("Failed to initialize the write operation.", e); + } + } + + @Override + public void open(int taskNumber, int numTasks) throws IOException { + try { + writer = writeOperation.createWriter(pipelineOptions); + } catch (Exception e) { + throw new IOException("Couldn't create writer.", e); + } + try { + writer.open(uid + "-" + String.valueOf(taskNumber)); + } catch (Exception e) { + throw new IOException("Couldn't open writer.", e); + } + } + + @Override + public void writeRecord(T record) throws IOException { + try { + writer.write(record); + } catch (Exception e) { + throw new IOException("Couldn't write record.", e); + } + } + + @Override + public void close() throws IOException { + try { + writer.close(); + } catch (Exception e) { + throw new IOException("Couldn't close writer.", e); + } + } + + private void writeObject(ObjectOutputStream out) throws IOException, ClassNotFoundException { + out.defaultWriteObject(); + ObjectMapper mapper = new ObjectMapper(); + mapper.writeValue(out, pipelineOptions); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + ObjectMapper mapper = new ObjectMapper(); + pipelineOptions = mapper.readValue(in, PipelineOptions.class); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java new file mode 100644 index 000000000000..cd5cd40768af --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Source; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import org.apache.flink.api.common.io.InputFormat; +import org.apache.flink.api.common.io.statistics.BaseStatistics; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.io.InputSplit; +import org.apache.flink.core.io.InputSplitAssigner; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.List; + +/** + * A Flink {@link org.apache.flink.api.common.io.InputFormat} that wraps a + * Dataflow {@link com.google.cloud.dataflow.sdk.io.Source}. + */ +public class SourceInputFormat implements InputFormat> { + private static final Logger LOG = LoggerFactory.getLogger(SourceInputFormat.class); + + private final BoundedSource initialSource; + private transient PipelineOptions options; + + private BoundedSource.BoundedReader reader = null; + private boolean reachedEnd = true; + + public SourceInputFormat(BoundedSource initialSource, PipelineOptions options) { + this.initialSource = initialSource; + this.options = options; + } + + private void writeObject(ObjectOutputStream out) + throws IOException, ClassNotFoundException { + out.defaultWriteObject(); + ObjectMapper mapper = new ObjectMapper(); + mapper.writeValue(out, options); + } + + private void readObject(ObjectInputStream in) + throws IOException, ClassNotFoundException { + in.defaultReadObject(); + ObjectMapper mapper = new ObjectMapper(); + options = mapper.readValue(in, PipelineOptions.class); + } + + @Override + public void configure(Configuration configuration) {} + + @Override + public void open(SourceInputSplit sourceInputSplit) throws IOException { + reader = ((BoundedSource) sourceInputSplit.getSource()).createReader(options); + reachedEnd = false; + } + + @Override + public BaseStatistics getStatistics(BaseStatistics baseStatistics) throws IOException { + try { + final long estimatedSize = initialSource.getEstimatedSizeBytes(options); + + return new BaseStatistics() { + @Override + public long getTotalInputSize() { + return estimatedSize; + + } + + @Override + public long getNumberOfRecords() { + return BaseStatistics.NUM_RECORDS_UNKNOWN; + } + + @Override + public float getAverageRecordWidth() { + return BaseStatistics.AVG_RECORD_BYTES_UNKNOWN; + } + }; + } catch (Exception e) { + LOG.warn("Could not read Source statistics: {}", e); + } + + return null; + } + + @Override + @SuppressWarnings("unchecked") + public SourceInputSplit[] createInputSplits(int numSplits) throws IOException { + long desiredSizeBytes; + try { + desiredSizeBytes = initialSource.getEstimatedSizeBytes(options) / numSplits; + List> shards = initialSource.splitIntoBundles(desiredSizeBytes, + options); + List> splits = new ArrayList<>(); + int splitCount = 0; + for (Source shard: shards) { + splits.add(new SourceInputSplit<>(shard, splitCount++)); + } + return splits.toArray(new SourceInputSplit[splits.size()]); + } catch (Exception e) { + throw new IOException("Could not create input splits from Source.", e); + } + } + + @Override + public 
InputSplitAssigner getInputSplitAssigner(final SourceInputSplit[] sourceInputSplits) { + return new InputSplitAssigner() { + private int index = 0; + private final SourceInputSplit[] splits = sourceInputSplits; + @Override + public InputSplit getNextInputSplit(String host, int taskId) { + if (index < splits.length) { + return splits[index++]; + } else { + return null; + } + } + }; + } + + + @Override + public boolean reachedEnd() throws IOException { + return reachedEnd; + } + + @Override + public T nextRecord(T t) throws IOException { + + reachedEnd = !reader.advance(); + if (!reachedEnd) { + return reader.getCurrent(); + } + return null; + } + + @Override + public void close() throws IOException { + reader.close(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputSplit.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputSplit.java new file mode 100644 index 000000000000..cde2b3571c71 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputSplit.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers; + +import com.google.cloud.dataflow.sdk.io.Source; +import org.apache.flink.core.io.InputSplit; + +/** + * {@link org.apache.flink.core.io.InputSplit} for + * {@link org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat}. We pass + * the sharded Source around in the input split because Sources simply split up into several + * Sources for sharding. This is different to how Flink creates a separate InputSplit from + * an InputFormat. + */ +public class SourceInputSplit implements InputSplit { + + private Source source; + private int splitNumber; + + public SourceInputSplit() { + } + + public SourceInputSplit(Source source, int splitNumber) { + this.source = source; + this.splitNumber = splitNumber; + } + + @Override + public int getSplitNumber() { + return splitNumber; + } + + public Source getSource() { + return source; + } + +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java new file mode 100644 index 000000000000..10c8bbf07a04 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming; + +import org.apache.beam.runners.flink.translation.wrappers.SerializableFnAggregatorWrapper; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.common.base.Preconditions; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.*; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Throwables; +import org.apache.flink.api.common.accumulators.Accumulator; +import org.apache.flink.api.common.accumulators.AccumulatorHelper; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; +import org.joda.time.format.PeriodFormat; + +import java.util.Collection; + +/** + * An abstract class that encapsulates the common code of the the {@link com.google.cloud.dataflow.sdk.transforms.ParDo.Bound} + * and {@link com.google.cloud.dataflow.sdk.transforms.ParDo.BoundMulti} wrappers. See the {@link FlinkParDoBoundWrapper} and + * {@link FlinkParDoBoundMultiWrapper} for the actual wrappers of the aforementioned transformations. + * */ +public abstract class FlinkAbstractParDoWrapper extends RichFlatMapFunction, WindowedValue> { + + private final DoFn doFn; + private final WindowingStrategy windowingStrategy; + private transient PipelineOptions options; + + private DoFnProcessContext context; + + public FlinkAbstractParDoWrapper(PipelineOptions options, WindowingStrategy windowingStrategy, DoFn doFn) { + Preconditions.checkNotNull(options); + Preconditions.checkNotNull(windowingStrategy); + Preconditions.checkNotNull(doFn); + + this.doFn = doFn; + this.options = options; + this.windowingStrategy = windowingStrategy; + } + + private void initContext(DoFn function, Collector> outCollector) { + if (this.context == null) { + this.context = new DoFnProcessContext(function, outCollector); + } + } + + @Override + public void flatMap(WindowedValue value, Collector> out) throws Exception { + this.initContext(doFn, out); + + // for each window the element belongs to, create a new copy here. 
+ Collection windows = value.getWindows(); + if (windows.size() <= 1) { + processElement(value); + } else { + for (BoundedWindow window : windows) { + processElement(WindowedValue.of( + value.getValue(), value.getTimestamp(), window, value.getPane())); + } + } + } + + private void processElement(WindowedValue value) throws Exception { + this.context.setElement(value); + this.doFn.startBundle(context); + doFn.processElement(context); + this.doFn.finishBundle(context); + } + + private class DoFnProcessContext extends DoFn.ProcessContext { + + private final DoFn fn; + + protected final Collector> collector; + + private WindowedValue element; + + private DoFnProcessContext(DoFn function, Collector> outCollector) { + function.super(); + super.setupDelegateAggregators(); + + this.fn = function; + this.collector = outCollector; + } + + public void setElement(WindowedValue value) { + this.element = value; + } + + @Override + public IN element() { + return this.element.getValue(); + } + + @Override + public Instant timestamp() { + return this.element.getTimestamp(); + } + + @Override + public BoundedWindow window() { + if (!(fn instanceof DoFn.RequiresWindowAccess)) { + throw new UnsupportedOperationException( + "window() is only available in the context of a DoFn marked as RequiresWindow."); + } + + Collection windows = this.element.getWindows(); + if (windows.size() != 1) { + throw new IllegalArgumentException("Each element is expected to belong to 1 window. " + + "This belongs to " + windows.size() + "."); + } + return windows.iterator().next(); + } + + @Override + public PaneInfo pane() { + return this.element.getPane(); + } + + @Override + public WindowingInternals windowingInternals() { + return windowingInternalsHelper(element, collector); + } + + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + + @Override + public T sideInput(PCollectionView view) { + throw new RuntimeException("sideInput() is not supported in Streaming mode."); + } + + @Override + public void output(OUTDF output) { + outputWithTimestamp(output, this.element.getTimestamp()); + } + + @Override + public void outputWithTimestamp(OUTDF output, Instant timestamp) { + outputWithTimestampHelper(element, output, timestamp, collector); + } + + @Override + public void sideOutput(TupleTag tag, T output) { + sideOutputWithTimestamp(tag, output, this.element.getTimestamp()); + } + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + sideOutputWithTimestampHelper(element, output, timestamp, collector, tag); + } + + @Override + protected Aggregator createAggregatorInternal(String name, Combine.CombineFn combiner) { + Accumulator acc = getRuntimeContext().getAccumulator(name); + if (acc != null) { + AccumulatorHelper.compareAccumulatorTypes(name, + SerializableFnAggregatorWrapper.class, acc.getClass()); + return (Aggregator) acc; + } + + SerializableFnAggregatorWrapper accumulator = + new SerializableFnAggregatorWrapper<>(combiner); + getRuntimeContext().addAccumulator(name, accumulator); + return accumulator; + } + } + + protected void checkTimestamp(WindowedValue ref, Instant timestamp) { + if (timestamp.isBefore(ref.getTimestamp().minus(doFn.getAllowedTimestampSkew()))) { + throw new IllegalArgumentException(String.format( + "Cannot output with timestamp %s. Output timestamps must be no earlier than the " + + "timestamp of the current input (%s) minus the allowed skew (%s). 
See the " + + "DoFn#getAllowedTimestmapSkew() Javadoc for details on changing the allowed skew.", + timestamp, ref.getTimestamp(), + PeriodFormat.getDefault().print(doFn.getAllowedTimestampSkew().toPeriod()))); + } + } + + protected WindowedValue makeWindowedValue( + T output, Instant timestamp, Collection windows, PaneInfo pane) { + final Instant inputTimestamp = timestamp; + final WindowFn windowFn = windowingStrategy.getWindowFn(); + + if (timestamp == null) { + timestamp = BoundedWindow.TIMESTAMP_MIN_VALUE; + } + + if (windows == null) { + try { + windows = windowFn.assignWindows(windowFn.new AssignContext() { + @Override + public Object element() { + throw new UnsupportedOperationException( + "WindowFn attempted to access input element when none was available"); + } + + @Override + public Instant timestamp() { + if (inputTimestamp == null) { + throw new UnsupportedOperationException( + "WindowFn attempted to access input timestamp when none was available"); + } + return inputTimestamp; + } + + @Override + public Collection windows() { + throw new UnsupportedOperationException( + "WindowFn attempted to access input windows when none were available"); + } + }); + } catch (Exception e) { + throw UserCodeException.wrap(e); + } + } + + return WindowedValue.of(output, timestamp, windows, pane); + } + + /////////// ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES ///////////////// + + public abstract void outputWithTimestampHelper( + WindowedValue inElement, + OUTDF output, + Instant timestamp, + Collector> outCollector); + + public abstract void sideOutputWithTimestampHelper( + WindowedValue inElement, + T output, + Instant timestamp, + Collector> outCollector, + TupleTag tag); + + public abstract WindowingInternals windowingInternalsHelper( + WindowedValue inElement, + Collector> outCollector); + +} \ No newline at end of file diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java new file mode 100644 index 000000000000..e115a15cc90a --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java @@ -0,0 +1,631 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming; + +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.wrappers.SerializableFnAggregatorWrapper; +import org.apache.beam.runners.flink.translation.wrappers.streaming.state.*; +import com.google.cloud.dataflow.sdk.coders.*; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.*; +import com.google.cloud.dataflow.sdk.values.*; +import com.google.common.base.Preconditions; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import org.apache.flink.api.common.accumulators.Accumulator; +import org.apache.flink.api.common.accumulators.AccumulatorHelper; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.StateHandle; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.KeyedStream; +import org.apache.flink.streaming.api.operators.*; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.StreamTaskState; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.*; + +/** + * This class is the key class implementing all the windowing/triggering logic of Apache Beam. + * To provide full compatibility and support for all the windowing/triggering combinations offered by + * Beam, we opted for a strategy that uses the SDK's code for doing these operations. See the code in + * ({@link com.google.cloud.dataflow.sdk.util.GroupAlsoByWindowsDoFn}. + *
    + * In a nutshell, when the execution arrives to this operator, we expect to have a stream already + * grouped by key. Each of the elements that enter here, registers a timer + * (see {@link TimerInternals#setTimer(TimerInternals.TimerData)} in the + * {@link FlinkGroupAlsoByWindowWrapper#activeTimers}. + * This is essentially a timestamp indicating when to trigger the computation over the window this + * element belongs to. + *
    + * When a watermark arrives, all the registered timers are checked to see which ones are ready to + * fire (see {@link FlinkGroupAlsoByWindowWrapper#processWatermark(Watermark)}). These are deregistered from + * the {@link FlinkGroupAlsoByWindowWrapper#activeTimers} + * list, and are fed into the {@link com.google.cloud.dataflow.sdk.util.GroupAlsoByWindowsDoFn} + * for furhter processing. + */ +public class FlinkGroupAlsoByWindowWrapper + extends AbstractStreamOperator>> + implements OneInputStreamOperator>, WindowedValue>> { + + private static final long serialVersionUID = 1L; + + private transient PipelineOptions options; + + private transient CoderRegistry coderRegistry; + + private DoFn, KV> operator; + + private ProcessContext context; + + private final WindowingStrategy, BoundedWindow> windowingStrategy; + + private final Combine.KeyedCombineFn combineFn; + + private final KvCoder inputKvCoder; + + /** + * State is kept per-key. This data structure keeps this mapping between an active key, i.e. a + * key whose elements are currently waiting to be processed, and its associated state. + */ + private Map> perKeyStateInternals = new HashMap<>(); + + /** + * Timers waiting to be processed. + */ + private Map> activeTimers = new HashMap<>(); + + private FlinkTimerInternals timerInternals = new FlinkTimerInternals(); + + /** + * Creates an DataStream where elements are grouped in windows based on the specified windowing strategy. + * This method assumes that elements are already grouped by key. + *
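For illustration, here is a plain-Java sketch of the per-key timer bookkeeping described in the class comment above: a timer is registered per key, and when a watermark arrives every timer whose timestamp lies before it is collected for firing and removed. Types are simplified to String keys and long timestamps; the real operator keeps `TimerInternals.TimerData` per element key.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class TimerBookkeepingSketch {

  // Simplified counterpart of the activeTimers map kept by the operator.
  private final Map<String, Set<Long>> activeTimers = new HashMap<>();

  public void registerTimer(String key, long timestamp) {
    activeTimers.computeIfAbsent(key, k -> new HashSet<>()).add(timestamp);
  }

  // Removes and returns, per key, all timers strictly before the watermark.
  public Map<String, List<Long>> fireTimersUpTo(long watermark) {
    Map<String, List<Long>> toFire = new HashMap<>();
    Iterator<Map.Entry<String, Set<Long>>> keys = activeTimers.entrySet().iterator();
    while (keys.hasNext()) {
      Map.Entry<String, Set<Long>> entry = keys.next();
      List<Long> ready = new ArrayList<>();
      Iterator<Long> timers = entry.getValue().iterator();
      while (timers.hasNext()) {
        long t = timers.next();
        if (t < watermark) {
          ready.add(t);
          timers.remove();
        }
      }
      if (!ready.isEmpty()) {
        toFire.put(entry.getKey(), ready);
      }
      if (entry.getValue().isEmpty()) {
        keys.remove();
      }
    }
    return toFire;
  }
}
```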
    + * The difference with {@link #createForIterable(PipelineOptions, PCollection, KeyedStream)} + * is that this method assumes that a combiner function is provided + * (see {@link com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn}). + * A combiner helps at increasing the speed and, in most of the cases, reduce the per-window state. + * + * @param options the general job configuration options. + * @param input the input Dataflow {@link com.google.cloud.dataflow.sdk.values.PCollection}. + * @param groupedStreamByKey the input stream, it is assumed to already be grouped by key. + * @param combiner the combiner to be used. + * @param outputKvCoder the type of the output values. + */ + public static DataStream>> create( + PipelineOptions options, + PCollection input, + KeyedStream>, K> groupedStreamByKey, + Combine.KeyedCombineFn combiner, + KvCoder outputKvCoder) { + Preconditions.checkNotNull(options); + + KvCoder inputKvCoder = (KvCoder) input.getCoder(); + FlinkGroupAlsoByWindowWrapper windower = new FlinkGroupAlsoByWindowWrapper<>(options, + input.getPipeline().getCoderRegistry(), input.getWindowingStrategy(), inputKvCoder, combiner); + + Coder>> windowedOutputElemCoder = WindowedValue.FullWindowedValueCoder.of( + outputKvCoder, + input.getWindowingStrategy().getWindowFn().windowCoder()); + + CoderTypeInformation>> outputTypeInfo = + new CoderTypeInformation<>(windowedOutputElemCoder); + + DataStream>> groupedByKeyAndWindow = groupedStreamByKey + .transform("GroupByWindowWithCombiner", + new CoderTypeInformation<>(outputKvCoder), + windower) + .returns(outputTypeInfo); + + return groupedByKeyAndWindow; + } + + /** + * Creates an DataStream where elements are grouped in windows based on the specified windowing strategy. + * This method assumes that elements are already grouped by key. + *
    + * The difference with {@link #create(PipelineOptions, PCollection, KeyedStream, Combine.KeyedCombineFn, KvCoder)} + * is that this method assumes no combiner function + * (see {@link com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn}). + * + * @param options the general job configuration options. + * @param input the input Dataflow {@link com.google.cloud.dataflow.sdk.values.PCollection}. + * @param groupedStreamByKey the input stream, it is assumed to already be grouped by key. + */ + public static DataStream>>> createForIterable( + PipelineOptions options, + PCollection input, + KeyedStream>, K> groupedStreamByKey) { + Preconditions.checkNotNull(options); + + KvCoder inputKvCoder = (KvCoder) input.getCoder(); + Coder keyCoder = inputKvCoder.getKeyCoder(); + Coder inputValueCoder = inputKvCoder.getValueCoder(); + + FlinkGroupAlsoByWindowWrapper windower = new FlinkGroupAlsoByWindowWrapper(options, + input.getPipeline().getCoderRegistry(), input.getWindowingStrategy(), inputKvCoder, null); + + Coder> valueIterCoder = IterableCoder.of(inputValueCoder); + KvCoder> outputElemCoder = KvCoder.of(keyCoder, valueIterCoder); + + Coder>>> windowedOutputElemCoder = WindowedValue.FullWindowedValueCoder.of( + outputElemCoder, + input.getWindowingStrategy().getWindowFn().windowCoder()); + + CoderTypeInformation>>> outputTypeInfo = + new CoderTypeInformation<>(windowedOutputElemCoder); + + DataStream>>> groupedByKeyAndWindow = groupedStreamByKey + .transform("GroupByWindow", + new CoderTypeInformation<>(windowedOutputElemCoder), + windower) + .returns(outputTypeInfo); + + return groupedByKeyAndWindow; + } + + public static FlinkGroupAlsoByWindowWrapper + createForTesting(PipelineOptions options, + CoderRegistry registry, + WindowingStrategy, BoundedWindow> windowingStrategy, + KvCoder inputCoder, + Combine.KeyedCombineFn combiner) { + Preconditions.checkNotNull(options); + + return new FlinkGroupAlsoByWindowWrapper(options, registry, windowingStrategy, inputCoder, combiner); + } + + private FlinkGroupAlsoByWindowWrapper(PipelineOptions options, + CoderRegistry registry, + WindowingStrategy, BoundedWindow> windowingStrategy, + KvCoder inputCoder, + Combine.KeyedCombineFn combiner) { + Preconditions.checkNotNull(options); + + this.options = Preconditions.checkNotNull(options); + this.coderRegistry = Preconditions.checkNotNull(registry); + this.inputKvCoder = Preconditions.checkNotNull(inputCoder);//(KvCoder) input.getCoder(); + this.windowingStrategy = Preconditions.checkNotNull(windowingStrategy);//input.getWindowingStrategy(); + this.combineFn = combiner; + this.operator = createGroupAlsoByWindowOperator(); + this.chainingStrategy = ChainingStrategy.ALWAYS; + } + + @Override + public void open() throws Exception { + super.open(); + this.context = new ProcessContext(operator, new TimestampedCollector<>(output), this.timerInternals); + } + + /** + * Create the adequate {@link com.google.cloud.dataflow.sdk.util.GroupAlsoByWindowsDoFn}, + * if not already created. + * If a {@link com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn} was provided, then + * a function with that combiner is created, so that elements are combined as they arrive. This is + * done for speed and (in most of the cases) for reduction of the per-window state. 
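To make that trade-off concrete, here is a plain-Java sketch, simplified to summing longs rather than the SDK's CombineFn machinery, of why providing a combiner shrinks per-window state: without one the runner must buffer every value until the window fires, with one it only keeps a running accumulator.

```java
import java.util.ArrayList;
import java.util.List;

public class CombinerStateSketch {

  /** No combiner: all values are buffered until the window fires. */
  static class BufferingWindowState {
    final List<Long> buffered = new ArrayList<>();
    void add(long value) { buffered.add(value); }   // state grows per element
    List<Long> result() { return buffered; }
  }

  /** With a combiner (here: sum): only a constant-size accumulator is kept. */
  static class CombiningWindowState {
    long accumulator = 0L;
    void add(long value) { accumulator += value; }  // state stays O(1)
    long result() { return accumulator; }
  }

  public static void main(String[] args) {
    BufferingWindowState buffering = new BufferingWindowState();
    CombiningWindowState combining = new CombiningWindowState();
    for (long v = 1; v <= 5; v++) {
      buffering.add(v);
      combining.add(v);
    }
    System.out.println(buffering.result()); // [1, 2, 3, 4, 5]
    System.out.println(combining.result()); // 15
  }
}
```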
+ */ + private DoFn, KV> createGroupAlsoByWindowOperator() { + if (this.operator == null) { + if (this.combineFn == null) { + // Thus VOUT == Iterable + Coder inputValueCoder = inputKvCoder.getValueCoder(); + + this.operator = (DoFn) GroupAlsoByWindowViaWindowSetDoFn.create( + (WindowingStrategy) this.windowingStrategy, SystemReduceFn.buffering(inputValueCoder)); + } else { + Coder inputKeyCoder = inputKvCoder.getKeyCoder(); + + AppliedCombineFn appliedCombineFn = AppliedCombineFn + .withInputCoder(combineFn, coderRegistry, inputKvCoder); + + this.operator = GroupAlsoByWindowViaWindowSetDoFn.create( + (WindowingStrategy) this.windowingStrategy, SystemReduceFn.combining(inputKeyCoder, appliedCombineFn)); + } + } + return this.operator; + } + + private void processKeyedWorkItem(KeyedWorkItem workItem) throws Exception { + context.setElement(workItem, getStateInternalsForKey(workItem.key())); + + // TODO: Ideally startBundle/finishBundle would be called when the operator is first used / about to be discarded. + operator.startBundle(context); + operator.processElement(context); + operator.finishBundle(context); + } + + @Override + public void processElement(StreamRecord>> element) throws Exception { + ArrayList> elements = new ArrayList<>(); + elements.add(WindowedValue.of(element.getValue().getValue().getValue(), element.getValue().getTimestamp(), + element.getValue().getWindows(), element.getValue().getPane())); + processKeyedWorkItem(KeyedWorkItems.elementsWorkItem(element.getValue().getValue().getKey(), elements)); + } + + @Override + public void processWatermark(Watermark mark) throws Exception { + context.setCurrentInputWatermark(new Instant(mark.getTimestamp())); + + Multimap timers = getTimersReadyToProcess(mark.getTimestamp()); + if (!timers.isEmpty()) { + for (K key : timers.keySet()) { + processKeyedWorkItem(KeyedWorkItems.timersWorkItem(key, timers.get(key))); + } + } + + /** + * This is to take into account the different semantics of the Watermark in Flink and + * in Dataflow. To understand the reasoning behind the Dataflow semantics and its + * watermark holding logic, see the documentation of + * {@link WatermarkHold#addHold(ReduceFn.ProcessValueContext, boolean)} + * */ + long millis = Long.MAX_VALUE; + for (FlinkStateInternals state : perKeyStateInternals.values()) { + Instant watermarkHold = state.getWatermarkHold(); + if (watermarkHold != null && watermarkHold.getMillis() < millis) { + millis = watermarkHold.getMillis(); + } + } + + if (mark.getTimestamp() < millis) { + millis = mark.getTimestamp(); + } + + context.setCurrentOutputWatermark(new Instant(millis)); + + // Don't forget to re-emit the watermark for further operators down the line. + // This is critical for jobs with multiple aggregation steps. + // Imagine a job with a groupByKey() on key K1, followed by a map() that changes + // the key K1 to K2, and another groupByKey() on K2. In this case, if the watermark + // is not re-emitted, the second aggregation would never be triggered, and no result + // will be produced. 
+ output.emitWatermark(new Watermark(millis)); + } + + @Override + public void close() throws Exception { + super.close(); + } + + private void registerActiveTimer(K key, TimerInternals.TimerData timer) { + Set timersForKey = activeTimers.get(key); + if (timersForKey == null) { + timersForKey = new HashSet<>(); + } + timersForKey.add(timer); + activeTimers.put(key, timersForKey); + } + + private void unregisterActiveTimer(K key, TimerInternals.TimerData timer) { + Set timersForKey = activeTimers.get(key); + if (timersForKey != null) { + timersForKey.remove(timer); + if (timersForKey.isEmpty()) { + activeTimers.remove(key); + } else { + activeTimers.put(key, timersForKey); + } + } + } + + /** + * Returns the list of timers that are ready to fire. These are the timers + * that are registered to be triggered at a time before the current watermark. + * We keep these timers in a Set, so that they are deduplicated, as the same + * timer can be registered multiple times. + */ + private Multimap getTimersReadyToProcess(long currentWatermark) { + + // we keep the timers to return in a different list and launch them later + // because we cannot prevent a trigger from registering another trigger, + // which would lead to concurrent modification exception. + Multimap toFire = HashMultimap.create(); + + Iterator>> it = activeTimers.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry> keyWithTimers = it.next(); + + Iterator timerIt = keyWithTimers.getValue().iterator(); + while (timerIt.hasNext()) { + TimerInternals.TimerData timerData = timerIt.next(); + if (timerData.getTimestamp().isBefore(currentWatermark)) { + toFire.put(keyWithTimers.getKey(), timerData); + timerIt.remove(); + } + } + + if (keyWithTimers.getValue().isEmpty()) { + it.remove(); + } + } + return toFire; + } + + /** + * Gets the state associated with the specified key. + * + * @param key the key whose state we want. + * @return The {@link FlinkStateInternals} + * associated with that key. 
+ */ + private FlinkStateInternals getStateInternalsForKey(K key) { + FlinkStateInternals stateInternals = perKeyStateInternals.get(key); + if (stateInternals == null) { + Coder windowCoder = this.windowingStrategy.getWindowFn().windowCoder(); + OutputTimeFn outputTimeFn = this.windowingStrategy.getWindowFn().getOutputTimeFn(); + stateInternals = new FlinkStateInternals<>(key, inputKvCoder.getKeyCoder(), windowCoder, outputTimeFn); + perKeyStateInternals.put(key, stateInternals); + } + return stateInternals; + } + + private class FlinkTimerInternals extends AbstractFlinkTimerInternals { + @Override + public void setTimer(TimerData timerKey) { + registerActiveTimer(context.element().key(), timerKey); + } + + @Override + public void deleteTimer(TimerData timerKey) { + unregisterActiveTimer(context.element().key(), timerKey); + } + } + + private class ProcessContext extends GroupAlsoByWindowViaWindowSetDoFn>.ProcessContext { + + private final FlinkTimerInternals timerInternals; + + private final TimestampedCollector>> collector; + + private FlinkStateInternals stateInternals; + + private KeyedWorkItem element; + + public ProcessContext(DoFn, KV> function, + TimestampedCollector>> outCollector, + FlinkTimerInternals timerInternals) { + function.super(); + super.setupDelegateAggregators(); + + this.collector = Preconditions.checkNotNull(outCollector); + this.timerInternals = Preconditions.checkNotNull(timerInternals); + } + + public void setElement(KeyedWorkItem element, + FlinkStateInternals stateForKey) { + this.element = element; + this.stateInternals = stateForKey; + } + + public void setCurrentInputWatermark(Instant watermark) { + this.timerInternals.setCurrentInputWatermark(watermark); + } + + public void setCurrentOutputWatermark(Instant watermark) { + this.timerInternals.setCurrentOutputWatermark(watermark); + } + + @Override + public KeyedWorkItem element() { + return this.element; + } + + @Override + public Instant timestamp() { + throw new UnsupportedOperationException("timestamp() is not available when processing KeyedWorkItems."); + } + + @Override + public PipelineOptions getPipelineOptions() { + // TODO: PipelineOptions need to be available on the workers. + // Ideally they are captured as part of the pipeline. + // For now, construct empty options so that StateContexts.createFromComponents + // will yield a valid StateContext, which is needed to support the StateContext.window(). 
+ if (options == null) { + options = new PipelineOptions() { + @Override + public T as(Class kls) { + return null; + } + + @Override + public T cloneAs(Class kls) { + return null; + } + + @Override + public Class> getRunner() { + return null; + } + + @Override + public void setRunner(Class> kls) { + + } + + @Override + public CheckEnabled getStableUniqueNames() { + return null; + } + + @Override + public void setStableUniqueNames(CheckEnabled enabled) { + } + }; + } + return options; + } + + @Override + public void output(KV output) { + throw new UnsupportedOperationException( + "output() is not available when processing KeyedWorkItems."); + } + + @Override + public void outputWithTimestamp(KV output, Instant timestamp) { + throw new UnsupportedOperationException( + "outputWithTimestamp() is not available when processing KeyedWorkItems."); + } + + @Override + public PaneInfo pane() { + throw new UnsupportedOperationException("pane() is not available when processing KeyedWorkItems."); + } + + @Override + public BoundedWindow window() { + throw new UnsupportedOperationException( + "window() is not available when processing KeyedWorkItems."); + } + + @Override + public WindowingInternals, KV> windowingInternals() { + return new WindowingInternals, KV>() { + + @Override + public com.google.cloud.dataflow.sdk.util.state.StateInternals stateInternals() { + return stateInternals; + } + + @Override + public void outputWindowedValue(KV output, Instant timestamp, Collection windows, PaneInfo pane) { + // TODO: No need to represent timestamp twice. + collector.setAbsoluteTimestamp(timestamp.getMillis()); + collector.collect(WindowedValue.of(output, timestamp, windows, pane)); + + } + + @Override + public TimerInternals timerInternals() { + return timerInternals; + } + + @Override + public Collection windows() { + throw new UnsupportedOperationException("windows() is not available in Streaming mode."); + } + + @Override + public PaneInfo pane() { + throw new UnsupportedOperationException("pane() is not available in Streaming mode."); + } + + @Override + public void writePCollectionViewData(TupleTag tag, Iterable> data, Coder elemCoder) throws IOException { + throw new RuntimeException("writePCollectionViewData() not available in Streaming mode."); + } + + @Override + public T sideInput(PCollectionView view, BoundedWindow mainInputWindow) { + throw new RuntimeException("sideInput() is not available in Streaming mode."); + } + }; + } + + @Override + public T sideInput(PCollectionView view) { + throw new RuntimeException("sideInput() is not supported in Streaming mode."); + } + + @Override + public void sideOutput(TupleTag tag, T output) { + // ignore the side output, this can happen when a user does not register + // side outputs but then outputs using a freshly created TupleTag. 
+ throw new RuntimeException("sideOutput() is not available when grouping by window."); + } + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + sideOutput(tag, output); + } + + @Override + protected Aggregator createAggregatorInternal(String name, Combine.CombineFn combiner) { + Accumulator acc = getRuntimeContext().getAccumulator(name); + if (acc != null) { + AccumulatorHelper.compareAccumulatorTypes(name, + SerializableFnAggregatorWrapper.class, acc.getClass()); + return (Aggregator) acc; + } + + SerializableFnAggregatorWrapper accumulator = + new SerializableFnAggregatorWrapper<>(combiner); + getRuntimeContext().addAccumulator(name, accumulator); + return accumulator; + } + } + + ////////////// Checkpointing implementation //////////////// + + @Override + public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp) throws Exception { + StreamTaskState taskState = super.snapshotOperatorState(checkpointId, timestamp); + AbstractStateBackend.CheckpointStateOutputView out = getStateBackend().createCheckpointStateOutputView(checkpointId, timestamp); + StateCheckpointWriter writer = StateCheckpointWriter.create(out); + Coder keyCoder = inputKvCoder.getKeyCoder(); + + // checkpoint the timers + StateCheckpointUtils.encodeTimers(activeTimers, writer, keyCoder); + + // checkpoint the state + StateCheckpointUtils.encodeState(perKeyStateInternals, writer, keyCoder); + + // checkpoint the timerInternals + context.timerInternals.encodeTimerInternals(context, writer, + inputKvCoder, windowingStrategy.getWindowFn().windowCoder()); + + taskState.setOperatorState(out.closeAndGetHandle()); + return taskState; + } + + @Override + public void restoreState(StreamTaskState taskState, long recoveryTimestamp) throws Exception { + super.restoreState(taskState, recoveryTimestamp); + + final ClassLoader userClassloader = getUserCodeClassloader(); + + Coder windowCoder = this.windowingStrategy.getWindowFn().windowCoder(); + Coder keyCoder = inputKvCoder.getKeyCoder(); + + @SuppressWarnings("unchecked") + StateHandle inputState = (StateHandle) taskState.getOperatorState(); + DataInputView in = inputState.getState(userClassloader); + StateCheckpointReader reader = new StateCheckpointReader(in); + + // restore the timers + this.activeTimers = StateCheckpointUtils.decodeTimers(reader, windowCoder, keyCoder); + + // restore the state + this.perKeyStateInternals = StateCheckpointUtils.decodeState( + reader, windowingStrategy.getOutputTimeFn(), keyCoder, windowCoder, userClassloader); + + // restore the timerInternals. + this.timerInternals.restoreTimerInternals(reader, inputKvCoder, windowCoder); + } +} \ No newline at end of file diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupByKeyWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupByKeyWrapper.java new file mode 100644 index 000000000000..1a6a665858c2 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupByKeyWrapper.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming; + +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.types.VoidCoderTypeSerializer; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.KV; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.KeyedStream; + +/** + * This class groups the elements by key. It assumes that already the incoming stream + * is composed of [Key,Value] pairs. + * */ +public class FlinkGroupByKeyWrapper { + + /** + * Just an auxiliary interface to bypass the fact that java anonymous classes cannot implement + * multiple interfaces. + */ + private interface KeySelectorWithQueryableResultType extends KeySelector>, K>, ResultTypeQueryable { + } + + public static KeyedStream>, K> groupStreamByKey(DataStream>> inputDataStream, KvCoder inputKvCoder) { + final Coder keyCoder = inputKvCoder.getKeyCoder(); + final TypeInformation keyTypeInfo = new CoderTypeInformation<>(keyCoder); + final boolean isKeyVoid = keyCoder instanceof VoidCoder; + + return inputDataStream.keyBy( + new KeySelectorWithQueryableResultType() { + + @Override + public K getKey(WindowedValue> value) throws Exception { + return isKeyVoid ? (K) VoidCoderTypeSerializer.VoidValue.INSTANCE : + value.getValue().getKey(); + } + + @Override + public TypeInformation getProducedType() { + return keyTypeInfo; + } + }); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkParDoBoundMultiWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkParDoBoundMultiWrapper.java new file mode 100644 index 000000000000..df7f95355c19 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkParDoBoundMultiWrapper.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Preconditions; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.util.Map; + +/** + * A wrapper for the {@link com.google.cloud.dataflow.sdk.transforms.ParDo.BoundMulti} Beam transformation. + * */ +public class FlinkParDoBoundMultiWrapper extends FlinkAbstractParDoWrapper { + + private final TupleTag mainTag; + private final Map, Integer> outputLabels; + + public FlinkParDoBoundMultiWrapper(PipelineOptions options, WindowingStrategy windowingStrategy, DoFn doFn, TupleTag mainTag, Map, Integer> tagsToLabels) { + super(options, windowingStrategy, doFn); + this.mainTag = Preconditions.checkNotNull(mainTag); + this.outputLabels = Preconditions.checkNotNull(tagsToLabels); + } + + @Override + public void outputWithTimestampHelper(WindowedValue inElement, OUT output, Instant timestamp, Collector> collector) { + checkTimestamp(inElement, timestamp); + Integer index = outputLabels.get(mainTag); + collector.collect(makeWindowedValue( + new RawUnionValue(index, output), + timestamp, + inElement.getWindows(), + inElement.getPane())); + } + + @Override + public void sideOutputWithTimestampHelper(WindowedValue inElement, T output, Instant timestamp, Collector> collector, TupleTag tag) { + checkTimestamp(inElement, timestamp); + Integer index = outputLabels.get(tag); + if (index != null) { + collector.collect(makeWindowedValue( + new RawUnionValue(index, output), + timestamp, + inElement.getWindows(), + inElement.getPane())); + } + } + + @Override + public WindowingInternals windowingInternalsHelper(WindowedValue inElement, Collector> outCollector) { + throw new RuntimeException("FlinkParDoBoundMultiWrapper is just an internal operator serving as " + + "an intermediate transformation for the ParDo.BoundMulti translation. windowingInternals() " + + "is not available in this class."); + } +} \ No newline at end of file diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkParDoBoundWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkParDoBoundWrapper.java new file mode 100644 index 000000000000..2ed56203b33a --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkParDoBoundWrapper.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.*; +import com.google.cloud.dataflow.sdk.util.state.StateInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.*; + +/** + * A wrapper for the {@link com.google.cloud.dataflow.sdk.transforms.ParDo.Bound} Beam transformation. + * */ +public class FlinkParDoBoundWrapper extends FlinkAbstractParDoWrapper { + + public FlinkParDoBoundWrapper(PipelineOptions options, WindowingStrategy windowingStrategy, DoFn doFn) { + super(options, windowingStrategy, doFn); + } + + @Override + public void outputWithTimestampHelper(WindowedValue inElement, OUT output, Instant timestamp, Collector> collector) { + checkTimestamp(inElement, timestamp); + collector.collect(makeWindowedValue( + output, + timestamp, + inElement.getWindows(), + inElement.getPane())); + } + + @Override + public void sideOutputWithTimestampHelper(WindowedValue inElement, T output, Instant timestamp, Collector> outCollector, TupleTag tag) { + // ignore the side output, this can happen when a user does not register + // side outputs but then outputs using a freshly created TupleTag. 
+ throw new RuntimeException("sideOutput() not not available in ParDo.Bound()."); + } + + @Override + public WindowingInternals windowingInternalsHelper(final WindowedValue inElement, final Collector> collector) { + return new WindowingInternals() { + @Override + public StateInternals stateInternals() { + throw new NullPointerException("StateInternals are not available for ParDo.Bound()."); + } + + @Override + public void outputWindowedValue(OUT output, Instant timestamp, Collection windows, PaneInfo pane) { + collector.collect(makeWindowedValue(output, timestamp, windows, pane)); + } + + @Override + public TimerInternals timerInternals() { + throw new NullPointerException("TimeInternals are not available for ParDo.Bound()."); + } + + @Override + public Collection windows() { + return inElement.getWindows(); + } + + @Override + public PaneInfo pane() { + return inElement.getPane(); + } + + @Override + public void writePCollectionViewData(TupleTag tag, Iterable> data, Coder elemCoder) throws IOException { + throw new RuntimeException("writePCollectionViewData() not supported in Streaming mode."); + } + + @Override + public T sideInput(PCollectionView view, BoundedWindow mainInputWindow) { + throw new RuntimeException("sideInput() not implemented."); + } + }; + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/FlinkStreamingCreateFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/FlinkStreamingCreateFunction.java new file mode 100644 index 000000000000..f6c243fe1230 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/FlinkStreamingCreateFunction.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.io; + +import org.apache.beam.runners.flink.translation.types.VoidCoderTypeSerializer; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.io.ByteArrayInputStream; +import java.util.List; + +/** + * This flat map function bootstraps from collection elements and turns them into WindowedValues + * (as required by the Flink runner). 
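For context, here is a small sketch of the byte-level round trip behind this function, assuming the Dataflow SDK 1.x Coder API used above: the collection elements handed to the function are presumably encoded to byte arrays up front, and the flat map decodes each one again before wrapping it in a WindowedValue. The literal value used here is made up.

```java
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;

public class CoderRoundTripSketch {
  public static void main(String[] args) throws Exception {
    Coder<String> coder = StringUtf8Coder.of();

    // Encode: roughly how a collection element would become one of the
    // byte[] entries handed to the create function.
    ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
    coder.encode("hello flink", bytesOut, Coder.Context.OUTER);
    byte[] bytes = bytesOut.toByteArray();

    // Decode: mirrors what the flat map function does for every element.
    String decoded = coder.decode(new ByteArrayInputStream(bytes), Coder.Context.OUTER);
    System.out.println(decoded); // hello flink
  }
}
```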
+ */ +public class FlinkStreamingCreateFunction implements FlatMapFunction> { + + private final List elements; + private final Coder coder; + + public FlinkStreamingCreateFunction(List elements, Coder coder) { + this.elements = elements; + this.coder = coder; + } + + @Override + public void flatMap(IN value, Collector> out) throws Exception { + + @SuppressWarnings("unchecked") + OUT voidValue = (OUT) VoidCoderTypeSerializer.VoidValue.INSTANCE; + for (byte[] element : elements) { + ByteArrayInputStream bai = new ByteArrayInputStream(element); + OUT outValue = coder.decode(bai, Coder.Context.OUTER); + + if (outValue == null) { + out.collect(WindowedValue.of(voidValue, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING)); + } else { + out.collect(WindowedValue.of(outValue, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING)); + } + } + + out.close(); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedFlinkSource.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedFlinkSource.java new file mode 100644 index 000000000000..2857efd40a23 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedFlinkSource.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.io; + +import org.apache.beam.runners.flink.FlinkPipelineRunner; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.io.UnboundedSource; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.common.base.Preconditions; +import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; + +import javax.annotation.Nullable; +import java.util.List; + +/** + * A wrapper translating Flink Sources implementing the {@link RichParallelSourceFunction} interface, into + * unbounded Beam sources (see {@link UnboundedSource}). 
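For reference, a minimal sketch of the kind of native Flink source the wrapper below is meant to carry, written against Flink's RichParallelSourceFunction contract (run/cancel). The class name and emission logic are made up, and how such a source is actually attached to a Beam pipeline is runner-specific and not shown here.

```java
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;

// Made-up example of a native Flink source that could be handed to
// UnboundedFlinkSource when running on the Flink runner.
public class CountingFlinkSource extends RichParallelSourceFunction<Long> {

  private volatile boolean running = true;
  private long counter = 0L;

  @Override
  public void run(SourceContext<Long> ctx) throws Exception {
    while (running) {
      // Emit elements; a real source would typically also manage timestamps
      // and synchronize emission via ctx.getCheckpointLock().
      ctx.collect(counter++);
      Thread.sleep(100);
    }
  }

  @Override
  public void cancel() {
    running = false;
  }
}
```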
+ * */ +public class UnboundedFlinkSource extends UnboundedSource { + + private final PipelineOptions options; + private final RichParallelSourceFunction flinkSource; + + public UnboundedFlinkSource(PipelineOptions pipelineOptions, RichParallelSourceFunction source) { + if(!pipelineOptions.getRunner().equals(FlinkPipelineRunner.class)) { + throw new RuntimeException("Flink Sources are supported only when running with the FlinkPipelineRunner."); + } + options = Preconditions.checkNotNull(pipelineOptions); + flinkSource = Preconditions.checkNotNull(source); + validate(); + } + + public RichParallelSourceFunction getFlinkSource() { + return this.flinkSource; + } + + @Override + public List> generateInitialSplits(int desiredNumSplits, PipelineOptions options) throws Exception { + throw new RuntimeException("Flink Sources are supported only when running with the FlinkPipelineRunner."); + } + + @Override + public UnboundedReader createReader(PipelineOptions options, @Nullable C checkpointMark) { + throw new RuntimeException("Flink Sources are supported only when running with the FlinkPipelineRunner."); + } + + @Nullable + @Override + public Coder getCheckpointMarkCoder() { + throw new RuntimeException("Flink Sources are supported only when running with the FlinkPipelineRunner."); + } + + + @Override + public void validate() { + Preconditions.checkNotNull(options); + Preconditions.checkNotNull(flinkSource); + if(!options.getRunner().equals(FlinkPipelineRunner.class)) { + throw new RuntimeException("Flink Sources are supported only when running with the FlinkPipelineRunner."); + } + } + + @Override + public Coder getDefaultOutputCoder() { + throw new RuntimeException("Flink Sources are supported only when running with the FlinkPipelineRunner."); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSocketSource.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSocketSource.java new file mode 100644 index 000000000000..1389e9d98cdb --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSocketSource.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.io; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.UnboundedSource; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.Serializable; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; + +import static com.google.common.base.Preconditions.checkArgument; + +/** + * An example unbounded Beam source that reads input from a socket. This is used mainly for testing and debugging. + * */ +public class UnboundedSocketSource extends UnboundedSource { + + private static final Coder DEFAULT_SOCKET_CODER = StringUtf8Coder.of(); + + private static final long serialVersionUID = 1L; + + private static final int DEFAULT_CONNECTION_RETRY_SLEEP = 500; + + private static final int CONNECTION_TIMEOUT_TIME = 0; + + private final String hostname; + private final int port; + private final char delimiter; + private final long maxNumRetries; + private final long delayBetweenRetries; + + public UnboundedSocketSource(String hostname, int port, char delimiter, long maxNumRetries) { + this(hostname, port, delimiter, maxNumRetries, DEFAULT_CONNECTION_RETRY_SLEEP); + } + + public UnboundedSocketSource(String hostname, int port, char delimiter, long maxNumRetries, long delayBetweenRetries) { + this.hostname = hostname; + this.port = port; + this.delimiter = delimiter; + this.maxNumRetries = maxNumRetries; + this.delayBetweenRetries = delayBetweenRetries; + } + + public String getHostname() { + return this.hostname; + } + + public int getPort() { + return this.port; + } + + public char getDelimiter() { + return this.delimiter; + } + + public long getMaxNumRetries() { + return this.maxNumRetries; + } + + public long getDelayBetweenRetries() { + return this.delayBetweenRetries; + } + + @Override + public List> generateInitialSplits(int desiredNumSplits, PipelineOptions options) throws Exception { + return Collections.>singletonList(this); + } + + @Override + public UnboundedReader createReader(PipelineOptions options, @Nullable C checkpointMark) { + return new UnboundedSocketReader(this); + } + + @Nullable + @Override + public Coder getCheckpointMarkCoder() { + // Flink and Dataflow have different checkpointing mechanisms. + // In our case we do not need a coder. 
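+      // The matching reader also returns a null CheckpointMark, so no recovery state is tracked for this source.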
+ return null; + } + + @Override + public void validate() { + checkArgument(port > 0 && port < 65536, "port is out of range"); + checkArgument(maxNumRetries >= -1, "maxNumRetries must be zero or larger (num retries), or -1 (infinite retries)"); + checkArgument(delayBetweenRetries >= 0, "delayBetweenRetries must be zero or positive"); + } + + @Override + public Coder getDefaultOutputCoder() { + return DEFAULT_SOCKET_CODER; + } + + public static class UnboundedSocketReader extends UnboundedSource.UnboundedReader implements Serializable { + + private static final long serialVersionUID = 7526472295622776147L; + private static final Logger LOG = LoggerFactory.getLogger(UnboundedSocketReader.class); + + private final UnboundedSocketSource source; + + private Socket socket; + private BufferedReader reader; + + private boolean isRunning; + + private String currentRecord; + + public UnboundedSocketReader(UnboundedSocketSource source) { + this.source = source; + } + + private void openConnection() throws IOException { + this.socket = new Socket(); + this.socket.connect(new InetSocketAddress(this.source.getHostname(), this.source.getPort()), CONNECTION_TIMEOUT_TIME); + this.reader = new BufferedReader(new InputStreamReader(this.socket.getInputStream())); + this.isRunning = true; + } + + @Override + public boolean start() throws IOException { + int attempt = 0; + while (!isRunning) { + try { + openConnection(); + LOG.info("Connected to server socket " + this.source.getHostname() + ':' + this.source.getPort()); + + return advance(); + } catch (IOException e) { + LOG.info("Lost connection to server socket " + this.source.getHostname() + ':' + this.source.getPort() + ". Retrying in " + this.source.getDelayBetweenRetries() + " msecs..."); + + if (this.source.getMaxNumRetries() == -1 || attempt++ < this.source.getMaxNumRetries()) { + try { + Thread.sleep(this.source.getDelayBetweenRetries()); + } catch (InterruptedException e1) { + e1.printStackTrace(); + } + } else { + this.isRunning = false; + break; + } + } + } + LOG.error("Unable to connect to host " + this.source.getHostname() + " : " + this.source.getPort()); + return false; + } + + @Override + public boolean advance() throws IOException { + final StringBuilder buffer = new StringBuilder(); + int data; + while (isRunning && (data = reader.read()) != -1) { + // check if the string is complete + if (data != this.source.getDelimiter()) { + buffer.append((char) data); + } else { + if (buffer.length() > 0 && buffer.charAt(buffer.length() - 1) == '\r') { + buffer.setLength(buffer.length() - 1); + } + this.currentRecord = buffer.toString(); + buffer.setLength(0); + return true; + } + } + return false; + } + + @Override + public byte[] getCurrentRecordId() throws NoSuchElementException { + return new byte[0]; + } + + @Override + public String getCurrent() throws NoSuchElementException { + return this.currentRecord; + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + return Instant.now(); + } + + @Override + public void close() throws IOException { + this.reader.close(); + this.socket.close(); + this.isRunning = false; + LOG.info("Closed connection to server socket at " + this.source.getHostname() + ":" + this.source.getPort() + "."); + } + + @Override + public Instant getWatermark() { + return Instant.now(); + } + + @Override + public CheckpointMark getCheckpointMark() { + return null; + } + + @Override + public UnboundedSource getCurrentSource() { + return this.source; + } + } +} diff --git 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java new file mode 100644 index 000000000000..97084cfc7ab8 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.io; + +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.UnboundedSource; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import org.apache.flink.streaming.api.functions.source.RichSourceFunction; +import org.apache.flink.streaming.api.operators.StreamSource; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.operators.Triggerable; +import org.joda.time.Instant; + +/** + * A wrapper for Beam's unbounded sources. This class wraps around a source implementing the {@link com.google.cloud.dataflow.sdk.io.Read.Unbounded} + * interface. + * + *
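+ * <p>The wrapper obtains an {@link UnboundedSource.UnboundedReader} from the transform's source, pulls
+ * records from it inside {@code run()}, emits each record as a {@link WindowedValue} assigned to the
+ * {@link GlobalWindow}, and periodically forwards the reader's watermark through Flink's timer service.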

+ * For now we support non-parallel, non-checkpointed sources.
+ * */
+public class UnboundedSourceWrapper extends RichSourceFunction> implements Triggerable {
+
+  private final String name;
+  private final UnboundedSource.UnboundedReader reader;
+
+  private StreamingRuntimeContext runtime = null;
+  private StreamSource.ManualWatermarkContext> context = null;
+
+  private volatile boolean isRunning = false;
+
+  public UnboundedSourceWrapper(PipelineOptions options, Read.Unbounded transform) {
+    this.name = transform.getName();
+    this.reader = transform.getSource().createReader(options, null);
+  }
+
+  public String getName() {
+    return this.name;
+  }
+
+  WindowedValue makeWindowedValue(T output, Instant timestamp) {
+    if (timestamp == null) {
+      timestamp = BoundedWindow.TIMESTAMP_MIN_VALUE;
+    }
+    return WindowedValue.of(output, timestamp, GlobalWindow.INSTANCE, PaneInfo.NO_FIRING);
+  }
+
+  @Override
+  public void run(SourceContext> ctx) throws Exception {
+    if (!(ctx instanceof StreamSource.ManualWatermarkContext)) {
+      throw new RuntimeException("We assume that all sources in Dataflow are EventTimeSourceFunction. " +
+          "Apparently " + this.name + " is not. Probably you should consider writing your own Wrapper for this source.");
+    }
+
+    context = (StreamSource.ManualWatermarkContext>) ctx;
+    runtime = (StreamingRuntimeContext) getRuntimeContext();
+
+    this.isRunning = true;
+    boolean inputAvailable = reader.start();
+
+    setNextWatermarkTimer(this.runtime);
+
+    while (isRunning) {
+
+      while (!inputAvailable && isRunning) {
+        // wait a bit until we retry to pull more records
+        Thread.sleep(50);
+        inputAvailable = reader.advance();
+      }
+
+      if (inputAvailable) {
+
+        // get it and its timestamp from the source
+        T item = reader.getCurrent();
+        Instant timestamp = reader.getCurrentTimestamp();
+
+        // write it to the output collector
+        synchronized (ctx.getCheckpointLock()) {
+          context.collectWithTimestamp(makeWindowedValue(item, timestamp), timestamp.getMillis());
+        }
+
+        inputAvailable = reader.advance();
+      }
+
+    }
+  }
+
+  @Override
+  public void cancel() {
+    isRunning = false;
+  }
+
+  @Override
+  public void trigger(long timestamp) throws Exception {
+    if (this.isRunning) {
+      synchronized (context.getCheckpointLock()) {
+        long watermarkMillis = this.reader.getWatermark().getMillis();
+        context.emitWatermark(new Watermark(watermarkMillis));
+      }
+      setNextWatermarkTimer(this.runtime);
+    }
+  }
+
+  private void setNextWatermarkTimer(StreamingRuntimeContext runtime) {
+    if (this.isRunning) {
+      long watermarkInterval = runtime.getExecutionConfig().getAutoWatermarkInterval();
+      long timeToNextWatermark = getTimeToNextWatermark(watermarkInterval);
+      runtime.registerTimer(timeToNextWatermark, this);
+    }
+  }
+
+  private long getTimeToNextWatermark(long watermarkInterval) {
+    return System.currentTimeMillis() + watermarkInterval;
+  }
+}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/AbstractFlinkTimerInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/AbstractFlinkTimerInternals.java
new file mode 100644
index 000000000000..fc759486dc58
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/AbstractFlinkTimerInternals.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.TimerInternals; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.values.KV; +import org.joda.time.Instant; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.io.Serializable; + +/** + * An implementation of Beam's {@link TimerInternals}, that also provides serialization functionality. + * The latter is used when snapshots of the current state are taken, for fault-tolerance. + * */ +public abstract class AbstractFlinkTimerInternals implements TimerInternals, Serializable { + private Instant currentInputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; + private Instant currentOutputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; + + public void setCurrentInputWatermark(Instant watermark) { + checkIfValidInputWatermark(watermark); + this.currentInputWatermark = watermark; + } + + public void setCurrentOutputWatermark(Instant watermark) { + checkIfValidOutputWatermark(watermark); + this.currentOutputWatermark = watermark; + } + + private void setCurrentInputWatermarkAfterRecovery(Instant watermark) { + if (!currentInputWatermark.isEqual(BoundedWindow.TIMESTAMP_MIN_VALUE)) { + throw new RuntimeException("Explicitly setting the input watermark is only allowed on " + + "initialization after recovery from a node failure. Apparently this is not " + + "the case here as the watermark is already set."); + } + this.currentInputWatermark = watermark; + } + + private void setCurrentOutputWatermarkAfterRecovery(Instant watermark) { + if (!currentOutputWatermark.isEqual(BoundedWindow.TIMESTAMP_MIN_VALUE)) { + throw new RuntimeException("Explicitly setting the output watermark is only allowed on " + + "initialization after recovery from a node failure. 
Apparently this is not " + + "the case here as the watermark is already set."); + } + this.currentOutputWatermark = watermark; + } + + @Override + public Instant currentProcessingTime() { + return Instant.now(); + } + + @Override + public Instant currentInputWatermarkTime() { + return currentInputWatermark; + } + + @Nullable + @Override + public Instant currentSynchronizedProcessingTime() { + // TODO + return null; + } + + @Override + public Instant currentOutputWatermarkTime() { + return currentOutputWatermark; + } + + private void checkIfValidInputWatermark(Instant newWatermark) { + if (currentInputWatermark.isAfter(newWatermark)) { + throw new IllegalArgumentException(String.format( + "Cannot set current input watermark to %s. Newer watermarks " + + "must be no earlier than the current one (%s).", + newWatermark, currentInputWatermark)); + } + } + + private void checkIfValidOutputWatermark(Instant newWatermark) { + if (currentOutputWatermark.isAfter(newWatermark)) { + throw new IllegalArgumentException(String.format( + "Cannot set current output watermark to %s. Newer watermarks " + + "must be no earlier than the current one (%s).", + newWatermark, currentOutputWatermark)); + } + } + + public void encodeTimerInternals(DoFn.ProcessContext context, + StateCheckpointWriter writer, + KvCoder kvCoder, + Coder windowCoder) throws IOException { + if (context == null) { + throw new RuntimeException("The Context has not been initialized."); + } + + writer.setTimestamp(currentInputWatermark); + writer.setTimestamp(currentOutputWatermark); + } + + public void restoreTimerInternals(StateCheckpointReader reader, + KvCoder kvCoder, + Coder windowCoder) throws IOException { + setCurrentInputWatermarkAfterRecovery(reader.getTimestamp()); + setCurrentOutputWatermarkAfterRecovery(reader.getTimestamp()); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java new file mode 100644 index 000000000000..6cf46e5344b3 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -0,0 +1,715 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.util.state.*; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.base.Preconditions; +import com.google.protobuf.ByteString; +import org.apache.flink.util.InstantiationUtil; +import org.joda.time.Instant; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.*; + +/** + * An implementation of the Beam {@link StateInternals}. This implementation simply keeps elements in memory. + * This state is periodically checkpointed by Flink, for fault-tolerance. + * + * TODO: State should be rewritten to redirect to Flink per-key state so that coders and combiners don't need + * to be serialized along with encoded values when snapshotting. + */ +public class FlinkStateInternals implements StateInternals { + + private final K key; + + private final Coder keyCoder; + + private final Coder windowCoder; + + private final OutputTimeFn outputTimeFn; + + private Instant watermarkHoldAccessor; + + public FlinkStateInternals(K key, + Coder keyCoder, + Coder windowCoder, + OutputTimeFn outputTimeFn) { + this.key = key; + this.keyCoder = keyCoder; + this.windowCoder = windowCoder; + this.outputTimeFn = outputTimeFn; + } + + public Instant getWatermarkHold() { + return watermarkHoldAccessor; + } + + /** + * This is the interface state has to implement in order for it to be fault tolerant when + * executed by the FlinkPipelineRunner. 
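+ * {@code shouldPersist()} reports whether the state holds anything worth writing, while
+ * {@code persistState(...)} appends the state's content to the Flink checkpoint being written.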
+ */ + private interface CheckpointableIF { + + boolean shouldPersist(); + + void persistState(StateCheckpointWriter checkpointBuilder) throws IOException; + } + + protected final StateTable inMemoryState = new StateTable() { + @Override + protected StateTag.StateBinder binderForNamespace(final StateNamespace namespace, final StateContext c) { + return new StateTag.StateBinder() { + + @Override + public ValueState bindValue(StateTag> address, Coder coder) { + return new FlinkInMemoryValue<>(encodeKey(namespace, address), coder); + } + + @Override + public BagState bindBag(StateTag> address, Coder elemCoder) { + return new FlinkInMemoryBag<>(encodeKey(namespace, address), elemCoder); + } + + @Override + public AccumulatorCombiningState bindCombiningValue( + StateTag> address, + Coder accumCoder, Combine.CombineFn combineFn) { + return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); + } + + @Override + public AccumulatorCombiningState bindKeyedCombiningValue( + StateTag> address, + Coder accumCoder, + Combine.KeyedCombineFn combineFn) { + return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); + } + + @Override + public AccumulatorCombiningState bindKeyedCombiningValueWithContext( + StateTag> address, + Coder accumCoder, + CombineWithContext.KeyedCombineFnWithContext combineFn) { + return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); + } + + @Override + public WatermarkHoldState bindWatermark(StateTag> address, OutputTimeFn outputTimeFn) { + return new FlinkWatermarkHoldStateImpl<>(encodeKey(namespace, address), outputTimeFn); + } + }; + } + }; + + @Override + public K getKey() { + return key; + } + + @Override + public StateT state(StateNamespace namespace, StateTag address) { + return inMemoryState.get(namespace, address, null); + } + + @Override + public T state(StateNamespace namespace, StateTag address, StateContext c) { + return inMemoryState.get(namespace, address, c); + } + + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + checkpointBuilder.writeInt(getNoOfElements()); + + for (State location : inMemoryState.values()) { + if (!(location instanceof CheckpointableIF)) { + throw new IllegalStateException(String.format( + "%s wasn't created by %s -- unable to persist it", + location.getClass().getSimpleName(), + getClass().getSimpleName())); + } + ((CheckpointableIF) location).persistState(checkpointBuilder); + } + } + + public void restoreState(StateCheckpointReader checkpointReader, ClassLoader loader) + throws IOException, ClassNotFoundException { + + // the number of elements to read. + int noOfElements = checkpointReader.getInt(); + for (int i = 0; i < noOfElements; i++) { + decodeState(checkpointReader, loader); + } + } + + /** + * We remove the first character which encodes the type of the stateTag ('s' for system + * and 'u' for user). For more details check out the source of + * {@link StateTags.StateTagBase#getId()}. + */ + private void decodeState(StateCheckpointReader reader, ClassLoader loader) + throws IOException, ClassNotFoundException { + + StateType stateItemType = StateType.deserialize(reader); + ByteString stateKey = reader.getTag(); + + // first decode the namespace and the tagId... 
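+    // (encodeKey() writes the key as "namespace+tag", so splitting on '+' recovers both parts)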
+ String[] namespaceAndTag = stateKey.toStringUtf8().split("\\+"); + if (namespaceAndTag.length != 2) { + throw new IllegalArgumentException("Invalid stateKey " + stateKey.toString() + "."); + } + StateNamespace namespace = StateNamespaces.fromString(namespaceAndTag[0], windowCoder); + + // ... decide if it is a system or user stateTag... + char ownerTag = namespaceAndTag[1].charAt(0); + if (ownerTag != 's' && ownerTag != 'u') { + throw new RuntimeException("Invalid StateTag name."); + } + boolean isSystemTag = ownerTag == 's'; + String tagId = namespaceAndTag[1].substring(1); + + // ...then decode the coder (if there is one)... + Coder coder = null; + switch (stateItemType) { + case VALUE: + case LIST: + case ACCUMULATOR: + ByteString coderBytes = reader.getData(); + coder = InstantiationUtil.deserializeObject(coderBytes.toByteArray(), loader); + break; + case WATERMARK: + break; + } + + // ...then decode the combiner function (if there is one)... + CombineWithContext.KeyedCombineFnWithContext combineFn = null; + switch (stateItemType) { + case ACCUMULATOR: + ByteString combinerBytes = reader.getData(); + combineFn = InstantiationUtil.deserializeObject(combinerBytes.toByteArray(), loader); + break; + case VALUE: + case LIST: + case WATERMARK: + break; + } + + //... and finally, depending on the type of the state being decoded, + // 1) create the adequate stateTag, + // 2) create the state container, + // 3) restore the actual content. + switch (stateItemType) { + case VALUE: { + StateTag stateTag = StateTags.value(tagId, coder); + stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; + @SuppressWarnings("unchecked") + FlinkInMemoryValue value = (FlinkInMemoryValue) inMemoryState.get(namespace, stateTag, null); + value.restoreState(reader); + break; + } + case WATERMARK: { + @SuppressWarnings("unchecked") + StateTag> stateTag = StateTags.watermarkStateInternal(tagId, outputTimeFn); + stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; + @SuppressWarnings("unchecked") + FlinkWatermarkHoldStateImpl watermark = (FlinkWatermarkHoldStateImpl) inMemoryState.get(namespace, stateTag, null); + watermark.restoreState(reader); + break; + } + case LIST: { + StateTag stateTag = StateTags.bag(tagId, coder); + stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; + FlinkInMemoryBag bag = (FlinkInMemoryBag) inMemoryState.get(namespace, stateTag, null); + bag.restoreState(reader); + break; + } + case ACCUMULATOR: { + @SuppressWarnings("unchecked") + StateTag> stateTag = StateTags.keyedCombiningValueWithContext(tagId, (Coder) coder, combineFn); + stateTag = isSystemTag ? 
StateTags.makeSystemTagInternal(stateTag) : stateTag; + @SuppressWarnings("unchecked") + FlinkInMemoryKeyedCombiningValue combiningValue = + (FlinkInMemoryKeyedCombiningValue) inMemoryState.get(namespace, stateTag, null); + combiningValue.restoreState(reader); + break; + } + default: + throw new RuntimeException("Unknown State Type " + stateItemType + "."); + } + } + + private ByteString encodeKey(StateNamespace namespace, StateTag address) { + StringBuilder sb = new StringBuilder(); + try { + namespace.appendTo(sb); + sb.append('+'); + address.appendTo(sb); + } catch (IOException e) { + throw new RuntimeException(e); + } + return ByteString.copyFromUtf8(sb.toString()); + } + + private int getNoOfElements() { + int noOfElements = 0; + for (State state : inMemoryState.values()) { + if (!(state instanceof CheckpointableIF)) { + throw new RuntimeException("State Implementations used by the " + + "Flink Dataflow Runner should implement the CheckpointableIF interface."); + } + + if (((CheckpointableIF) state).shouldPersist()) { + noOfElements++; + } + } + return noOfElements; + } + + private final class FlinkInMemoryValue implements ValueState, CheckpointableIF { + + private final ByteString stateKey; + private final Coder elemCoder; + + private T value = null; + + public FlinkInMemoryValue(ByteString stateKey, Coder elemCoder) { + this.stateKey = stateKey; + this.elemCoder = elemCoder; + } + + @Override + public void clear() { + value = null; + } + + @Override + public void write(T input) { + this.value = input; + } + + @Override + public T read() { + return value; + } + + @Override + public ValueState readLater() { + // Ignore + return this; + } + + @Override + public boolean shouldPersist() { + return value != null; + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (value != null) { + // serialize the coder. + byte[] coder = InstantiationUtil.serializeObject(elemCoder); + + // encode the value into a ByteString + ByteString.Output stream = ByteString.newOutput(); + elemCoder.encode(value, stream, Coder.Context.OUTER); + ByteString data = stream.toByteString(); + + checkpointBuilder.addValueBuilder() + .setTag(stateKey) + .setData(coder) + .setData(data); + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + ByteString valueContent = checkpointReader.getData(); + T outValue = elemCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); + write(outValue); + } + } + + private final class FlinkWatermarkHoldStateImpl + implements WatermarkHoldState, CheckpointableIF { + + private final ByteString stateKey; + + private Instant minimumHold = null; + + private OutputTimeFn outputTimeFn; + + public FlinkWatermarkHoldStateImpl(ByteString stateKey, OutputTimeFn outputTimeFn) { + this.stateKey = stateKey; + this.outputTimeFn = outputTimeFn; + } + + @Override + public void clear() { + // Even though we're clearing we can't remove this from the in-memory state map, since + // other users may already have a handle on this WatermarkBagInternal. 
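+      // We therefore only reset the hold value itself and the shared accessor below.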
+ minimumHold = null; + watermarkHoldAccessor = null; + } + + @Override + public void add(Instant watermarkHold) { + if (minimumHold == null || minimumHold.isAfter(watermarkHold)) { + watermarkHoldAccessor = watermarkHold; + minimumHold = watermarkHold; + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + return minimumHold == null; + } + + @Override + public ReadableState readLater() { + // Ignore + return this; + } + }; + } + + @Override + public OutputTimeFn getOutputTimeFn() { + return outputTimeFn; + } + + @Override + public Instant read() { + return minimumHold; + } + + @Override + public WatermarkHoldState readLater() { + // Ignore + return this; + } + + @Override + public String toString() { + return Objects.toString(minimumHold); + } + + @Override + public boolean shouldPersist() { + return minimumHold != null; + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (minimumHold != null) { + checkpointBuilder.addWatermarkHoldsBuilder() + .setTag(stateKey) + .setTimestamp(minimumHold); + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + Instant watermark = checkpointReader.getTimestamp(); + add(watermark); + } + } + + + private static CombineWithContext.KeyedCombineFnWithContext withContext( + final Combine.KeyedCombineFn combineFn) { + return new CombineWithContext.KeyedCombineFnWithContext() { + @Override + public AccumT createAccumulator(K key, CombineWithContext.Context c) { + return combineFn.createAccumulator(key); + } + + @Override + public AccumT addInput(K key, AccumT accumulator, InputT value, CombineWithContext.Context c) { + return combineFn.addInput(key, accumulator, value); + } + + @Override + public AccumT mergeAccumulators(K key, Iterable accumulators, CombineWithContext.Context c) { + return combineFn.mergeAccumulators(key, accumulators); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator, CombineWithContext.Context c) { + return combineFn.extractOutput(key, accumulator); + } + }; + } + + private static CombineWithContext.KeyedCombineFnWithContext withKeyAndContext( + final Combine.CombineFn combineFn) { + return new CombineWithContext.KeyedCombineFnWithContext() { + @Override + public AccumT createAccumulator(K key, CombineWithContext.Context c) { + return combineFn.createAccumulator(); + } + + @Override + public AccumT addInput(K key, AccumT accumulator, InputT value, CombineWithContext.Context c) { + return combineFn.addInput(accumulator, value); + } + + @Override + public AccumT mergeAccumulators(K key, Iterable accumulators, CombineWithContext.Context c) { + return combineFn.mergeAccumulators(accumulators); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator, CombineWithContext.Context c) { + return combineFn.extractOutput(accumulator); + } + }; + } + + private final class FlinkInMemoryKeyedCombiningValue + implements AccumulatorCombiningState, CheckpointableIF { + + private final ByteString stateKey; + private final CombineWithContext.KeyedCombineFnWithContext combineFn; + private final Coder accumCoder; + private final CombineWithContext.Context context; + + private AccumT accum = null; + private boolean isClear = true; + + private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, + Combine.CombineFn combineFn, + Coder accumCoder, + final StateContext stateContext) { + this(stateKey, withKeyAndContext(combineFn), accumCoder, 
stateContext); + } + + + private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, + Combine.KeyedCombineFn combineFn, + Coder accumCoder, + final StateContext stateContext) { + this(stateKey, withContext(combineFn), accumCoder, stateContext); + } + + private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, + CombineWithContext.KeyedCombineFnWithContext combineFn, + Coder accumCoder, + final StateContext stateContext) { + Preconditions.checkNotNull(combineFn); + Preconditions.checkNotNull(accumCoder); + + this.stateKey = stateKey; + this.combineFn = combineFn; + this.accumCoder = accumCoder; + this.context = new CombineWithContext.Context() { + @Override + public PipelineOptions getPipelineOptions() { + return stateContext.getPipelineOptions(); + } + + @Override + public T sideInput(PCollectionView view) { + return stateContext.sideInput(view); + } + }; + accum = combineFn.createAccumulator(key, context); + } + + @Override + public void clear() { + accum = combineFn.createAccumulator(key, context); + isClear = true; + } + + @Override + public void add(InputT input) { + isClear = false; + accum = combineFn.addInput(key, accum, input, context); + } + + @Override + public AccumT getAccum() { + return accum; + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public ReadableState readLater() { + // Ignore + return this; + } + + @Override + public Boolean read() { + return isClear; + } + }; + } + + @Override + public void addAccum(AccumT accum) { + isClear = false; + this.accum = combineFn.mergeAccumulators(key, Arrays.asList(this.accum, accum), context); + } + + @Override + public AccumT mergeAccumulators(Iterable accumulators) { + return combineFn.mergeAccumulators(key, accumulators, context); + } + + @Override + public OutputT read() { + return combineFn.extractOutput(key, accum, context); + } + + @Override + public AccumulatorCombiningState readLater() { + // Ignore + return this; + } + + @Override + public boolean shouldPersist() { + return !isClear; + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (!isClear) { + // serialize the coder. + byte[] coder = InstantiationUtil.serializeObject(accumCoder); + + // serialize the combiner. 
+ byte[] combiner = InstantiationUtil.serializeObject(combineFn); + + // encode the accumulator into a ByteString + ByteString.Output stream = ByteString.newOutput(); + accumCoder.encode(accum, stream, Coder.Context.OUTER); + ByteString data = stream.toByteString(); + + // put the flag that the next serialized element is an accumulator + checkpointBuilder.addAccumulatorBuilder() + .setTag(stateKey) + .setData(coder) + .setData(combiner) + .setData(data); + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + ByteString valueContent = checkpointReader.getData(); + AccumT accum = this.accumCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); + addAccum(accum); + } + } + + private static final class FlinkInMemoryBag implements BagState, CheckpointableIF { + private final List contents = new ArrayList<>(); + + private final ByteString stateKey; + private final Coder elemCoder; + + public FlinkInMemoryBag(ByteString stateKey, Coder elemCoder) { + this.stateKey = stateKey; + this.elemCoder = elemCoder; + } + + @Override + public void clear() { + contents.clear(); + } + + @Override + public Iterable read() { + return contents; + } + + @Override + public BagState readLater() { + // Ignore + return this; + } + + @Override + public void add(T input) { + contents.add(input); + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public ReadableState readLater() { + // Ignore + return this; + } + + @Override + public Boolean read() { + return contents.isEmpty(); + } + }; + } + + @Override + public boolean shouldPersist() { + return !contents.isEmpty(); + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (!contents.isEmpty()) { + // serialize the coder. + byte[] coder = InstantiationUtil.serializeObject(elemCoder); + + checkpointBuilder.addListUpdatesBuilder() + .setTag(stateKey) + .setData(coder) + .writeInt(contents.size()); + + for (T item : contents) { + // encode the element + ByteString.Output stream = ByteString.newOutput(); + elemCoder.encode(item, stream, Coder.Context.OUTER); + ByteString data = stream.toByteString(); + + // add the data to the checkpoint. + checkpointBuilder.setData(data); + } + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + int noOfValues = checkpointReader.getInt(); + for (int j = 0; j < noOfValues; j++) { + ByteString valueContent = checkpointReader.getData(); + T outValue = elemCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); + add(outValue); + } + } + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointReader.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointReader.java new file mode 100644 index 000000000000..5aadccdefe3d --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointReader.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import com.google.protobuf.ByteString; +import org.apache.flink.core.memory.DataInputView; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +public class StateCheckpointReader { + + private final DataInputView input; + + public StateCheckpointReader(DataInputView in) { + this.input = in; + } + + public ByteString getTag() throws IOException { + return ByteString.copyFrom(readRawData()); + } + + public String getTagToString() throws IOException { + return input.readUTF(); + } + + public ByteString getData() throws IOException { + return ByteString.copyFrom(readRawData()); + } + + public int getInt() throws IOException { + validate(); + return input.readInt(); + } + + public byte getByte() throws IOException { + validate(); + return input.readByte(); + } + + public Instant getTimestamp() throws IOException { + validate(); + Long watermarkMillis = input.readLong(); + return new Instant(TimeUnit.MICROSECONDS.toMillis(watermarkMillis)); + } + + public K deserializeKey(CoderTypeSerializer keySerializer) throws IOException { + return deserializeObject(keySerializer); + } + + public T deserializeObject(CoderTypeSerializer objectSerializer) throws IOException { + return objectSerializer.deserialize(input); + } + + ///////// Helper Methods /////// + + private byte[] readRawData() throws IOException { + validate(); + int size = input.readInt(); + + byte[] serData = new byte[size]; + int bytesRead = input.read(serData); + if (bytesRead != size) { + throw new RuntimeException("Error while deserializing checkpoint. Not enough bytes in the input stream."); + } + return serData; + } + + private void validate() { + if (this.input == null) { + throw new RuntimeException("StateBackend not initialized yet."); + } + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointUtils.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointUtils.java new file mode 100644 index 000000000000..b2dc33cab6b2 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointUtils.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals; +import com.google.cloud.dataflow.sdk.util.state.StateNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +public class StateCheckpointUtils { + + public static void encodeState(Map> perKeyStateInternals, + StateCheckpointWriter writer, Coder keyCoder) throws IOException { + CoderTypeSerializer keySerializer = new CoderTypeSerializer<>(keyCoder); + + int noOfKeys = perKeyStateInternals.size(); + writer.writeInt(noOfKeys); + for (Map.Entry> keyStatePair : perKeyStateInternals.entrySet()) { + K key = keyStatePair.getKey(); + FlinkStateInternals state = keyStatePair.getValue(); + + // encode the key + writer.serializeKey(key, keySerializer); + + // write the associated state + state.persistState(writer); + } + } + + public static Map> decodeState( + StateCheckpointReader reader, + OutputTimeFn outputTimeFn, + Coder keyCoder, + Coder windowCoder, + ClassLoader classLoader) throws IOException, ClassNotFoundException { + + int noOfKeys = reader.getInt(); + Map> perKeyStateInternals = new HashMap<>(noOfKeys); + perKeyStateInternals.clear(); + + CoderTypeSerializer keySerializer = new CoderTypeSerializer<>(keyCoder); + for (int i = 0; i < noOfKeys; i++) { + + // decode the key. + K key = reader.deserializeKey(keySerializer); + + //decode the state associated to the key. 
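+      // (a fresh FlinkStateInternals instance is created and then populated from the checkpoint stream)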
+ FlinkStateInternals stateForKey = + new FlinkStateInternals<>(key, keyCoder, windowCoder, outputTimeFn); + stateForKey.restoreState(reader, classLoader); + perKeyStateInternals.put(key, stateForKey); + } + return perKeyStateInternals; + } + + ////////////// Encoding/Decoding the Timers //////////////// + + + public static void encodeTimers(Map> allTimers, + StateCheckpointWriter writer, + Coder keyCoder) throws IOException { + CoderTypeSerializer keySerializer = new CoderTypeSerializer<>(keyCoder); + + int noOfKeys = allTimers.size(); + writer.writeInt(noOfKeys); + for (Map.Entry> timersPerKey : allTimers.entrySet()) { + K key = timersPerKey.getKey(); + + // encode the key + writer.serializeKey(key, keySerializer); + + // write the associated timers + Set timers = timersPerKey.getValue(); + encodeTimerDataForKey(writer, timers); + } + } + + public static Map> decodeTimers( + StateCheckpointReader reader, + Coder windowCoder, + Coder keyCoder) throws IOException { + + int noOfKeys = reader.getInt(); + Map> activeTimers = new HashMap<>(noOfKeys); + activeTimers.clear(); + + CoderTypeSerializer keySerializer = new CoderTypeSerializer<>(keyCoder); + for (int i = 0; i < noOfKeys; i++) { + + // decode the key. + K key = reader.deserializeKey(keySerializer); + + // decode the associated timers. + Set timers = decodeTimerDataForKey(reader, windowCoder); + activeTimers.put(key, timers); + } + return activeTimers; + } + + private static void encodeTimerDataForKey(StateCheckpointWriter writer, Set timers) throws IOException { + // encode timers + writer.writeInt(timers.size()); + for (TimerInternals.TimerData timer : timers) { + String stringKey = timer.getNamespace().stringKey(); + + writer.setTag(stringKey); + writer.setTimestamp(timer.getTimestamp()); + writer.writeInt(timer.getDomain().ordinal()); + } + } + + private static Set decodeTimerDataForKey( + StateCheckpointReader reader, Coder windowCoder) throws IOException { + + // decode the timers: first their number and then the content itself. + int noOfTimers = reader.getInt(); + Set timers = new HashSet<>(noOfTimers); + for (int i = 0; i < noOfTimers; i++) { + String stringKey = reader.getTagToString(); + Instant instant = reader.getTimestamp(); + TimeDomain domain = TimeDomain.values()[reader.getInt()]; + + StateNamespace namespace = StateNamespaces.fromString(stringKey, windowCoder); + timers.add(TimerInternals.TimerData.of(namespace, instant, domain)); + } + return timers; + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointWriter.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointWriter.java new file mode 100644 index 000000000000..18e118a8900e --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointWriter.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import com.google.protobuf.ByteString; +import org.apache.flink.runtime.state.AbstractStateBackend; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +public class StateCheckpointWriter { + + private final AbstractStateBackend.CheckpointStateOutputView output; + + public static StateCheckpointWriter create(AbstractStateBackend.CheckpointStateOutputView output) { + return new StateCheckpointWriter(output); + } + + private StateCheckpointWriter(AbstractStateBackend.CheckpointStateOutputView output) { + this.output = output; + } + + ///////// Creating the serialized versions of the different types of state held by dataflow /////// + + public StateCheckpointWriter addValueBuilder() throws IOException { + validate(); + StateType.serialize(StateType.VALUE, this); + return this; + } + + public StateCheckpointWriter addWatermarkHoldsBuilder() throws IOException { + validate(); + StateType.serialize(StateType.WATERMARK, this); + return this; + } + + public StateCheckpointWriter addListUpdatesBuilder() throws IOException { + validate(); + StateType.serialize(StateType.LIST, this); + return this; + } + + public StateCheckpointWriter addAccumulatorBuilder() throws IOException { + validate(); + StateType.serialize(StateType.ACCUMULATOR, this); + return this; + } + + ///////// Setting the tag for a given state element /////// + + public StateCheckpointWriter setTag(ByteString stateKey) throws IOException { + return writeData(stateKey.toByteArray()); + } + + public StateCheckpointWriter setTag(String stateKey) throws IOException { + output.writeUTF(stateKey); + return this; + } + + + public StateCheckpointWriter serializeKey(K key, CoderTypeSerializer keySerializer) throws IOException { + return serializeObject(key, keySerializer); + } + + public StateCheckpointWriter serializeObject(T object, CoderTypeSerializer objectSerializer) throws IOException { + objectSerializer.serialize(object, output); + return this; + } + + ///////// Write the actual serialized data ////////// + + public StateCheckpointWriter setData(ByteString data) throws IOException { + return writeData(data.toByteArray()); + } + + public StateCheckpointWriter setData(byte[] data) throws IOException { + return writeData(data); + } + + public StateCheckpointWriter setTimestamp(Instant timestamp) throws IOException { + validate(); + output.writeLong(TimeUnit.MILLISECONDS.toMicros(timestamp.getMillis())); + return this; + } + + public StateCheckpointWriter writeInt(int number) throws IOException { + validate(); + output.writeInt(number); + return this; + } + + public StateCheckpointWriter writeByte(byte b) throws IOException { + validate(); + output.writeByte(b); + return this; + } + + ///////// Helper Methods /////// + + private StateCheckpointWriter writeData(byte[] data) throws IOException { + validate(); + output.writeInt(data.length); + output.write(data); + return this; + } + + private void validate() { + if 
(this.output == null) { + throw new RuntimeException("StateBackend not initialized yet."); + } + } +} \ No newline at end of file diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateType.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateType.java new file mode 100644 index 000000000000..58497730dd27 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateType.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import java.io.IOException; + +/** + * The available types of state, as provided by the Beam SDK. This class is used for serialization/deserialization + * purposes. + * */ +public enum StateType { + + VALUE(0), + + WATERMARK(1), + + LIST(2), + + ACCUMULATOR(3); + + private final int numVal; + + StateType(int value) { + this.numVal = value; + } + + public static void serialize(StateType type, StateCheckpointWriter output) throws IOException { + if (output == null) { + throw new IllegalArgumentException("Cannot write to a null output."); + } + + if(type.numVal < 0 || type.numVal > 3) { + throw new RuntimeException("Unknown State Type " + type + "."); + } + + output.writeByte((byte) type.numVal); + } + + public static StateType deserialize(StateCheckpointReader input) throws IOException { + if (input == null) { + throw new IllegalArgumentException("Cannot read from a null input."); + } + + int typeInt = (int) input.getByte(); + if(typeInt < 0 || typeInt > 3) { + throw new RuntimeException("Unknown State Type " + typeInt + "."); + } + + StateType resultType = null; + for(StateType st: values()) { + if(st.numVal == typeInt) { + resultType = st; + break; + } + } + return resultType; + } +} diff --git a/runners/flink/src/main/resources/log4j.properties b/runners/flink/src/main/resources/log4j.properties new file mode 100644 index 000000000000..4daaad1e22a5 --- /dev/null +++ b/runners/flink/src/main/resources/log4j.properties @@ -0,0 +1,23 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +log4j.rootLogger=INFO,console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{2}: %m%n diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/AvroITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/AvroITCase.java new file mode 100644 index 000000000000..5b32d54513be --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/AvroITCase.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.common.base.Joiner; +import org.apache.flink.api.io.avro.example.User; +import org.apache.flink.test.util.JavaProgramTestBase; + + +public class AvroITCase extends JavaProgramTestBase { + + protected String resultPath; + protected String tmpPath; + + public AvroITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] { + "Joe red 3", + "Mary blue 4", + "Mark green 1", + "Julia purple 5" + }; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + tmpPath = getTempDirPath("tmp"); + + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws Exception { + runProgram(tmpPath, resultPath); + } + + private static void runProgram(String tmpPath, String resultPath) { + Pipeline p = FlinkTestPipeline.createForBatch(); + + p + .apply(Create.of( + new User("Joe", 3, "red"), + new User("Mary", 4, "blue"), + new User("Mark", 1, "green"), + new User("Julia", 5, "purple")) + .withCoder(AvroCoder.of(User.class))) + + .apply(AvroIO.Write.to(tmpPath) + .withSchema(User.class)); + + p.run(); + + p = FlinkTestPipeline.createForBatch(); + + p + .apply(AvroIO.Read.from(tmpPath).withSchema(User.class).withoutValidation()) + + .apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) throws Exception { + User u = c.element(); + String result = u.getName() + " " + u.getFavoriteColor() + " " + u.getFavoriteNumber(); + c.output(result); + } + })) + + .apply(TextIO.Write.to(resultPath)); + + p.run(); + } + +} + diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlattenizeITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlattenizeITCase.java new file mode 100644 index 000000000000..5ae0e832fd8b --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlattenizeITCase.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +public class FlattenizeITCase extends JavaProgramTestBase { + + private String resultPath; + private String resultPath2; + + private static final String[] words = {"hello", "this", "is", "a", "DataSet!"}; + private static final String[] words2 = {"hello", "this", "is", "another", "DataSet!"}; + private static final String[] words3 = {"hello", "this", "is", "yet", "another", "DataSet!"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + resultPath2 = getTempDirPath("result2"); + } + + @Override + protected void postSubmit() throws Exception { + String join = Joiner.on('\n').join(words); + String join2 = Joiner.on('\n').join(words2); + String join3 = Joiner.on('\n').join(words3); + compareResultsByLinesInMemory(join + "\n" + join2, resultPath); + compareResultsByLinesInMemory(join + "\n" + join2 + "\n" + join3, resultPath2); + } + + + @Override + protected void testProgram() throws Exception { + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection p1 = p.apply(Create.of(words)); + PCollection p2 = p.apply(Create.of(words2)); + + PCollectionList list = PCollectionList.of(p1).and(p2); + + list.apply(Flatten.pCollections()).apply(TextIO.Write.to(resultPath)); + + PCollection p3 = p.apply(Create.of(words3)); + + PCollectionList list2 = list.and(p3); + + list2.apply(Flatten.pCollections()).apply(TextIO.Write.to(resultPath2)); + + p.run(); + } + +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkTestPipeline.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkTestPipeline.java new file mode 100644 index 000000000000..aadda24b4c84 --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkTestPipeline.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; + +/** + * {@link com.google.cloud.dataflow.sdk.Pipeline} for testing Dataflow programs on the + * {@link org.apache.beam.runners.flink.FlinkPipelineRunner}. 
+ */ +public class FlinkTestPipeline extends Pipeline { + + /** + * Creates and returns a new test pipeline for batch execution. + * + *
    Use {@link com.google.cloud.dataflow.sdk.testing.DataflowAssert} to add tests, then call + * {@link Pipeline#run} to execute the pipeline and check the tests. + */ + public static FlinkTestPipeline createForBatch() { + return create(false); + } + + /** + * Creates and returns a new test pipeline for streaming execution. + * + *
    Use {@link com.google.cloud.dataflow.sdk.testing.DataflowAssert} to add tests, then call + * {@link Pipeline#run} to execute the pipeline and check the tests. + * + * @return The Test Pipeline + */ + public static FlinkTestPipeline createForStreaming() { + return create(true); + } + + /** + * Creates and returns a new test pipeline for streaming or batch execution. + * + *
    Use {@link com.google.cloud.dataflow.sdk.testing.DataflowAssert} to add tests, then call + * {@link Pipeline#run} to execute the pipeline and check the tests. + * + * @param streaming True for streaming mode, False for batch. + * @return The Test Pipeline. + */ + private static FlinkTestPipeline create(boolean streaming) { + FlinkPipelineRunner flinkRunner = FlinkPipelineRunner.createForTest(streaming); + return new FlinkTestPipeline(flinkRunner, flinkRunner.getPipelineOptions()); + } + + private FlinkTestPipeline(PipelineRunner runner, + PipelineOptions options) { + super(runner, options); + } +} + diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/JoinExamplesITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/JoinExamplesITCase.java new file mode 100644 index 000000000000..f60056ded067 --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/JoinExamplesITCase.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import org.apache.beam.runners.flink.util.JoinExamples; +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.util.Arrays; +import java.util.List; + + +/** + * Unfortunately we need to copy the code from the Dataflow SDK because it is not public there. 
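+ *
+ * <p>The test joins a small set of event rows with country-code rows via
+ * {@code JoinExamples.joinEvents} and compares the formatted result against the expected lines.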
+ */ +public class JoinExamplesITCase extends JavaProgramTestBase { + + protected String resultPath; + + public JoinExamplesITCase(){ + } + + private static final TableRow row1 = new TableRow() + .set("ActionGeo_CountryCode", "VM").set("SQLDATE", "20141212") + .set("Actor1Name", "BANGKOK").set("SOURCEURL", "http://cnn.com"); + private static final TableRow row2 = new TableRow() + .set("ActionGeo_CountryCode", "VM").set("SQLDATE", "20141212") + .set("Actor1Name", "LAOS").set("SOURCEURL", "http://www.chicagotribune.com"); + private static final TableRow row3 = new TableRow() + .set("ActionGeo_CountryCode", "BE").set("SQLDATE", "20141213") + .set("Actor1Name", "AFGHANISTAN").set("SOURCEURL", "http://cnn.com"); + static final TableRow[] EVENTS = new TableRow[] { + row1, row2, row3 + }; + static final List EVENT_ARRAY = Arrays.asList(EVENTS); + + private static final TableRow cc1 = new TableRow() + .set("FIPSCC", "VM").set("HumanName", "Vietnam"); + private static final TableRow cc2 = new TableRow() + .set("FIPSCC", "BE").set("HumanName", "Belgium"); + static final TableRow[] CCS = new TableRow[] { + cc1, cc2 + }; + static final List CC_ARRAY = Arrays.asList(CCS); + + static final String[] JOINED_EVENTS = new String[] { + "Country code: VM, Country name: Vietnam, Event info: Date: 20141212, Actor1: LAOS, " + + "url: http://www.chicagotribune.com", + "Country code: VM, Country name: Vietnam, Event info: Date: 20141212, Actor1: BANGKOK, " + + "url: http://cnn.com", + "Country code: BE, Country name: Belgium, Event info: Date: 20141213, Actor1: AFGHANISTAN, " + + "url: http://cnn.com" + }; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(JOINED_EVENTS), resultPath); + } + + @Override + protected void testProgram() throws Exception { + + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection input1 = p.apply(Create.of(EVENT_ARRAY)); + PCollection input2 = p.apply(Create.of(CC_ARRAY)); + + PCollection output = JoinExamples.joinEvents(input1, input2); + + output.apply(TextIO.Write.to(resultPath)); + + p.run(); + } +} + diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/MaybeEmptyTestITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/MaybeEmptyTestITCase.java new file mode 100644 index 000000000000..199602c6c3e0 --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/MaybeEmptyTestITCase.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.io.Serializable; + +public class MaybeEmptyTestITCase extends JavaProgramTestBase implements Serializable { + + protected String resultPath; + + protected final String expected = "test"; + + public MaybeEmptyTestITCase() { + } + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(expected, resultPath); + } + + @Override + protected void testProgram() throws Exception { + + Pipeline p = FlinkTestPipeline.createForBatch(); + + p.apply(Create.of((Void) null)).setCoder(VoidCoder.of()) + .apply(ParDo.of( + new DoFn() { + @Override + public void processElement(DoFn.ProcessContext c) { + c.output(expected); + } + })).apply(TextIO.Write.to(resultPath)); + p.run(); + } + +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/ParDoMultiOutputITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/ParDoMultiOutputITCase.java new file mode 100644 index 000000000000..403de29600b7 --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/ParDoMultiOutputITCase.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.io.Serializable; + +public class ParDoMultiOutputITCase extends JavaProgramTestBase implements Serializable { + + private String resultPath; + + private static String[] expectedWords = {"MAAA", "MAAFOOO"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on("\n").join(expectedWords), resultPath); + } + + @Override + protected void testProgram() throws Exception { + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection words = p.apply(Create.of("Hello", "Whatupmyman", "hey", "SPECIALthere", "MAAA", "MAAFOOO")); + + // Select words whose length is below a cut off, + // plus the lengths of words that are above the cut off. + // Also select words starting with "MARKER". + final int wordLengthCutOff = 3; + // Create tags to use for the main and side outputs. + final TupleTag wordsBelowCutOffTag = new TupleTag(){}; + final TupleTag wordLengthsAboveCutOffTag = new TupleTag(){}; + final TupleTag markedWordsTag = new TupleTag(){}; + + PCollectionTuple results = + words.apply(ParDo + .withOutputTags(wordsBelowCutOffTag, TupleTagList.of(wordLengthsAboveCutOffTag) + .and(markedWordsTag)) + .of(new DoFn() { + final TupleTag specialWordsTag = new TupleTag() { + }; + + public void processElement(ProcessContext c) { + String word = c.element(); + if (word.length() <= wordLengthCutOff) { + c.output(word); + } else { + c.sideOutput(wordLengthsAboveCutOffTag, word.length()); + } + if (word.startsWith("MAA")) { + c.sideOutput(markedWordsTag, word); + } + + if (word.startsWith("SPECIAL")) { + c.sideOutput(specialWordsTag, word); + } + } + })); + + // Extract the PCollection results, by tag. + PCollection wordsBelowCutOff = results.get(wordsBelowCutOffTag); + PCollection wordLengthsAboveCutOff = results.get + (wordLengthsAboveCutOffTag); + PCollection markedWords = results.get(markedWordsTag); + + markedWords.apply(TextIO.Write.to(resultPath)); + + p.run(); + } +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java new file mode 100644 index 000000000000..323c41ba0da6 --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + + +public class ReadSourceITCase extends JavaProgramTestBase { + + protected String resultPath; + + public ReadSourceITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] { + "1", "2", "3", "4", "5", "6", "7", "8", "9"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws Exception { + runProgram(resultPath); + } + + private static void runProgram(String resultPath) { + + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection result = p + .apply(Read.from(new ReadSource(1, 10))) + .apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) throws Exception { + c.output(c.element().toString()); + } + })); + + result.apply(TextIO.Write.to(resultPath)); + p.run(); + } + + + private static class ReadSource extends BoundedSource { + final int from; + final int to; + + ReadSource(int from, int to) { + this.from = from; + this.to = to; + } + + @Override + public List splitIntoBundles(long desiredShardSizeBytes, PipelineOptions options) + throws Exception { + List res = new ArrayList<>(); + FlinkPipelineOptions flinkOptions = options.as(FlinkPipelineOptions.class); + int numWorkers = flinkOptions.getParallelism(); + Preconditions.checkArgument(numWorkers > 0, "Number of workers should be larger than 0."); + + float step = 1.0f * (to - from) / numWorkers; + for (int i = 0; i < numWorkers; ++i) { + res.add(new ReadSource(Math.round(from + i * step), Math.round(from + (i + 1) * step))); + } + return res; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + return 8 * (to - from); + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return true; + } + + @Override + public BoundedReader createReader(PipelineOptions options) throws IOException { + return new RangeReader(this); + } + + @Override + public void validate() {} + + @Override + public Coder getDefaultOutputCoder() { + return BigEndianIntegerCoder.of(); + } + + private class RangeReader extends BoundedReader { + private 
int current; + + public RangeReader(ReadSource source) { + this.current = source.from - 1; + } + + @Override + public boolean start() throws IOException { + return true; + } + + @Override + public boolean advance() throws IOException { + current++; + return (current < to); + } + + @Override + public Integer getCurrent() { + return current; + } + + @Override + public void close() throws IOException { + // Nothing + } + + @Override + public BoundedSource getCurrentSource() { + return ReadSource.this; + } + } + } +} + + diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesEmptyITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesEmptyITCase.java new file mode 100644 index 000000000000..524554aa8d5c --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesEmptyITCase.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.util.Collections; +import java.util.List; + + +public class RemoveDuplicatesEmptyITCase extends JavaProgramTestBase { + + protected String resultPath; + + public RemoveDuplicatesEmptyITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] {}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws Exception { + + List strings = Collections.emptyList(); + + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection input = + p.apply(Create.of(strings)) + .setCoder(StringUtf8Coder.of()); + + PCollection output = + input.apply(RemoveDuplicates.create()); + + output.apply(TextIO.Write.to(resultPath)); + p.run(); + } +} + diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesITCase.java new file mode 100644 index 000000000000..54e92aa9ec39 --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesITCase.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.util.Arrays; +import java.util.List; + + +public class RemoveDuplicatesITCase extends JavaProgramTestBase { + + protected String resultPath; + + public RemoveDuplicatesITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] { + "k1", "k5", "k2", "k3"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws Exception { + + List strings = Arrays.asList("k1", "k5", "k5", "k2", "k1", "k2", "k3"); + + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection input = + p.apply(Create.of(strings)) + .setCoder(StringUtf8Coder.of()); + + PCollection output = + input.apply(RemoveDuplicates.create()); + + output.apply(TextIO.Write.to(resultPath)); + p.run(); + } +} + diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/SideInputITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/SideInputITCase.java new file mode 100644 index 000000000000..7f73b8309605 --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/SideInputITCase.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.io.Serializable; + +public class SideInputITCase extends JavaProgramTestBase implements Serializable { + + private static final String expected = "Hello!"; + + protected String resultPath; + + @Override + protected void testProgram() throws Exception { + + + Pipeline p = FlinkTestPipeline.createForBatch(); + + + final PCollectionView sidesInput = p + .apply(Create.of(expected)) + .apply(View.asSingleton()); + + p.apply(Create.of("bli")) + .apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) throws Exception { + String s = c.sideInput(sidesInput); + c.output(s); + } + }).withSideInputs(sidesInput)).apply(TextIO.Write.to(resultPath)); + + p.run(); + } + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(expected, resultPath); + } +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/TfIdfITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/TfIdfITCase.java new file mode 100644 index 000000000000..8722feefb960 --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/TfIdfITCase.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import org.apache.beam.runners.flink.examples.TFIDF; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringDelegateCoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Keys; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.net.URI; + + +public class TfIdfITCase extends JavaProgramTestBase { + + protected String resultPath; + + public TfIdfITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] { + "a", "m", "n", "b", "c", "d"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws Exception { + + Pipeline pipeline = FlinkTestPipeline.createForBatch(); + + pipeline.getCoderRegistry().registerCoder(URI.class, StringDelegateCoder.of(URI.class)); + + PCollection>> wordToUriAndTfIdf = pipeline + .apply(Create.of( + KV.of(new URI("x"), "a b c d"), + KV.of(new URI("y"), "a b c"), + KV.of(new URI("z"), "a m n"))) + .apply(new TFIDF.ComputeTfIdf()); + + PCollection words = wordToUriAndTfIdf + .apply(Keys.create()) + .apply(RemoveDuplicates.create()); + + words.apply(TextIO.Write.to(resultPath)); + + pipeline.run(); + } +} + diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/WordCountITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/WordCountITCase.java new file mode 100644 index 000000000000..8ca978e79792 --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/WordCountITCase.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import org.apache.beam.runners.flink.examples.WordCount; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.util.Arrays; +import java.util.List; + + +public class WordCountITCase extends JavaProgramTestBase { + + protected String resultPath; + + public WordCountITCase(){ + } + + static final String[] WORDS_ARRAY = new String[] { + "hi there", "hi", "hi sue bob", + "hi sue", "", "bob hi"}; + + static final List WORDS = Arrays.asList(WORDS_ARRAY); + + static final String[] COUNTS_ARRAY = new String[] { + "hi: 5", "there: 1", "sue: 2", "bob: 2"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(COUNTS_ARRAY), resultPath); + } + + @Override + protected void testProgram() throws Exception { + + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection input = p.apply(Create.of(WORDS)).setCoder(StringUtf8Coder.of()); + + input + .apply(new WordCount.CountWords()) + .apply(MapElements.via(new WordCount.FormatAsTextFn())) + .apply(TextIO.Write.to(resultPath)); + + p.run(); + } +} + diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/WordCountJoin2ITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/WordCountJoin2ITCase.java new file mode 100644 index 000000000000..e73c4568df0f --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/WordCountJoin2ITCase.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + + +public class WordCountJoin2ITCase extends JavaProgramTestBase { + + static final String[] WORDS_1 = new String[] { + "hi there", "hi", "hi sue bob", + "hi sue", "", "bob hi"}; + + static final String[] WORDS_2 = new String[] { + "hi tim", "beauty", "hooray sue bob", + "hi there", "", "please say hi"}; + + static final String[] RESULTS = new String[] { + "beauty -> Tag1: Tag2: 1", + "bob -> Tag1: 2 Tag2: 1", + "hi -> Tag1: 5 Tag2: 3", + "hooray -> Tag1: Tag2: 1", + "please -> Tag1: Tag2: 1", + "say -> Tag1: Tag2: 1", + "sue -> Tag1: 2 Tag2: 1", + "there -> Tag1: 1 Tag2: 1", + "tim -> Tag1: Tag2: 1" + }; + + static final TupleTag tag1 = new TupleTag<>("Tag1"); + static final TupleTag tag2 = new TupleTag<>("Tag2"); + + protected String resultPath; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(RESULTS), resultPath); + } + + @Override + protected void testProgram() throws Exception { + Pipeline p = FlinkTestPipeline.createForBatch(); + + /* Create two PCollections and join them */ + PCollection> occurences1 = p.apply(Create.of(WORDS_1)) + .apply(ParDo.of(new ExtractWordsFn())) + .apply(Count.perElement()); + + PCollection> occurences2 = p.apply(Create.of(WORDS_2)) + .apply(ParDo.of(new ExtractWordsFn())) + .apply(Count.perElement()); + + /* CoGroup the two collections */ + PCollection> mergedOccurences = KeyedPCollectionTuple + .of(tag1, occurences1) + .and(tag2, occurences2) + .apply(CoGroupByKey.create()); + + /* Format output */ + mergedOccurences.apply(ParDo.of(new FormatCountsFn())) + .apply(TextIO.Write.named("test").to(resultPath)); + + p.run(); + } + + + static class ExtractWordsFn extends DoFn { + + @Override + public void startBundle(Context c) { + } + + @Override + public void processElement(ProcessContext c) { + // Split the line into words. + String[] words = c.element().split("[^a-zA-Z']+"); + + // Output each word encountered into the output PCollection. 
+ for (String word : words) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + } + + static class FormatCountsFn extends DoFn, String> { + @Override + public void processElement(ProcessContext c) { + CoGbkResult value = c.element().getValue(); + String key = c.element().getKey(); + String countTag1 = tag1.getId() + ": "; + String countTag2 = tag2.getId() + ": "; + for (Long count : value.getAll(tag1)) { + countTag1 += count + " "; + } + for (Long count : value.getAll(tag2)) { + countTag2 += count; + } + c.output(key + " -> " + countTag1 + countTag2); + } + } + + +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/WordCountJoin3ITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/WordCountJoin3ITCase.java new file mode 100644 index 000000000000..6b57d771070b --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/WordCountJoin3ITCase.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + + +public class WordCountJoin3ITCase extends JavaProgramTestBase { + + static final String[] WORDS_1 = new String[] { + "hi there", "hi", "hi sue bob", + "hi sue", "", "bob hi"}; + + static final String[] WORDS_2 = new String[] { + "hi tim", "beauty", "hooray sue bob", + "hi there", "", "please say hi"}; + + static final String[] WORDS_3 = new String[] { + "hi stephan", "beauty", "hooray big fabian", + "hi yo", "", "please say hi"}; + + static final String[] RESULTS = new String[] { + "beauty -> Tag1: Tag2: 1 Tag3: 1", + "bob -> Tag1: 2 Tag2: 1 Tag3: ", + "hi -> Tag1: 5 Tag2: 3 Tag3: 3", + "hooray -> Tag1: Tag2: 1 Tag3: 1", + "please -> Tag1: Tag2: 1 Tag3: 1", + "say -> Tag1: Tag2: 1 Tag3: 1", + "sue -> Tag1: 2 Tag2: 1 Tag3: ", + "there -> Tag1: 1 Tag2: 1 Tag3: ", + "tim -> Tag1: Tag2: 1 Tag3: ", + "stephan -> Tag1: Tag2: Tag3: 1", + "yo -> Tag1: Tag2: Tag3: 1", + "fabian -> Tag1: Tag2: Tag3: 1", + "big -> Tag1: Tag2: Tag3: 1" + }; + + static final 
TupleTag tag1 = new TupleTag<>("Tag1"); + static final TupleTag tag2 = new TupleTag<>("Tag2"); + static final TupleTag tag3 = new TupleTag<>("Tag3"); + + protected String resultPath; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(RESULTS), resultPath); + } + + @Override + protected void testProgram() throws Exception { + + Pipeline p = FlinkTestPipeline.createForBatch(); + + /* Create two PCollections and join them */ + PCollection> occurences1 = p.apply(Create.of(WORDS_1)) + .apply(ParDo.of(new ExtractWordsFn())) + .apply(Count.perElement()); + + PCollection> occurences2 = p.apply(Create.of(WORDS_2)) + .apply(ParDo.of(new ExtractWordsFn())) + .apply(Count.perElement()); + + PCollection> occurences3 = p.apply(Create.of(WORDS_3)) + .apply(ParDo.of(new ExtractWordsFn())) + .apply(Count.perElement()); + + /* CoGroup the two collections */ + PCollection> mergedOccurences = KeyedPCollectionTuple + .of(tag1, occurences1) + .and(tag2, occurences2) + .and(tag3, occurences3) + .apply(CoGroupByKey.create()); + + /* Format output */ + mergedOccurences.apply(ParDo.of(new FormatCountsFn())) + .apply(TextIO.Write.named("test").to(resultPath)); + + p.run(); + } + + + static class ExtractWordsFn extends DoFn { + + @Override + public void startBundle(Context c) { + } + + @Override + public void processElement(ProcessContext c) { + // Split the line into words. + String[] words = c.element().split("[^a-zA-Z']+"); + + // Output each word encountered into the output PCollection. + for (String word : words) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + } + + static class FormatCountsFn extends DoFn, String> { + @Override + public void processElement(ProcessContext c) { + CoGbkResult value = c.element().getValue(); + String key = c.element().getKey(); + String countTag1 = tag1.getId() + ": "; + String countTag2 = tag2.getId() + ": "; + String countTag3 = tag3.getId() + ": "; + for (Long count : value.getAll(tag1)) { + countTag1 += count + " "; + } + for (Long count : value.getAll(tag2)) { + countTag2 += count + " "; + } + for (Long count : value.getAll(tag3)) { + countTag3 += count; + } + c.output(key + " -> " + countTag1 + countTag2 + countTag3); + } + } + +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/WriteSinkITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/WriteSinkITCase.java new file mode 100644 index 000000000000..dfa15ce62b19 --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/WriteSinkITCase.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.Sink; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Write; +import com.google.common.base.Joiner; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.Path; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.io.File; +import java.io.PrintWriter; +import java.net.URI; + +import static org.junit.Assert.*; + +/** + * Tests the translation of custom Write.Bound sinks. + */ +public class WriteSinkITCase extends JavaProgramTestBase { + + protected String resultPath; + + public WriteSinkITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] { + "Joe red 3", "Mary blue 4", "Max yellow 23"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws Exception { + runProgram(resultPath); + } + + private static void runProgram(String resultPath) { + Pipeline p = FlinkTestPipeline.createForBatch(); + + p.apply(Create.of(EXPECTED_RESULT)).setCoder(StringUtf8Coder.of()) + .apply("CustomSink", Write.to(new MyCustomSink(resultPath))); + + p.run(); + } + + /** + * Simple custom sink which writes to a file. + */ + private static class MyCustomSink extends Sink { + + private final String resultPath; + + public MyCustomSink(String resultPath) { + this.resultPath = resultPath; + } + + @Override + public void validate(PipelineOptions options) { + assertNotNull(options); + } + + @Override + public WriteOperation createWriteOperation(PipelineOptions options) { + return new MyWriteOperation(); + } + + private class MyWriteOperation extends WriteOperation { + + @Override + public Coder getWriterResultCoder() { + return StringUtf8Coder.of(); + } + + @Override + public void initialize(PipelineOptions options) throws Exception { + + } + + @Override + public void finalize(Iterable writerResults, PipelineOptions options) throws Exception { + + } + + @Override + public Writer createWriter(PipelineOptions options) throws Exception { + return new MyWriter(); + } + + @Override + public Sink getSink() { + return MyCustomSink.this; + } + + /** + * Simple Writer which writes to a file. 
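+ *
+ * <p>Each writer opens {@code resultPath + "/" + uId}, so bundles that run in parallel
+ * write to separate files inside the result directory.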
+ */ + private class MyWriter extends Writer { + + private PrintWriter internalWriter; + + @Override + public void open(String uId) throws Exception { + Path path = new Path(resultPath + "/" + uId); + FileSystem.get(new URI("file:///")).create(path, false); + internalWriter = new PrintWriter(new File(path.toUri())); + } + + @Override + public void write(String value) throws Exception { + internalWriter.println(value); + } + + @Override + public String close() throws Exception { + internalWriter.close(); + return resultPath; + } + + @Override + public WriteOperation getWriteOperation() { + return MyWriteOperation.this; + } + } + } + } + +} + diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupAlsoByWindowTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupAlsoByWindowTest.java new file mode 100644 index 000000000000..880da59792af --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupAlsoByWindowTest.java @@ -0,0 +1,508 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import org.apache.beam.runners.flink.FlinkTestPipeline; +import org.apache.beam.runners.flink.translation.wrappers.streaming.FlinkGroupAlsoByWindowWrapper; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.*; +import com.google.cloud.dataflow.sdk.util.UserCodeException; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingStrategy; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.base.Throwables; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.TestHarnessUtil; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; + +import java.util.Collection; +import java.util.Comparator; +import java.util.concurrent.ConcurrentLinkedQueue; + +public class GroupAlsoByWindowTest { + + private final Combine.CombineFn combiner = new Sum.SumIntegerFn(); + + private final WindowingStrategy slidingWindowWithAfterWatermarkTriggerStrategy = + WindowingStrategy.of(SlidingWindows.of(Duration.standardSeconds(10)).every(Duration.standardSeconds(5))) + .withTrigger(AfterWatermark.pastEndOfWindow()).withMode(WindowingStrategy.AccumulationMode.ACCUMULATING_FIRED_PANES); + + private final WindowingStrategy sessionWindowingStrategy = + WindowingStrategy.of(Sessions.withGapDuration(Duration.standardSeconds(2))) + .withTrigger(Repeatedly.forever(AfterWatermark.pastEndOfWindow())) + .withMode(WindowingStrategy.AccumulationMode.ACCUMULATING_FIRED_PANES) + .withAllowedLateness(Duration.standardSeconds(100)); + + private final WindowingStrategy fixedWindowingStrategy = + WindowingStrategy.of(FixedWindows.of(Duration.standardSeconds(10))); + + private final WindowingStrategy fixedWindowWithCountTriggerStrategy = + fixedWindowingStrategy.withTrigger(AfterPane.elementCountAtLeast(5)); + + private final WindowingStrategy fixedWindowWithAfterWatermarkTriggerStrategy = + fixedWindowingStrategy.withTrigger(AfterWatermark.pastEndOfWindow()); + + private final WindowingStrategy fixedWindowWithCompoundTriggerStrategy = + fixedWindowingStrategy.withTrigger( + AfterWatermark.pastEndOfWindow().withEarlyFirings(AfterPane.elementCountAtLeast(5)) + .withLateFirings(AfterPane.elementCountAtLeast(5)).buildTrigger()); + + /** + * The default accumulation mode is + * {@link com.google.cloud.dataflow.sdk.util.WindowingStrategy.AccumulationMode#DISCARDING_FIRED_PANES}. 
+ * This strategy changes it to + * {@link com.google.cloud.dataflow.sdk.util.WindowingStrategy.AccumulationMode#ACCUMULATING_FIRED_PANES} + */ + private final WindowingStrategy fixedWindowWithCompoundTriggerStrategyAcc = + fixedWindowWithCompoundTriggerStrategy + .withMode(WindowingStrategy.AccumulationMode.ACCUMULATING_FIRED_PANES); + + @Test + public void testWithLateness() throws Exception { + WindowingStrategy strategy = WindowingStrategy.of(FixedWindows.of(Duration.standardSeconds(2))) + .withMode(WindowingStrategy.AccumulationMode.ACCUMULATING_FIRED_PANES) + .withAllowedLateness(Duration.millis(1000)); + long initialTime = 0L; + Pipeline pipeline = FlinkTestPipeline.createForStreaming(); + + KvCoder inputCoder = KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()); + + FlinkGroupAlsoByWindowWrapper gbwOperaror = + FlinkGroupAlsoByWindowWrapper.createForTesting( + pipeline.getOptions(), + pipeline.getCoderRegistry(), + strategy, + inputCoder, + combiner.asKeyedFn()); + + OneInputStreamOperatorTestHarness>, WindowedValue>> testHarness = + new OneInputStreamOperatorTestHarness<>(gbwOperaror); + testHarness.open(); + + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1000), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1200), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processWatermark(new Watermark(initialTime + 2000)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1200), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1200), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processWatermark(new Watermark(initialTime + 4000)); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key1", 4), + new Instant(initialTime + 1), + new IntervalWindow(new Instant(0), new Instant(2000)), + PaneInfo.createPane(true, false, PaneInfo.Timing.ON_TIME, 0, 0)) + , initialTime + 1)); + expectedOutput.add(new Watermark(initialTime + 2000)); + + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key1", 5), + new Instant(initialTime + 1999), + new IntervalWindow(new Instant(0), new Instant(2000)), + PaneInfo.createPane(false, false, PaneInfo.Timing.LATE, 1, 1)) + , initialTime + 1999)); + + + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key1", 6), + new Instant(initialTime + 1999), + new IntervalWindow(new Instant(0), new Instant(2000)), + PaneInfo.createPane(false, false, PaneInfo.Timing.LATE, 2, 2)) + , initialTime + 1999)); + expectedOutput.add(new Watermark(initialTime + 4000)); + + TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new ResultSortComparator()); + testHarness.close(); + } + + @Test + public void testSessionWindows() throws Exception { + WindowingStrategy strategy = 
sessionWindowingStrategy; + + long initialTime = 0L; + Pipeline pipeline = FlinkTestPipeline.createForStreaming(); + + KvCoder inputCoder = KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()); + + FlinkGroupAlsoByWindowWrapper gbwOperaror = + FlinkGroupAlsoByWindowWrapper.createForTesting( + pipeline.getOptions(), + pipeline.getCoderRegistry(), + strategy, + inputCoder, + combiner.asKeyedFn()); + + OneInputStreamOperatorTestHarness>, WindowedValue>> testHarness = + new OneInputStreamOperatorTestHarness<>(gbwOperaror); + testHarness.open(); + + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1000), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 3500), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 3700), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 2700), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processWatermark(new Watermark(initialTime + 6000)); + + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 6700), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 6800), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 8900), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 7600), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 5600), null, PaneInfo.NO_FIRING), initialTime + 20)); + + testHarness.processWatermark(new Watermark(initialTime + 12000)); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key1", 6), + new Instant(initialTime + 1), + new IntervalWindow(new Instant(1), new Instant(5700)), + PaneInfo.createPane(true, false, PaneInfo.Timing.ON_TIME, 0, 0)) + , initialTime + 1)); + expectedOutput.add(new Watermark(initialTime + 6000)); + + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key1", 11), + new Instant(initialTime + 6700), + new IntervalWindow(new Instant(1), new Instant(10900)), + PaneInfo.createPane(true, false, PaneInfo.Timing.ON_TIME, 0, 0)) + , initialTime + 6700)); + expectedOutput.add(new Watermark(initialTime + 12000)); + + TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new ResultSortComparator()); + testHarness.close(); + } + + @Test + public void testSlidingWindows() throws Exception { + 
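+    // Sliding windows of 10 seconds every 5 seconds: each element is assigned to two
+    // overlapping windows, so the same input contributes to two of the expected panes below.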
WindowingStrategy strategy = slidingWindowWithAfterWatermarkTriggerStrategy; + long initialTime = 0L; + OneInputStreamOperatorTestHarness>, WindowedValue>> testHarness = + createTestingOperatorAndState(strategy, initialTime); + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + testHarness.processWatermark(new Watermark(initialTime + 25000)); + + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key1", 6), + new Instant(initialTime + 5000), + new IntervalWindow(new Instant(0), new Instant(10000)), + PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME)) + , initialTime + 5000)); + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key1", 6), + new Instant(initialTime + 1), + new IntervalWindow(new Instant(-5000), new Instant(5000)), + PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME)) + , initialTime + 1)); + expectedOutput.add(new Watermark(initialTime + 10000)); + + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key1", 11), + new Instant(initialTime + 15000), + new IntervalWindow(new Instant(10000), new Instant(20000)), + PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME)) + , initialTime + 15000)); + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key1", 3), + new Instant(initialTime + 10000), + new IntervalWindow(new Instant(5000), new Instant(15000)), + PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME)) + , initialTime + 10000)); + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key2", 1), + new Instant(initialTime + 19500), + new IntervalWindow(new Instant(10000), new Instant(20000)), + PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME)) + , initialTime + 19500)); + expectedOutput.add(new Watermark(initialTime + 20000)); + + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key2", 1), + new Instant(initialTime + 20000), + /** + * this is 20000 and not 19500 because of a convention in dataflow where + * timestamps of windowed values in a window cannot be smaller than the + * end of a previous window. 
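+                 * Otherwise the output of an earlier overlapping window could hold back the watermark.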
Checkout the documentation of the + * {@link WindowFn#getOutputTime(Instant, BoundedWindow)} + */ + new IntervalWindow(new Instant(15000), new Instant(25000)), + PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME)) + , initialTime + 20000)); + expectedOutput.add(new StreamRecord<>( + WindowedValue.of(KV.of("key1", 8), + new Instant(initialTime + 20000), + new IntervalWindow(new Instant(15000), new Instant(25000)), + PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME)) + , initialTime + 20000)); + expectedOutput.add(new Watermark(initialTime + 25000)); + + TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new ResultSortComparator()); + testHarness.close(); + } + + @Test + public void testAfterWatermarkProgram() throws Exception { + WindowingStrategy strategy = fixedWindowWithAfterWatermarkTriggerStrategy; + long initialTime = 0L; + OneInputStreamOperatorTestHarness>, WindowedValue>> testHarness = + createTestingOperatorAndState(strategy, initialTime); + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 6), + new Instant(initialTime + 1), null, PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME)), initialTime + 1)); + expectedOutput.add(new Watermark(initialTime + 10000)); + + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 11), + new Instant(initialTime + 10000), null, PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME)), initialTime + 10000)); + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key2", 1), + new Instant(initialTime + 19500), null, PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME)), initialTime + 19500)); + expectedOutput.add(new Watermark(initialTime + 20000)); + + TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new ResultSortComparator()); + testHarness.close(); + } + + @Test + public void testAfterCountProgram() throws Exception { + WindowingStrategy strategy = fixedWindowWithCountTriggerStrategy; + + long initialTime = 0L; + OneInputStreamOperatorTestHarness>, WindowedValue>> testHarness = + createTestingOperatorAndState(strategy, initialTime); + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 5), + new Instant(initialTime + 1), null, PaneInfo.createPane(true, true, PaneInfo.Timing.EARLY)), initialTime + 1)); + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 5), + new Instant(initialTime + 10000), null, PaneInfo.createPane(true, true, PaneInfo.Timing.EARLY)), initialTime + 10000)); + expectedOutput.add(new Watermark(initialTime + 10000)); + + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key2", 1), + new Instant(initialTime + 19500), null, PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME, 0, 0)), initialTime + 19500)); + expectedOutput.add(new Watermark(initialTime + 20000)); + TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new ResultSortComparator()); + + testHarness.close(); + } + + @Test + public void testCompoundProgram() throws Exception { + WindowingStrategy strategy = fixedWindowWithCompoundTriggerStrategy; + + long initialTime = 0L; + OneInputStreamOperatorTestHarness>, WindowedValue>> testHarness = + 
createTestingOperatorAndState(strategy, initialTime); + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + /** + * PaneInfo are: + * isFirst (pane in window), + * isLast, Timing (of triggering), + * index (of pane in the window), + * onTimeIndex (if it the 1st,2nd, ... pane that was fired on time) + * */ + + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 5), + new Instant(initialTime + 1), null, PaneInfo.createPane(true, false, PaneInfo.Timing.EARLY)), initialTime + 1)); + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 5), + new Instant(initialTime + 10000), null, PaneInfo.createPane(true, false, PaneInfo.Timing.EARLY)), initialTime + 10000)); + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 5), + new Instant(initialTime + 19500), null, PaneInfo.createPane(false, false, PaneInfo.Timing.EARLY, 1, -1)), initialTime + 19500)); + + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), + new Instant(initialTime + 1200), null, PaneInfo.createPane(false, true, PaneInfo.Timing.ON_TIME, 1, 0)), initialTime + 1200)); + + expectedOutput.add(new Watermark(initialTime + 10000)); + + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), + new Instant(initialTime + 19500), null, PaneInfo.createPane(false, true, PaneInfo.Timing.ON_TIME, 2, 0)), initialTime + 19500)); + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key2", 1), + new Instant(initialTime + 19500), null, PaneInfo.createPane(true, true, PaneInfo.Timing.ON_TIME)), initialTime + 19500)); + + expectedOutput.add(new Watermark(initialTime + 20000)); + TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new ResultSortComparator()); + + testHarness.close(); + } + + @Test + public void testCompoundAccumulatingPanesProgram() throws Exception { + WindowingStrategy strategy = fixedWindowWithCompoundTriggerStrategyAcc; + long initialTime = 0L; + OneInputStreamOperatorTestHarness>, WindowedValue>> testHarness = + createTestingOperatorAndState(strategy, initialTime); + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 5), + new Instant(initialTime + 1), null, PaneInfo.createPane(true, false, PaneInfo.Timing.EARLY)), initialTime + 1)); + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 5), + new Instant(initialTime + 10000), null, PaneInfo.createPane(true, false, PaneInfo.Timing.EARLY)), initialTime + 10000)); + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 10), + new Instant(initialTime + 19500), null, PaneInfo.createPane(false, false, PaneInfo.Timing.EARLY, 1, -1)), initialTime + 19500)); + + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 6), + new Instant(initialTime + 1200), null, PaneInfo.createPane(false, true, PaneInfo.Timing.ON_TIME, 1, 0)), initialTime + 1200)); + + expectedOutput.add(new Watermark(initialTime + 10000)); + + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 11), + new Instant(initialTime + 19500), null, PaneInfo.createPane(false, true, PaneInfo.Timing.ON_TIME, 2, 0)), initialTime + 19500)); + expectedOutput.add(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key2", 1), + new Instant(initialTime + 19500), null, PaneInfo.createPane(true, 
true, PaneInfo.Timing.ON_TIME)), initialTime + 19500)); + + expectedOutput.add(new Watermark(initialTime + 20000)); + TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new ResultSortComparator()); + + testHarness.close(); + } + + private OneInputStreamOperatorTestHarness createTestingOperatorAndState(WindowingStrategy strategy, long initialTime) throws Exception { + Pipeline pipeline = FlinkTestPipeline.createForStreaming(); + + KvCoder inputCoder = KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()); + + FlinkGroupAlsoByWindowWrapper gbwOperaror = + FlinkGroupAlsoByWindowWrapper.createForTesting( + pipeline.getOptions(), + pipeline.getCoderRegistry(), + strategy, + inputCoder, + combiner.asKeyedFn()); + + OneInputStreamOperatorTestHarness>, WindowedValue>> testHarness = + new OneInputStreamOperatorTestHarness<>(gbwOperaror); + testHarness.open(); + + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1000), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1200), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1200), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 1200), null, PaneInfo.NO_FIRING), initialTime + 20)); + + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 10000), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 12100), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 14200), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 15300), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 16500), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 19500), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 19500), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 19500), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 19500), null, PaneInfo.NO_FIRING), initialTime + 20)); + 
testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 19500), null, PaneInfo.NO_FIRING), initialTime + 20)); + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key1", 1), new Instant(initialTime + 19500), null, PaneInfo.NO_FIRING), initialTime + 20)); + + testHarness.processElement(new StreamRecord<>(makeWindowedValue(strategy, KV.of("key2", 1), new Instant(initialTime + 19500), null, PaneInfo.NO_FIRING), initialTime + 20)); + + testHarness.processWatermark(new Watermark(initialTime + 10000)); + testHarness.processWatermark(new Watermark(initialTime + 20000)); + + return testHarness; + } + + private static class ResultSortComparator implements Comparator { + @Override + public int compare(Object o1, Object o2) { + if (o1 instanceof Watermark && o2 instanceof Watermark) { + Watermark w1 = (Watermark) o1; + Watermark w2 = (Watermark) o2; + return (int) (w1.getTimestamp() - w2.getTimestamp()); + } else { + StreamRecord>> sr0 = (StreamRecord>>) o1; + StreamRecord>> sr1 = (StreamRecord>>) o2; + + int comparison = (int) (sr0.getValue().getTimestamp().getMillis() - sr1.getValue().getTimestamp().getMillis()); + if (comparison != 0) { + return comparison; + } + + comparison = sr0.getValue().getValue().getKey().compareTo(sr1.getValue().getValue().getKey()); + if(comparison == 0) { + comparison = Integer.compare( + sr0.getValue().getValue().getValue(), + sr1.getValue().getValue().getValue()); + } + if(comparison == 0) { + Collection windowsA = sr0.getValue().getWindows(); + Collection windowsB = sr1.getValue().getWindows(); + + if(windowsA.size() != 1 || windowsB.size() != 1) { + throw new IllegalStateException("A value cannot belong to more than one windows after grouping."); + } + + BoundedWindow windowA = (BoundedWindow) windowsA.iterator().next(); + BoundedWindow windowB = (BoundedWindow) windowsB.iterator().next(); + comparison = Long.compare(windowA.maxTimestamp().getMillis(), windowB.maxTimestamp().getMillis()); + } + return comparison; + } + } + } + + private WindowedValue makeWindowedValue(WindowingStrategy strategy, + T output, Instant timestamp, Collection windows, PaneInfo pane) { + final Instant inputTimestamp = timestamp; + final WindowFn windowFn = strategy.getWindowFn(); + + if (timestamp == null) { + timestamp = BoundedWindow.TIMESTAMP_MIN_VALUE; + } + + if (windows == null) { + try { + windows = windowFn.assignWindows(windowFn.new AssignContext() { + @Override + public Object element() { + throw new UnsupportedOperationException( + "WindowFn attempted to access input element when none was available"); + } + + @Override + public Instant timestamp() { + if (inputTimestamp == null) { + throw new UnsupportedOperationException( + "WindowFn attempted to access input timestamp when none was available"); + } + return inputTimestamp; + } + + @Override + public Collection windows() { + throw new UnsupportedOperationException( + "WindowFn attempted to access input windows when none were available"); + } + }); + } catch (Exception e) { + throw UserCodeException.wrap(e); + } + } + + return WindowedValue.of(output, timestamp, windows, pane); + } +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupByNullKeyTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupByNullKeyTest.java new file mode 100644 index 000000000000..63e0bcf718f2 --- /dev/null +++ 
b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupByNullKeyTest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.streaming; + +import org.apache.beam.runners.flink.FlinkTestPipeline; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.streaming.util.StreamingProgramTestBase; +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.io.Serializable; +import java.util.Arrays; + +public class GroupByNullKeyTest extends StreamingProgramTestBase implements Serializable { + + + protected String resultPath; + + static final String[] EXPECTED_RESULT = new String[] { + "k: null v: user1 user1 user1 user2 user2 user2 user2 user3" + }; + + public GroupByNullKeyTest(){ + } + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + public static class ExtractUserAndTimestamp extends DoFn, String> { + private static final long serialVersionUID = 0; + + @Override + public void processElement(ProcessContext c) { + KV record = c.element(); + long now = System.currentTimeMillis(); + int timestamp = record.getKey(); + String userName = record.getValue(); + if (userName != null) { + // Sets the implicit timestamp field to be used in windowing. 
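+        // outputWithTimestamp attaches an event-time timestamp to the element; the
+        // Window.into(FixedWindows) transform applied later uses it for window assignment.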
+ c.outputWithTimestamp(userName, new Instant(timestamp + now)); + } + } + } + + @Override + protected void testProgram() throws Exception { + + Pipeline p = FlinkTestPipeline.createForStreaming(); + + PCollection output = + p.apply(Create.of(Arrays.asList( + KV.of(0, "user1"), + KV.of(1, "user1"), + KV.of(2, "user1"), + KV.of(10, "user2"), + KV.of(1, "user2"), + KV.of(15000, "user2"), + KV.of(12000, "user2"), + KV.of(25000, "user3")))) + .apply(ParDo.of(new ExtractUserAndTimestamp())) + .apply(Window.into(FixedWindows.of(Duration.standardHours(1))) + .triggering(AfterWatermark.pastEndOfWindow()) + .withAllowedLateness(Duration.ZERO) + .discardingFiredPanes()) + + .apply(ParDo.of(new DoFn>() { + @Override + public void processElement(ProcessContext c) throws Exception { + String elem = c.element(); + c.output(KV.of((Void) null, elem)); + } + })) + .apply(GroupByKey.create()) + .apply(ParDo.of(new DoFn>, String>() { + @Override + public void processElement(ProcessContext c) throws Exception { + KV> elem = c.element(); + StringBuilder str = new StringBuilder(); + str.append("k: " + elem.getKey() + " v:"); + for (String v : elem.getValue()) { + str.append(" " + v); + } + c.output(str.toString()); + } + })); + output.apply(TextIO.Write.to(resultPath)); + p.run(); + } +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/StateSerializationTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/StateSerializationTest.java new file mode 100644 index 000000000000..77a8de65082e --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/StateSerializationTest.java @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals; +import org.apache.beam.runners.flink.translation.wrappers.streaming.state.StateCheckpointReader; +import org.apache.beam.runners.flink.translation.wrappers.streaming.state.StateCheckpointUtils; +import org.apache.beam.runners.flink.translation.wrappers.streaming.state.StateCheckpointWriter; +import com.google.cloud.dataflow.sdk.coders.*; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext; +import com.google.cloud.dataflow.sdk.transforms.Sum; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals; +import com.google.cloud.dataflow.sdk.util.state.*; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.runtime.util.DataInputDeserializer; +import org.joda.time.Instant; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.*; + +import static org.junit.Assert.assertEquals; + +public class StateSerializationTest { + + private static final StateNamespace NAMESPACE_1 = StateNamespaces.global(); + private static final String KEY_PREFIX = "TEST_"; + + // TODO: This can be replaced with the standard Sum.SumIntererFn once the state no longer needs + // to create a StateTag at the point of restoring state. Currently StateTags are compared strictly + // by type and combiners always use KeyedCombineFnWithContext rather than KeyedCombineFn or CombineFn. 
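+  // The combiner below sums Integers into a single-element int[] accumulator, mirroring
+  // Sum.SumIntegerFn but expressed as a KeyedCombineFnWithContext as required above.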
+ private static CombineWithContext.KeyedCombineFnWithContext SUM_COMBINER = + new CombineWithContext.KeyedCombineFnWithContext() { + @Override + public int[] createAccumulator(Object key, CombineWithContext.Context c) { + return new int[1]; + } + + @Override + public int[] addInput(Object key, int[] accumulator, Integer value, CombineWithContext.Context c) { + accumulator[0] += value; + return accumulator; + } + + @Override + public int[] mergeAccumulators(Object key, Iterable accumulators, CombineWithContext.Context c) { + int[] r = new int[1]; + for (int[] a : accumulators) { + r[0] += a[0]; + } + return r; + } + + @Override + public Integer extractOutput(Object key, int[] accumulator, CombineWithContext.Context c) { + return accumulator[0]; + } + }; + + private static Coder INT_ACCUM_CODER = DelegateCoder.of( + VarIntCoder.of(), + new DelegateCoder.CodingFunction() { + @Override + public Integer apply(int[] accumulator) { + return accumulator[0]; + } + }, + new DelegateCoder.CodingFunction() { + @Override + public int[] apply(Integer value) { + int[] a = new int[1]; + a[0] = value; + return a; + } + }); + + private static final StateTag> STRING_VALUE_ADDR = + StateTags.value("stringValue", StringUtf8Coder.of()); + private static final StateTag> INT_VALUE_ADDR = + StateTags.value("stringValue", VarIntCoder.of()); + private static final StateTag> SUM_INTEGER_ADDR = + StateTags.keyedCombiningValueWithContext("sumInteger", INT_ACCUM_CODER, SUM_COMBINER); + private static final StateTag> STRING_BAG_ADDR = + StateTags.bag("stringBag", StringUtf8Coder.of()); + private static final StateTag> WATERMARK_BAG_ADDR = + StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtEarliestInputTimestamp()); + + private Map> statePerKey = new HashMap<>(); + + private Map> activeTimers = new HashMap<>(); + + private void initializeStateAndTimers() throws CannotProvideCoderException { + for (int i = 0; i < 10; i++) { + String key = KEY_PREFIX + i; + + FlinkStateInternals state = initializeStateForKey(key); + Set timers = new HashSet<>(); + for (int j = 0; j < 5; j++) { + TimerInternals.TimerData timer = TimerInternals + .TimerData.of(NAMESPACE_1, + new Instant(1000 + i + j), TimeDomain.values()[j % 3]); + timers.add(timer); + } + + statePerKey.put(key, state); + activeTimers.put(key, timers); + } + } + + private FlinkStateInternals initializeStateForKey(String key) throws CannotProvideCoderException { + FlinkStateInternals state = createState(key); + + ValueState value = state.state(NAMESPACE_1, STRING_VALUE_ADDR); + value.write("test"); + + ValueState value2 = state.state(NAMESPACE_1, INT_VALUE_ADDR); + value2.write(4); + value2.write(5); + + AccumulatorCombiningState combiningValue = state.state(NAMESPACE_1, SUM_INTEGER_ADDR); + combiningValue.add(1); + combiningValue.add(2); + + WatermarkHoldState watermark = state.state(NAMESPACE_1, WATERMARK_BAG_ADDR); + watermark.add(new Instant(1000)); + + BagState bag = state.state(NAMESPACE_1, STRING_BAG_ADDR); + bag.add("v1"); + bag.add("v2"); + bag.add("v3"); + bag.add("v4"); + return state; + } + + private boolean restoreAndTestState(DataInputView in) throws Exception { + StateCheckpointReader reader = new StateCheckpointReader(in); + final ClassLoader userClassloader = this.getClass().getClassLoader(); + Coder windowCoder = IntervalWindow.getCoder(); + Coder keyCoder = StringUtf8Coder.of(); + + boolean comparisonRes = true; + + for (String key : statePerKey.keySet()) { + comparisonRes &= checkStateForKey(key); + } + + // restore the timers + Map> 
restoredTimersPerKey = StateCheckpointUtils.decodeTimers(reader, windowCoder, keyCoder); + if (activeTimers.size() != restoredTimersPerKey.size()) { + return false; + } + + for (String key : statePerKey.keySet()) { + Set originalTimers = activeTimers.get(key); + Set restoredTimers = restoredTimersPerKey.get(key); + comparisonRes &= checkTimersForKey(originalTimers, restoredTimers); + } + + // restore the state + Map> restoredPerKeyState = + StateCheckpointUtils.decodeState(reader, OutputTimeFns.outputAtEarliestInputTimestamp(), keyCoder, windowCoder, userClassloader); + if (restoredPerKeyState.size() != statePerKey.size()) { + return false; + } + + for (String key : statePerKey.keySet()) { + FlinkStateInternals originalState = statePerKey.get(key); + FlinkStateInternals restoredState = restoredPerKeyState.get(key); + comparisonRes &= checkStateForKey(originalState, restoredState); + } + return comparisonRes; + } + + private boolean checkStateForKey(String key) throws CannotProvideCoderException { + FlinkStateInternals state = statePerKey.get(key); + + ValueState value = state.state(NAMESPACE_1, STRING_VALUE_ADDR); + boolean comp = value.read().equals("test"); + + ValueState value2 = state.state(NAMESPACE_1, INT_VALUE_ADDR); + comp &= value2.read().equals(5); + + AccumulatorCombiningState combiningValue = state.state(NAMESPACE_1, SUM_INTEGER_ADDR); + comp &= combiningValue.read().equals(3); + + WatermarkHoldState watermark = state.state(NAMESPACE_1, WATERMARK_BAG_ADDR); + comp &= watermark.read().equals(new Instant(1000)); + + BagState bag = state.state(NAMESPACE_1, STRING_BAG_ADDR); + Iterator it = bag.read().iterator(); + int i = 0; + while (it.hasNext()) { + comp &= it.next().equals("v" + (++i)); + } + return comp; + } + + private void storeState(AbstractStateBackend.CheckpointStateOutputView out) throws Exception { + StateCheckpointWriter checkpointBuilder = StateCheckpointWriter.create(out); + Coder keyCoder = StringUtf8Coder.of(); + + // checkpoint the timers + StateCheckpointUtils.encodeTimers(activeTimers, checkpointBuilder, keyCoder); + + // checkpoint the state + StateCheckpointUtils.encodeState(statePerKey, checkpointBuilder, keyCoder); + } + + private boolean checkTimersForKey(Set originalTimers, Set restoredTimers) { + boolean comp = true; + if (restoredTimers == null) { + return false; + } + + if (originalTimers.size() != restoredTimers.size()) { + return false; + } + + for (TimerInternals.TimerData timer : originalTimers) { + comp &= restoredTimers.contains(timer); + } + return comp; + } + + private boolean checkStateForKey(FlinkStateInternals originalState, FlinkStateInternals restoredState) throws CannotProvideCoderException { + if (restoredState == null) { + return false; + } + + ValueState orValue = originalState.state(NAMESPACE_1, STRING_VALUE_ADDR); + ValueState resValue = restoredState.state(NAMESPACE_1, STRING_VALUE_ADDR); + boolean comp = orValue.read().equals(resValue.read()); + + ValueState orIntValue = originalState.state(NAMESPACE_1, INT_VALUE_ADDR); + ValueState resIntValue = restoredState.state(NAMESPACE_1, INT_VALUE_ADDR); + comp &= orIntValue.read().equals(resIntValue.read()); + + AccumulatorCombiningState combOrValue = originalState.state(NAMESPACE_1, SUM_INTEGER_ADDR); + AccumulatorCombiningState combResValue = restoredState.state(NAMESPACE_1, SUM_INTEGER_ADDR); + comp &= combOrValue.read().equals(combResValue.read()); + + WatermarkHoldState orWatermark = originalState.state(NAMESPACE_1, WATERMARK_BAG_ADDR); + WatermarkHoldState resWatermark = 
restoredState.state(NAMESPACE_1, WATERMARK_BAG_ADDR); + comp &= orWatermark.read().equals(resWatermark.read()); + + BagState orBag = originalState.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState resBag = restoredState.state(NAMESPACE_1, STRING_BAG_ADDR); + + Iterator orIt = orBag.read().iterator(); + Iterator resIt = resBag.read().iterator(); + + while (orIt.hasNext() && resIt.hasNext()) { + comp &= orIt.next().equals(resIt.next()); + } + + return !((orIt.hasNext() && !resIt.hasNext()) || (!orIt.hasNext() && resIt.hasNext())) && comp; + } + + private FlinkStateInternals createState(String key) throws CannotProvideCoderException { + return new FlinkStateInternals<>( + key, + StringUtf8Coder.of(), + IntervalWindow.getCoder(), + OutputTimeFns.outputAtEarliestInputTimestamp()); + } + + @Test + public void test() throws Exception { + StateSerializationTest test = new StateSerializationTest(); + test.initializeStateAndTimers(); + + MemoryStateBackend.MemoryCheckpointOutputStream memBackend = new MemoryStateBackend.MemoryCheckpointOutputStream(32048); + AbstractStateBackend.CheckpointStateOutputView out = new AbstractStateBackend.CheckpointStateOutputView(memBackend); + + test.storeState(out); + + byte[] contents = memBackend.closeAndGetBytes(); + DataInputView in = new DataInputDeserializer(contents, 0, contents.length); + + assertEquals(test.restoreAndTestState(in), true); + } + +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TopWikipediaSessionsITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TopWikipediaSessionsITCase.java new file mode 100644 index 000000000000..83c1661fcca6 --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TopWikipediaSessionsITCase.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import org.apache.beam.runners.flink.FlinkTestPipeline; +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Count; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.windowing.Sessions; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.streaming.util.StreamingProgramTestBase; +import org.joda.time.Duration; +import org.joda.time.Instant; + +import java.io.Serializable; +import java.util.Arrays; + + +/** + * Session window test + */ +public class TopWikipediaSessionsITCase extends StreamingProgramTestBase implements Serializable { + protected String resultPath; + + public TopWikipediaSessionsITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] { + "user: user1 value:3", + "user: user1 value:1", + "user: user2 value:4", + "user: user2 value:6", + "user: user3 value:7", + "user: user3 value:2" + }; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws Exception { + + Pipeline p = FlinkTestPipeline.createForStreaming(); + + Long now = (System.currentTimeMillis() + 10000) / 1000; + + PCollection> output = + p.apply(Create.of(Arrays.asList(new TableRow().set("timestamp", now).set + ("contributor_username", "user1"), new TableRow().set("timestamp", now + 10).set + ("contributor_username", "user3"), new TableRow().set("timestamp", now).set + ("contributor_username", "user2"), new TableRow().set("timestamp", now).set + ("contributor_username", "user1"), new TableRow().set("timestamp", now + 2).set + ("contributor_username", "user1"), new TableRow().set("timestamp", now).set + ("contributor_username", "user2"), new TableRow().set("timestamp", now + 1).set + ("contributor_username", "user2"), new TableRow().set("timestamp", now + 5).set + ("contributor_username", "user2"), new TableRow().set("timestamp", now + 7).set + ("contributor_username", "user2"), new TableRow().set("timestamp", now + 8).set + ("contributor_username", "user2"), new TableRow().set("timestamp", now + 200).set + ("contributor_username", "user2"), new TableRow().set("timestamp", now + 230).set + ("contributor_username", "user1"), new TableRow().set("timestamp", now + 230).set + ("contributor_username", "user2"), new TableRow().set("timestamp", now + 240).set + ("contributor_username", "user2"), new TableRow().set("timestamp", now + 245).set + ("contributor_username", "user3"), new TableRow().set("timestamp", now + 235).set + ("contributor_username", "user3"), new TableRow().set("timestamp", now + 236).set + ("contributor_username", "user3"), new TableRow().set("timestamp", now + 237).set + ("contributor_username", "user3"), new TableRow().set("timestamp", now + 238).set + ("contributor_username", "user3"), new TableRow().set("timestamp", now + 239).set + ("contributor_username", "user3"), new 
TableRow().set("timestamp", now + 240).set + ("contributor_username", "user3"), new TableRow().set("timestamp", now + 241).set + ("contributor_username", "user2"), new TableRow().set("timestamp", now) + .set("contributor_username", "user3")))) + + + + .apply(ParDo.of(new DoFn() { + @Override + public void processElement(ProcessContext c) throws Exception { + TableRow row = c.element(); + long timestamp = (Integer) row.get("timestamp"); + String userName = (String) row.get("contributor_username"); + if (userName != null) { + // Sets the timestamp field to be used in windowing. + c.outputWithTimestamp(userName, new Instant(timestamp * 1000L)); + } + } + })) + + .apply(Window.into(Sessions.withGapDuration(Duration.standardMinutes(1)))) + + .apply(Count.perElement()); + + PCollection format = output.apply(ParDo.of(new DoFn, String>() { + @Override + public void processElement(ProcessContext c) throws Exception { + KV el = c.element(); + String out = "user: " + el.getKey() + " value:" + el.getValue(); + c.output(out); + } + })); + + format.apply(TextIO.Write.to(resultPath)); + + p.run(); + } +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/util/JoinExamples.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/util/JoinExamples.java new file mode 100644 index 000000000000..e850dd6a451e --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/util/JoinExamples.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.util; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.BigQueryIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.options.Validation; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; + +/** + * Copied from {@link com.google.cloud.dataflow.examples.JoinExamples} because the code + * is private there. + */ +public class JoinExamples { + + // A 1000-row sample of the GDELT data here: gdelt-bq:full.events. 
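+  // Both table references below use BigQuery's "project:dataset.table" form.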
+ private static final String GDELT_EVENTS_TABLE = + "clouddataflow-readonly:samples.gdelt_sample"; + // A table that maps country codes to country names. + private static final String COUNTRY_CODES = + "gdelt-bq:full.crosswalk_geocountrycodetohuman"; + + /** + * Join two collections, using country code as the key. + */ + public static PCollection joinEvents(PCollection eventsTable, + PCollection countryCodes) throws Exception { + + final TupleTag eventInfoTag = new TupleTag<>(); + final TupleTag countryInfoTag = new TupleTag<>(); + + // transform both input collections to tuple collections, where the keys are country + // codes in both cases. + PCollection> eventInfo = eventsTable.apply( + ParDo.of(new ExtractEventDataFn())); + PCollection> countryInfo = countryCodes.apply( + ParDo.of(new ExtractCountryInfoFn())); + + // country code 'key' -> CGBKR (, ) + PCollection> kvpCollection = KeyedPCollectionTuple + .of(eventInfoTag, eventInfo) + .and(countryInfoTag, countryInfo) + .apply(CoGroupByKey.create()); + + // Process the CoGbkResult elements generated by the CoGroupByKey transform. + // country code 'key' -> string of , + PCollection> finalResultCollection = + kvpCollection.apply(ParDo.of(new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) { + KV e = c.element(); + CoGbkResult val = e.getValue(); + String countryCode = e.getKey(); + String countryName; + countryName = e.getValue().getOnly(countryInfoTag, "Kostas"); + for (String eventInfo : c.element().getValue().getAll(eventInfoTag)) { + // Generate a string that combines information from both collection values + c.output(KV.of(countryCode, "Country name: " + countryName + + ", Event info: " + eventInfo)); + } + } + })); + + // write to GCS + return finalResultCollection + .apply(ParDo.of(new DoFn, String>() { + @Override + public void processElement(ProcessContext c) { + String outputstring = "Country code: " + c.element().getKey() + + ", " + c.element().getValue(); + c.output(outputstring); + } + })); + } + + /** + * Examines each row (event) in the input table. Output a KV with the key the country + * code of the event, and the value a string encoding event information. + */ + static class ExtractEventDataFn extends DoFn> { + @Override + public void processElement(ProcessContext c) { + TableRow row = c.element(); + String countryCode = (String) row.get("ActionGeo_CountryCode"); + String sqlDate = (String) row.get("SQLDATE"); + String actor1Name = (String) row.get("Actor1Name"); + String sourceUrl = (String) row.get("SOURCEURL"); + String eventInfo = "Date: " + sqlDate + ", Actor1: " + actor1Name + ", url: " + sourceUrl; + c.output(KV.of(countryCode, eventInfo)); + } + } + + + /** + * Examines each row (country info) in the input table. Output a KV with the key the country + * code, and the value the country name. + */ + static class ExtractCountryInfoFn extends DoFn> { + @Override + public void processElement(ProcessContext c) { + TableRow row = c.element(); + String countryCode = (String) row.get("FIPSCC"); + String countryName = (String) row.get("HumanName"); + c.output(KV.of(countryCode, countryName)); + } + } + + + /** + * Options supported by {@link JoinExamples}. + *

    + * Inherits standard configuration options. + */ + private interface Options extends PipelineOptions { + @Description("Path of the file to write to") + @Validation.Required + String getOutput(); + void setOutput(String value); + } + + public static void main(String[] args) throws Exception { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline p = Pipeline.create(options); + // the following two 'applys' create multiple inputs to our pipeline, one for each + // of our two input sources. + PCollection eventsTable = p.apply(BigQueryIO.Read.from(GDELT_EVENTS_TABLE)); + PCollection countryCodes = p.apply(BigQueryIO.Read.from(COUNTRY_CODES)); + PCollection formattedResults = joinEvents(eventsTable, countryCodes); + formattedResults.apply(TextIO.Write.to(options.getOutput())); + p.run(); + } + +} diff --git a/runners/flink/src/test/resources/log4j-test.properties b/runners/flink/src/test/resources/log4j-test.properties new file mode 100644 index 000000000000..4c74d85d7c62 --- /dev/null +++ b/runners/flink/src/test/resources/log4j-test.properties @@ -0,0 +1,27 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +log4j.rootLogger=OFF, testlogger + +# A1 is set to be a ConsoleAppender. +log4j.appender.testlogger=org.apache.log4j.ConsoleAppender +log4j.appender.testlogger.target = System.err +log4j.appender.testlogger.layout=org.apache.log4j.PatternLayout +log4j.appender.testlogger.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n diff --git a/runners/pom.xml b/runners/pom.xml new file mode 100644 index 000000000000..757e2081d1cc --- /dev/null +++ b/runners/pom.xml @@ -0,0 +1,43 @@ + + + + + 4.0.0 + + + com.google.cloud.dataflow + google-cloud-dataflow-java-sdk-parent + 1.5.0-SNAPSHOT + + + org.apache.beam + runners + 1.5.0-SNAPSHOT + + pom + + Beam Runners + + + flink + + +